inVAE is a conditionally invariant variational autoencoder that identifies both spurious (distractor) and invariant features. It leverages domain variability to learn conditionally invariant representations. We show that inVAE captures biological variation in single-cell datasets obtained from diverse conditions and labs. inVAE incorporates biological covariates and mechanisms, such as disease states, to learn an invariant data representation, which significantly improves cell classification accuracy.
Install from PyPI:
pip install invae
Or install the development version (latest version on GitHub):
git clone https://github.com/theislab/inVAE.git
cd inVAE
pip install .
Integration of the Human Lung Cell Atlas using both healthy and disease samples
- Load the data:
import scanpy as sc
adata = sc.read('path/to/data')  # placeholder path, e.g. an .h5ad file
- Optional - Split the data into training, validation, and test sets (in the supervised case, these are also used to train the classifier):
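The split step can be sketched with NumPy as a random split by cell. This is a minimal sketch assuming fixed validation and test fractions; split_indices and its default fractions are hypothetical and not part of inVAE.

```python
import numpy as np


def split_indices(n_obs, val_frac=0.1, test_frac=0.1, seed=0):
    """Randomly partition range(n_obs) into train/val/test index arrays (hypothetical helper)."""
    if val_frac < 0 or test_frac < 0 or val_frac + test_frac >= 1:
        raise ValueError("val_frac and test_frac must be non-negative and sum to < 1")
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_obs)  # shuffled cell indices
    n_val = int(round(val_frac * n_obs))
    n_test = int(round(test_frac * n_obs))
    val_idx = perm[:n_val]
    test_idx = perm[n_val:n_val + n_test]
    train_idx = perm[n_val + n_test:]
    return train_idx, val_idx, test_idx
```

Index the AnnData object with the returned arrays, e.g. adata_train = adata[train_idx].copy(), and likewise for adata_val and adata_test. Because inVAE targets generalization across conditions and labs, holding out entire batches (selecting test cells by their 'batch' value) may be a more demanding alternative to a purely random split.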
- Initialize the model, either factorized (FinVAE) or non-factorized (NFinVAE):
from inVAE import FinVAE, NFinVAE
inv_covar_keys = {
'cont': [], # continuous covariates
'cat': ['cell_type', 'disease'] # categorical covariates; set to the matching keys in adata.obs
}
spur_covar_keys = {
'cont': [], # continuous covariates
'cat': ['batch'] # categorical covariates; set to the matching keys in adata.obs
}
model = FinVAE(
adata = adata_train,
layer = 'counts', # The layer where the raw counts are stored in adata (None for adata.X: default)
inv_covar_keys = inv_covar_keys,
spur_covar_keys = spur_covar_keys,
latent_dim_inv = 20,
latent_dim_spur = 5,
device = 'cpu',
decoder_dist = 'nb'
)
Set inject_covar_in_latent=True if you wish to add the spurious conditions directly to the latent space (instead of learning the spurious latents). This gives you the version most compatible with scVI.
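Conceptually, the two options differ in what fills the spurious part of the representation that is passed on to the decoder. The sketch below is an illustration under stated assumptions, not inVAE's implementation: it assumes the spurious condition is one-hot encoded and concatenated with the invariant latent (the usual scVI-style approach), and one_hot and decoder_input are hypothetical helpers.

```python
import numpy as np


def one_hot(labels, categories):
    """One-hot encode labels against a fixed category order."""
    index = {c: i for i, c in enumerate(categories)}
    out = np.zeros((len(labels), len(categories)))
    for row, label in enumerate(labels):
        out[row, index[label]] = 1.0
    return out


def decoder_input(z_inv, batch_labels, batch_categories,
                  z_spur=None, inject_covar_in_latent=False):
    """Hypothetical illustration of the decoder input for each option.

    inject_covar_in_latent=True: the observed spurious condition (here a batch
    label) is one-hot encoded and appended to the invariant latent.
    inject_covar_in_latent=False: a learned spurious latent z_spur is appended.
    """
    if inject_covar_in_latent:
        spur_part = one_hot(batch_labels, batch_categories)
    elif z_spur is None:
        raise ValueError("z_spur is required when inject_covar_in_latent=False")
    else:
        spur_part = z_spur
    return np.concatenate([z_inv, spur_part], axis=1)


# Example: 4 cells, 20 invariant dimensions, 3 batches, 5 spurious dimensions
z_inv = np.zeros((4, 20))
batches = ["a", "b", "a", "c"]
injected = decoder_input(z_inv, batches, ["a", "b", "c"], inject_covar_in_latent=True)
learned = decoder_input(z_inv, batches, ["a", "b", "c"], z_spur=np.ones((4, 5)))
print(injected.shape, learned.shape)  # (4, 23) (4, 25)
```

With injection, the spurious part is fixed by the observed conditions rather than learned, which is why its width equals the number of batch categories instead of latent_dim_spur.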
For the non-factorized model, use:
model = NFinVAE(
adata = adata_train,
layer = 'counts', # The layer where the raw counts are stored in adata (None for adata.X: default)
inv_covar_keys = inv_covar_keys,
spur_covar_keys = spur_covar_keys,
latent_dim_inv = 20,
latent_dim_spur = 5,
device = 'cpu',
decoder_dist = 'nb'
)
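Both models look up the covariates by name, so every key in inv_covar_keys and spur_covar_keys must exist in the AnnData object. A small pre-flight check, assuming the keys refer to columns of adata.obs; missing_covar_keys is a hypothetical helper, not an inVAE function:

```python
def missing_covar_keys(covar_keys, obs_columns):
    """Return covariate names listed in covar_keys that are absent from obs_columns."""
    available = set(obs_columns)
    missing = []
    for kind in ("cont", "cat"):  # continuous and categorical covariates
        for key in covar_keys.get(kind, []):
            if key not in available:
                missing.append(key)
    return missing
```

For example, missing_covar_keys(inv_covar_keys, adata_train.obs.columns) should return an empty list before either model is constructed.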
- Train the generative model:
model.train(n_epochs=500, lr_train=0.001, weight_decay=0.0001)
- Get the latent representation. If a covariate used during training is missing, the encoder receives zeros as input for that sample and covariate:
# This works for an arbitrary AnnData object, not only for the training data
# Other options for latent_type are 'full' or 'spurious'
latent = model.get_latent_representation(adata, latent_type='invariant')
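The zero-input behaviour for missing covariates can be pictured as a one-hot encoding in which a missing value yields an all-zero row. This sketch assumes categorical covariates are one-hot encoded before entering the encoder, which is an assumption rather than documented inVAE internals; encode_categorical is hypothetical, and mapping unseen labels to zeros as well is a choice made here.

```python
import numpy as np


def encode_categorical(values, categories):
    """One-hot encode values; missing entries become all-zero rows.

    Anything not found in categories (None, NaN, or an unseen label) maps to
    zeros, mirroring the 'zeros for missing covariates' behaviour.
    """
    index = {c: i for i, c in enumerate(categories)}
    out = np.zeros((len(values), len(categories)), dtype=np.float32)
    for row, value in enumerate(values):
        col = index.get(value)  # None when value is missing or unseen
        if col is not None:
            out[row, col] = 1.0
    return out


encoded = encode_categorical(["healthy", None, "COPD"], ["COPD", "healthy"])
print(encoded.tolist())  # [[0.0, 1.0], [0.0, 0.0], [1.0, 0.0]]
```

Downstream, assuming latent is an array with one row per cell, it can be used with scanpy as usual: store it with adata.obsm['X_invae'] = latent (the key name is arbitrary), then call sc.pp.neighbors(adata, use_rep='X_invae') and sc.tl.umap(adata).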
- Optional - Train the classifier (for cell types). If adata_val is not given or has no labels, the classifier is trained only on the AnnData object the generative model was trained on (here: adata_train):
model.train_classifier(
adata_val,
batch_key = 'batch',
label_key = 'cell_type',
)
- Optional - Predict cell types:
# Other possible values for dataset_type: 'train' or 'val'
# 'train' corresponds to the adata_train object above
# 'val' corresponds to the AnnData passed to train_classifier
# 'test' is for a new, unseen AnnData object
pred_test = model.predict(adata_test, dataset_type='test')
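To check classification performance on held-out cells, predictions can be compared with known labels. This sketch assumes pred_test holds one predicted label per cell of adata_test, in the same order, and that adata_test.obs['cell_type'] contains the true labels; accuracy_scores is a hypothetical helper. It reports plain accuracy and balanced accuracy (mean per-class recall), which weighs rare cell types as much as abundant ones.

```python
import numpy as np


def accuracy_scores(y_true, y_pred):
    """Return (accuracy, balanced accuracy) for two equally long label sequences."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same shape")
    accuracy = float(np.mean(y_true == y_pred))
    # Recall per class present in y_true, then the unweighted mean over classes
    recalls = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
    return accuracy, float(np.mean(recalls))


acc, bal_acc = accuracy_scores(["T", "T", "T", "B"], ["T", "T", "B", "B"])
print(acc, round(bal_acc, 3))  # 0.75 0.833
```

Usage would be accuracy_scores(adata_test.obs['cell_type'], pred_test). If scikit-learn is installed, its accuracy_score and balanced_accuracy_score compute the same two quantities.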
- Optional - Get the latent representation inferred via the trained classifier:
# Use 'val' or 'test' as the key, depending on which dataset_type was used in the predict function above
# E.g. for the invariant latent representation:
# (do not subset for the full representation, or keep only the dimensions after the first latent_dim_inv for the spurious one)
latent_samples_inv = model.saved_latent['val'][:, :model.latent_dim_inv]
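Following the comments above, the full latent stores the invariant dimensions first and the spurious ones after them. A minimal helper returning both parts, assuming the stored latent is a 2-D array-like supporting NumPy-style slicing (a NumPy array or a tensor); split_latent is hypothetical:

```python
import numpy as np


def split_latent(latent, latent_dim_inv):
    """Split a full latent matrix into (invariant, spurious) column blocks."""
    return latent[:, :latent_dim_inv], latent[:, latent_dim_inv:]


# Example with 2 cells, 20 invariant and 5 spurious dimensions
full = np.arange(50).reshape(2, 25)
inv, spur = split_latent(full, 20)
print(inv.shape, spur.shape)  # (2, 20) (2, 5)
```

Applied to the model, this would read split_latent(model.saved_latent['val'], model.latent_dim_inv).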
- Optional - Save and load the model:
model.save('./checkpoints/path.pt')
model.load('./checkpoints/path.pt')
The newest version supports saving and loading the model parameters and weights together:
# Same syntax for saving but now saves model params too
model.save('./checkpoints/path.pt')
# New loading function (the old load function can still be used for older checkpoints)
device = 'cpu'  # or e.g. 'cuda'
FinVAE.load_model('./checkpoints/path.pt', adata_train, device)
# or for NFinVAE
NFinVAE.load_model('./checkpoints/path.pt', adata_train, device)
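The example paths assume a ./checkpoints directory. Whether model.save creates missing directories is not stated here (a plain torch.save to a non-existent directory raises an error), so it may be safer to create it first. A minimal sketch using only the standard library; ensure_parent_dir is a hypothetical helper:

```python
from pathlib import Path


def ensure_parent_dir(checkpoint_path):
    """Create the parent directory of checkpoint_path if needed and return the path."""
    path = Path(checkpoint_path)
    path.parent.mkdir(parents=True, exist_ok=True)  # no error if it already exists
    return path
```

Call ensure_parent_dir('./checkpoints/path.pt') before model.save('./checkpoints/path.pt').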
Dependencies (pinned versions):
- scanpy==1.9.3
- torch==2.0.1
- tensorboard==2.13.0
- anndata==0.8.0
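The versions above are pinned; whether newer releases work is not stated. To see how an environment differs from these pins, a small check with the standard library's importlib.metadata (Python 3.8+) can help. version_mismatches is a hypothetical helper, and it ignores local build suffixes such as '+cu117' on torch when comparing.

```python
from importlib.metadata import PackageNotFoundError, version

PINNED = {
    "scanpy": "1.9.3",
    "torch": "2.0.1",
    "tensorboard": "2.13.0",
    "anndata": "0.8.0",
}


def version_mismatches(pins):
    """Map each package whose installed version differs from its pin to the
    installed version string (None if the package is not installed)."""
    mismatches = {}
    for package, pinned in pins.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None
        # Drop local version suffixes such as '+cu117' before comparing
        if installed is None or installed.split("+")[0] != pinned:
            mismatches[package] = installed
    return mismatches


if __name__ == "__main__":
    for package, installed in version_mismatches(PINNED).items():
        print(f"{package}: pinned {PINNED[package]}, installed {installed}")
```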