Bug Report

GPJax version: 0.8.0

Current behavior:

This may be a problem with how I am interpreting how GPJax handles combination kernels, so apologies if I've missed something.
It seems that kernels which are combinations of combination kernels are not handled as expected when more than one type of combination operator is used (e.g. the kernel is a sum of product kernels, or a product of sum kernels). There doesn't appear to be a problem if both combination operators are identical (a sum of sum kernels, or a product of product kernels).
Expected behavior:
When using a combination of combination kernels, the predictive mean should be identical whether it is computed with GPJax or manually.
Steps to reproduce:
See below.
Related code:
import gpjax as gpx
import jax.numpy as jnp
import matplotlib.pyplot as plt

xall = jnp.linspace(-5, 5, 1000)
toy_fun = lambda x: 1/5*x**2 + jnp.sin(x*5)**3 + jnp.cos(x*3)**2
xtrain = xall[0::25][:, None]
ytrain = toy_fun(xtrain)
xtest = xall[:, None]
ytest = toy_fun(xtest)
D = gpx.Dataset(X=xtrain, y=ytrain)
kernel1 = gpx.kernels.RBF()
kernel2 = gpx.kernels.Matern32()
sum_kernel = kernel1 + kernel2
# using GPJax
pos_kernel = sum_kernel * sum_kernel  # pos = product of sums
pos_prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=pos_kernel)
pos_posterior = pos_prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)
latent_dist_pos = pos_posterior(xtest, train_data=D)
predictive_dist_pos = pos_posterior.likelihood(latent_dist_pos)
mu_pos = predictive_dist_pos.mean()
std_pos = predictive_dist_pos.stddev()
# manual calculation of predictive dist
kxx = (kernel1.gram(xtrain).to_dense() + kernel2.gram(xtrain).to_dense()) * (kernel1.gram(xtrain).to_dense() + kernel2.gram(xtrain).to_dense())
kxt = (kernel1.cross_covariance(xtrain,xtest) + kernel2.cross_covariance(xtrain,xtest)) * (kernel1.cross_covariance(xtrain,xtest) + kernel2.cross_covariance(xtrain,xtest))
ktt = (kernel1.gram(xtest).to_dense() + kernel2.gram(xtest).to_dense()) * (kernel1.gram(xtest).to_dense() + kernel2.gram(xtest).to_dense())
L = jnp.linalg.cholesky(kxx + 1.0 * jnp.eye(D.n))  # 1.0 matches the obs noise variance assigned in the GPJax likelihood
alpha = jnp.linalg.solve(L.T, jnp.linalg.solve(L, ytrain))
v = jnp.linalg.solve(L, kxt)
mu_manual_pos = kxt.T @ alpha
cov_manual_pos = ktt - v.T @ v
var_manual_pos = jnp.diag(cov_manual_pos) + 1.0  # adding obs variance to match GPJax stddev output
std_manual_pos = jnp.sqrt(var_manual_pos)  # comparable to std_pos
plt.plot(xtest, mu_manual_pos, ':')
plt.plot(xtest, mu_pos, '--')
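To quantify the mismatch rather than eyeballing the plot, one can compare the two means directly, for example:

# Maximum absolute deviation between the two predictive means; with the
# bug present this is large, while sum-of-sums and product-of-products
# variants of the same kernel agree up to numerical jitter.
print(jnp.max(jnp.abs(mu_pos - mu_manual_pos.squeeze())))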
There is a discrepancy between "mu_manual_pos" and "mu_pos" when I don't believe there should be. The same is true if we use a kernel that is a sum of individual product kernels (see the snippet below). However, if the combination operators are identical (a sum of sums, or a product of products), the results agree, so it appears there is a problem with the way GPJax handles combinations of combinations that contain multiple operators.
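For reference, the sum-of-products case mentioned above can be reproduced with the same script by swapping only the kernel construction (the manual computation becomes k1*k2 + k1*k2 accordingly):

# sop = sum of products; shows the same discrepancy against the
# manual computation described above.
prod_kernel = kernel1 * kernel2
sop_kernel = prod_kernel + prod_kernel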
Other information:
I found this issue while working with kernels that are combinations of combinations for a personal project, where I am seeing drastic differences between GPJax and manual computation. I've tried to simplify the problem for this post to make it as clear as possible.
Hey Matthew,
I just ran into the same problem. I think the issue is in the __post_init__ of the CombinationKernel class:
def __post_init__(self):
    # Add kernels to a list, flattening out instances of this class therein, as in GPFlow kernels.
    kernels_list: List[AbstractKernel] = []

    for kernel in self.kernels:
        if not isinstance(kernel, AbstractKernel):
            raise TypeError("can only combine Kernel instances")  # pragma: no cover

        if isinstance(kernel, self.__class__):
            kernels_list.extend(kernel.kernels)
        else:
            kernels_list.append(kernel)

    self.kernels = kernels_list
Here it computes a flattened list of kernels and saves it to the kernels attribute. When the kernel is called, it applies the combination's operator across all kernels in that list.
So the structure of the operations is lost: the current operation (e.g. sum) is blindly applied to all sub-kernels. This explains why the results are consistent when all the combination operations are the same.
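If that diagnosis is right, the structure loss should be visible directly on the kernel object. A minimal sketch of what I mean, assuming the 0.8.0 behaviour where SumKernel and ProductKernel are both plain CombinationKernel instances (so the isinstance check matches across different operators):

import gpjax as gpx
import jax.numpy as jnp

k1, k2 = gpx.kernels.RBF(), gpx.kernels.Matern32()
pos = (k1 + k2) * (k1 + k2)  # intended: (k1 + k2)**2

# The sums' children get spliced into the product's flat list, so the
# grouping is gone and only the outer operator (prod) survives:
print(pos.kernels)  # [RBF, Matern32, RBF, Matern32]

# Consequently the gram matrix is k1*k2*k1*k2 rather than (k1 + k2)**2:
x = jnp.array([[0.0], [1.0]])
print(pos.gram(x).to_dense())
print((k1.gram(x).to_dense() + k2.gram(x).to_dense()) ** 2)  # expected value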
I assume the easy fix would be to have two attributes, self.kernels and self.flattened_kernels.
I don't think we need to have a separate flattened_kernels; I would either
a) change SumKernel and ProductKernel to be actual subclasses of CombinationKernel (in which case the test on self.__class__ would only allow combining when the operation matches), or
b) explicitly add an additional check that self.operator is kernel.operator.
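A minimal sketch of option (b), against the __post_init__ shown above (untested):

for kernel in self.kernels:
    if not isinstance(kernel, AbstractKernel):
        raise TypeError("can only combine Kernel instances")  # pragma: no cover

    # Only flatten a nested combination when it applies the same
    # operation as this one; a sum nested inside a product (or vice
    # versa) is kept intact as a single sub-kernel.
    if isinstance(kernel, self.__class__) and kernel.operator is self.operator:
        kernels_list.extend(kernel.kernels)
    else:
        kernels_list.append(kernel)

Option (a) seems cleaner in the long run, though, since with real subclasses the existing isinstance check already does the right thing without the extra condition.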