From 3d79d0d5222c456d92acb2c32cafebb747b18363 Mon Sep 17 00:00:00 2001
From: Yi Wang <37149810+yiwang12@users.noreply.github.com>
Date: Wed, 9 Oct 2024 21:37:47 +0200
Subject: [PATCH] Update pf.py

---
 mNSF/NSF/pf.py | 23 ++++-------------------
 1 file changed, 4 insertions(+), 19 deletions(-)

diff --git a/mNSF/NSF/pf.py b/mNSF/NSF/pf.py
index 3a99950..43812ae 100644
--- a/mNSF/NSF/pf.py
+++ b/mNSF/NSF/pf.py
@@ -25,20 +25,6 @@ dtp = "float32"
 rng = np.random.default_rng()
 
-@tf.custom_gradient
-def checkpoint_grad(x):
-  y = tf.identity(x)
-  def grad(dy):
-    if tf.executing_eagerly():
-      with tf.GradientTape() as tape:
-        tape.watch(x)
-        y = tf.identity(x)
-      return tape.gradient(y, x, output_gradients=dy)
-    else:
-      return tf.gradients(y, x, dy)[0]
-  return y, grad
-
-
 class ProcessFactorization(tf.Module):
   def __init__(self, J, L, Z, lik="poi", chol = True, X=None,
                psd_kernel=tfk.MaternThreeHalves, nugget=1e-5, length_scale=0.1,
                disp="default",
@@ -252,8 +238,7 @@ def sample_latent_GP_funcs(self, X, S=1, kernel=None, mu_z=None, Kuu_chol=None,
     Kuf = kernel.matrix(self.Z, X) #LxMxN
     Kff_diag = kernel.apply(X, X, example_ndims=1)+self.nugget #LxN
     alpha_x = tfl.cholesky_solve(Kuu_chol, Kuf) #LxMxN
-    #mu_tilde = mu_x + tfl.matvec(alpha_x, self.delta-mu_z, transpose_a=True) #LxN
-    mu_tilde = checkpoint_grad(mu_x + tfl.matvec(alpha_x, self.delta-mu_z, transpose_a=True))
+    mu_tilde = mu_x + tfl.matvec(alpha_x, self.delta-mu_z, transpose_a=True) #LxN
     #compute the alpha(x_i)'(K_uu-Omega)alpha(x_i) term
     a_t_Kchol = tfl.matmul(alpha_x, Kuu_chol, transpose_a=True) #LxNxM
     aKa = tf.reduce_sum(tf.square(a_t_Kchol), axis=2) #LxN
@@ -269,8 +254,8 @@ def sample_predictive_mean(self, X, sz=1, S=1, kernel=None, mu_z=None, Kuu_chol=
     sz is a tensor of shape (N,1) of size factors.
     Typically sz would be the rowSums or rowMeans of the outcome matrix Y.
     """
-    F = checkpoint_grad(self.sample_latent_GP_funcs(X, S=S, kernel=kernel, mu_z=mu_z,
-                                    Kuu_chol=Kuu_chol, chol=chol))
+    F = self.sample_latent_GP_funcs(X, S=S, kernel=kernel, mu_z=mu_z,
+                                    Kuu_chol=Kuu_chol, chol=chol) #SxLxN
     if self.nonneg:
       Lam = tfl.matrix_transpose(tfl.matmul(self.W, tf.exp(F))) #SxNxJ
       if self.lik=="gau":
@@ -335,7 +320,7 @@ def elbo_avg(self, X, Y, sz=1, S=1, Ntot=None, chol=True):
     #kl_terms is not affected by minibatching so use reduce_sum
     #print(1111)
     kl_term = tf.reduce_sum(self.eval_kl_term(mu_z, Kuu_chol))
-    Mu = checkpoint_grad(self.sample_predictive_mean(X, sz=sz, S=S, kernel=ker, mu_z=mu_z, Kuu_chol=Kuu_chol))
+    Mu = self.sample_predictive_mean(X, sz=sz, S=S, kernel=ker, mu_z=mu_z, Kuu_chol=Kuu_chol)
     eloglik = likelihoods.lik_to_distr(self.lik, Mu, self.disp).log_prob(Y)
     return J*tf.reduce_mean(eloglik) - kl_term/Ntot
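
Context for the change: the removed checkpoint_grad was a tf.custom_gradient
wrapper whose forward pass returned tf.identity(x) and whose backward pass
differentiated that same identity, so it passed dy through unchanged. It
therefore neither altered gradients nor rematerialized the upstream GP
computation, which is presumably why this patch backs it out. If
rematerialization (true gradient checkpointing) of the expensive GP sampling
step were actually wanted, TensorFlow's built-in tf.recompute_grad covers that
use case. The sketch below is illustrative only: heavy_block is a hypothetical
stand-in for a costly step such as sample_latent_GP_funcs and is not mNSF code.

import tensorflow as tf

# Minimal sketch of gradient checkpointing via TensorFlow's built-in
# tf.recompute_grad: intermediates inside the wrapped function are
# dropped after the forward pass and recomputed during backprop,
# trading extra compute for lower peak memory.
# NOTE: heavy_block is a hypothetical stand-in for an expensive step
# like sample_latent_GP_funcs; it is not part of mNSF.
@tf.recompute_grad
def heavy_block(x, w):
    h = tf.tanh(tf.matmul(x, w))  # intermediate recomputed, not stored
    return tf.matmul(h, w, transpose_b=True)

x = tf.random.normal([8, 16])
w = tf.random.normal([16, 16])
with tf.GradientTape() as tape:
    tape.watch(w)
    loss = tf.reduce_sum(heavy_block(x, w))
grad_w = tape.gradient(loss, w)  # same values as the unwrapped version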