Skip to content

Commit

Permalink
Update pf.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yiwang12 authored Oct 9, 2024
1 parent 6d30d39 commit 3d79d0d
Showing 1 changed file with 4 additions and 19 deletions.
23 changes: 4 additions & 19 deletions mNSF/NSF/pf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,6 @@
# Default floating-point dtype string used for tensors in this module.
dtp = "float32"
# Module-level NumPy random generator (seeded from OS entropy by default).
rng = np.random.default_rng()

@tf.custom_gradient
def checkpoint_grad(x):
    """Identity op with an explicitly defined (pass-through) gradient.

    Forward pass returns ``tf.identity(x)``.  The custom backward pass
    recomputes the identity under a fresh ``GradientTape`` in eager mode
    (or falls back to ``tf.gradients`` in graph mode) so the upstream
    gradient ``dy`` is propagated through unchanged.

    Parameters
    ----------
    x : tf.Tensor
        Input tensor; passed through unchanged.

    Returns
    -------
    tf.Tensor
        ``tf.identity(x)`` (with the custom gradient attached).

    NOTE(review): since the op is a pure identity, both branches should
    reduce to returning ``dy`` itself; the tape/``tf.gradients`` machinery
    looks redundant — presumably intended as a gradient-checkpointing
    hook.  Confirm before relying on it.
    """
    y = tf.identity(x)

    def grad(dy):
        if tf.executing_eagerly():
            # Recompute the identity under a tape and backprop dy through it.
            with tf.GradientTape() as tape:
                tape.watch(x)
                # Bug fix: the original assigned to `y` here, which made `y`
                # local to `grad` and left it unbound in the `else` branch
                # (UnboundLocalError in graph mode).  Use a distinct name so
                # the closure variable `y` stays accessible below.
                y_recomputed = tf.identity(x)
            return tape.gradient(y_recomputed, x, output_gradients=dy)
        else:
            # Graph mode: symbolic gradient of the forward identity `y`.
            return tf.gradients(y, x, dy)[0]

    return y, grad


class ProcessFactorization(tf.Module):
def __init__(self, J, L, Z, lik="poi", chol = True, X=None, psd_kernel=tfk.MaternThreeHalves,
nugget=1e-5, length_scale=0.1, disp="default",
Expand Down Expand Up @@ -252,8 +238,7 @@ def sample_latent_GP_funcs(self, X, S=1, kernel=None, mu_z=None, Kuu_chol=None,
Kuf = kernel.matrix(self.Z, X) #LxMxN
Kff_diag = kernel.apply(X, X, example_ndims=1)+self.nugget #LxN
alpha_x = tfl.cholesky_solve(Kuu_chol, Kuf) #LxMxN
#mu_tilde = mu_x + tfl.matvec(alpha_x, self.delta-mu_z, transpose_a=True) #LxN
mu_tilde = checkpoint_grad(mu_x + tfl.matvec(alpha_x, self.delta-mu_z, transpose_a=True))
mu_tilde = mu_x + tfl.matvec(alpha_x, self.delta-mu_z, transpose_a=True) #LxN
#compute the alpha(x_i)'(K_uu-Omega)alpha(x_i) term
a_t_Kchol = tfl.matmul(alpha_x, Kuu_chol, transpose_a=True) #LxNxM
aKa = tf.reduce_sum(tf.square(a_t_Kchol), axis=2) #LxN
Expand All @@ -269,8 +254,8 @@ def sample_predictive_mean(self, X, sz=1, S=1, kernel=None, mu_z=None, Kuu_chol=
sz is a tensor of shape (N,1) of size factors.
Typically sz would be the rowSums or rowMeans of the outcome matrix Y.
"""
F = checkpoint_grad(self.sample_latent_GP_funcs(X, S=S, kernel=kernel, mu_z=mu_z,
Kuu_chol=Kuu_chol, chol=chol))
F = self.sample_latent_GP_funcs(X, S=S, kernel=kernel, mu_z=mu_z,
Kuu_chol=Kuu_chol, chol=chol) #SxLxN
if self.nonneg:
Lam = tfl.matrix_transpose(tfl.matmul(self.W, tf.exp(F))) #SxNxJ
if self.lik=="gau":
Expand Down Expand Up @@ -335,7 +320,7 @@ def elbo_avg(self, X, Y, sz=1, S=1, Ntot=None, chol=True):
#kl_terms is not affected by minibatching so use reduce_sum
#print(1111)
kl_term = tf.reduce_sum(self.eval_kl_term(mu_z, Kuu_chol))
Mu = checkpoint_grad(self.sample_predictive_mean(X, sz=sz, S=S, kernel=ker, mu_z=mu_z, Kuu_chol=Kuu_chol))
Mu = self.sample_predictive_mean(X, sz=sz, S=S, kernel=ker, mu_z=mu_z, Kuu_chol=Kuu_chol)
eloglik = likelihoods.lik_to_distr(self.lik, Mu, self.disp).log_prob(Y)
return J*tf.reduce_mean(eloglik) - kl_term/Ntot

Expand Down

0 comments on commit 3d79d0d

Please sign in to comment.