Skip to content

Commit

Permalink
Added functionality to model.py and opnmf.py to run out-of-sample OPNMF.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alfie Wearn committed Nov 27, 2024
1 parent 88d7273 commit 884b89a
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 40 deletions.
2 changes: 1 addition & 1 deletion opnmf/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.3.dev'
__version__ = '0.0.3.dev1+g88d7273.d20240120'
20 changes: 18 additions & 2 deletions opnmf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def fit_transform(self, X, init_W=None):

def transform(self, X):
"""Transform the data X according to the fitted OPNMF model.
Added functionality by Alfie Wearn 2024-01-20
Parameters
----------
Expand All @@ -131,10 +132,25 @@ def transform(self, X):
Returns
-------
W : ndarray of shape (n_samples, n_components)
H : ndarray of shape (n_components, n_features)
Transformed data.
As the OPNMF is: X~W*H, this calculates a new H given pre-calculated W and a new X.
"""
raise NotImplementedError("Don't know how to do this!")
# Ensure the model is fitted
check_is_fitted(self, 'components_')

# Apply transformation to new subjects
# Use the fixed W (self.coef_) learned during training to transform the new data
_, H, _ = opnmf(X, n_components=self.n_components_, W_fixed=self.coef_,
max_iter=self.max_iter, tol=self.tol)

# Calculate the reconstruction
X_reconstructed = self.coef_ @ H

# Calculate MSE
mse = np.linalg.norm(X - (self.coef_ @ H), ord='fro')

return H, mse

def mse(self):
check_is_fitted(self)
Expand Down
81 changes: 44 additions & 37 deletions opnmf/opnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from . logging import logger, warn


def opnmf(X, n_components, max_iter=50000, tol=1e-5, init='nndsvd',
init_W=None):
def opnmf(X, n_components, max_iter=50000, tol=1e-5, init='nndsvd', init_W=None, W_fixed=None):
"""
Orthogonal projective non-negative matrix factorization.
Expand Down Expand Up @@ -40,6 +39,9 @@ def opnmf(X, n_components, max_iter=50000, tol=1e-5, init='nndsvd',
init_W: array (n_samples, n_components)
Fixed initial coefficient matrix.
W_fixed: ndarray of shape (n_samples, n_components), default=None
Fixed basis matrix. If provided, the function will only solve for H.
Returns
-------
W : ndarray of shape (n_samples, n_components)
Expand All @@ -49,42 +51,47 @@ def opnmf(X, n_components, max_iter=50000, tol=1e-5, init='nndsvd',
mse : float
Reconstruction error
"""
if init != 'custom':
if init_W is not None:
warn('Initialisation was not set to "custom" but an initial W '
'matrix was specified. This matrix will be ignored.')
logger.info(f'Initializing using {init}')
W, _ = _initialize_nmf(X, n_components, init=init)
init_W = None
if W_fixed is not None:
# Use the fixed W and skip the update loop
W = W_fixed
else:
W = init_W
delta_W = np.inf
XX = X @ X.T

with warnings.catch_warnings():
warnings.simplefilter("ignore")
for iter in range(max_iter):
old_W = W

enum = XX @ W
denom = W @ (W.T @ XX @ W)
W = W * enum / denom

W[W < 1e-16] = 1e-16
W = W / np.linalg.norm(W, ord=2)

delta_W = (np.linalg.norm(old_W - W, ord='fro') /
np.linalg.norm(old_W, ord='fro'))
if (iter % 100) == 0:
obj = np.linalg.norm(X - (W @ (W.T @ X)), ord='fro')
logger.info(f'iter={iter} diff={delta_W}, obj={obj}')
if delta_W < tol:
logger.info(f'Converged in {iter} iterations')
break

if delta_W > tol:
warn('OPNMF did not converge with '
f'tolerance = {tol} under {max_iter} iterations')
# Initialization as per the original code
if init != 'custom':
if init_W is not None:
warn('Initialisation was not set to "custom" but an initial W '
'matrix was specified. This matrix will be ignored.')
logger.info(f'Initializing using {init}')
W, _ = _initialize_nmf(X, n_components, init=init)
init_W = None
else:
W = init_W
# Main iterative loop for updating W
delta_W = np.inf
XX = X @ X.T
with warnings.catch_warnings():
warnings.simplefilter("ignore")
for iter in range(max_iter):
old_W = W

enum = XX @ W
denom = W @ (W.T @ XX @ W)
W = W * enum / denom

W[W < 1e-16] = 1e-16
W = W / np.linalg.norm(W, ord=2)

delta_W = (np.linalg.norm(old_W - W, ord='fro') /
np.linalg.norm(old_W, ord='fro'))
if (iter % 100) == 0:
obj = np.linalg.norm(X - (W @ (W.T @ X)), ord='fro')
logger.info(f'iter={iter} diff={delta_W}, obj={obj}')
if delta_W < tol:
logger.info(f'Converged in {iter} iterations')
break

if delta_W > tol:
warn('OPNMF did not converge with '
f'tolerance = {tol} under {max_iter} iterations')

H = W.T @ X

Expand Down

0 comments on commit 884b89a

Please sign in to comment.