GMM fitting with full covariance crashes unexpectedly unlike sklearn GMM fitting #20
Comments
Hmm, I've seen this issue occur non-deterministically at times, thanks for the MWE! I'll try to investigate the issue in the coming days, but I'd also be happy about any more input ;) In the meantime, you might get around your issue by increasing the covariance regularization.
I have found another test case which fails unexpectedly. I use tied covariances here, and the ONLY way I can get the training to converge is to set the covariance regularization to 10.0. You can try values of 1.0, 1e-1, 1e-2, 1e-3 and 1e-6, but they all fail. I find this odd because the dimensionality of the data is relatively low here. If you look at the eigenvalues of the covariance matrix (using a covariance regularization of 1e-6) you get -0.7701, -0.5969, 0.9763, 0.9994, 1.0165, so it is clearly not positive-definite, although it is still symmetric. Therefore, however the covariances are being computed must be at fault, because a covariance matrix cannot have eigenvalues like these. At first I thought there might be some small negative eigenvalues due to numerical precision errors, but these are the same magnitude as the positive eigenvalues. I will try to find out more by poking around the Lightning module internals!
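The sample code attached to this comment survives only as outline comments in the thread. A rough reconstruction under stated assumptions (the dimensionality of 5 is inferred from the five quoted eigenvalues; the seed, sizes, component count, and the covariance_regularization keyword and fitted-model attribute are assumptions based on the rest of the thread) might look like this:
from pycave.bayes import GaussianMixture
import numpy as np
import torch

# Set seed
np.random.seed(0)
torch.manual_seed(0)

# Inputs (p = 5 is inferred from the five eigenvalues quoted above;
# n and k are assumptions)
n, p, k = 10000, 5, 10

# Make some Gaussian data
X = torch.Tensor(np.random.randn(n, p))

# Fit PyCave GMM with tied covariances; regularization values below 10.0
# reportedly fail with the Cholesky error
gmm = GaussianMixture(num_components=k,
                      covariance_type='tied',
                      init_strategy='kmeans++',
                      covariance_regularization=1e-6)
gmm.fit(X)

# Inspect eigenvalues of the learned tied covariance matrix (the attribute
# name is an assumption about PyCave's fitted-model API)
print(torch.linalg.eigvalsh(gmm.model_.covariances))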
Does this issue occur when you do not perform mini-batch training? Also, I would advise trying double precision (I think you can pass precision=64 to the trainer).
If I use the whole dataset of 10,000 points instead of mini-batches of 1,000, I still get the same issue for covariance regularizations under 1.0; however, a covariance regularization of 1.0 now works. I passed precision=64 to the trainer with no change in behavior.
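For reference, a minimal sketch of how double precision might be passed through to the underlying Lightning trainer (the trainer_params keyword is an assumption about PyCave's API; k and X are as in the minimal working example further down):
from pycave.bayes import GaussianMixture

# Pass Lightning trainer options through PyCave (keyword assumed);
# precision=64 requests double-precision training
gmm = GaussianMixture(num_components=k,
                      covariance_type='full',
                      init_strategy='kmeans',
                      trainer_params=dict(precision=64))
gmm.fit(X)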
I just realized that if I initialize with 'kmeans' instead of 'kmeans++', it works fine. So maybe there is something weird going on with 'kmeans++'?
When I fit GMMs using the kmeans or kmeans++ initializations, I get a non-positive-definiteness error if the covariance regularization is too low. The error typically comes from the initialization, not necessarily from the fitting of the GMM itself. Could there be two different covariance regularizations, one for the initialization and one for fitting? One may want a more heavily regularized initialization, so that you get a good start, without necessarily having a heavily regularized GMM fit.
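For context, covariance regularization just adds a small multiple of the identity to the covariance diagonal. A minimal, PyCave-independent PyTorch illustration of why Cholesky fails on a rank-deficient covariance and how the regularizer repairs it:
import torch

torch.manual_seed(0)

# 50 points in 100 dimensions: the sample covariance has rank <= 49,
# so it cannot be positive-definite
X = torch.randn(50, 100, dtype=torch.float64)
X = X - X.mean(dim=0)
cov = (X.T @ X) / (X.shape[0] - 1)

try:
    torch.linalg.cholesky(cov)
except RuntimeError as e:  # torch raises a _LinAlgError, a RuntimeError subclass
    print("Cholesky fails:", e)

# Adding eps * I (the covariance regularization) makes it positive-definite
eps = 1e-6
torch.linalg.cholesky(cov + eps * torch.eye(100, dtype=torch.float64))
print("Cholesky succeeds with regularization")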
I was trying to fit a GMM on data and kept getting the same error (with varying numbers for the batch element and leading minor order):
_LinAlgError: torch.linalg_cholesky: (Batch element 3): The factorization could not be completed because the input is not positive-definite (the leading minor of order 941 is not positive-definite).
I made a minimal working example to show when this comes up in practice. I compared to sklearn, and somehow sklearn is able to avoid this problem. The issue happens both on CPU and GPU. I am using PyCave 3.1.3 and sklearn 0.24.2. Do you have any idea what could be the issue?
Minimum working example:
from pycave.bayes import GaussianMixture
import torch
import numpy as np
from sklearn import mixture

# Set seed
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# Inputs
n = 5000   # number of points
p = 2000   # dimensionality
k = 10     # number of components

# Make some non-Gaussian data (the shifted ReLU leaves many exact zeros)
X = np.random.randn(n, p)
X = torch.Tensor(X)
X = torch.nn.ReLU()(X - 1)

# Fit sklearn GMM (succeeds)
gmm_sk = mixture.GaussianMixture(n_components=k,
                                 covariance_type='full',
                                 init_params='kmeans')
gmm_sk.fit(X.numpy())

# Fit PyCave GMM (crashes with the Cholesky error above)
gmm = GaussianMixture(num_components=k,
                      covariance_type='full',
                      init_strategy='kmeans')
gmm.fit(X)
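For completeness, the workaround suggested earlier in the thread would look roughly like this on the example above (the covariance_regularization keyword and its 1e-6 default are assumptions about PyCave's API; sklearn's corresponding reg_covar parameter defaults to 1e-6):
# Possible workaround from the thread: raise the covariance regularization
# well above the (assumed) 1e-6 default
gmm = GaussianMixture(num_components=k,
                      covariance_type='full',
                      init_strategy='kmeans',
                      covariance_regularization=1e-3)
gmm.fit(X)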