Update pf_multiSample.py
yiwang12 authored Nov 14, 2024
1 parent 89ede6f commit 283969b
Showing 1 changed file with 14 additions and 24 deletions.
38 changes: 14 additions & 24 deletions mNSF/pf_multiSample.py
@@ -250,29 +250,19 @@ def sample_latent_GP_funcs(self, X, S=1, kernel=None, mu_z=None, Kuu_chol=None,
       mu_z = self.get_mu_z()
     if Kuu_chol is None:
       Kuu_chol = self.get_Kuu_chol(kernel=kernel, from_cache=(not chol))
-    if (not chol):
-      N = X.shape[0]
-      L = self.W.shape[1]
-      mu_x = self.beta0+tfl.matmul(self.beta, X, transpose_b=True) #LxN
-      mu_tilde = mu_x + tfl.matvec(self.alpha_x, self.delta-mu_z, transpose_a=True) #LxN
-      #a_t_Kchol = self.a_t_Kchol
-      #aKa = tf.reduce_sum(tf.square(a_t_Kchol), axis=2) #LxN
-      Sigma_tilde = self.Sigma_tilde #LxN
-    if chol:
-      alpha_x = tfl.cholesky_solve(Kuu_chol, Kuf) #LxMxN
-      N = X.shape[0]
-      L = self.W.shape[1]
-      mu_x = self.beta0+tfl.matmul(self.beta, X, transpose_b=True) #LxN
-      Kuf = kernel.matrix(self.Z, X) #LxMxN
-      Kff_diag = kernel.apply(X, X, example_ndims=1)+self.nugget #LxN
-
-      mu_tilde = mu_x + tfl.matvec(alpha_x, self.delta-mu_z, transpose_a=True) #LxN
-      #compute the alpha(x_i)'(K_uu-Omega)alpha(x_i) term
-      a_t_Kchol = tfl.matmul(alpha_x, Kuu_chol, transpose_a=True) #LxNxM
-      aKa = tf.reduce_sum(tf.square(a_t_Kchol), axis=2) #LxN
-      a_t_Omega_tril = tfl.matmul(alpha_x, self.Omega_tril, transpose_a=True) #LxNxM
-      aOmega_a = tf.reduce_sum(tf.square(a_t_Omega_tril), axis=2) #LxN
-      Sigma_tilde = Kff_diag - aKa + aOmega_a #LxN
+    N = X.shape[0]
+    L = self.W.shape[1]
+    mu_x = self.beta0+tfl.matmul(self.beta, X, transpose_b=True) #LxN
+    Kuf = kernel.matrix(self.Z, X) #LxMxN
+    Kff_diag = kernel.apply(X, X, example_ndims=1)+self.nugget #LxN
+    alpha_x = tfl.cholesky_solve(Kuu_chol, Kuf) #LxMxN
+    mu_tilde = mu_x + tfl.matvec(alpha_x, self.delta-mu_z, transpose_a=True) #LxN
+    #compute the alpha(x_i)'(K_uu-Omega)alpha(x_i) term
+    a_t_Kchol = tfl.matmul(alpha_x, Kuu_chol, transpose_a=True) #LxNxM
+    aKa = tf.reduce_sum(tf.square(a_t_Kchol), axis=2) #LxN
+    a_t_Omega_tril = tfl.matmul(alpha_x, self.Omega_tril, transpose_a=True) #LxNxM
+    aOmega_a = tf.reduce_sum(tf.square(a_t_Omega_tril), axis=2) #LxN
+    Sigma_tilde = Kff_diag - aKa + aOmega_a #LxN
     #print(S)
     #print(L)
     #print(N)
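
In substance, this hunk collapses the two chol branches into a single unconditional path: Kuf is now computed before alpha_x = tfl.cholesky_solve(Kuu_chol, Kuf) (in the removed "if chol:" branch, alpha_x appears to have been assigned from Kuf before Kuf was defined), and the cached "not chol" path that read self.alpha_x and self.Sigma_tilde is dropped. The aKa / aOmega_a lines use the usual Cholesky trick for a diagonal quadratic form. Below is a standalone NumPy sketch of that identity (illustrative only, not mNSF code; array names are stand-ins for one factor's slices):

# Standalone NumPy check (illustrative, not mNSF code): for K = Lc @ Lc.T,
# diag(A.T @ K @ A) equals the row-wise squared norms of A.T @ Lc.
# Per factor, this is what a_t_Kchol/aKa (and a_t_Omega_tril/aOmega_a) compute above.
import numpy as np

rng = np.random.default_rng(0)
M, N = 5, 7                              # inducing points, observations (toy sizes)
A = rng.normal(size=(M, N))              # stand-in for one factor's alpha_x slice (MxN)
Lc = np.tril(rng.normal(size=(M, M)))    # stand-in for Kuu_chol (or Omega_tril), lower-triangular
K = Lc @ Lc.T                            # the full PSD matrix

direct = np.diag(A.T @ K @ A)            # forms the full NxN quadratic form, keeps its diagonal
via_chol = np.sum((A.T @ Lc)**2, axis=1) # only the N row norms, as in the diff

assert np.allclose(direct, via_chol)

Applying the same identity with Omega_tril in place of Kuu_chol gives aOmega_a, so Sigma_tilde = Kff_diag - aKa + aOmega_a matches the commented alpha(x_i)'(K_uu-Omega)alpha(x_i) correction term.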
@@ -351,7 +341,7 @@ def elbo_avg(self, X, Y, sz=1, S=1, Ntot=None, chol=True):
     #kl_terms is not affected by minibatching so use reduce_sum
     #print(1111)
     kl_term = tf.reduce_sum(self.eval_kl_term(mu_z, Kuu_chol))
-    Mu = self.sample_predictive_mean(X, sz=sz, S=S, kernel=ker, mu_z=mu_z, Kuu_chol=Kuu_chol, chol = chol)
+    Mu = self.sample_predictive_mean(X, sz=sz, S=S, kernel=ker, mu_z=mu_z, Kuu_chol=Kuu_chol)
     eloglik = likelihoods.lik_to_distr(self.lik, Mu, self.disp).log_prob(Y)
     return J*tf.reduce_mean(eloglik) - kl_term/Ntot
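
For context on the one-line change: the chol keyword is no longer passed to sample_predictive_mean, consistent with the removal of the chol branches above. Reading the unchanged return line, and assuming eloglik holds one log-likelihood value per Monte Carlo sample s (S of them), gene j (J of them), and minibatch observation i, the returned quantity appears to be a per-observation ELBO estimate (a hedged reading of the code, not the authors' documentation):

\widehat{\mathcal{L}}
  = \frac{1}{S}\sum_{s=1}^{S} \frac{1}{|B|}\sum_{i \in B} \sum_{j=1}^{J}
    \log p\big(y_{ij} \mid \mu_{ij}^{(s)}\big)
  \;-\; \frac{1}{N_{\mathrm{tot}}}\,\mathrm{KL}\big(q(u)\,\|\,p(u)\big)

so that multiplying by Ntot gives the usual minibatch-scaled estimate of the full-data ELBO.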
