Skip to content

Commit

Permalink
Update pf.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yiwang12 authored Oct 9, 2024
1 parent 53fff6c commit e0f15fe
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions mNSF/NSF/pf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
rng = np.random.default_rng()

class ProcessFactorization(tf.Module):
def __init__(self, J, L, Z, lik="poi", psd_kernel=tfk.MaternThreeHalves,
def __init__(self, J, L, Z, lik="poi", chol = True,psd_kernel=tfk.MaternThreeHalves,
nugget=1e-5, length_scale=0.1, disp="default",
nonneg=False, isotropic=True, feature_means=None, **kwargs):
"""
Expand Down Expand Up @@ -98,6 +98,21 @@ def __init__(self, J, L, Z, lik="poi", psd_kernel=tfk.MaternThreeHalves,
self._disp0 = disp
self._init_misc()
self.Kuu_chol = tf.Variable(self.eval_Kuu_chol(self.get_kernel()), dtype=dtp, trainable=False)
if chol:
self.alpha_x = tfl.cholesky_solve(Kuu_chol, Kuf) #LxMxN
N = X.shape[0]
L = self.W.shape[1]
mu_x = self.beta0+tfl.matmul(self.beta, X, transpose_b=True) #LxN
Kuf = kernel.matrix(self.Z, X) #LxMxN
Kff_diag = kernel.apply(X, X, example_ndims=1)+self.nugget #LxN

mu_tilde = mu_x + tfl.matvec(self.alpha_x, self.delta-mu_z, transpose_a=True) #LxN
#compute the alpha(x_i)'(K_uu-Omega)alpha(x_i) term
self.a_t_Kchol = tfl.matmul(self.alpha_x, Kuu_chol, transpose_a=True) #LxNxM
aKa = tf.reduce_sum(tf.square(self.a_t_Kchol), axis=2) #LxN
self.a_t_Omega_tril = tfl.matmul(self.alpha_x, self.Omega_tril, transpose_a=True) #LxNxM
aOmega_a = tf.reduce_sum(tf.square(self.a_t_Omega_tril ), axis=2) #LxN
Sigma_tilde = Kff_diag - aKa + aOmega_a #LxN
if self.lik=="gau" and not self.nonneg:
self.feature_means = feature_means
else:
Expand Down Expand Up @@ -461,4 +476,3 @@ def init_npf_with_nmf(fit, Y, X=None, sz=1, pseudocount=1e-2, factors=None,
fit.beta0.assign(beta0[:,None],read_value=False)
fit.delta.assign(U.T,read_value=False)
if beta is not None: fit.beta.assign(beta,read_value=False)

0 comments on commit e0f15fe

Please sign in to comment.