diff --git a/.gitignore b/.gitignore index f16cf9f..b0f5e9b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ dist/ **/__pycache__/ .ipynb_checkpoints/ +build/ +*.egg-info/ diff --git a/build/lib/scnym/__init__.py b/build/lib/scnym/__init__.py deleted file mode 100644 index 14f386b..0000000 --- a/build/lib/scnym/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -__author__ = "Jacob C. Kimmel, David R. Kelley" -__email__ = "jacobkimmel+scnym@gmail.com, drk@calicolabs.com" -__version__ = "0.3.4" - -# populate the namespace so top level imports work -# e.g. -# >> from scnym.model import CellTypeCLF -from . import api, main, dataprep, interpret, model, predict, trainer, utils diff --git a/build/lib/scnym/__main__.py b/build/lib/scnym/__main__.py deleted file mode 100644 index 0eccd5e..0000000 --- a/build/lib/scnym/__main__.py +++ /dev/null @@ -1,3 +0,0 @@ -from scnym.main import main - -main() diff --git a/build/lib/scnym/api.py b/build/lib/scnym/api.py deleted file mode 100644 index d39dc54..0000000 --- a/build/lib/scnym/api.py +++ /dev/null @@ -1,1515 +0,0 @@ -""" -Classify cell identities using scNym - -scnym_api() is the main API endpoint for users. -This function allows for training and prediction using scnym_train() -and scnym_predict(). Both of these functions will be infrequently -accessed by users. - -get_pretrained_weights() is a wrapper function that downloads pretrained -weights from our cloud storage bucket. -atlas2target() downloads preprocessed reference datasets and concatenates -them onto a user supplied target dataset. -""" -from typing import Optional, Union, List, Tuple -from anndata import AnnData -import scanpy as sc -import numpy as np -import pandas as pd -import torch -import os -import os.path as osp -import copy -import pickle -import warnings -import itertools -import pprint -import logging -import tqdm - -# for fetching pretrained weights, all in standard lib -import requests -import json -import urllib - -# for data splits -from sklearn.model_selection import StratifiedKFold - -# from scnym -from . import utils -from . import model -from . import main -from . import predict -from . import dataprep -from . import interpret - -# Define constants - -logger = logging.getLogger(__name__) - -TEST_URL = ( - "https://storage.googleapis.com/calico-website-mca-storage/kang_2017_stim_pbmc.h5ad" -) -WEIGHTS_JSON = "https://storage.googleapis.com/calico-website-scnym-storage/link_tables/pretrained_weights.json" -REFERENCE_JSON = "https://storage.googleapis.com/calico-website-scnym-storage/link_tables/cell_atlas.json" - -ATLAS_ANNOT_KEYS = { - "human": "celltype", - "mouse": "cell_ontology_class", - "rat": "cell_ontology_class", -} - -TASKS = ( - "train", - "predict", -) - -# Define configurations - -CONFIGS = { - "default": { - "n_epochs": 100, - "patience": 40, - "lr": 1.0, - "optimizer_name": "adadelta", - "weight_decay": 1e-4, - "batch_size": 256, - "balanced_classes": False, - "weighted_classes": False, - "mixup_alpha": 0.3, - "unsup_max_weight": 1.0, - "unsup_mean_teacher": False, - "ssl_method": "mixmatch", - "ssl_kwargs": { - "augment_pseudolabels": False, - "augment": "log1p_drop", - "unsup_criterion": "mse", - "n_augmentations": 1, - "T": 0.5, - "ramp_epochs": 100, - "burn_in_epochs": 0, - "dan_criterion": True, - "dan_ramp_epochs": 20, - "dan_max_weight": 0.1, - "min_epochs": 20, - }, - "model_kwargs": { - "n_hidden": 256, - "n_layers": 2, - "init_dropout": 0.0, - "residual": False, - }, - "tensorboard": False, - }, -} - -CONFIGS["no_new_identity"] = copy.deepcopy(CONFIGS["default"]) -CONFIGS["no_new_identity"][ - "description" -] = "Train scNym models with MixMatch and a domain adversary, assuming no new cell types in the target data." - -CONFIGS["new_identity_discovery"] = copy.deepcopy(CONFIGS["default"]) -CONFIGS["new_identity_discovery"]["ssl_kwargs"]["pseudolabel_min_confidence"] = 0.9 -CONFIGS["new_identity_discovery"]["ssl_kwargs"]["dan_use_conf_pseudolabels"] = True -CONFIGS["new_identity_discovery"][ - "description" -] = "Train scNym models with MixMatch and a domain adversary, using pseudolabel thresholding to allow for new cell type discoveries." - -CONFIGS["no_dan"] = copy.deepcopy(CONFIGS["default"]) -CONFIGS["no_dan"]["ssl_kwargs"]["dan_max_weight"] = 0.0 -CONFIGS["no_dan"]["ssl_kwargs"]["dan_ramp_epochs"] = 1 -CONFIGS["no_dan"][ - "description" -] = "Train scNym models with MixMatch but no domain adversary. May be useful if class imbalance is very large." - -CONFIGS["no_ssl"] = copy.deepcopy(CONFIGS["default"]) -CONFIGS["no_ssl"]["ssl_kwargs"]["dan_max_weight"] = 0.0 -CONFIGS["no_ssl"]["ssl_kwargs"]["dan_ramp_epochs"] = 1 -CONFIGS["no_ssl"]["ssl_kwargs"]["unsup_max_weight"] = 0.0 -CONFIGS["no_ssl"][ - "description" -] = "Train scNym models with MixMatch but no domain adversary. May be useful if class imbalance is very large." - - -UNLABELED_TOKEN = "Unlabeled" - - -def scnym_api( - adata: AnnData, - task: str = "train", - groupby: str = None, - domain_groupby: str = None, - out_path: str = "./scnym_outputs", - trained_model: str = None, - config: Union[dict, str] = "new_identity_discovery", - key_added: str = "scNym", - copy: bool = False, - **kwargs, -) -> Optional[AnnData]: - """ - scNym: Semi-supervised adversarial neural networks for - single cell classification [Kimmel2020]_. - - scNym is a cell identity classifier that transfers annotations from one - single cell experiment to another. The model is implemented as a neural - network that employs MixMatch semi-supervision and a domain adversary to - take advantage of unlabeled data during training. scNym offers superior - performance to many baseline single cell identity classification methods. - - Parameters - ---------- - adata - Annotated data matrix used for training or prediction. - If `"scNym_split"` in `.obs_keys()`, uses the cells annotated - `"train", "val"` to select data splits. - task - Task to perform, either "train" or "predict". - If "train", uses `adata` as labeled training data. - If "predict", uses `trained_model` to infer cell identities for - observations in `adata`. - groupby - Column in `adata.obs` that contains cell identity annotations. - Values of `"Unlabeled"` indicate that a given cell should be used - only as unlabeled data during training. - domain_groupby - Column in `adata.obs` that contains domain labels as integers. - Each domain of origin (e.g. batch, species) should be given a unique - domain label. - If `domain_groupby is None`, train and target data are each considered - a unique domain. - out_path - Path to a directory for saving scNym model weights and training logs. - trained_model - Path to the output directory of an scNym training run - or a string specifying a pretrained model. - If provided while `task == "train"`, used as an initialization. - config - Configuration name or dictionary of configuration of parameters. - Pre-defined configurations: - "new_identity_discovery" - Default. Employs pseudolabel thresholding to - allow for discovery of new cell identities in the target dataset using - scNym confidence scores. - "no_new_identity" - Assumes all cells in the target data belong to one - of the classes in the training data. Recommended to improve performance - when this assumption is valid. - key_added - Key added to `adata.obs` with scNym predictions if `task=="predict"`. - copy - copy the AnnData object before predicting cell types. - - Returns - ------- - Depending on `copy`, returns or updates `adata` with the following fields. - - `X_scnym` : :class:`~numpy.ndarray`, (:attr:`~anndata.AnnData.obsm`, shape=(n_samples, n_hidden), dtype `float`) - scNym embedding coordinates of data. - `scNym` : (`adata.obs`, dtype `str`) - scNym cell identity predictions for each observation. - `scNym_train_results` : :class:`~dict`, (:attr:`~anndata.AnnData.uns`) - results of scNym model training. - - Examples - -------- - >>> import scanpy as sc - >>> from scnym.api import scnym_api, atlas2target - - **Loading Data and preparing labels** - - >>> adata = sc.datasets.kang17() - >>> target_bidx = adata.obs['stim']=='stim' - >>> adata.obs['cell'] = np.array(adata.obs['cell']) - >>> adata.obs.loc[target_bidx, 'cell'] = 'Unlabeled' - - **Train an scNym model** - - >>> scnym_api( - ... adata=adata, - ... task='train', - ... groupby='clusters', - ... out_path='./scnym_outputs', - ... config='no_new_identity', - ... ) - - **Predict cell identities with the trained scNym model** - - >>> path_to_model = './scnym_outputs/' - >>> scnym_api( - ... adata=adata, - ... task='predict', - ... groupby='scNym', - ... trained_model=path_to_model, - ... config='no_new_identity', - ... ) - - **Perform semi-supervised training with an atlas** - - >>> joint_adata = atlas2target( - ... adata=adata, - ... species='mouse', - ... key_added='annotations', - ... ) - >>> scnym_api( - ... adata=joint_adata, - ... task='train', - ... groupby='annotations', - ... out_path='./scnym_outputs', - ... config='no_new_identity', - ... ) - """ - if task not in TASKS: - msg = f"{task} is not a valid scNym task.\n" - msg += f"must be one of {TASKS}" - raise ValueError(msg) - - # check configuration arguments and choose a config - if type(config) == str: - if config not in CONFIGS.keys(): - msg = f"{config} is not a predefined configuration.\n" - msg += f"must be one of {CONFIGS.keys()}." - raise ValueError(msg) - else: - config = CONFIGS[config] - elif type(config) != dict: - msg = f"`config` was a {type(config)}, must be dict or str." - raise TypeError(msg) - else: - # config is a dictionary of parameters - # add or update default parameters based on these - dconf = CONFIGS["default"] - for k in config.keys(): - dconf[k] = config[k] - config = dconf - logger.debug(f"Finalized config: {config}") - - # check for CUDA - if torch.cuda.is_available(): - print("CUDA compute device found.") - else: - print("No CUDA device found.") - print("Computations will be performed on the CPU.") - print("Add a CUDA compute device to improve speed dramatically.\n") - - if not osp.exists(out_path): - os.makedirs(out_path, exist_ok=True) - - # add args to `config` - config["out_path"] = out_path - config["groupby"] = groupby - config["key_added"] = key_added - config["trained_model"] = trained_model - config["domain_groupby"] = domain_groupby - - ################################################ - # check that there are no duplicate genes in the input object - ################################################ - n_genes = adata.shape[1] - n_unique_genes = len(np.unique(adata.var_names)) - if n_genes != n_unique_genes: - msg = "Duplicate Genes Error\n" - msg += "Not all genes passed to scNym were unique.\n" - msg += f"{n_genes} genes are present but only {n_unique_genes} unique genes were detected.\n" - msg += "Please use unique gene names in your input object.\n" - msg += "This can be achieved by running `adata.var_names_make_unique()`" - raise ValueError(msg) - - ################################################ - # check that `adata.X` are log1p(CPM) counts - ################################################ - # we can't directly check if cells were normalized to CPM because - # users may have filtered out genes *a priori*, so the cell sum - # may no longer be ~= 1e6. - # however, we can check that our assumptions about log normalization - # are true. - - # check that the min/max are within log1p(CPM) range - x_max = np.max(adata.X) > np.log1p(1e6) - x_min = np.min(adata.X) < 0.0 - - # check to see if a user accidently provided raw counts - if type(adata.X) == np.ndarray: - int_counts = np.all(np.equal(np.mod(adata.X, 1), 0)) - else: - int_counts = np.all(np.equal(np.mod(adata.X.data, 1), 0)) - - if x_max or x_min or int_counts: - msg = "Normalization error\n" - msg += ( - "`adata.X` does not appear to be log(CountsPerMillion+1) normalized data.\n" - ) - msg += "Please replace `adata.X` with log1p(CPM) values.\n" - msg += ">>> # starting from raw counts in `adata.X`\n" - msg += ">>> sc.pp.normalize_total(adata, target_sum=1e6))\n" - msg += ">>> sc.pp.log1p(adata)" - raise ValueError(msg) - - ################################################ - # check inputs and launch the appropriate task - ################################################ - - if task == "train": - # pass parameters to training routine - if groupby not in adata.obs.columns: - msg = f"{groupby} is not a variable in `adata.obs`" - raise ValueError(msg) - - scnym_train( - adata=adata, - config=config, - ) - elif task == "predict": - # check that a pre-trained model was specified or - # provided for prediction - if trained_model is None: - msg = "must provide a path to a trained model for prediction." - raise ValueError(msg) - if not os.path.exists(trained_model) and "pretrained_" not in trained_model: - msg = "path to the trained model does not exist." - raise FileNotFoundError(msg) - # predict identities - config["model_weights"] = trained_model - scnym_predict( - adata=adata, - config=config, - ) - - elif task == "interpret": - - scnym_interpret( - adata=adata, - config=config, - **kwargs, - ) - - else: - msg = f"{task} is not a valid task." - raise ValueError(msg) - - return - - -def scnym_train( - adata: AnnData, - config: dict, -) -> None: - """Train an scNym model. - - Parameters - ---------- - adata : AnnData - [Cells, Genes] experiment containing annotated - cells to train on. - config : dict - configuration options. - - Returns - ------- - None. - Saves model outputs to `config["out_path"]` and adds model results - to `adata.uns["scnym_train_results"]`. - - Notes - ----- - This method should only be directly called by advanced users. - Most users should use `scnym_api`. - - See Also - -------- - scnym_api - """ - # determine if unlabeled examples are present - n_unlabeled = np.sum(adata.obs[config["groupby"]] == UNLABELED_TOKEN) - if n_unlabeled == 0: - print("No unlabeled data was found.") - print(f'Did you forget to set some examples as `"{UNLABELED_TOKEN}"`?') - print("Proceeding with purely supervised training.") - print() - - unlabeled_counts = None - unlabeled_genes = None - - X = utils.get_adata_asarray(adata) - y = pd.Categorical( - np.array(adata.obs[config["groupby"]]), - categories=np.unique(adata.obs[config["groupby"]]), - ).codes - class_names = np.unique(adata.obs[config["groupby"]]) - # set all samples for training - train_adata = adata - # set no samples as `target_bidx` - target_bidx = np.zeros(adata.shape[0], dtype=np.bool) - else: - print(f"{n_unlabeled} unlabeled observations found.") - print( - "Using unlabeled data as a target set for semi-supervised, adversarial training." - ) - print() - - target_bidx = adata.obs[config["groupby"]] == UNLABELED_TOKEN - - train_adata = adata[~target_bidx, :] - target_adata = adata[target_bidx, :] - - print("training examples: ", train_adata.shape) - print("target examples: ", target_adata.shape) - - X = utils.get_adata_asarray(train_adata) - y = pd.Categorical( - np.array(train_adata.obs[config["groupby"]]), - categories=np.unique(train_adata.obs[config["groupby"]]), - ).codes - unlabeled_counts = utils.get_adata_asarray(target_adata) - class_names = np.unique(train_adata.obs[config["groupby"]]) - - print("X: ", X.shape) - print("y: ", y.shape) - - if "scNym_split" not in adata.obs_keys(): - # perform a 90/10 train test split - traintest_idx = np.random.choice( - X.shape[0], size=int(np.floor(0.9 * X.shape[0])), replace=False - ) - val_idx = np.setdiff1d(np.arange(X.shape[0]), traintest_idx) - else: - train_idx = np.where(train_adata.obs["scNym_split"] == "train")[0] - test_idx = np.where( - train_adata.obs["scNym_split"] == "test", - )[0] - val_idx = np.where(train_adata.obs["scNym_split"] == "val")[0] - - if len(train_idx) < 100 or len(test_idx) < 10 or len(val_idx) < 10: - msg = "Few samples in user provided data split.\n" - msg += f"{len(train_idx)} training samples.\n" - msg += f"{len(test_idx)} testing samples.\n" - msg += f"{len(val_idx)} validation samples.\n" - msg += "Halting." - raise RuntimeError(msg) - # `fit_model()` takes a tuple of `traintest_idx` - # as a training index and testing index pair. - traintest_idx = ( - train_idx, - test_idx, - ) - - # check if domain labels were manually specified - if config.get("domain_groupby", None) is not None: - domain_groupby = config["domain_groupby"] - # check that the column actually exists - if domain_groupby not in adata.obs.columns: - msg = f"no column `{domain_groupby}` exists in `adata.obs`.\n" - msg += "if domain labels are specified, a matching column must exist." - raise ValueError(msg) - # get the label indices as unique integers using pd.Categorical - # to code each unique label with an int - domains = np.array( - pd.Categorical( - adata.obs[domain_groupby], - categories=np.unique(adata.obs[domain_groupby]), - ).codes, - dtype=np.int32, - ) - # split domain labels into source and target sets for `fit_model` - input_domain = domains[~target_bidx] - unlabeled_domain = domains[target_bidx] - print("Using user provided domain labels.") - n_source_doms = len(np.unique(input_domain)) - n_target_doms = len(np.unique(unlabeled_domain)) - print( - f"Found {n_source_doms} source domains and {n_target_doms} target domains." - ) - else: - # no domains manually supplied, providing `None` to `fit_model` - # will treat source data as one domain and target data as another - input_domain = None - unlabeled_domain = None - - # check if pre-trained weights should be used to initialize the model - if config["trained_model"] is None: - pretrained = None - elif "pretrained_" in config["trained_model"]: - msg = "pretrained model fetching is not supported for training." - raise NotImplementedError(msg) - else: - # setup a prediction model - pretrained = osp.join( - config["trained_model"], - "00_best_model_weights.pkl", - ) - if not osp.exists(pretrained): - msg = f"{pretrained} file not found." - raise FileNotFoundError(msg) - - acc, loss = main.fit_model( - X=X, - y=y, - traintest_idx=traintest_idx, - val_idx=val_idx, - batch_size=config["batch_size"], - n_epochs=config["n_epochs"], - lr=config["lr"], - optimizer_name=config["optimizer_name"], - weight_decay=config["weight_decay"], - ModelClass=model.CellTypeCLF, - balanced_classes=config["balanced_classes"], - weighted_classes=config["weighted_classes"], - out_path=config["out_path"], - mixup_alpha=config["mixup_alpha"], - unlabeled_counts=unlabeled_counts, - input_domain=input_domain, - unlabeled_domain=unlabeled_domain, - unsup_max_weight=config["unsup_max_weight"], - unsup_mean_teacher=config["unsup_mean_teacher"], - ssl_method=config["ssl_method"], - ssl_kwargs=config["ssl_kwargs"], - pretrained=pretrained, - patience=config.get("patience", None), - save_freq=config.get("save_freq", None), - tensorboard=config.get("tensorboard", False), - **config["model_kwargs"], - ) - - # add the final model results to `adata` - results = { - "model_path": osp.realpath( - osp.join(config["out_path"], "00_best_model_weights.pkl") - ), - "final_acc": acc, - "final_loss": loss, - "n_genes": adata.shape[1], - "n_cell_types": len(np.unique(y)), - "class_names": class_names, - "gene_names": adata.var_names.tolist(), - "model_kwargs": config["model_kwargs"], - "traintest_idx": traintest_idx, - "val_idx": val_idx, - } - assert osp.exists(results["model_path"]) - - adata.uns["scNym_train_results"] = results - - # save the final model results to disk - train_results_path = osp.join( - config["out_path"], - "scnym_train_results.pkl", - ) - - with open(train_results_path, "wb") as f: - pickle.dump(results, f) - return - - -@torch.no_grad() -def scnym_predict( - adata: AnnData, - config: dict, -) -> None: - """Predict cell identities using an scNym model. - - Parameters - ---------- - adata : AnnData - [Cells, Genes] experiment containing annotated - cells to train on. - config : dict - configuration options. - - Returns - ------- - None. Adds `adata.obs[config["key_added"]]` and `adata.obsm["X_scnym"]`. - - Notes - ----- - This method should only be directly called by advanced users. - Most users should use `scnym_api`. - - See Also - -------- - scnym_api - """ - # check if a pretrained model was requested - if "pretrained_" in config["trained_model"]: - msg = "Pretrained Request Error\n" - msg += "Pretrained weights are no longer supported in scNym.\n" - raise NotImplementedError(msg) - # species = _get_pretrained_weights( - # trained_model=config['trained_model'], - # out_path=config['out_path'], - # ) - # print(f'Successfully downloaded pretrained model for {species}.') - # config['trained_model'] = config['out_path'] - - # load training parameters - with open( - osp.join(config["trained_model"], "scnym_train_results.pkl"), - "rb", - ) as f: - results = pickle.load(f) - - # setup a prediction model - model_weights_path = osp.join( - config["trained_model"], - "00_best_model_weights.pkl", - ) - - P = predict.Predicter( - model_weights=model_weights_path, - n_genes=results["n_genes"], - n_cell_types=results["n_cell_types"], - labels=results["class_names"], - **config["model_kwargs"], - ) - n_cell_types = results["n_cell_types"] - n_genes = results["n_genes"] - print(f"Loaded model predicting {n_cell_types} classes from {n_genes} features") - print(results["class_names"]) - - # Generate a classification matrix - print("Building a classification matrix...") - X_raw = utils.get_adata_asarray(adata) - X = utils.build_classification_matrix( - X=X_raw, - model_genes=np.array(results["gene_names"]), - sample_genes=np.array(adata.var_names), - ) - - # Predict cell identities - print("Predicting cell types...") - pred, names, prob = P.predict( - X, - output="prob", - ) - - prob = pd.DataFrame( - prob, - columns=results["class_names"], - index=adata.obs_names, - ) - - # Extract model embeddings - print("Extracting model embeddings...") - ds = dataprep.SingleCellDS(X=X, y=np.zeros(X.shape[0])) - dl = torch.utils.data.DataLoader( - ds, - batch_size=config["batch_size"], - shuffle=False, - ) - - model = P.models[0] - lz_02 = torch.nn.Sequential(*list(list(model.modules())[0].children())[1][:-1]) - - embeddings = [] - for data in dl: - input_ = data["input"] - input_ = input_.to(device=next(model.parameters()).device) - z = lz_02(input_) - embeddings.append(z.detach().cpu()) - Z = torch.cat(embeddings, 0) - - # Store results in the anndata object - adata.obs[config["key_added"]] = names - adata.obs[config["key_added"] + "_confidence"] = np.max(prob, axis=1) - adata.uns["scNym_probabilities"] = prob - adata.obsm["X_scnym"] = Z.numpy() - - return - - -def _get_pretrained_weights( - trained_model: str, - out_path: str, -) -> str: - """Given the name of a set of pretrained model weights, - fetch weights from GCS and return the model state dict. - - Parameters - ---------- - trained_model : str - the name of a pretrained model to use, formatted as - "pretrained_{species}". - species should be one of {"human", "mouse", "rat"}. - out_path : str - path for saving model weights and outputs. - - Returns - ------- - species : str - species parsed from the trained model name. - Saves "{out_path}/00_best_model_weights.pkl" and - "{out_path}/scnym_train_results.pkl". - - Notes - ----- - Requires an internet connection to download pre-trained weights. - """ - # check that the trained_model argument is valid - if "pretrained_" not in trained_model: - msg = 'pretrained model names must contain `"pretrained_"`' - raise ValueError(msg) - - species = trained_model.split("pretrained_")[1] - - # download a table of available pretrained models - try: - pretrained_weights_dict = json.loads(requests.get(WEIGHTS_JSON).text) - except requests.exceptions.ConnectionError: - print("Could not download pretrained weighs listing from:") - print(f"\t{WEIGHTS_JSON}") - print("Loading pretrained model failed.") - - # check that the species specified has pretrained weights - if species not in pretrained_weights_dict.keys(): - msg = f"pretrained weights not available for {species}." - raise ValueError(species) - - # get pretrained weights - path_for_weights = osp.join(out_path, f"00_best_model_weights.pkl") - urllib.request.urlretrieve( - pretrained_weights_dict[species], - path_for_weights, - ) - - # load model parameters - model_params = {} - urllib.request.urlretrieve( - pretrained_weights_dict["model_params"][species]["gene_names"], - osp.join(out_path, "pretrained_gene_names.csv"), - ) - urllib.request.urlretrieve( - pretrained_weights_dict["model_params"][species]["class_names"], - osp.join(out_path, "pretrained_class_names.csv"), - ) - model_params["gene_names"] = np.loadtxt( - osp.join(out_path, "pretrained_gene_names.csv"), - delimiter=",", - dtype="str", - ) - model_params["class_names"] = np.loadtxt( - osp.join(out_path, "pretrained_class_names.csv"), - delimiter=",", - dtype="str", - ) - model_params["n_genes"] = len(model_params["gene_names"]) - model_params["n_cell_types"] = len(model_params["class_names"]) - - # save model parameters to a results file in the output dir - path_for_results = f"{out_path}/scnym_train_results.pkl" - with open(path_for_results, "wb") as f: - pickle.dump(model_params, f) - - # check that files are present - if not osp.exists(path_for_weights): - raise FileNotFoundError(path_for_weights) - if not osp.exists(path_for_results): - raise FileNotFoundError(path_for_results) - - return species - - -def atlas2target( - adata: AnnData, - species: str, - key_added: str = "annotations", -) -> AnnData: - """Download a preprocessed cell atlas dataset and - append your new dataset as a target to allow for - semi-supervised scNym training. - - Parameters - ---------- - adata : anndata.AnnData - [Cells, Features] experiment to use as a target - dataset. - `adata.var_names` must be formatted as Ensembl gene - names for the relevant species to match the atlas. - e.g. `"Gapdh`" for mouse or `"GAPDH"` for human, rather - than Ensembl gene IDs or another gene annotation. - - Returns - ------- - joint_adata : anndata.AnnData - [Cells, Features] experiment concatenated with a - preprocessed cell atlas reference dataset. - Annotations from the atlas are copied to `.obs[key_added]` - and all cells in the target dataset `adata` are labeled - with the special "Unlabeled" token. - - Examples - -------- - >>> adata = sc.datasets.pbmc3k() - >>> joint_adata = scnym.api.atlas2target( - ... adata=adata, - ... species='human', - ... key_added='annotations', - ... ) - - Notes - ----- - Requires an internet connection to download reference datasets. - """ - # download a directory of cell atlases - try: - reference_dict = json.loads(requests.get(REFERENCE_JSON).text) - except requests.exceptions.ConnectionError: - print("Could not download pretrained weighs listing from:") - print(f"\t{REFERENCE_JSON}") - print("Loading pretrained model failed.") - - # check that the species presented is available - if species not in reference_dict.keys(): - msg = f"pretrained weights not available for {species}." - raise ValueError(species) - - # check that there are no gene duplications - n_uniq_genes = len(np.unique(adata.var_names)) - if n_uniq_genes < len(adata.var_names): - msg = f"{n_uniq_genes} unique features found, but {adata.shape[1]} features are listed.\n" - msg += "Please de-duplicate features in `adata` before joining with an atlas dataset.\n" - msg += "Consider `adata.var_names_make_unique()` or aggregating values for features with the same identifier." - raise ValueError(msg) - - # download the atlas of interest - atlas = sc.datasets._datasets.read( - sc.settings.datasetdir / f"atlas_{species}.h5ad", - backup_url=reference_dict[species], - ) - del atlas.raw - - # get the key used by the cell atlas - atlas_annot_key = ATLAS_ANNOT_KEYS[species] - - # copy atlas annotations to the specified column - atlas.obs[key_added] = np.array(atlas.obs[atlas_annot_key]) - atlas.obs["scNym_dataset"] = "atlas_reference" - - # label target data with "Unlabeled" - adata.obs[key_added] = "Unlabeled" - adata.obs["scNym_dataset"] = "target" - - # check that at least some genes overlap between the atlas - # and the target data - FEW_GENES = 100 - n_overlapping_genes = len(np.intersect1d(adata.var_names, atlas.var_names)) - if n_overlapping_genes == 0: - msg = "No genes overlap between the target data `adata` and the atlas.\n" - msg += 'Genes in the atlas are named using Ensembl gene symbols (e.g. `"Gapdh"`).\n' - msg += "Ensure `adata.var_names` also uses gene symbols." - raise RuntimeError(msg) - elif n_overlapping_genes < FEW_GENES: - msg = f"Only {n_overlapping_genes} overlapping genes were found between the target and atlas.\n" - msg += "Ensure your target dataset `adata.var_names` are Ensembl gene names.\n" - msg += "Continuing with transer, but performance is likely to be poor." - warnings.warn(msg) - else: - msg = f"{n_overlapping_genes} overlapping genes found between the target and atlas data." - logger.info(msg) - - # join the target and atlas data - joint_adata = atlas.concatenate( - adata, - join="inner", - ) - - return joint_adata - - -def list_configs(): - for k in CONFIGS.keys(): - print(f"name: {k}") - print("\t" + CONFIGS[k]["description"]) - return - - -def _get_keys_and_list(d: dict) -> Tuple[List[list], List[list]]: - """Get a set of keys mapping to a list in a - nested dictionary structure and the list value. - - Parameters - ---------- - d : dict - a nested dictionary structure where all terminal - values are lists. - - Returns - ------- - keys : List[list] - sequential keys required to access a set of - associated terminal values. - mapped by index to `values`. - values : List[list] - lists of terminal values, each accessed by the - set of `keys` with a matching index from `d`. - """ - accession_keys = [] - associated_values = [] - for k in d.keys(): - if type(d[k]) == dict: - # the value is nested, recurse - keys, values = _get_keys_and_list(d[k]) - keys = [ - [ - k, - ] - + x - for x in keys - ] - else: - keys = [ - [k], - ] - values = [d[k]] - - for i in range(len(values)): - accession_keys.append(keys[i]) - associated_values.append(values[i]) - - return accession_keys, associated_values - - -def _updated_nested(d: dict, keys: list, value: list) -> dict: - """Updated the values in a dictionary with multiple nested levels. - - Parameters - ---------- - d : dict - multilevel dictionary. - keys : list - sequential keys specifying a value to update - value : list - new value to use in the update. - - Returns - ------- - d : dict - updated dictionary. - """ - if type(d.get(keys[0], None)) == dict: - # multilevel, recurse - _updated_nested(d[keys[0]], keys[1:], value) - else: - d[keys[0]] = value - return - - -def split_data( - adata: AnnData, - groupby: str, - n_splits: int, -) -> None: - """Split data using a stratified k-fold. - - Parameters - ---------- - adata : anndata.AnnData - [Cells, Genes] experiment. - groupby : str - annotation column in `.obs`. - used for stratification. - n_splits : int - number of train/test/val splits to perform for tuning. - performs at least 5-fold splitting and uses a subset of - the folds if `n_splits < 5`. - - Returns - ------- - None. Adds `f"scNym_split_{n}"` to `adata.obs` for all `n` - in `[0, n_splits)`. - """ - # generate cross val splits - cv = StratifiedKFold( - n_splits=max(5, n_splits), - shuffle=True, - ) - split_indices = list(cv.split(adata.X, adata.obs[groupby])) - - for split_number, train_test in enumerate(split_indices): - - train_idx = train_test[0] - testval_idx = train_test[1] - - test_idx = np.random.choice( - testval_idx, - size=int(np.ceil(len(testval_idx) / 2)), - replace=False, - ) - val_idx = np.setdiff1d( - testval_idx, - test_idx, - ) - - # these tokens are recognized by `api.scnym_train` - adata.obs[f"scNym_split_{split_number}"] = "ERROR" - adata.obs.loc[ - adata.obs_names[train_idx], f"scNym_split_{split_number}" - ] = "train" - adata.obs.loc[adata.obs_names[test_idx], f"scNym_split_{split_number}"] = "test" - adata.obs.loc[adata.obs_names[val_idx], f"scNym_split_{split_number}"] = "val" - - return - - -def _circular_train( - search_config: dict, - params: tuple, - adata: AnnData, - groupby: str, - out_path: str, - accession_keys: List[list], - hold_out_only: bool, - groupby_eval: str, -) -> pd.DataFrame: - """ - Perform a circular training loop for a parameter set. - - Parameters - ---------- - search_config : tuple - configuration for parameter search. - params : tuple - search parameter values - adata : anndata.AnnData - [Cells, Genes] experiment for optimization. - groupby : str - annotation column in `.obs`. - accession_keys : List[list] - sequential keys required to access a set of - associated terminal values. - mapped by index to `values`. - hold_out_only : bool - evaluate the circular accuracy only on a held-out set of - training data, not used in the training of the first - source -> target model. - - Returns - ------- - search_df : pd.DataFrame - [1, (params,) + (acc,)] - search_config : dict - adjusted configuration file for this parameter search. - """ - search_number = search_config["search_number"] - split_number = search_config["split_number"] - # fit the source2target - s2t_out_path = osp.join( - out_path, f"search_{search_number:04}_split_{split_number:04}_source2target" - ) - adata = adata.copy() - - logger.info("\n>>>\nTraining source2target model\n>>>\n") - scnym_api( - adata=adata, - groupby=groupby, - task="train", - out_path=s2t_out_path, - config=search_config, - ) - - # load the hold out test acc - with open(osp.join(s2t_out_path, "scnym_train_results.pkl"), "rb") as f: - s2t_res = pickle.load(f) - s2t_source_test_acc = s2t_res["final_acc"] - - logger.info("\n>>>\nPredicting with source2target model\n>>>\n") - # predict on the target set - scnym_api( - adata=adata, - task="predict", - trained_model=s2t_out_path, - config=search_config, - ) - - # invert the problem -- train on the new labels - circ_adata = adata.copy() - circ_adata.obs[groupby] = adata.obs["scNym"] - circ_adata.obs.drop(columns=["scNym"], inplace=True) - # set the training data as unlabeled, leaving labels only on the target data - circ_adata.obs.loc[adata.obs[groupby] != UNLABELED_TOKEN, groupby] = UNLABELED_TOKEN - - # fit a new model - t2s_out_path = osp.join( - out_path, f"search_{search_number:04}_split_{split_number:04}_target2source" - ) - - logger.info("\n>>>\nTraining target2source model\n>>>\n") - - scnym_api( - adata=circ_adata, - groupby=groupby, - task="train", - out_path=t2s_out_path, - config=search_config, - ) - - # predict with new model - logger.info("\n>>>\nPredicting with target2source model\n>>>\n") - scnym_api( - adata=circ_adata, - task="predict", - trained_model=t2s_out_path, - config=search_config, - ) - - # evaluate the model - samples_bidx = adata.obs[groupby] != "Unlabeled" - samples_bidx = ( - samples_bidx & (adata.obs["scNym_split"] == "val") - if hold_out_only - else samples_bidx - ) - y_true = np.array(adata.obs[groupby])[samples_bidx] - y_pred = np.array(circ_adata.obs["scNym"])[samples_bidx] - - n_correct = np.sum(y_true == y_pred) - n_total = len(y_true) - acc = n_correct / n_total - - accession_keys_str = ["::".join(x) for x in accession_keys] - search_df = pd.DataFrame( - columns=accession_keys_str + ["acc"], - index=[search_number], - ) - search_df.loc[search_number] = params + (acc,) - search_df["test_source_acc"] = s2t_source_test_acc - - if groupby_eval is not None: - # compute the test accuracy in the target domain - # here, we use the predictions made by the source2target - # model stored in `adata.obs["scNym"]`. - samples_bidx = adata.obs[groupby] == "Unlabeled" - y_true = np.array(adata.obs[groupby_eval])[samples_bidx] - y_pred = np.array(adata.obs["scNym"])[samples_bidx] - n_correct = np.sum(y_true == y_pred) - test_acc = n_correct / len(y_true) - search_df["test_target_acc"] = "None" - search_df.loc[search_number, "test_target_acc"] = test_acc - - search_df.to_csv(osp.join(t2s_out_path, "result.csv")) - - return search_df - - -def scnym_tune( - adata: AnnData, - groupby: str, - parameters: dict, - search: str = "grid", - base_config: str = "no_new_identity", - n_points: int = 100, - out_path: str = "./scnym_tune", - hold_out_only: bool = True, - groupby_eval: str = None, - n_splits: int = 1, -) -> Tuple[pd.DataFrame, dict]: - """Perform hyperparameter tuning of an scNym model using - circular cross-validation. - - Parameters - ---------- - adata : anndata.AnnData - [Cells, Genes] experiment for optimization. - groupby : str - annotation column in `.obs`. - parameters : dict - key:List[value] pairs of parameters to use for - hyperparameter tuning. - base_config : str - one of {"no_new_identity", "new_identity_discovery"}. - base configuration for model training that described - default parameters, not explicitly provided in - `parameters`. - search : str - {"grid", "random"} perform either a random or grid - search over `parameters`. - n_points : int - number of random points to search if `search == "random"`. - out_path : str - path for intermediary files during hyperparameter tuning. - hold_out_only : bool - evaluate the circular accuracy only on a held-out set of - training data, not used in the training of the first - source -> target model. - groupby_eval : str - column in `adata.obs` containing ground truth labels - for the "Unlabeled" dataset to use for evaluation. - n_splits : int - number of train/test/val splits to perform for tuning. - performs at least 5-fold splitting and uses a subset of - the folds if `n_splits < 5`. - - Returns - ------- - tuning_results : pd.DataFrame - [n_points, (parameters,) + (circ_acc, circ_loss)] - best_parameter_set : dict - a configuration describing the best parameter set tested. - - Examples - -------- - >>> # `adata` contains labels in `.obs["annotations"]` where - ... # the target dataset is labeled "Unlabeled" - >>> tuning_results, best_parameters = scnym_tune( - ... adata=adata, - ... groupby="annotations", - ... parameters={ - ... "weight_decay": [1e-6, 1e-5, 1e-4], - ... "unsup_max_weight": [0.1, 1., 10.], - ... }, - ... base_config="no_new_identity", - ... search="grid", - ... out_path="./scnym_tuning", - ... n_splits=5, - ... ) - - Notes - ----- - Circular/Reverse cross-validation evaluates the impact of hyperparameter - selection in semi-supervised learning settings using the training data, - training labels, and target data, but not the target labels. - - This is achieved by training a model :math:`f` on the training set, then - predicting "pseudolabels" for the target set. - A second model :math:`g` is then trained on the target data and - the associated pseudolabels. - The model :math:`g` is used to predict labels for the *training* set. - The accuracy of this "reverse" prediction is then used as an estimate - of the effectiveness of a hyperparameter set. - """ - os.makedirs(out_path, exist_ok=True) - - # get the base configuration dict - # configurations have one layer of nested dictionaries within - config = CONFIGS.get(base_config, None) - if config is None: - msg = f"{base_config} is not a valid base configuration." - raise ValueError(msg) - - ################################################# - # get all possible combinations of parameters - ################################################# - # `_get_keys_and_list` traverses a nested dictionary and - # returns a List[list] of sequential keys to access each - # item in `parameter_ranges`. - # items in `parameter_ranges: List[list]` are lists of - # values for the parameter specified in `accession_keys`. - accession_keys, parameter_ranges = _get_keys_and_list(parameters) - # find all possible combinations of parameters - # each item in `param_sets` is a tuple of parameter values - # each element in the tuple matches the keys in `keys` with - # the same index. - param_sets = list( - itertools.product( - *parameter_ranges, - ) - ) - - ################################################# - # select a set of parameters to search - ################################################# - if search.lower() == "random": - # perform a random search by subsetting grid points - param_idx = np.random.choice( - len(param_sets), - size=n_points, - replace=False, - ) - else: - param_idx = range(len(param_sets)) - - ################################################# - # set a common train/test/val split for all params - ################################################# - - splits_provided = "scNym_split_0" in adata.obs.columns - splits_provided = splits_provided or "scNym_split" in adata.obs.columns - - if not splits_provided: - split_data( - adata, - groupby=groupby, - n_splits=n_splits, - ) - elif n_splits == 1 and "scNym_split" in adata.obs.columns: - adata.obs["scNym_split_0"] = adata.obs["scNym_split"] - elif n_splits > 1 and splits_provided: - # check that we have the relevant split for each fold - splits_correct = True - for s in range(n_splits): - splits_correct = splits_correct & (f"scNym_split_{s}" in adata.obs.columns) - if not splits_correct: - msg = '"scNym_split_" was provided with `n_splits>1.\n' - msg += 'f"scNym_split_{n}"" must be present in `adata.obs` for all {n} in `range(n_splits)`\n' - raise ValueError(msg) - else: - msg = "invalid argument for n_splits" - raise ValueError(msg) - - ################################################# - # circular training for each parameter set - ################################################# - - accession_keys_str = ["::".join(x) for x in accession_keys] - - search_results = [] - search_config_store = [] - for search_number, idx in enumerate(param_idx): - # get the parameter set - params = param_sets[idx] - # update the base config with search parameters - search_config = copy.deepcopy(config) - for p_i in range(len(params)): - keys2update = accession_keys[p_i] - value2set = params[p_i] - # updates in place - _updated_nested( - search_config, - keys2update, - value2set, - ) - - # disable checkpoints, tensorboard to reduce I/O - search_config["save_freq"] = 10000 - search_config["tensorboard"] = False - # add search number to config - search_config["search_number"] = search_number - - search_config_store.append( - copy.deepcopy(search_config), - ) - logger.info("searching config:") - logger.info(f"{search_config}") - - for split_number in range(n_splits): - # set the relevant split indices - adata.obs["scNym_split"] = adata.obs[f"scNym_split_{split_number}"] - # set the split number - split_config = copy.deepcopy(search_config) - split_config["split_number"] = split_number - search_df = _circular_train( - search_config=split_config, - params=params, - adata=adata, - groupby=groupby, - out_path=out_path, - accession_keys=accession_keys, - hold_out_only=hold_out_only, - groupby_eval=groupby_eval, - ) - # add the split information - search_df["split_number"] = split_number - search_df["search_number"] = search_number - # save results - search_results.append(search_df) - - # concatenate - search_results = pd.concat(search_results, 0) - best_idx = np.argmax(search_results["acc"]) - best_search = int(search_results.iloc[best_idx]["search_number"]) - - best_config = search_config_store[best_search] - print(">>>>>>") - print("Best config") - print(best_config) - print(">>>>>>") - print() - return search_results, best_config - - -def scnym_interpret( - adata: AnnData, - groupby: str, - source: str, - target: str, - trained_model: str, - **kwargs, -) -> dict: - """ - Extract salient features motivating scNym model predictions by estimating - expected gradients. - - Parameters - ---------- - adata - Annotated data matrix used for training or prediction. - If `"scNym_split"` in `.obs_keys()`, uses the cells annotated - `"train", "val"` to select data splits. - groupby - Column in `adata.obs` that contains cell identity annotations. - Values of `"Unlabeled"` indicate that a given cell should be used - only as unlabeled data during training. - source : str - class name for source class in `adata.obs[groupby]`. - target : str - class name for target class in `adata.obs[groupby]`. - trained_model - Path to the output directory of an scNym training run - or a string specifying a pretrained model. - If provided while `task == "train"`, used as an initialization. - kwargs : dict - keyword arguments passed to `scnym.interpret.ExpectedGradients.query(...)`. - - Returns - ------- - expgrad : dict - "gradients" - [Cells, Features] pd.DataFrame of expected gradients for - the target class. - "saliency" - [Features,] pd.Series of mean expected gradients across query - cells, sorted by saliency positive -> negative. - - See Also - -------- - scnym.interpret.ExpectedGradients - """ - # check if a pretrained model was requested - if "pretrained_" in trained_model: - msg = "Pretrained Request Error\n" - msg += "Pretrained weights are no longer supported in scNym.\n" - raise NotImplementedError(msg) - - # load training parameters - with open( - osp.join(trained_model, "scnym_train_results.pkl"), - "rb", - ) as f: - results = pickle.load(f) - - # setup a model object for interpretation - clf = model.CellTypeCLF( - n_genes=results["n_genes"], - n_cell_types=results["n_cell_types"], - **results["model_kwargs"], - ) - - model_weights_path = osp.join( - trained_model, - "00_best_model_weights.pkl", - ) - clf.load_state_dict( - torch.load( - model_weights_path, - map_location="cpu", - ) - ) - if torch.cuda.is_available(): - clf = clf.cuda() - logger.info("Model moved to CUDA compute device.") - - # setup expected gradients - EG = interpret.ExpectedGradient( - model=clf, - gene_names=np.array(results["gene_names"]), - class_names=np.array(results["class_names"]), - ) - - # perform expected gradient estimation - saliency = EG.query( - adata=adata, - source=source, - target=target, - cell_type_col=groupby, - **kwargs, - ) - gradients = EG.gradients - - r = { - "saliency": saliency, - "gradients": gradients, - } - return r diff --git a/build/lib/scnym/attributionpriors.py b/build/lib/scnym/attributionpriors.py deleted file mode 100644 index 978ec93..0000000 --- a/build/lib/scnym/attributionpriors.py +++ /dev/null @@ -1,605 +0,0 @@ -#!/usr/bin/env python -# adopted from https://github.com/suinleelab/attributionpriors -import functools -import operator -from typing import Callable, Union -import numpy as np -import torch -from torch.autograd import grad -from torch.utils.data import DataLoader -import logging - -logger = logging.getLogger(__name__) - -DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - -def gather_nd(params, indices): - """ - Args: - params: Tensor to index - indices: k-dimension tensor of integers. - Returns: - output: 1-dimensional tensor of elements of ``params``, where - output[i] = params[i][indices[i]] - - params indices output - - 1 2 1 1 4 - 3 4 2 0 ----> 5 - 5 6 0 0 1 - """ - max_value = functools.reduce(operator.mul, list(params.size())) - 1 - indices = indices.t().long() - ndim = indices.size(0) - idx = torch.zeros_like(indices[0]).long() - m = 1 - - for i in range(ndim)[::-1]: - idx += indices[i] * m - m *= params.size(i) - - idx[idx < 0] = 0 - idx[idx > max_value] = 0 - return torch.take(params, idx) - - -def adj2lap( - adj: torch.FloatTensor, -) -> torch.FloatTensor: - """Convert an adjacency matrix to a graph Laplacian - - Notes - ----- - Graph Laplacian is - - .. math:: - - L = D - A - - where :math:`D` is a diagonal matrix with the degree of - each node and :math:`A` is the graph adjacency matrix. - """ - adj = (adj > 0).float() - row_sum = torch.sum(adj, dim=1) - # constructs [n_vertices, n_vertices] with row_sum on diagonal - D = torch.diag(row_sum) - return D - adj - - -def tgini(x): - mad = torch.mean(torch.abs(x.reshape(-1, 1) - x.reshape(1, -1))) - rmad = mad / torch.mean(x) - g = 0.5 * rmad - return g - - -def gini_eg(shaps: torch.FloatTensor) -> torch.FloatTensor: - """Gini coefficient sparsity prior - - Parameters - ---------- - shaps : torch.FloatTensor - [Observations, Features] estimated Shapley values. - - Returns - ------- - gini_prior : torch.FloatTensor - inverse Gini coefficient prior penalty. - """ - abs_attrib = shaps.abs() - return -tgini(abs_attrib.mean(0)) - - -def gini_classwise_eg( - shaps: torch.FloatTensor, - target: torch.LongTensor, -) -> torch.FloatTensor: - """Compute Gini coefficient sparsity prior within individual - classes. This allows each class to have a unique set of sparsely - activated features, rather than globally requiring all classes - to use the same small feature set. - - Parameters - ---------- - shaps : torch.FloatTensor - [Observations, Features] estimated Shapley values. - target : torch.LongTenspr - [Observations,] int class labels. - - Returns - ------- - gini_prior : torch.FloatTensor - inverse Gini coefficient prior penalty. - """ - classes = torch.unique(target) - ginis = torch.zeros((len(classes),)).to(device=shaps.device) - n_obs = torch.zeros((len(classes),)).to(device=shaps.device) - for i, c in enumerate(classes): - c_shaps = shaps[target == c] - c_gini = gini_eg(c_shaps) - ginis[i] = c_gini - n_obs[i] = c_shaps.size(0) - # compute weighted gini coefficient - p_obs = n_obs / torch.sum(n_obs) - weighted_gini = torch.sum(p_obs * ginis) - return weighted_gini - - -def graph_eg( - shaps: torch.FloatTensor, - graph: torch.FloatTensor, -) -> torch.FloatTensor: - """Graph attribution prior - - Parameters - ---------- - shaps : torch.FloatTensor - [Observations, Features] estimated Shapley values. - graph : torch.FloatTensor - [Features, Features] adjacency matrix (weighted or binary). - - Returns - ------- - graph_prior : torch.FloatTensor - graph prior penalty. - """ - # get mean gradient for each feature - feature_grad = torch.mean(shaps, dim=0) - # get a matrix of differences between feature grads - cols = feature_grad.view(1, -1).repeat(feature_grad.size(0), 1) - rows = feature_grad.view(-1, 1).repeat(1, feature_grad.size(0)) - # delta[i, j] is grad_i - grad_j - delta = rows - cols - # "Gaussian" penalty is just square of delta - penalty = torch.pow(delta, 2) - weighted_penalty = penalty * graph - return weighted_penalty - - -def check(key, sets: dict, reference: set) -> list: - return [x in reference for x in sets[key]] - - -class AttributionPriorExplainer(object): - def __init__( - self, - background_dataset: torch.utils.data.Dataset, - batch_size: int, - random_alpha: bool = True, - k: int = 1, - scale_by_inputs: bool = True, - abs_scale: bool = True, - input_batch_index: Union[str, int, tuple] = None, - ) -> None: - """Estimates feature gradients using expected gradients. - - Parameters - ---------- - background_dataset : torch.utils.data.Dataset - dataset of samples to use as background references. - most commonly, this is the whole training set. - batch_size : int - batch size used for training. must be the same as the - batch size for the training dataloader. - random_alpha : bool - use randomized `alpha ~ Unif(0, 1)` values for computing - an intermediary sample between the reference and target - sample at each minibatch. - k : int - number of references to use per training example per minibatch. - `k=1` works well as a default with minimal computational - overhead. - scale_by_inputs : bool - scale expected gradient values using a dot-product with the - difference `(input-reference)` feature values. - abs_scale : bool - only considered if `scale_by_inputs=True`. Rather than scaling - by the raw difference, scale by the absolute value of the - difference. - input_batch_index : Union[str,int,tuple], optional - key for extracting the input values from a batch drawn from - `background_dataset`. e.g. if batches are stored in `dict`, - this is the key for the input tensor. if batches are `tuple`, - this is the index of the input tensor. - - Returns - ------- - None. - - References - ---------- - https://github.com/suinleelab/attributionpriors - """ - self.random_alpha = random_alpha - self.k = k - self.scale_by_inputs = scale_by_inputs - self.abs_scale = abs_scale - self.batch_size = batch_size - self.ref_set = background_dataset - self.ref_sampler = DataLoader( - dataset=background_dataset, - batch_size=batch_size * k, - shuffle=True, - drop_last=True, - ) - self.input_batch_index = input_batch_index - return - - def _get_ref_batch( - self, - k=None, - ): - """Get a batch from the reference dataset""" - b = next(iter(self.ref_sampler)) - if self.input_batch_index is not None: - # extract the input tensor using a provided index - b = b[self.input_batch_index].float() - b = b.to(device=self.DEFAULT_DEVICE) - if self.batch_transformation is not None: - # transform the reference batch with a specified transformation - b = self.batch_transformation(b) - return b - - def _get_samples_input( - self, - input_tensor: torch.FloatTensor, - reference_tensor: torch.FloatTensor, - ) -> torch.FloatTensor: - """ - Calculate interpolation points between input samples and reference - samples. - - Parameters - ---------- - input_tensor : torch.FloatTensor - shape (batch, ...), where ... indicates the input dimensions. - reference_tensor : torch.FloatTensor - shape (batch, k, ...) where k represents the number of - background reference samples to draw per input in the batch. - - Returns - ------- - samples_input : torch.FloatTensor - shape (batch, k, ...) with the interpolated points between - input and ref. - - Notes - ----- - For integrated gradients, we compute some `M=100+` samples interpolating - between each input and a relevant reference sample. For expected - gradients, we rather compute interpolation points that lie randomly - along the linear path between the sample and reference in each minibatch. - """ - input_dims = list(input_tensor.size())[1:] - num_input_dims = len(input_dims) - - batch_size = reference_tensor.size()[0] - k_ = reference_tensor.size()[1] - - # Grab a [batch_size, k]-sized interpolation sample - if self.random_alpha: - t_tensor = ( - torch.FloatTensor(batch_size, k_).uniform_(0, 1).to(self.DEFAULT_DEVICE) - ) - else: - if k_ == 1: - t_tensor = torch.cat( - [torch.Tensor([1.0]) for i in range(batch_size)] - ).to(self.DEFAULT_DEVICE) - else: - t_tensor = torch.cat( - [torch.linspace(0, 1, k_) for i in range(batch_size)] - ).to(self.DEFAULT_DEVICE) - - shape = [batch_size, k_] + [1] * num_input_dims - interp_coef = t_tensor.view(*shape) - - # Evaluate the end points - end_point_ref = (1.0 - interp_coef) * reference_tensor - - input_expand_mult = input_tensor.unsqueeze(1) - end_point_input = interp_coef * input_expand_mult - - # A fine Affine Combine - samples_input = end_point_input + end_point_ref - return samples_input - - def _get_samples_delta( - self, - input_tensor: torch.FloatTensor, - reference_tensor: torch.FloatTensor, - ) -> torch.FloatTensor: - """Compute the distance in feature space between input samples - and reference samples. - - Parameters - ---------- - input_tensor : torch.FloatTensor - shape (batch, ...), where ... indicates the input dimensions. - reference_tensor : torch.FloatTensor - shape (batch, k, ...) where k represents the number of - background reference samples to draw per input in the batch. - - Returns - ------- - sd : torch.FloatTensor - (batch, k, ...) differences in each feature between input - samples and the assigned reference. - """ - input_expand_mult = input_tensor.unsqueeze(1) - sd = input_expand_mult - reference_tensor - if self.abs_scale: - sd = torch.abs(sd) - return sd - - def _get_grads( - self, - samples_input: torch.FloatTensor, - model: torch.nn.Module, - sparse_labels: torch.LongTensor = None, - ) -> torch.FloatTensor: - """Compute gradients for a given model and input tensor, - taking into account sparse labels if provided. - - Parameters - ---------- - samples_input : torch.FloatTensor - (batch, k, ...) input features. - during training, these are interpolated samples between input - and reference. - during evaluation, these are raw input samples. - model : torch.nn.Module - model for evaluation. - sparse_labels : torch.LongTensor, optional - (batch, classes) one-hot labels for class assignments. - must be provided if `classes > 1`. - - Returns - ------- - grad_tensor : torch.FloatTensor - (batch, ...) gradient values - """ - samples_input.requires_grad = True - - grad_tensor = torch.zeros(samples_input.shape).float().to(self.DEFAULT_DEVICE) - - for i in range(self.k): - particular_slice = samples_input[:, i] - batch_output = model(particular_slice) - # should check that users pass in sparse labels - # Only look at the user-specified label - # if there is only one class, `batch_output` is already `(batch, 1)` - if batch_output.size(1) > 1: - if sparse_labels is None: - msg = "`sparse_labels` must be provided if more than one\n" - msg += "output class is present." - raise TypeError(msg) - - sample_indices = torch.arange(0, batch_output.size(0)).to( - self.DEFAULT_DEVICE - ) - indices_tensor = torch.cat( - [ - sample_indices.unsqueeze(1), - sparse_labels.unsqueeze(1), - ], - dim=1, - ) - # gathers the relevant class output for each sample to create - # batch_output shape : (batch, 1). - batch_output = gather_nd(batch_output, indices_tensor) - - model_grads = grad( - outputs=batch_output, - inputs=particular_slice, - grad_outputs=torch.ones_like(batch_output).to(self.DEFAULT_DEVICE), - create_graph=True, - ) - grad_tensor[:, i, :] = model_grads[0] - return grad_tensor - - def shap_values( - self, - model: torch.nn.Module, - input_tensor: torch.FloatTensor, - sparse_labels: torch.LongTensor = None, - batch_transformation: Callable = None, - ) -> torch.FloatTensor: - """ - Calculate expected gradients approximation of Shapley values for the - sample ``input_tensor``. - - Parameters - ---------- - model : torch.nn.Module - Pytorch model for which the output should be explained. - input_tensor : torch.Tensor - (batch, ...) tensor representing the input to be explained, - where `...` are feature dimensions. - sparse_labels : torch.LongTensor, optional - (batch, classes) one-hot class labels. - not required if only one output class is present. - batch_transformation : Callable, optional. - transformation to apply to reference batches after drawing. - - Returns - ------- - expected_grads : torch.FloatTensor - (batch, ...) expected gradients for each sample in the input. - """ - # set device to use - self.DEFAULT_DEVICE = list(model.parameters())[0].device - # set a batch transformation if applicable - self.batch_transformation = batch_transformation - if batch_transformation is not None and not callable(batch_transformation): - msg = "`batch_transformation` arguments must be callable." - raise TypeError(msg) - # sample a batch from the reference dataset and reshape - # to match the inputs - reference_tensor = self._get_ref_batch() - shape = reference_tensor.shape - reference_tensor = reference_tensor.view( - self.batch_size, self.k, *(shape[1:]) - ).to(self.DEFAULT_DEVICE) - # get interpolation points between provided inputs and the - # assigned reference sample for each sample in the batch - samples_input = self._get_samples_input(input_tensor, reference_tensor) - # compute the difference across each feature between - # input and reference samples - samples_delta = self._get_samples_delta(input_tensor, reference_tensor) - # compute gradients on label scores w.r.t. the interpolation inputs - grad_tensor = self._get_grads(samples_input, model, sparse_labels) - # scale the gradient tensor by the difference - mult_grads = ( - samples_delta * grad_tensor if self.scale_by_inputs else grad_tensor - ) - expected_grads = mult_grads.mean(1) - return expected_grads - - -class VariableBatchExplainer(AttributionPriorExplainer): - """ - Subclasses AttributionPriorExplainer to avoid pre-specified batch size. Will adapt batch - size based on shape of input tensor. - """ - - def __init__(self, background_dataset, random_alpha=True, scale_by_inputs=True): - """ - Arguments: - background_dataset: PyTorch dataset - may not work with iterable-type (vs map-type) datasets - random_alpha: boolean - Whether references should be interpolated randomly (True, corresponds - to Expected Gradients) or on a uniform grid (False - corresponds to Integrated Gradients) - """ - self.random_alpha = random_alpha - self.k = None - self.scale_by_inputs = scale_by_inputs - self.ref_set = background_dataset - self.ref_sampler = DataLoader( - dataset=background_dataset, batch_size=1, shuffle=True, drop_last=True - ) - self.refs_needed = -1 - return - - def _get_ref_batch(self, refs_needed=None): - """ - Arguments: - refs_needed: int - number of references to provide - """ - if refs_needed != self.refs_needed: - self.ref_sampler = DataLoader( - dataset=self.ref_set, - batch_size=refs_needed, - shuffle=True, - drop_last=True, - ) - self.refs_needed = refs_needed - return next(iter(self.ref_sampler))[0].float() - - def shap_values(self, model, input_tensor, sparse_labels=None, k=1): - """ - Arguments: - base_model: PyTorch network - input_tensor: PyTorch tensor to get attributions for, as in normal torch.nn.Module API - sparse_labels: np.array of sparse integer labels, i.e. 0-9 for MNIST. Used if you only - want to explain the prediction for the true class per sample. - k: int - Number of references to use default for explanations. As low as 1 for training. - 100-200 for reliable explanations. - """ - self.k = k - n_input = input_tensor.shape[0] - refs_needed = n_input * self.k - # This is a reasonable check but prevents compatibility with non-Map datasets - assert refs_needed <= len( - self.ref_set - ), "Can't have more samples*references than there are reference points!" - reference_tensor = self._get_ref_batch(refs_needed) - shape = reference_tensor.shape - reference_tensor = reference_tensor.view(n_input, self.k, *(shape[1:])).to( - DEFAULT_DEVICE - ) - samples_input = self._get_samples_input(input_tensor, reference_tensor) - samples_delta = self._get_samples_delta(input_tensor, reference_tensor) - grad_tensor = self._get_grads(samples_input, model, sparse_labels) - mult_grads = ( - samples_delta * grad_tensor if self.scale_by_inputs else grad_tensor - ) - expected_grads = mult_grads.mean(1) - - return expected_grads - - -class ExpectedGradientsModel(torch.nn.Module): - """ - Wraps a PyTorch model (one that implements torch.nn.Module) so that model(x) - produces SHAP values as well as predictions (controllable by 'shap_values' - flag. - """ - - def __init__( - self, base_model, refset, k=1, random_alpha=True, scale_by_inputs=True - ): - """ - Arguments: - base_model: PyTorch network that subclasses torch.nn.Module - refset: PyTorch dataset - may not work with iterable-type (vs map-type) datasets - k: int - Number of references to use by default for explanations. As low as 1 for training. - 100-200 for reliable explanations. - """ - super(ExpectedGradientsModel, self).__init__() - self.k = k - self.base = base_model - self.refset = refset - self.random_alpha = random_alpha - self.exp = VariableBatchExplainer( - self.refset, - random_alpha=random_alpha, - scale_by_inputs=scale_by_inputs, - ) - - def forward(self, x, shap_values=False, sparse_labels=None, k=1): - """ - Arguments: - x: PyTorch tensor to predict with, as in normal torch.nn.Module API - shap_values: Binary flag -- whether to produce SHAP values - sparse_labels: np.array of sparse integer labels, i.e. 0-9 for MNIST. Used if you only - want to explain the prediction for the true class per sample. - k: int - Number of references to use default for explanations. As low as 1 for training. - 100-200 for reliable explanations. - """ - output = self.base(x) - if not shap_values: - return output - else: - shaps = self.exp.shap_values(self.base, x, sparse_labels=sparse_labels, k=k) - return output, shaps - - -def tmp(): - """ - def convert_csr_to_sparse_tensor_inputs(X): - coo = sp.coo_matrix(X) - indices = np.mat([coo.row, coo.col]).transpose() - return indices, coo.data, coo.shape - - def graph_mult(values, indices, shape, y): - # sparse tensor multiplication function - x_tensor = tf.SparseTensor(indices, values, shape) - out_layer = tf.sparse_tensor_dense_matmul(x_tensor, y) - return out_layer - - def adj_to_lap(x): - # calculate graph laplacian from adjacency matrix - rowsum = np.array(x.sum(1)) - D = sp.diags(rowsum) - return D - x - - adj = adj_to_lap(adj) - adj_indices, adj_values, adj_shape = convert_csr_to_sparse_tensor_inputs(adj) - - # ... during training ... - ma_eg = tf.reduce_mean(tf.abs(expected_gradients_op),axis=0) - graph_reg = tf.matmul(tf.transpose(graph_mult(adj_values, adj_indices, adj_shape, ma_eg[145:,:])),ma_eg[145:,:]) - """ - # pass - return diff --git a/build/lib/scnym/dataprep.py b/build/lib/scnym/dataprep.py deleted file mode 100644 index 1dbf1f0..0000000 --- a/build/lib/scnym/dataprep.py +++ /dev/null @@ -1,765 +0,0 @@ -import torch -import numpy as np -from scipy import sparse -from torch.utils.data import Dataset -from typing import Callable, Any, Union -import logging - - -logger = logging.getLogger(__name__) - - -class SingleCellDS(Dataset): - """Dataset class for loading single cell profiles. - - Attributes - ---------- - X : np.ndarray, sparse.csr_matrix - [Cells, Genes] cell profiles. - y_labels : np.ndarray, sparse.csr_matrix - [Cells,] integer class labels. - y : torch.FloatTensor - [Cells, Classes] one hot labels. - transform : Callable - performs data transformation operations on a - `sample` dict. - num_classes : int - number of classes in the dataset. default `-1` infers - the number of classes as `len(unique(y))`. - """ - - def __init__( - self, - X: Union[sparse.csr.csr_matrix, np.ndarray], - y: Union[sparse.csr.csr_matrix, np.ndarray], - domain: Union[sparse.csr.csr_matrix, np.ndarray] = None, - transform: Callable = None, - num_classes: int = -1, - num_domains: int = -1, - ) -> None: - """ - Load single cell expression profiles. - - Parameters - ---------- - X : np.ndarray, sparse.csr_matrix - [Cells, Genes] expression count matrix. - scNym models expect ln(Counts Per Million + 1). - Pathfinder models expect raw counts. - y : np.ndarray, sparse.csr_matrix - [Cells,] integer cell type labels. - domain : np.ndarray, sparse.csr_matrix - [Cells,] integer domain labels. - transform : Callable - transform to apply to samples. - num_classes : int - total number of classes for the task. - num_domains : int - total number of domains for the task. - - Returns - ------- - None. - """ - super(SingleCellDS, self).__init__() - - # check types on input arrays - if type(X) not in ( - np.ndarray, - sparse.csr_matrix, - ): - msg = f"X is type {type(X)}, must `np.ndarray` or `sparse.csr_matrix`" - raise TypeError(msg) - - if type(y) not in ( - np.ndarray, - sparse.csr_matrix, - ): - msg = f"X is type {type(y)}, must `np.ndarray` or `sparse.csr_matrix`" - raise TypeError(msg) - - if type(y) != np.ndarray: - # densify labels - y = y.toarray() - - self.X = X - self.y_labels = torch.from_numpy(y).long() - self.y = torch.nn.functional.one_hot( - self.y_labels, - num_classes=num_classes, - ).float() - - self.dom_labels = domain - if self.dom_labels is not None: - self.dom = torch.nn.functional.one_hot( - torch.from_numpy(self.dom_labels).long(), - num_classes=num_domains, - ).float() - else: - self.dom = np.zeros_like(self.y) - 1 - - self.transform = transform - - if not self.X.shape[0] == self.y.shape[0]: - sizes = (self.X.shape[0], self.y.shape[0]) - raise ValueError("X rows %d not equal to y rows %d." % sizes) - return - - def __len__( - self, - ) -> int: - """Return the number of examples in the data set.""" - return self.X.shape[0] - - def __getitem__( - self, - idx: int, - ) -> dict: - """Get a single cell expression profile and corresponding label. - - Parameters - ---------- - idx : int - index value in `range(len(self))`. - - Returns - ------- - sample : dict - 'input' - torch.FloatTensor, input vector - 'output' - torch.LongTensor, target label - """ - if type(idx) != int: - raise TypeError(f"indices must be int, you passed {type(idx)}, {idx}") - - # check if the idx value is valid given the dataset size - if idx < 0 or idx > len(self): - vals = (idx, len(self)) - raise ValueError("idx %d is invalid for dataset with %d examples." % vals) - - # retrieve relevant sample vector and associated label - # store in a hash table for later manipulation and retrieval - - # input_ is either an `np.ndarray` or `sparse.csr.csr_matrix` - input_ = self.X[idx, ...] - # label is already a `torch.Tensor` - label = self.y[idx] - - # if the corresponding vectors are sparse, convert them to dense - # we perform this operation on a samplewise-basis to avoid - # storing the whole count matrix in dense format - if type(input_) != np.ndarray: - input_ = input_.toarray() - - input_ = torch.from_numpy(input_).float() - if input_.size(0) == 1: - input_ = input_.squeeze() - - sample = { - "input": input_, - "output": label, - } - - sample["domain"] = self.dom[idx] - - # if a transformer was supplied, apply transformations - # to the sample vector and label - if self.transform is not None: - sample = self.transform(sample) - return sample - - -def balance_classes( - y: np.ndarray, - class_min: int = 256, -) -> np.ndarray: - """ - Perform class balancing by undersampling majority classes - and oversampling minority classes, down to a minimum value. - - Parameters - ---------- - y : np.ndarray - class assignment indices. - class_min : int - minimum number of examples to use for a class. - below this value, minority classes will be oversampled - with replacement. - - Returns - ------- - all_idx : np.ndarray - indices for balanced classes. some indices may be repeated. - """ - # determine the size of the smallest class - # if < `class_min`, we oversample to `class_min` samples. - classes, counts = np.unique(y, return_counts=True) - min_count = int(np.min(counts)) - if min_count < class_min: - min_count = class_min - - # generate indices with equal representation of each class - all_idx = [] - for i, c in enumerate(classes): - class_idx = np.where(y == c)[0].astype("int") - rep = counts[i] < min_count # oversample minority classes - if rep: - print("Count for class %s is %d. Oversampling." % (c, counts[i])) - ridx = np.random.choice(class_idx, size=min_count, replace=rep) - all_idx += [ridx] - all_idx = np.concatenate(all_idx).astype("int") - return all_idx - - -class LibrarySizeNormalize(object): - """Perform library size normalization.""" - - def __init__( - self, - counts_per_cell_after: int = int(1e6), - log1p: bool = True, - ) -> None: - self.counts_per_cell_after = counts_per_cell_after - self.log1p = log1p - return - - def __call__( - self, - sample: dict, - ) -> dict: - """Perform library size normalization in-place - on a sample dict. - - Parameters - ---------- - sample : dict - 'input' - torch.FloatTensor, input vector [N, C] - 'output' - torch.LongTensor, target label [N,] - - Returns - ------- - sample : dict - 'input' - torch.FloatTensor, input vector [N, C] - 'output' - torch.LongTensor, target label [N,] - """ - input_ = sample["input"] - size = torch.sum(input_, dim=1).reshape(-1, 1) - - # get proportions of each feature per sample, - # scale by `counts_per_cell_after` - prop_input_ = input_ / size - norm_input_ = prop_input_ * self.counts_per_cell_after - if self.log1p: - norm_input_ = torch.log1p(norm_input_) - sample["input"] = norm_input_ - return sample - - -class ExpMinusOne(object): - def __init__( - self, - ) -> None: - """Perform an exponential minus one transformation - on an input vector""" - return - - def __call__( - self, - sample: dict, - ) -> dict: - """Perform an exponential minus one transformation - on the sample input.""" - sample["input"] = torch.expm1( - sample["input"], - ) - return sample - - -class MultinomialSample(object): - """Sample an mRNA abundance profile from a multinomial - distribution parameterized by observations. - """ - - def __init__( - self, - depth: tuple = (10000, 100000), - depth_ratio: tuple = None, - ) -> None: - """Sample an mRNA abundance profile from a multinomial - distribution parameterized by observations. - - Parameters - ---------- - depth : tuple - (min, max) depth for multinomial sampling. - depth_ratio : tuple - (min, max) ratio of profile depth for multinomial - sampling. supercedes `depth`. - - Returns - ------- - None. - """ - self.depth = depth - self.depth_ratio = depth_ratio - - if self.depth_ratio is not None: - self.depth = None - - return - - def __call__( - self, - sample: dict, - ) -> dict: - """ - Sample an mRNA profile from a multinomial - parameterized by observations. - - Parameters - ---------- - sample : dict - 'input' - torch.FloatTensor, input vector [N, C] - 'output' - torch.LongTensor, target label [N,] - - Returns - ------- - sample : dict - 'input' - torch.FloatTensor, input vector [N, C] - 'output' - torch.LongTensor, target label [N,] - - Notes - ----- - We perform multinomial sampling with a call to `np.random.multinomial` - for each observation. This may be faster in the future using the native - `torch.distributions.Multinomial`, but right now the sampling procedure - is incredibly slow. The implementation below is ~100X slower than our - `numpy` calls. - - ``` - multi = torch.distributions.Multinomial( - total_count=d, - probs=p, - ) - - m = multi.sample() - m = m.float() - ``` - - Follow: - https://github.com/pytorch/pytorch/issues/11931 - """ - # input is a torch.FloatTensor - # we assume x is NOT log-transformed - # cast to float64 to preserve precision of proportions - x = sample["input"].to(torch.float64) - size = torch.sum(x, dim=1).detach().cpu().numpy() - - # generate a relative abundance profile - p = x / torch.sum(x, dim=1).reshape(-1, 1) - # normalize to ensure roundoff errors don't - # give us p.sum() > 1 - idx = torch.where(p.sum(1) > 1) - for i in idx[0]: - p[i, :] = p[i, :] / np.min([p[i, :].sum(), 1.0]) - # sample a sequencing depth - if self.depth_ratio is None: - # tile the specified depth for all cells - depth = np.tile(np.array(self.depth).reshape(1, -1), (x.size(0), 1)).astype( - np.int - ) - else: - # compute a range of depths based on the library size - # of each observation - depth = np.concatenate( - [ - np.floor(self.depth_ratio[0] * size).reshape(-1, 1), - np.ceil(self.depth_ratio[1] * size).reshape(-1, 1), - ], - axis=1, - ).astype(np.int) - - # sample from a multinomial - # np.random.multinomial is ~100X faster than the native - # torch.distributions.Multinomial, implemented in Notes - m = np.zeros(x.size()) - for i in range(x.size(0)): - - d = int( - np.random.choice( - np.arange(depth[i, 0], depth[i, 1]), - size=1, - ) - ) - - m[i, :] = np.random.multinomial( - d, - pvals=p[i, :].detach().cpu().numpy(), - ) - m = torch.from_numpy(m).float() - m = m.to(device=x.device) - output = { - "input": m, - "output": sample["output"], - } - return output - - -class GeneMasking(object): - def __init__( - self, - p_drop: float = 0.1, - p_apply: float = 0.5, - sample_p_drop: bool = False, - ) -> None: - """Mask a subset of genes in the gene expression vector - with zeros. This may simulate a failed detection event. - This mask is applied to `p_apply`*100% of input vectors. - - Parameters - ---------- - p_drop : float - proportion of genes to mask with zeros. - p_apply : float - proportion of samples to mask. - sample_p_drop : bool - sample the proportion of genes to drop from - `Unif(0, p_drop)`. - - Returns - ------- - None. - """ - self.p_drop = p_drop - self.p_apply = p_apply - self.sample_p_drop = sample_p_drop - return - - def __call__( - self, - sample: dict, - ) -> dict: - """Mask a subset of genes.""" - do_apply = np.random.random() - if do_apply > self.p_apply: - # no-op - return sample - - # input is a torch.FloatTensor - x = sample["input"].clone() - - if self.sample_p_drop: - p_drop = np.random.random() * self.p_drop - else: - p_drop = self.p_drop - - # mask a proportion `p` of genes with `0` - # assume x [N, Genes] - n_genes = x.size(1) - for i in range(x.size(0)): - idx = np.random.choice( - np.arange(n_genes), - size=int(np.floor(n_genes * p_drop)), - replace=False, - ).astype(np.int) - x[i, idx] = 0 - - sample["input"] = x - return sample - - -class InputDropout(object): - def __init__( - self, - p_drop: float = 0.1, - ) -> None: - """Randomly mask `p_drop` genes. - - Parameters - ---------- - p_drop : float - proportion of genes to mask. - - Returns - ------- - None - """ - self.p_drop = p_drop - return - - def __call__( - self, - sample: dict, - ) -> dict: - sample["input"] = torch.nn.functional.dropout( - sample["input"], - p=self.p_drop, - inplace=False, - ) - return sample - - -class PoissonSample(object): - """Sample a gene expression profile based on gene-specific - Poisson distributions""" - - def __init__( - self, - depth: Union[float, tuple] = 1.0, - ) -> None: - """Sample a gene expression profile based on gene-specific - Poisson distributions. - - Parameters - ---------- - depth : tuple, float - (min_factor, max_factor) for scaling the rate of the Poisson - that samples are drawn from. Scaling down produces sparser - profiles, scaling up produces less sparse profiles. - if `float`, uses a single depth value. Default = 1. - - Returns - ------- - None. - - Notes - ----- - Treats a raw gene count as an estimate of the rate for a Poisson - distribution. - """ - self.depth = depth - return - - def __call__( - self, - sample: dict, - ) -> dict: - # input is a torch.FloatTensor - # we assume x is NOT log-transformed - x = sample["input"].to(torch.float64) - - if type(self.depth) != float: - # sample a scale factor for the rate in the specified interval - # Unif(r1, r2) = Unif(0, 1) * (r1 - r2) + r2 - logging.debug("Multiscale Poisson depths") - r = torch.rand(x.size(0)).to(device=x.device) - r = r * (self.depth[0] - self.depth[1]) + self.depth[1] - else: - logging.debug("Single scale Poisson sampling") - r = torch.ones(x.size(0)).to(device=x.device) - r *= self.depth - - logger.debug(f"Poisson rate: {r}") - logger.debug(f"Poisson sample: {x}") - # torch Poisson can't handle rates equal to zero - # here we manually set zero rates to eps, then zero - # them back out later - rate = x * r.view(-1, 1) - rate[x == 0.0] = 1.0 - P = torch.distributions.Poisson( - rate=rate, - ) - x_poisson = P.sample() - x_poisson[x == 0.0] = 0.0 - - assert x.size() == x_poisson.size() - - sample["input"] = x_poisson.float() - return sample - - -"""Implement MixUp training""" - - -def mixup( - a: torch.FloatTensor, - b: torch.FloatTensor, - gamma: torch.FloatTensor, -) -> torch.FloatTensor: - """Perform a MixUp operation. - This is effectively just a weighted average, where - `gamma = 0.5` yields the mean of `a` and `b`. - - Parameters - ---------- - a : torch.FloatTensor - [Batch, C] first sample matrix. - b : torch.FloatTensor - [Batch, C] second sample matrix. - gamma : torch.FloatTensor - [Batch,] MixUp coefficient. - - Returns - ------- - m : torch.FloatTensor - [Batch, C] mixed sample matrix. - """ - return gamma * a + (1 - gamma) * b - - -class SampleMixUp(object): - def __init__( - self, - alpha: float = 0.2, - keep_dominant_obs: bool = False, - ) -> None: - """Perform a MixUp operation on a sample batch. - - Parameters - ---------- - alpha : float - alpha parameter of the Beta distribution. - keep_dominant_obs : bool - use max(gamma, 1-gamma) for each pair of samples - so the identity of the dominant observation can be - associated with the mixed sample. - - Returns - ------- - None. - - References - ---------- - mixup: Beyond Empirical Risk Minimization - Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz - arXiv:1710.09412 - - Notes - ----- - Zhang et. al. note alpha [0.1, 0.4] improve performance on CIFAR-10, - while larger values of alpha induce underfitting. - """ - self.alpha = alpha - if alpha > 0.0: - self.beta = torch.distributions.beta.Beta( - self.alpha, - self.alpha, - ) - self.keep_dominant_obs = keep_dominant_obs - return - - def __call__( - self, - sample: dict, - ) -> dict: - """Perform a MixUp operation on the sample. - - Parameters - ---------- - sample : dict - 'input' - torch.FloatTensor, input vector - 'output' - torch.LongTensor, target label - - Returns - ------- - sample : dict - 'input' - torch.FloatTensor, input vector - 'output' - torch.LongTensor, target label - """ - if self.alpha == 0.0: - # mixup is deactivated, return the original - # sample without mixing - return sample - - input_ = sample["input"] - output = sample["output"] - - # randomly permute the input and output - ridx = torch.randperm(input_.size(0)) - r_input_ = input_[ridx] - r_output = output[ridx] - - # perform the mixup operation between the source - # data and the rearranged data -- random pairs - gamma = self.beta.sample((input_.size(0),)) - if self.keep_dominant_obs: - gamma, _ = torch.max( - torch.stack( - [ - gamma, - 1 - gamma, - ], - dim=1, - ), - dim=1, - ) - gamma = gamma.reshape(-1, 1) - # move gamma weights to the same device as the - # inputs - gamma = gamma.to(device=input_.device) - - mix_input_ = mixup(input_, r_input_, gamma=gamma) - mix_output = mixup(output, r_output, gamma=gamma) - - sample["input"] = mix_input_ - sample["output"] = mix_output - - # if there are additional tensors in sample, also mix - # them up - other_keys = [k for k in sample.keys() if k not in ("input", "output")] - for k in other_keys: - if type(sample[k]) == torch.Tensor: - sample[k] = mixup(sample[k], sample[k][ridx], gamma=gamma) - - # add the randomization index to the sample in case - # it's useful downstream - sample["random_idx"] = ridx - - return sample - - -################################################# -# Define augmentation series -################################################# - -from torchvision import transforms - - -def identity(x: Any) -> Any: - """Identity function""" - return x - - -AUGMENTATION_SCHEMES = { - "log1p_drop": transforms.Compose( - [ - ExpMinusOne(), - InputDropout( - p_drop=0.1, - ), - LibrarySizeNormalize(log1p=True), - ] - ), - "log1p_mask": transforms.Compose( - [ - ExpMinusOne(), - GeneMasking( - p_drop=0.1, - p_apply=0.5, - ), - LibrarySizeNormalize(log1p=True), - ] - ), - "log1p_poisson": transforms.Compose( - [ - ExpMinusOne(), - PoissonSample(), - LibrarySizeNormalize(log1p=True), - ] - ), - "log1p_poisson_drop": transforms.Compose( - [ - ExpMinusOne(), - PoissonSample(depth=(0.1, 2.0)), - InputDropout(p_drop=0.1), - LibrarySizeNormalize(log1p=True), - ] - ), - "count_poisson": transforms.Compose( - [ - PoissonSample(), - ] - ), - "None": identity, - "none": identity, - None: identity, -} diff --git a/build/lib/scnym/distributions.py b/build/lib/scnym/distributions.py deleted file mode 100644 index a0f1aad..0000000 --- a/build/lib/scnym/distributions.py +++ /dev/null @@ -1,420 +0,0 @@ -"""torch Distributions for use with scNym models - -Negative Binomial adopted from scvi-tools -https://github.com/YosefLab/scvi-tools/blob/42315756ba879b9421630696ea7afcd74e012a07/scvi/distributions/_negative_binomial.py -""" -import warnings -from typing import Optional, Tuple, Union - -import torch -import torch.nn.functional as F -from torch.distributions import Distribution, Gamma, Poisson, constraints -from torch.distributions.utils import ( - broadcast_all, - lazy_property, - logits_to_probs, - probs_to_logits, -) - - -def log_zinb_positive( - x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor, pi: torch.Tensor, eps=1e-8 -): - """ - Log likelihood (scalar) of a minibatch according to a zinb model. - Parameters - ---------- - x - Data - mu - mean of the negative binomial (has to be positive support) (shape: minibatch x vars) - theta - inverse dispersion parameter (has to be positive support) (shape: minibatch x vars) - pi - logit of the dropout parameter (real support) (shape: minibatch x vars) - eps - numerical stability constant - Notes - ----- - We parametrize the bernoulli using the logits, hence the softplus functions appearing. - """ - # theta is the dispersion rate. If .ndimension() == 1, it is shared for all cells (regardless of batch or labels) - if theta.ndimension() == 1: - theta = theta.view( - 1, theta.size(0) - ) # In this case, we reshape theta for broadcasting - - softplus_pi = F.softplus(-pi) # uses log(sigmoid(x)) = -softplus(-x) - log_theta_eps = torch.log(theta + eps) - log_theta_mu_eps = torch.log(theta + mu + eps) - pi_theta_log = -pi + theta * (log_theta_eps - log_theta_mu_eps) - - case_zero = F.softplus(pi_theta_log) - softplus_pi - mul_case_zero = torch.mul((x < eps).type(torch.float32), case_zero) - - case_non_zero = ( - -softplus_pi - + pi_theta_log - + x * (torch.log(mu + eps) - log_theta_mu_eps) - + torch.lgamma(x + theta) - - torch.lgamma(theta) - - torch.lgamma(x + 1) - ) - mul_case_non_zero = torch.mul((x > eps).type(torch.float32), case_non_zero) - - res = mul_case_zero + mul_case_non_zero - - return res - - -def log_nb_positive(x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor, eps=1e-8): - """ - Log likelihood (scalar) of a minibatch according to a nb model. - Parameters - ---------- - x - data - mu - mean of the negative binomial (has to be positive support) (shape: minibatch x vars) - theta - inverse dispersion parameter (has to be positive support) (shape: minibatch x vars) - eps - numerical stability constant - Notes - ----- - We parametrize the bernoulli using the logits, hence the softplus functions appearing. - """ - if theta.ndimension() == 1: - theta = theta.view( - 1, theta.size(0) - ) # In this case, we reshape theta for broadcasting - - log_theta_mu_eps = torch.log(theta + mu + eps) - - res = ( - theta * (torch.log(theta + eps) - log_theta_mu_eps) - + x * (torch.log(mu + eps) - log_theta_mu_eps) - + torch.lgamma(x + theta) - - torch.lgamma(theta) - - torch.lgamma(x + 1) - ) - - return res - - -def log_mixture_nb( - x: torch.Tensor, - mu_1: torch.Tensor, - mu_2: torch.Tensor, - theta_1: torch.Tensor, - theta_2: torch.Tensor, - pi_logits: torch.Tensor, - eps=1e-8, -): - """ - Log likelihood (scalar) of a minibatch according to a mixture nb model. - pi_logits is the probability (logits) to be in the first component. - For totalVI, the first component should be background. - Parameters - ---------- - x - Observed data - mu_1 - Mean of the first negative binomial component (has to be positive support) (shape: minibatch x features) - mu_2 - Mean of the second negative binomial (has to be positive support) (shape: minibatch x features) - theta_1 - First inverse dispersion parameter (has to be positive support) (shape: minibatch x features) - theta_2 - Second inverse dispersion parameter (has to be positive support) (shape: minibatch x features) - If None, assume one shared inverse dispersion parameter. - pi_logits - Probability of belonging to mixture component 1 (logits scale) - eps - Numerical stability constant - """ - if theta_2 is not None: - log_nb_1 = log_nb_positive(x, mu_1, theta_1) - log_nb_2 = log_nb_positive(x, mu_2, theta_2) - # this is intended to reduce repeated computations - else: - theta = theta_1 - if theta.ndimension() == 1: - theta = theta.view( - 1, theta.size(0) - ) # In this case, we reshape theta for broadcasting - - log_theta_mu_1_eps = torch.log(theta + mu_1 + eps) - log_theta_mu_2_eps = torch.log(theta + mu_2 + eps) - lgamma_x_theta = torch.lgamma(x + theta) - lgamma_theta = torch.lgamma(theta) - lgamma_x_plus_1 = torch.lgamma(x + 1) - - log_nb_1 = ( - theta * (torch.log(theta + eps) - log_theta_mu_1_eps) - + x * (torch.log(mu_1 + eps) - log_theta_mu_1_eps) - + lgamma_x_theta - - lgamma_theta - - lgamma_x_plus_1 - ) - log_nb_2 = ( - theta * (torch.log(theta + eps) - log_theta_mu_2_eps) - + x * (torch.log(mu_2 + eps) - log_theta_mu_2_eps) - + lgamma_x_theta - - lgamma_theta - - lgamma_x_plus_1 - ) - - logsumexp = torch.logsumexp(torch.stack((log_nb_1, log_nb_2 - pi_logits)), dim=0) - softplus_pi = F.softplus(-pi_logits) - - log_mixture_nb = logsumexp - softplus_pi - - return log_mixture_nb - - -def _convert_mean_disp_to_counts_logits(mu, theta, eps=1e-6): - r""" - NB parameterizations conversion. - Parameters - ---------- - mu - mean of the NB distribution. - theta - inverse overdispersion. - eps - constant used for numerical log stability. (Default value = 1e-6) - Returns - ------- - type - the number of failures until the experiment is stopped - and the success probability. - """ - if not (mu is None) == (theta is None): - raise ValueError( - "If using the mu/theta NB parameterization, both parameters must be specified" - ) - logits = (mu + eps).log() - (theta + eps).log() - total_count = theta - return total_count, logits - - -def _convert_counts_logits_to_mean_disp(total_count, logits): - """ - NB parameterizations conversion. - Parameters - ---------- - total_count - Number of failures until the experiment is stopped. - logits - success logits. - Returns - ------- - type - the mean and inverse overdispersion of the NB distribution. - """ - theta = total_count - mu = logits.exp() * theta - return mu, theta - - -def _gamma(theta, mu): - concentration = theta - rate = theta / mu - # Important remark: Gamma is parametrized by the rate = 1/scale! - gamma_d = Gamma(concentration=concentration, rate=rate) - return gamma_d - - -class NegativeBinomial(Distribution): - r""" - Negative binomial distribution. - One of the following parameterizations must be provided: - (1), (`total_count`, `probs`) where `total_count` is the number of failures until - the experiment is stopped and `probs` the success probability. (2), (`mu`, `theta`) - parameterization, which is the one used by scvi-tools. These parameters respectively - control the mean and inverse dispersion of the distribution. - In the (`mu`, `theta`) parameterization, samples from the negative binomial are generated as follows: - 1. :math:`w \sim \textrm{Gamma}(\underbrace{\theta}_{\text{shape}}, \underbrace{\theta/\mu}_{\text{rate}})` - 2. :math:`x \sim \textrm{Poisson}(w)` - Parameters - ---------- - total_count - Number of failures until the experiment is stopped. - probs - The success probability. - mu - Mean of the distribution. - theta - Inverse dispersion. - validate_args - Raise ValueError if arguments do not match constraints - """ - - arg_constraints = { - "mu": constraints.greater_than_eq(0), - "theta": constraints.greater_than_eq(0), - } - support = constraints.nonnegative_integer - - def __init__( - self, - total_count: Optional[torch.Tensor] = None, - probs: Optional[torch.Tensor] = None, - logits: Optional[torch.Tensor] = None, - mu: Optional[torch.Tensor] = None, - theta: Optional[torch.Tensor] = None, - validate_args: bool = False, - ): - self._eps = 1e-8 - if (mu is None) == (total_count is None): - raise ValueError( - "Please use one of the two possible parameterizations. Refer to the documentation for more information." - ) - - using_param_1 = total_count is not None and ( - logits is not None or probs is not None - ) - if using_param_1: - logits = logits if logits is not None else probs_to_logits(probs) - total_count = total_count.type_as(logits) - total_count, logits = broadcast_all(total_count, logits) - mu, theta = _convert_counts_logits_to_mean_disp(total_count, logits) - else: - mu, theta = broadcast_all(mu, theta) - self.mu = mu - self.theta = theta - super().__init__(validate_args=validate_args) - - @property - def mean(self): - return self.mu - - @property - def variance(self): - return self.mean + (self.mean ** 2) / self.theta - - def sample( - self, sample_shape: Union[torch.Size, Tuple] = torch.Size() - ) -> torch.Tensor: - with torch.no_grad(): - gamma_d = self._gamma() - p_means = gamma_d.sample(sample_shape) - - # Clamping as distributions objects can have buggy behaviors when - # their parameters are too high - l_train = torch.clamp(p_means, max=1e8) - counts = Poisson( - l_train - ).sample() # Shape : (n_samples, n_cells_batch, n_vars) - return counts - - def log_prob(self, value: torch.Tensor) -> torch.Tensor: - if self._validate_args: - try: - self._validate_sample(value) - except ValueError: - warnings.warn( - "The value argument must be within the support of the distribution", - UserWarning, - ) - - return log_nb_positive(value, mu=self.mu, theta=self.theta, eps=self._eps) - - def _gamma(self): - return _gamma(self.theta, self.mu) - - -class ZeroInflatedNegativeBinomial(NegativeBinomial): - r""" - Zero-inflated negative binomial distribution. - One of the following parameterizations must be provided: - (1), (`total_count`, `probs`) where `total_count` is the number of failures until - the experiment is stopped and `probs` the success probability. (2), (`mu`, `theta`) - parameterization, which is the one used by scvi-tools. These parameters respectively - control the mean and inverse dispersion of the distribution. - In the (`mu`, `theta`) parameterization, samples from the negative binomial are generated as follows: - 1. :math:`w \sim \textrm{Gamma}(\underbrace{\theta}_{\text{shape}}, \underbrace{\theta/\mu}_{\text{rate}})` - 2. :math:`x \sim \textrm{Poisson}(w)` - Parameters - ---------- - total_count - Number of failures until the experiment is stopped. - probs - The success probability. - mu - Mean of the distribution. - theta - Inverse dispersion. - zi_logits - Logits scale of zero inflation probability. - validate_args - Raise ValueError if arguments do not match constraints - """ - - arg_constraints = { - "mu": constraints.greater_than_eq(0), - "theta": constraints.greater_than_eq(0), - "zi_probs": constraints.half_open_interval(0.0, 1.0), - "zi_logits": constraints.real, - } - support = constraints.nonnegative_integer - - def __init__( - self, - total_count: Optional[torch.Tensor] = None, - probs: Optional[torch.Tensor] = None, - logits: Optional[torch.Tensor] = None, - mu: Optional[torch.Tensor] = None, - theta: Optional[torch.Tensor] = None, - zi_logits: Optional[torch.Tensor] = None, - validate_args: bool = False, - ): - - super().__init__( - total_count=total_count, - probs=probs, - logits=logits, - mu=mu, - theta=theta, - validate_args=validate_args, - ) - self.zi_logits, self.mu, self.theta = broadcast_all( - zi_logits, self.mu, self.theta - ) - - @property - def mean(self): - pi = self.zi_probs - return (1 - pi) * self.mu - - @property - def variance(self): - raise NotImplementedError - - @lazy_property - def zi_logits(self) -> torch.Tensor: - return probs_to_logits(self.zi_probs, is_binary=True) - - @lazy_property - def zi_probs(self) -> torch.Tensor: - return logits_to_probs(self.zi_logits, is_binary=True) - - def sample( - self, sample_shape: Union[torch.Size, Tuple] = torch.Size() - ) -> torch.Tensor: - with torch.no_grad(): - samp = super().sample(sample_shape=sample_shape) - is_zero = torch.rand_like(samp) <= self.zi_probs - samp[is_zero] = 0.0 - return samp - - def log_prob(self, value: torch.Tensor) -> torch.Tensor: - try: - self._validate_sample(value) - except ValueError: - warnings.warn( - "The value argument must be within the support of the distribution", - UserWarning, - ) - return log_zinb_positive(value, self.mu, self.theta, self.zi_logits, eps=1e-08) diff --git a/build/lib/scnym/interpret.py b/build/lib/scnym/interpret.py deleted file mode 100644 index 70c9c98..0000000 --- a/build/lib/scnym/interpret.py +++ /dev/null @@ -1,1368 +0,0 @@ -"""Tools for interpreting trained scNym models""" -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -import pandas as pd -from scipy import sparse -import anndata - -# self -from .utils import build_classification_matrix, get_adata_asarray -from . import dataprep -from . import attributionpriors as attrprior - -# stdlib -import typing -import copy -import warnings -import logging -import time -from pathlib import Path - -logger = logging.getLogger(__name__) - - -class Salience(object): - """ - Performs backpropogation to compute gradients on a target - class with regards to an input. - - Notes - ----- - Saliency analysis computes a gradient on a target class - score :math:`f_i(x)` with regards to some input :math:`x`. - - - .. math:: - - S_i = \frac{\partial f_i(x)}{\partial x} - """ - - def __init__( - self, - model: nn.Module, - class_names: np.ndarray, - gene_names: np.ndarray = None, - layer_to_hook: int = None, - verbose: bool = False, - ) -> None: - """ - Performs backpropogation to compute gradients on a target - class with regards to an input. - - Parameters - ---------- - model : torch.nn.Module - trained scNym model. - class_names : np.ndarray - list of str names matching output nodes in `model`. - gene_names : np.ndarray, optional - gene names for the model. - layer_to_hook : int - index of the layer from which to record gradients. - defaults to the gene level input features. - - Returns - ------- - None. - """ - # ensure class names are unique for each output node - if len(np.unique(class_names)) != len(class_names): - msg = "`class_names` must all be unique." - raise ValueError(msg) - - self.class_names = np.array(class_names) - self.n_classes = len(class_names) - self.verbose = verbose - - # load model into CUDA compute if available - if torch.cuda.is_available(): - self.model = model.cuda() - else: - self.model = model - # ensure we're not in training mode - self.model = self.model.eval() - - self.gene_names = gene_names - - if layer_to_hook is None: - self._hook_first_layer_gradients() - else: - self._hook_nth_layer_gradients(n=layer_to_hook) - return - - def _hook_first_layer_gradients(self): - """Set up hooks to record gradients from the first linear - layer into a target tensor. - - References - ---------- - https://pytorch.org/docs/stable/nn.html#torch.nn.Module.register_backward_hook - """ - - def _record_gradients(module, grad_in, grad_out): - """Record gradients of a layer with the correct input - shape""" - self.gradients = grad_in[1] - if self.verbose: - print([x.size() if x is not None else "None" for x in grad_in]) - print("Hooked gradients to: ", module) - - for module in self.model.modules(): - if isinstance(module, nn.Linear) and module.in_features == len( - self.gene_names - ): - module.register_backward_hook(_record_gradients) - return - - def _hook_nth_layer_gradients(self, n: int): - """Set up hooks to record gradients from an arbitrary layer. - - References - ---------- - https://pytorch.org/docs/stable/nn.html#torch.nn.Module.register_backward_hook - """ - - def _record_gradients(module, grad_in, grad_out): - """Record gradients of a layer with the correct input - shape""" - self.gradients = grad_in[1] - if self.verbose: - print([x.size() if x is not None else "None" for x in grad_in]) - print("Hooked gradients to: ", module) - - module = list(self.model.modules())[n] - module.register_backward_hook(_record_gradients) - return - - def _guided_backprop_hooks(self): - """Set up forward and backward hook functions to perform - "Guided backpropogation" - - Notes - ----- - Guided backpropogation only passes positive gradients upward through the network. - - Normal backprop: - - .. math:: - - f_i^{(l + 1)} = ReLU(f_i^{(l)}) - - R_i^{(l)} = (f_i^{(l)} > 0) \cdot R_i^{(l+1)} - - where - - .. math:: - - R_i^{(l + 1)} = \frac{\partial f_{out}}{\partial f_i^{l + 1}} - - - By contrast, guided backpropogation only passes gradient values where both - the activates :math:`f_i^{(l)}` and the gradients :math:`R_i^{(l + 1)}` are - greater than :math:`0`. - - - References - ---------- - https://arxiv.org/pdf/1412.6806.pdf - - https://pytorch.org/docs/stable/nn.html#torch.nn.Module.register_forward_hook - https://pytorch.org/docs/stable/nn.html#torch.nn.Module.register_backward_hook - """ - - def _record_relu_outputs(module, in_, out_): - """Store the outputs to each ReLU layer""" - self.rectified_outputs.append( - out_, - ) - self.store_rectified_outputs.append( - out_, - ) - - def _clamp_grad(module, grad_in, grad_out): - """Clamp ReLU gradients to [0, inf] and return a - new gradient to be used in subsequent outputs. - """ - self.store_grad.append(grad_in[0]) - - grad = grad_in[0].clamp(min=0.0) - self.store_clamped_grad.append(grad) - - # here we pop the outputs off to ensure that the - # final output is always the current ReLU layer - # we're investigating - last_relu_output = self.rectified_outputs.pop() - last_relu_output = copy.copy(last_relu_output) - last_relu_output[last_relu_output > 0] = 1 - rectified_grad = last_relu_output * grad - - self.store_rectified_grad.append(rectified_grad) - return (rectified_grad,) - - self.store_rectified_outputs = [] - self.store_grad = [] - self.store_clamped_grad = [] - - for _, module in self.model.named_modules(): - if isinstance(module, nn.ReLU): - module.register_forward_hook(_record_relu_outputs) - module.register_backward_hook(_clamp_grad) - - return - - def get_saliency( - self, - x: torch.FloatTensor, - target_class: str, - guide_backprop: bool = False, - ) -> torch.FloatTensor: - """Compute the saliency of a target class on an input - vector `x`. - - Parameters - ---------- - x : torch.FloatTensor - [1, Genes] vector of gene expression. - target_class : str - class in `.class_names` for which to compute gradients. - guide_backprop : bool - perform "guided backpropogation" by clamping gradients - to only positive values at each ReLU. - see: https://arxiv.org/pdf/1412.6806.pdf - - Returns - ------- - salience : torch.FloatTensor - gradients on `target_class` with respect to `x`. - """ - if target_class not in self.class_names: - msg = f"{target_class} is not in `.class_names`" - raise ValueError(msg) - - target_idx = np.where(target_class == self.class_names)[0].astype(np.int) - target_idx = int(target_idx) - - self.model.zero_grad() - - if guide_backprop: - self.rectified_outputs = [] - self.store_rectified_grad = [] - self._guided_backprop_hooks() - - # store gradients on the input - if torch.cuda.is_available(): - x = x.cuda() - x.requires_grad = True - - # module hook will record gradients here - self.gradients = torch.zeros_like(x) - - # forward pass - output = self.model(x) - - # create a [N, C] tensor to store gradients - target = torch.zeros_like(output) - # set the target class to `1`, creating a one-hot - # of the target class - target[:, target_idx] = 1 - - # compute gradients with backprop - output.backward( - gradient=target, - ) - - # detach from the graph and move to main memory - target = target.detach().cpu() - - return self.gradients - - def rank_genes_by_saliency( - self, - **kwargs, - ) -> np.ndarray: - """ - Rank genes by saliency for a target class and input. - - Passes **kwargs to `.get_saliency` and uses the output - to rank genes. - - Returns - ------- - ranked_genes : np.ndarray - gene names with high saliency, ranked highest to - lowest. - """ - s = self.get_saliency(**kwargs) - sort_idx = torch.argsort(s) - idx = sort_idx[0].numpy()[::-1] - return self.gene_names[idx.astype(np.int)] - - -class IntegratedGradient(object): - def __init__( - self, - model: nn.Module, - class_names: typing.Union[list, np.ndarray], - gene_names: typing.Union[list, np.ndarray] = None, - grad_activation: str = "input", - verbose: bool = False, - ) -> None: - """Performs integrated gradient computations for feature attribution - in scNym models. - - Parameters - ---------- - model : torch.nn.Module - trained scNym model. - class_names : list or np.ndarray - list of str names matching output nodes in `model`. - gene_names : list or np.ndarray, optional - gene names for the model. - grad_activation : str - activations where gradients should be collected. - default "input" collects gradients at the level of input features. - verbose : bool - verbose outputs for stdout. - - Returns - ------- - None. - - Notes - ----- - Integrated gradients are computed as the path integral between a "baseline" - gene expression vector (all 0 counts) and an observed gene expression vector. - The path integral is computed along a straight line in the feature space. - - Stated formally, we define a our baseline gene expression vector as :math:`x`, - our observed vector as :math:`x'`, an scnym model :math:`f(\cdot)`, and a - number of steps :math:`M` for approximating the integral by Reimann sums. - - The integrated gradient :math:`\int \nabla` for a feature :math:`x_i` is then - - .. math:: - - r = \sum_{m=1}^M \partial f(x' + \frac{m}{M}(x - x')) / \partial x_i \\ - \int \nabla_i = (x_i' - x_i) \frac{1}{M} r - """ - self.model = copy.deepcopy(model) - if torch.cuda.is_available(): - self.model = self.model.cuda() - print("Model loaded on CUDA compute device.") - self.model.zero_grad() - for param in self.model.parameters(): - param.requires_grad = False - - # get gradients on the specified layer activation if - # the specified layer is not "input" - self.grad_activation = grad_activation - - if grad_activation == "input": - self.get_grad = self._get_grad_input - elif grad_activation == "first_layer": - self.get_grad = self._get_grad_first_layer - self.input2first = nn.Sequential(*list(model.modules())[3:7]) - self.first2output = nn.Sequential(*list(model.modules())[7:]) - else: - msg = f"`grad_activation={grad_activation}` is not implemented." - raise NotImplementedError(msg) - - self.class_names = class_names - self.gene_names = gene_names - self.verbose = verbose - self.grads_for_class = {} - - if type(self.class_names) == np.ndarray: - self.class_names = self.class_names.tolist() - - return - - def _get_grad_input( - self, - x: torch.Tensor, - target_class: str, - ) -> typing.Tuple[torch.Tensor, torch.Tensor]: - """Get the gradient on the observed features with respect - to a target class. - - Parameters - ---------- - x : torch.Tensor - [Batch, Features] input tensor. - target_class : str - target class for gradient computation. - - Returns - ------- - grad : torch.Tensor - [Batch, Features] feature gradients with respect to the - target class. - target : torch.Tensor - [Batch,] value of the target class score. - """ - target_idx = self.class_names.index(target_class) - - # store gradients on the input - if torch.cuda.is_available(): - x = x.cuda() - x.requires_grad = True - - # forward pass through the model - output = self.model(x) - sm_output = F.softmax(output, dim=-1) - - # get the softmax output on the target class for each - # observation as a loss - index = torch.ones(output.size(0)).view(-1, 1) * target_idx - index = index.long() - index = index.to(device=sm_output.device) - # `.gather(dim, index)` takes a dimension number and a tensor - # of indices size [Batch,] where each val is an integer index - # grabs the specific element for each observation along the given dim. - target = sm_output.gather(1, index) - - # zero any existing gradients - self.model.zero_grad() - if x.grad is not None: - x.grad.zero_() - target.backward() - - grad = x.grad.detach().cpu() - - return grad, target - - def _catch_grad(self, grad) -> None: - """Hook to catch gradients from an activation - of interest.""" - self.caught_grad = grad.detach() - return - - def _get_grad_first_layer( - self, - x: torch.Tensor, - target_class: str, - ): - """Get the gradient on the first layer activations. - - Parameters - ---------- - x : torch.Tensor - [Batch, Features] input tensor. e.g. first layer - embedding coordinates to pass to the rest of the model. - target_class : str - target class for gradient computation. - - Returns - ------- - grad : torch.Tensor - [Batch, Features] feature gradients with respect to the - target class. - target : torch.Tensor - [Batch,] value of the target class score. - """ - target_idx = self.class_names.index(target_class) - # store gradients on the input - if torch.cuda.is_available(): - x = x.cuda() - x.requires_grad = True - - # forward through the activation embedder - x.register_hook(self._catch_grad) - # forward through to outputs - output = self.first2output(x) - sm_output = F.softmax(output, dim=-1) - - # get the softmax output on the target class for each - # observation as a loss - index = torch.ones(output.size(0)).view(-1, 1) * target_idx - index = index.long() - index = index.to(device=sm_output.device) - # `.gather(dim, index)` takes a dimension number and a tensor - # of indices size [Batch,] where each val is an integer index - # grabs the specific element for each observation along the given dim. - target = sm_output.gather(1, index) - - # zero any existing gradients - self.model.zero_grad() - if x.grad is not None: - x.grad.zero_() - - target.backward() - grad = self.caught_grad - - return grad, target - - def _check_integration( - self, - integrated_grad: torch.Tensor, - ) -> bool: - """Check that the approximation of the path integral is appropriate. - If we used a sufficient number of steps in the Reimann sum, we should - find that the gradient sum is roughly equivalent to the difference in - class scores for the baseline vector and target vector. - """ - score_difference = self.raw_scores[-1] - self.raw_scores[0] - check = torch.isclose( - integrated_grad.sum(), - score_difference, - rtol=0.1, - ) - if not check: - msg = "integrated gradient magnitude does not match the difference in scores.\n" - msg += f"magnitude {integrated_grad.sum().item()} vs. {score_difference.item()}.\n" - msg += "consider using more steps to estimate the path integral." - warnings.warn(msg) - return check - - def get_integrated_gradient( - self, - x: torch.Tensor, - target_class: str, - M: int = 300, - baseline: torch.Tensor = None, - ) -> torch.Tensor: - """Compute the integrated gradient for a single observation. - - Parameters - ---------- - x : torch.Tensor - [Features,] input tensor. - target_class : str - class in `self.class_names` for optimization. - M : int - number of gradient steps to use when approximating - the path integral. - baseline : torch.Tensor - [Features,] baseline gene expression vector to use. - if `None`, uses the `0` vector. - - Returns - ------- - integrated_grad : torch.Tensor - [Features,] integrated gradient tensor. - - Notes - ----- - 1. Define a difference between the baseline input and observation. - 2. Approximate a linear path between the baseline and observation - with `M` steps. - 3. Compute the gradient at each step in the path. - 4. Sum gradients across steps and divide by number of steps. - 5. Elementwise multiply with input features as in saliency. - """ - if baseline is None: - n_dims = ( - len(self.gene_names) - if self.grad_activation == "input" - else self.model.n_hidden_init - ) - - if self.verbose: - print("Using the 0-vector as a baseline.") - base = self.baseline_input = torch.zeros((1, n_dims)).float() - else: - base = self.baseline_input = baseline - if base.dim() > 1 and base.size(0) != 1: - msg = "baseline must be a single gene expression vector" - raise ValueError(msg) - base = base.view(1, -1) - - self.target_class = target_class - - if x.dim() > 1 and x.size(0) == 1: - # tensor has an empty batch dimension, flatten it - x = x.view(-1) - - # create a batch of observations where each observation is - # a single step along the path integral - path = base.repeat((M, 1)) - - # if `first_layer` activations are used, x_activ is the relevant - # activation setting for saliency - if self.grad_activation == "first_layer": - x = x.to(device=list(self.input2first.parameters())[0].device) - x_rel = self.input2first(x.view(1, -1)).detach().cpu() - else: - x_rel = x - self.x_rel = x_rel - - # create a tensor marking the "step number" for each observation - step = ((x_rel - base) / M).view(1, -1) - step_coord = torch.arange(1, M + 1).view(-1, 1).repeat((1, path.size(1))) - - # add the correct number of steps to fill the path tensor - path += step * step_coord - - if self.verbose: - print("baseline", base.size()) - print(base.sort()) - print("observation", x.size()) - print(x.sort()) - print() - print("step : ", step.size()) - print(step) - print("step_coord : ", step_coord.size()) - print(step_coord) - print("path : ", path.size()) - print(path[0].sort()) - print("-" * 3) - print(path[-1].sort()) - - # compute the gradient on the input at each step - # along the path - grad_dim = ( - path.size(1) - if self.grad_activation == "input" - else self.model.n_hidden_init - ) - gradients = torch.zeros((path.size(0), grad_dim)) - scores = torch.zeros(path.size(0)) - - for m in range(M): - gradients[m, :], target_scores = self.get_grad( - path[m, :].view(1, -1), - self.target_class, - ) - scores[m] = target_scores - - self.raw_gradients = gradients - self.raw_scores = scores - self.path = path - - # sum gradients and normalize by step number - integrated_grad = x_rel * (gradients.sum(0) / M) - - self._check_integration(integrated_grad) - - return integrated_grad - - def get_gradients_for_class( - self, - adata: anndata.AnnData, - groupby: str, - target_class: str, - reference_class: str = None, - n_cells: int = None, - *args, - **kwargs, - ) -> pd.DataFrame: - """Get integrated gradients for a target class given - an AnnData experiment. - - Parameters - ---------- - adata : anndata.AnnData - [Cells, Features] experiment. - groupby : str - column in `adata.obs` containing class names. - target_class : str - class in `self.class_names` and `adata.obs[groupby]` - for optimization. - reference_class : str - reference class in `self.class_names`. "all" uses all - non-target classes as a reference. - n_cells : int - number of cells to use to compute a characteristic - integrated gradient. - if `None`, uses all cells. - *args, **kwargs : dict - passed to `self.get_integrated_gradient`. - - Returns - ------- - gradients : pd.DataFrame - [Cells, Features] integrated gradients. - Sets `self.grads_for_class[target_class]` with the value - of `gradients`. - - See Also - -------- - get_integrated_gradient - """ - if not np.all(adata.var_names == self.gene_names): - # gene names don't match, check if IG names are a subset - shared_genes = np.intersect1d( - adata.var_names, - self.gene_names, - ) - if len(shared_genes) < len(self.gene_names): - # some genes are missing - msg = "Not all genes in `gene_names` were found in `adata`." - raise ValueError(msg) - else: - # subset adata to the gene set used - # this will also handle gene name permutations - adata = adata[:, self.gene_names] - - if groupby not in adata.obs_keys(): - msg = f"{groupby} not in `adata.obs` columns." - raise ValueError(msg) - - groups = np.unique(adata.obs[groupby]) - if target_class not in groups: - msg = f"`{target_class}` is not a class in `{groupby}`" - raise ValueError(msg) - if target_class not in self.class_names: - msg = f"`{target_class}` is not a class in `self.class_names`" - raise ValueError(msg) - - # get the indices for cells of the target class - cell_idx = np.where(adata.obs[groupby] == target_class)[0].astype(np.int) - if n_cells is not None: - if n_cells < len(cell_idx): - # subset if a specific number of cells was specified - cell_idx = np.random.choice( - cell_idx, - size=n_cells, - replace=False, - ) - msg = f"Using {n_cells} cells for integrated gradient analysis." - logger.debug(msg) - else: - msg = f"n_cells {n_cells} > n_cells_in_class {len(cell_idx)}.\n" - msg += "Using all available cells." - logger.warning(msg) - - # compute integrated gradients - grads = [] - for i, idx in enumerate(cell_idx): - x = adata.X[idx, :] - if type(x) == np.matrix: - x = np.array(x) - if type(x) == sparse.csr_matrix: - x = x.toarray() - if type(x) != np.ndarray: - msg = "gene vector was not coerced to np.ndarray" - raise TypeError(msg) - x = x.flatten() - x = torch.from_numpy(x).float() - - g = self.get_integrated_gradient( - x=x, - target_class=target_class, - *args, - **kwargs, - ) - grads.append(g.view(-1)) - - logger.debug(f"x size: {x.size()}") - logger.debug(f"g size: {g.size()}") - - G = torch.stack(grads, dim=0).cpu().numpy() - - if self.grad_activation == "input": - col_names = self.gene_names - else: - col_names = [f"z_{i}" for i in range(G.shape[1])] - - gradients = pd.DataFrame( - G, - columns=col_names, - index=adata.obs_names[cell_idx], - ) - - self.grads_for_class[target_class] = gradients - - return gradients - - def get_top_features_from_gradients( - self, - target_class: str = None, - gradients: pd.DataFrame = None, - ) -> np.ndarray: - """Get the top features from a set of pre-computed integrated - gradients. - - Parameters - ---------- - target_class : str - target class with gradients stored in `self.grads_for_class[target_class]`. - gradients : pd.DataFrame - [Cells, Features] integrated gradients to use. If provided, supercedes - `target_class`. - - Returns - ------- - top_features : np.ndarray - [Features,] sorted [High, Low] values. - i.e. `top_features[0]` is the top feature. - """ - if target_class is None and gradients is None: - raise ValueError("must provide `gradients` or `target_class`") - - # `if gradients is not None`, use gradients instead of - # the stored gradients regardless of whether or not - # target_class as provided - if gradients is None: - gradients = self.grads_for_class[target_class] - logger.debug(f"Using stored gradients for {target_class}") - - grad_means = gradients.mean(0) - sort_idx = np.argsort(grad_means)[::-1] # high to low - - top_features = self.gene_names[sort_idx] - return top_features - - -class ExpectedGradient(object): - def __init__( - self, - model: nn.Module, - class_names: typing.Union[list, np.ndarray], - gene_names: typing.Union[list, np.ndarray] = None, - verbose: bool = False, - ) -> None: - """Performs expected gradient computations for feature attribution - in scNym models. - - Parameters - ---------- - model : torch.nn.Module - trained scNym model. - class_names : list or np.ndarray - list of str names matching output nodes in `model`. - gene_names : list or np.ndarray, optional - gene names for the model. - verbose : bool - verbose outputs for stdout. - - Returns - ------- - None. - - Notes - ----- - Integrated gradients are computed as the path integral between a "baseline" - gene expression vector (all 0 counts) and an observed gene expression vector. - The path integral is computed along a straight line in the feature space. - - Stated formally, we define a our baseline gene expression vector as :math:`x`, - our observed vector as :math:`x'`, an scnym model :math:`f(\cdot)`, and a - number of steps :math:`M` for approximating the integral by Reimann sums. - - The integrated gradient :math:`\int \nabla` for a feature :math:`x_i` is then - - .. math:: - - r = \sum_{m=1}^M \partial f(x' + \frac{m}{M}(x - x')) / \partial x_i \\ - \int \nabla_i = (x_i' - x_i) \frac{1}{M} r - """ - self.model = model - if torch.cuda.is_available(): - self.model = self.model.cuda() - logger.info("Model loaded on CUDA device for E[Grad] estimation.") - self.model.zero_grad() - for param in self.model.parameters(): - param.requires_grad = False - - self.model_device = list(self.model.parameters())[0].device - - self.class_names = np.array(class_names) - self.gene_names = np.array(gene_names) - self.verbose = verbose - self.grads_for_class = {} - # define the values for `source` that will trigger using all data as the - # reference dataset - self.background_vals = ( - "all", - None, - ) - - return - - def _check_inputs( - self, - adata: anndata.AnnData, - source: str, - target: str, - cell_type_col: str, - ) -> anndata.AnnData: - """Check that inputs match model expectations. - - Parameters - ---------- - adata : anndata.AnnData - [Cells, Genes] - source : str - class name for source class. - target : str - class name for target class. - cell_type_col : str - column in `adata.obs` containing cell type labels. - - Returns - ------- - adata : anndata.AnnData - [Cells, len(self.gene_names)] experiment. - modifies anndata to match model training gene names - if necessary. - """ - # check cell type arguments - if cell_type_col not in adata.obs.columns: - msg = f"{cell_type_col} is not a column in `adata.obs`" - raise ValueError(msg) - self.cell_type_col = cell_type_col - - cell_types = np.unique(adata.obs[self.cell_type_col]) - if source not in cell_types and source not in self.background_vals: - msg = f"{source} not in the detected cell types or background values." - raise ValueError(msg) - if target not in cell_types: - msg = f"{target} not in the detected cell types." - raise ValueError(msg) - - # check that genes match the training gene names - match = np.all(np.array(adata.var_names) == np.array(self.gene_names)) - if not match: - msg = "Gene names for model and `adata` query do not match.\n" - msg += "\t Coercing..." - logger.warn(msg) - X = build_classification_matrix( - X=get_adata_asarray( - adata, - ), - model_genes=np.array(self.gene_names), - sample_genes=np.array(adata.var_names), - gene_batch_size=1024, - ) - adata2 = anndata.AnnData( - X=X, - obs=adata.obs.copy(), - ) - adata2.var_names = self.gene_names - else: - logger.debug("Model and query gene names match.") - adata2 = adata - - return adata2 - - def _get_exp_grad( - self, - model: torch.nn.Module, - input_: torch.FloatTensor, - target: torch.LongTensor, - ) -> torch.FloatTensor: - """Get expected gradients from the input layer""" - exp_grad = self.APExp.shap_values( - model, - input_, - sparse_labels=target, - ) - return exp_grad - - def _setup_dataset(self, X, y, adata=None) -> None: - """Setup `Dataset` and `DataLoader`classes for train - and validation data. - - Returns - ------- - None. - Sets `.train_ds`, `.val_ds` and `.train_dl`, `.val_dl`. - """ - self.n_cell_types = len(np.unique(y)) - self.n_genes = X.shape[1] - - self.y_orig = y - y = np.array(pd.Categorical(y, categories=np.unique(y)).codes) - self.y = y - self.y_categories = np.unique(self.y_orig) - # setup dataset, model, and training components - # for querying, we also set a dataset with all of the data - self.all_ds = dataprep.SingleCellDS( - X=X, - y=np.array(y), - ) - self.all_dl = torch.utils.data.DataLoader( - self.all_ds, - batch_size=self.batch_size, - shuffle=False, - drop_last=False, - ) - - return - - def query( - self, - adata: anndata.AnnData, - target: str, - source: str = "all", - cell_type_col: str = "cell_ontology_class", - batch_size: int = 512, - n_batches: int = 100, - n_cells: int = None, - ) -> pd.DataFrame: - """Find the features that distinguish `target` cells from `source` cells - using expected gradient estimation. - - Parameters - ---------- - adata : anndata.AnnData - [Cells, Genes] - target : str - class name for target class. - expected gradients highlight important features that define this cell type. - source : str - class name for source class to use as reference cells for expected - gradient estimation. - if `"all"` or `None`, uses all cells in `adata` as possible references. - cell_type_col : str - column in `adata.obs` containing cell type labels. - n_batches : int - number of reference batches to draw for each target sample. - n_cells : int - number of target samples to use for E[G] estimation. - if `None`, uses all available samples. - - Returns - ------- - saliency : pd.DataFrame - [Genes, 1] mean expected gradient across cells used for - estimation for each gene. - """ - self.batch_size = batch_size - adata = self._check_inputs( - adata=adata, - source=source, - target=target, - cell_type_col=cell_type_col, - ) - self._setup_dataset(adata.X, adata.obs[cell_type_col], adata=adata) - self.model.train(False) - - target_bidx = adata.obs[self.cell_type_col] == target - if source in self.background_vals: - source_bidx = np.ones(adata.shape[0], dtype=np.bool) - # ensure target cells aren't in the source data - source_bidx[target_bidx] = False - else: - source_bidx = adata.obs[self.cell_type_col] == source - # regenerate labels in case the query dataset is different from the - # training dataset - class_names = self.class_names.tolist() - target_y = np.array( - [class_names.index(target)] * sum(target_bidx), - dtype=np.int32, - ) - source_y = adata.obs.loc[source_bidx, self.cell_type_col].tolist() - source_y = np.array( - [class_names.index(x) for x in source_y], - dtype=np.int32, - ) - - source_adata = adata[source_bidx, :].copy() - target_adata = adata[target_bidx, :].copy() - logging.info(f"Subset adata to {target_adata.shape[0]} target cells.") - - if n_cells is not None: - target_idx = np.random.choice( - np.arange(target_adata.shape[0]), - size=n_cells, - replace=target_adata.shape[0] < n_cells, - ).astype(np.int32) - else: - target_idx = np.arange(target_adata.shape[0]) - - target_ds = dataprep.SingleCellDS( - X=target_adata.X[target_idx], - y=target_y[target_idx], - ) - logging.info( - f"Using {target_ds.X.shape[0]} target cells for expgrad estimation." - ) - # save the cell indices in attributes - self._query_cell_obs_names = pd.DataFrame( - { - "names": source_adata.obs_names.tolist() - + target_adata.obs_names[target_idx].tolist(), - "dataset": ["source"] * source_adata.shape[0] - + ["target"] * len(target_idx), - }, - ) - - # make sure the source dataset has at least as many examples as - # the target by replicating at random - n_reps = int(np.ceil(sum(target_bidx) / sum(source_bidx))) - source_indices = np.arange(source_adata.X.shape[0]) - source_indices = np.tile(source_indices, (n_reps,)) - source_ds = dataprep.SingleCellDS( - X=source_adata.X[source_indices], - y=source_y[source_indices], - ) - - batch_size = min(self.batch_size, len(target_idx)) - target_dl = torch.utils.data.DataLoader( - target_ds, - batch_size=batch_size, - shuffle=False, - drop_last=self.batch_size == batch_size, - ) - - # use only the source samples as references if specified - # otherwise, use the whole training set - self.APExp = attrprior.AttributionPriorExplainer( - source_ds, - batch_size=batch_size, - k=1, - input_batch_index="input", - ) - logger.debug("Set up Attribution Prior Explainer") - gradients_by_batch = [] - for input_batch in target_dl: - batch_grads = [] - input_ = input_batch["input"].to(device=self.model_device) - _, labels = torch.max(input_batch["output"], dim=1) - labels = labels.to(device=self.model_device).long() - # for each input, use `n_batches` different random references - for i in range(n_batches): - s = time.time() - logger.debug(f"gradient batch {i}") - g = self._get_exp_grad( - self.model, - input_, - target=labels, - ) - g = g.detach() - batch_grads.append(g.detach().cpu()) - e = time.time() - logger.debug(f"time: {e-s} secs") - # [Obs, Features, estimation_batch] - batch_grads = torch.stack(batch_grads, dim=-1) - batch_grads = torch.mean(batch_grads, dim=-1) - - gradients_by_batch.append(batch_grads) - gradients = torch.cat(gradients_by_batch, dim=0) - gradients = gradients.detach().cpu().numpy() - - gradients = pd.DataFrame( - gradients, - index=target_adata.obs_names[target_idx][: gradients.shape[0]], - ) - if gradients.shape[1] == len(adata.var_names): - gradients.columns = adata.var_names.tolist() - - # compute mean gradients across cells - saliency = gradients.mean(0).sort_values(ascending=False) - saliency.columns = ["exp_grad"] - - self.saliency = saliency - self.gradients = gradients - return saliency - - def save_query( - self, - path: str, - ) -> None: - """Save intermediary representations generated during a - `query` call""" - if path is None: - return - # save query outputs - saliency_path = str(Path(path) / Path("saliency.csv")) - self.saliency.to_csv(saliency_path) - gradients_path = str(Path(path) / Path("gradients.csv")) - self.gradients.to_csv(gradients_path) - obs_names_path = str(Path(path) / Path("obs_names.csv")) - self._query_cell_obs_names.to_csv(obs_names_path) - return - - -class ClassificationEntropy(object): - def __init__(self, reduce: str = "mean") -> None: - """Compute the entropy of a classification probability vector""" - self.reduce = reduce - return - - def __call__(self, x: torch.FloatTensor) -> torch.FloatTensor: - """Compute entropy for a probability tensor `x` - - Parameters - ---------- - x : torch.FloatTensor - [Cells, Classes] probability tensor - - Returns - ------- - H : torch.FloatTensor - either [Cells,] or [1,] if `reduce is not None`. - """ - H = -1 * torch.sum(x * torch.log(x), dim=1) - if self.reduce == "mean": - H = torch.mean(H) - return H - - -class Tesseract(IntegratedGradient): - """Tessaract finds a path from a source vector in feature - space to a destination vector. - - Attributes - ---------- - model : torch.nn.Module - trained scNym model. - class_names : list or np.ndarray - list of str names matching output nodes in `model`. - gene_names : list or np.ndarray, optional - gene names for the model. - grad_activation : str - activations where gradients should be collected. - default "input" collects gradients at the level of input features. - verbose : bool - verbose outputs for stdout. - energy_criterion : Callable - criterion to evaluate the potential energy of a gene - expression state given args `model` and `x` where - `x` is a gene expression vector. - optimizer : torch.optim.Optimizer - optimizer for finding paths through gene expression - space using a parametric gene expression vector. - """ - - def __init__( - self, - *, - energy_criterion: typing.Callable, - optimizer_class: typing.Callable, - **kwargs, - ) -> None: - """Tessaract finds a path from a source vector in feature - space to a destination vector that maximizes the likelihood - of observing each intermediate position using a trained - classification model. - - Parameters - ---------- - energy_criterion : Callable - criterion to evaluate the potential energy of a gene - expression state given args `model` and `x` where - `x` is a gene expression vector. - optimizer_class : Callable - function to initialize a `torch.optim.Optimizer`. - - Returns - ------- - None - """ - super(Tesseract, self).__init__(**kwargs) - self.energy_criterion = energy_criterion - self.optimizer_class = optimizer_class - return - - def find_path( - self, - adata: anndata.AnnData, - groupby: str, - source_class: str, - target_class: str, - energy_weight: float = 1.0, - n_epochs: int = 500, - tol: float = 1.0, - patience: int = 10, - ) -> torch.FloatTensor: - """Find a path between a source and target cell class - given an AnnData experiment containing both. - - Parameters - ---------- - adata : anndata.AnnData - [Cells, Features] experiment. - groupby : str - column in `adata.obs` containing class names. - source_class : str - class in `self.class_names` and `adata.obs[groupby]` - for initialization. - target_class : str - class in `self.class_names` and `adata.obs[groupby]` - for optimization. - energy_weight : float, optional - weight for the energy criterion relative to the class - scores. - n_epochs : int, optional - number of epochs for optimization. - tol : float, optional - minimum L2 difference across epochs to consider - the optimization to be progressing. - patience : int, optional - number of epochs to wait before early stopping. - - Returns - ------- - path : torch.FloatTensor - [epochs, Features] path through gene expression space. - Sets `self.last_path` with the value of path. - """ - - source_cell_idx = adata.obs[groupby] == source_class - source_mean = torch.from_numpy( - np.array(adata[source_cell_idx, :].X.mean(0)) - ).float() - model_device = list(self.model.parameters())[0].device - source_mean = source_mean.to(device=model_device) - - if self.grad_activation == "first_layer": - # we're using first layer embeddings as the relevant - # space for integrated gradient computation and - # optimization - source_mean2use = self.input2first(source_mean) - self.scoring_model = self.first2output - else: - source_mean2use = source_mean - self.scoring_model = self.model - - # initialize the gene expression vector at the source - x = copy.deepcopy(source_mean2use) - self.optimizer = self.optimizer_class({"params": x, "name": "expression_path"}) - - # perform optimization to the target class while - # minimizing an energy criterion - def loss( - x, - ): - target_idx = self.class_names.index(target_class) - source_idx = self.class_names.index(source_class) - outputs = self.scoring_model( - x, - ) - probs = torch.nn.functional.softmax(outputs, dim=1) - energy = ( - self.energy_criterion( - x, - ) - * energy_weight - ) - l = (probs[source_idx] - probs[target_idx]) + energy - return l - - # intialize path collector and set the `waiting_epochs` - # for early stopping to an initial zero value - path_points = [] - waiting_epochs = 0 - logger.info("Beginning pathfinding optimization") - for epoch in range(n_epochs): - # save path locations - path_points.append(copy.deepcopy(x.detach().cpu())) - - # compute loss and perform an update step - l = loss( - x, - ) - self.optimizer.zero_grad() - l.backward() - self.optimizer.step() - - # check if x is changing substantially - delta = x.data - path_points[-1] - l2 = torch.norm(delta, p=2) - if l2 < tol and waiting_epochs > patience: - msg = f"\tchange in l2 < {tol} for {patience} epochs\n" - msg += "\tending optimizing." - logger.warning(msg) - elif l2 < tol: - waiting_epochs += 1 - else: - waiting_epochs = 0 - - path = torch.cat(path_points, dim=0) - self.last_path = path - return path diff --git a/build/lib/scnym/losses.py b/build/lib/scnym/losses.py deleted file mode 100644 index 8b950a4..0000000 --- a/build/lib/scnym/losses.py +++ /dev/null @@ -1,1838 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import logging -from typing import Callable, Union, Iterable, Tuple -from .dataprep import SampleMixUp -from .model import CellTypeCLF, DANN, AE -from .distributions import NegativeBinomial -import copy - -logger = logging.getLogger(__name__) - - -class MultiTaskCriterion(object): - def __init__( - self, - ) -> None: - """Abstraction for MultiTask losses - - Note: Depreceated, inheriting from `torch.nn.Module` now. - """ - return - - def train(self, on: bool) -> None: - """Toggle the training mode of learned parameters""" - return - - def eval( - self, - ) -> None: - """Disable training of learned parameters""" - self.train(on=False) - return - - -def get_class_weight( - y: np.ndarray, -) -> np.ndarray: - """Generate relative class weights based on the representation - of classes in a label vector `y` - - Parameters - ---------- - y : np.ndarray - [N,] vector of class labels. - - Returns - ------- - class_weight : np.ndarray - [Classes,] vector of loss weight coefficients. - if classes are `str`, returns weights in lexographically - sorted order. - - """ - # find all unique class in y and their counts - u_classes, class_counts = np.unique(y, return_counts=True) - # compute class proportions - class_prop = class_counts / len(y) - # invert proportions to get class weights - class_weight = 1.0 / class_prop - # normalize so that the minimum value is 1. - class_weight = class_weight / class_weight.min() - return class_weight - - -def cross_entropy( - pred_: torch.FloatTensor, - label: torch.FloatTensor, - class_weight: torch.FloatTensor = None, - sample_weight: torch.FloatTensor = None, - reduction: str = "mean", -) -> torch.FloatTensor: - """Compute cross entropy loss for prediction outputs - and potentially non-binary targets. - - Parameters - ---------- - pred_ : torch.FloatTensor - [Batch, C] model outputs. - label : torch.FloatTensor - [Batch, C] labels. may not necessarily be one-hot, - but must satisfy simplex criterion. - class_weight : torch.FloatTensor - [C,] relative weights for each of the output classes. - useful for increasing attention to underrepresented - classes. - reduction : str - reduction method across the batch. - - Returns - ------- - loss : torch.FloatTensor - mean cross-entropy loss across the batch indices. - - Notes - ----- - Crossentropy is defined as: - - .. math:: - - H(P, Q) = -\Sum_{k \in K} P(k) log(Q(k)) - - where P, Q are discrete probability distributions defined - with a common support K. - - References - ---------- - See for class weight computation: - https://pytorch.org/docs/stable/nn.html#crossentropyloss - """ - if pred_.size() != label.size(): - msg = ( - f"pred size {pred_.size()} not compatible with label size {label.size()}\n" - ) - raise ValueError(msg) - - if reduction.lower() not in ("mean", "sum", "none"): - raise ValueError(f"{reduction} is not a valid reduction method.") - - # Apply softmax transform to predictions and log transform - pred_log_sm = torch.nn.functional.log_softmax(pred_, dim=1) - # Compute cross-entropy with the label vector - samplewise_loss = -1 * torch.sum(label * pred_log_sm, dim=1) - - if sample_weight is not None: - # weight individual samples using sample_weight - # we squeeze into a single column in-case it had an - # empty singleton dimension - samplewise_loss *= sample_weight.squeeze() - - if class_weight is not None: - class_weight = class_weight.to(label.device) - # weight the losses - # copy the weights across the batch to allow for elementwise - # multiplication with the samplewise losses - class_weight = class_weight.repeat(samplewise_loss.size(0), 1) - # compute an [N,] vector of weights for each samples' loss - weight_vec, _ = torch.max( - class_weight * label, - dim=1, - ) - - samplewise_loss = samplewise_loss * weight_vec - if reduction == "mean": - loss = torch.mean(samplewise_loss) - elif reduction == "sum": - loss = torch.sum(samplewise_loss) - else: - loss = samplewise_loss - return loss - - -class scNymCrossEntropy(nn.Module): - def __init__( - self, - class_weight: torch.FloatTensor = None, - sample_weight: torch.FloatTensor = None, - reduction: str = "mean", - ) -> None: - """Class wrapper for scNym cross-entropy loss to be used - in conjuction with `MultiTaskTrainer` - - Parameters - ---------- - class_weight : torch.FloatTensor - [C,] relative weights for each of the output classes. - useful for increasing attention to underrepresented - classes. - reduction : str - reduction method across the batch. - - See Also - -------- - cross_entropy - .trainer.MultiTaskTrainer - """ - super(scNymCrossEntropy, self).__init__() - - self.class_weight = class_weight - self.sample_weight = sample_weight - self.reduction = reduction - return - - def __call__( - self, - labeled_sample: dict, - unlabeled_sample: dict, - model: nn.Module, - weight: float = None, - ) -> torch.FloatTensor: - """Perform class prediction and compute the supervised loss - - Parameters - ---------- - labeled_sample : dict - input - torch.FloatTensor - [Batch, Features] minibatch of labeled examples. - output - torch.LongTensor - one-hot labels. - unlabeled_sample : dict - input - torch.FloatTensor - [Batch, Features] minibatch of unlabeled samples. - output - torch.LongTensor - zeros. - pass `None` if there are no unlabeled samples. - model : nn.Module - model with parameters accessible via the `.parameters()` - method. - weight : float - default None. no-op, included for API compatibility. - - - Returns - ------- - loss : torch.FloatTensor - """ - data = labeled_sample["input"] - # forward pass - outputs, x_embed = model(data, return_embed=True) - probs = torch.nn.functional.softmax(outputs, dim=-1) - _, predictions = torch.max(probs, dim=-1) - - # compute loss - loss = cross_entropy( - pred_=probs, - label=labeled_sample["output"], - sample_weight=self.sample_weight, - class_weight=self.class_weight, - reduction=self.reduction, - ) - - labeled_sample["embed"] = x_embed - - if unlabeled_sample is not None: - outputs, u_embed = model(unlabeled_sample["input"], return_embed=True) - unlabeled_sample["embed"] = u_embed - - return loss - - -class InterpolationConsistencyLoss(nn.Module): - def __init__( - self, - unsup_criterion: Callable, - sup_criterion: Callable, - decay_coef: float = 0.9997, - mean_teacher: bool = True, - augment: Callable = None, - teacher_eval: bool = True, - teacher_bn_running_stats: bool = None, - **kwargs, - ) -> None: - """Computes an Interpolation Consistency Loss - given a trained model and an unlabeled minibatch. - - Parameters - ---------- - unsup_criterion : Callable - loss criterion for similarity between "mixed-up" - "fake labels" and predicted labels for "mixed-up" - samples. - sup_criterion : Callable - loss for samples with a primarily labeled component. - decay_coef : float - decay coefficient for mean teacher parameter - updates. - mean_teacher : bool - use a mean teacher model for interpolation consistency - loss estimation. - augment : Callable - augments a batch of samples. - teacher_eval : bool - place teacher in evaluation mode, deactivating stochastic - model components. - teacher_bn_running_stats : bool - use running statistics for batch normalization mean and - variance. - if False, uses minibatch statistics. - if None, uses setting of the student model batch norm layer. - - Returns - ------- - None. - - Notes - ----- - Instantiates a `SampleMixUp` class and passes any - `**kwargs` to this class. - - Uses a "mean teacher" method by keeping a running - average of parameter sets used in the `__call__` - method. - - `decay_coef` taken from the Mean Teacher paper experiments - on ImageNet. - https://arxiv.org/abs/1703.01780 - - Formalism: - - .. math:: - - icl(u) = criterion( f(Mixup(u_i, u_j)), - Mixup(f(u_i), f(u_j)) ) - - References - ---------- - 1. Interpolation consistency training for semi-supervised learning - 2019, arXiv:1903.03825v3, stat.ML - Vikas Verma, Alex Lamb, Juho Kannala, Yoshua Bengio - - 2. Mean teachers are better role models: \ - Weight-averaged consistency targets improve \ - semi-supervised deep learning results - 2017, arXiv:1703.01780, cs.NE - Antti Tarvainen, Harri Valpola - """ - super(InterpolationConsistencyLoss, self).__init__() - - self.unsup_criterion = unsup_criterion - self.sup_criterion = sup_criterion - self.decay_coef = decay_coef - self.mean_teacher = mean_teacher - if self.mean_teacher: - print("IC Loss is using a mean teacher.") - self.augment = augment - self.teacher_eval = teacher_eval - self.teacher_bn_running_stats = teacher_bn_running_stats - - # instantiate a callable MixUp operation - self.mixup_op = SampleMixUp(**kwargs) - - self.teacher = None - self.step = 0 - return - - def _update_teacher( - self, - model: nn.Module, - ) -> None: - """Update the teacher model based on settings""" - if self.mean_teacher: - if self.teacher is None: - # instantiate the teacher with a copy - # of the model - self.teacher = copy.deepcopy( - model, - ) - else: - self._update_teacher_params( - model, - ) - else: - self.teacher = copy.deepcopy( - model, - ) - - if self.teacher_eval: - self.teacher = self.teacher.eval() - - if self.teacher_bn_running_stats is not None: - # enforce our preference for teacher model batch - # normalization statistics - for m in self.teacher.modules(): - if isinstance(m, nn.BatchNorm1d): - m.track_running_stats = self.teacher_bn_running_stats - - # check that our parameters are preserved - if self.teacher_bn_running_stats is not None: - # enforce our preference for teacher model batch - # normalization statistics - for m in self.teacher.modules(): - if isinstance(m, nn.BatchNorm1d): - assert m.track_running_stats == self.teacher_bn_running_stats - - return - - def _update_teacher_params( - self, - model: nn.Module, - ) -> None: - """Update parameters in the teacher model using an - exponential averaging method. - - Notes - ----- - Logic derived from the Mean Teacher implementation - https://github.com/CuriousAI/mean-teacher/ - """ - # Per the mean-teacher paper, we use the global average - # of parameter values until the exponential average is more effective - # For a `decay_coef ~= 0.997`, this hand-off happens at ~step 333. - alpha = min(1 - 1 / (self.step + 1), self.decay_coef) - # Perform in-place operations on the teacher parameters to average - # with the new model parameters - # Here, we're computing a simple weighted average where alpha is - # the weight on past parameters, and (1 - alpha) is the weight on - # new parameters - zipped_params = zip(self.teacher.parameters(), model.parameters()) - for teacher_param, model_param in zipped_params: - (teacher_param.data.mul_(alpha).add_(1 - alpha, model_param.data)) - return - - def __call__( - self, - model: nn.Module, - unlabeled_sample: dict, - labeled_sample: dict, - ) -> torch.FloatTensor: - """Takes a model and set of unlabeled samples as input - and computes the Interpolation Consistency Loss. - - Parameters - ---------- - model : nn.Module - model with parameters accessible via the `.parameters()` - method. - unlabeled_sample : dict - input - torch.FloatTensor - [Batch, Features] minibatch of unlabeled samples. - output - torch.LongTensor - zeros. - labeled_sample : dict - input - torch.FloatTensor - [Batch, Features] minibatch of labeled examples. - output - torch.LongTensor - one-hot labels. - - Returns - ------- - supervised_loss : torch.FloatTensor - supervised loss computed using `sup_criterion` between - model predictions on mixed observations and true labels. - unsupervised_loss : torch.FloatTensor - unsupervised loss computed using `criterion` and the - interpolation consistency method. - supervised_outputs : torch.FloatTensor - [Batch, Classes] model outputs for augmented labeled examples. - - - Notes - ----- - Algorithm description: - - (0) Update the mean teacher. - (1) Compute "fake labels" for unlabeled samples by performing - a forward pass through the "mean teacher" and using the output - as a representative label for the sample. - (2) Perform a MixUp operation on unlabeled samples and their - corresponding fake labels. - (3) Compute the loss criterion between the mixed-up fake labels - and the predicted fake labels for the mixed up samples. - """ - ############################### - # (0) Update the mean teacher - ############################### - - self._update_teacher( - model, - ) - - ############################### - # (1) Compute Fake Labels - ############################### - - with torch.no_grad(): - fake_y = F.softmax( - self.teacher(unlabeled_sample["input"]), - dim=1, - ) - - ############################### - # (2) Perform MixUp and Forward - ############################### - - unlabeled_sample["output"] = fake_y - - mixed_sample = self.mixup_op(unlabeled_sample) - # move sample to model device if necessary - mixed_sample["input"] = mixed_sample["input"].to( - device=next(model.parameters()).device, - ) - mixed_output = F.softmax( - model(mixed_sample["input"]), - ) - assert mixed_output.requires_grad - - # set outputs as attributes for later access - self.mixed_output = mixed_output - self.mixed_sample = mixed_sample - self.unlabeled_sample = unlabeled_sample - - ############################### - # (3) Compute unsupervised loss - ############################### - - icl = self.unsup_criterion( - mixed_output, - fake_y, - ) - - ############################### - # (4) Compute supervised loss - ############################### - - if self.augment is not None: - labeled_sample = self.augment(labeled_sample) - # move sample to the model device if necessary - labeled_sample["input"] = labeled_sample["input"].to( - device=next(model.parameters()).device, - ) - labeled_sample["input"].requires_grad = True - - sup_outputs = model(labeled_sample["input"]) - sup_loss = self.sup_criterion( - sup_outputs, - labeled_sample["output"], - ) - - self.step += 1 - return sup_loss, icl, sup_outputs - - -def sharpen_labels( - q: torch.FloatTensor, - T: float = 0.5, -) -> torch.FloatTensor: - """Reduce the entropy of a categorical label using a - temperature adjustment - - Parameters - ---------- - q : torch.FloatTensor - [N, C] pseudolabel. - T : float - temperature parameter. - - Returns - ------- - q_s : torch.FloatTensor - [C,] sharpened pseudolabel. - - Notes - ----- - .. math:: - - S(q, T) = q_i^{1/T} / \sum_j^L q_j^{1/T} - - """ - if T == 0.0: - # equivalent to argmax - _, idx = torch.max(q, dim=1) - oh = torch.nn.functional.one_hot( - idx, - num_classes=q.size(1), - ) - return oh - - if T == 1.0: - # no-op - return q - - q = torch.pow(q, 1.0 / T) - q /= torch.sum( - q, - dim=1, - ).reshape(-1, 1) - return q - - -class MixMatchLoss(InterpolationConsistencyLoss): - """Compute the MixMatch Loss given a batch of labeled - and unlabeled examples. - - Attributes - ---------- - n_augmentations : int - number of augmentated samples to average across when - computing pseudolabels. - default = 2 from MixMatch paper. - T : float - temperature parameter. - augment_pseudolabels : bool - perform augmentations during pseudolabel generation. - pseudolabel_min_confidence : float - minimum confidence to compute a loss for a given pseudolabeled - example. examples below this confidence threshold will be given - `0` loss. see the `FixMatch` paper for discussion. - teacher : nn.Module - teacher model for pseudolabeling. - running_confidence_scores : list - [n_batches_to_store,] (torch.Tensor, torch.Tensor,) of unlabeled - example (Confident_Bool, BestConfidenceScore) tuples. - n_batches_to_store : int - determines how many batches to keep in `running_confidence_scores`. - """ - - def __init__( - self, - n_augmentations: int = 2, - T: float = 0.5, - augment_pseudolabels: bool = True, - pseudolabel_min_confidence: float = 0.0, - **kwargs, - ) -> None: - """Compute the MixMatch Loss given a batch of labeled - and unlabeled examples. - - Parameters - ---------- - n_augmentations : int - number of augmentated samples to average across when - computing pseudolabels. - default = 2 from MixMatch paper. - T : float - temperature parameter. - augment_pseudolabels : bool - perform augmentations during pseudolabel generation. - pseudolabel_min_confidence : float - minimum confidence to compute a loss for a given pseudolabeled - example. examples below this confidence threshold will be given - `0` loss. see the `FixMatch` paper for discussion. - - Returns - ------- - None. - - References - ---------- - MixMatch: A Holistic Approach to Semi-Supervised Learning - http://papers.nips.cc/paper/8749-mixmatch-a-holistic-approach-to-semi-supervised-learning - - FixMatch: https://arxiv.org/abs/2001.07685 - """ - # inherit from IC loss, forcing the SampleMixUp to keep - # the identity of the dominant observation in each mixed sample - super(MixMatchLoss, self).__init__( - **kwargs, - keep_dominant_obs=True, - ) - if not callable(self.augment): - msg = "MixMatch requires a Callable for augment" - raise TypeError(msg) - self.n_augmentations = n_augmentations - self.augment_pseudolabels = augment_pseudolabels - self.T = T - - self.pseudolabel_min_confidence = pseudolabel_min_confidence - # keep a running score of the last 50 batches worth of pseudolabel - # confidence outcomes - self.n_batches_to_store = 50 - self.running_confidence_scores = [] - return - - @torch.no_grad() - def _generate_labels( - self, - unlabeled_sample: dict, - ) -> torch.FloatTensor: - """Generate labels by applying a set of augmentations - to each unlabeled example and keeping the mean. - - Parameters - ---------- - unlabeled_batch : dict - "input" - [Batch, Features] minibatch of unlabeled samples. - """ - # let the teacher model take guesses at the label for augmented - # versions of the unlabeled observations - raw_guesses = [] - for i in range(self.n_augmentations): - to_augment = { - "input": unlabeled_sample["input"].clone(), - "output": torch.zeros(1), - } - if self.augment_pseudolabels: - # augment the batch before pseudolabeling - augmented_batch = self.augment(to_augment) - else: - augmented_batch = to_augment - # convert model guess to probability distribution `q` - # with softmax, prior to considering it a label - guess = F.softmax( - self.teacher(augmented_batch["input"]), - dim=1, - ) - raw_guesses.append(guess) - - # compute pseudolabels as the mean across all label guesses - pseudolabels = torch.mean( - torch.stack( - raw_guesses, - dim=0, - ), - dim=0, - ) - - # before sharpening labels, determine if the labels are - # sufficiently confidence to use - highest_conf, likeliest_class = torch.max( - pseudolabels, - dim=1, - ) - # confident is a bool that we will use to decide if we should - # keep loss from a given example or zero it out - confident = highest_conf >= self.pseudolabel_min_confidence - # store confidence outcomes in a running list so we can monitor - # which fraction of pseudolabels are being used - if len(self.running_confidence_scores) > self.n_batches_to_store: - # remove the oldest batch - self.running_confidence_scores.pop(0) - - # store tuples of (torch.Tensor, torch.Tensor) - # (confident_bool, highest_conf_score) - self.running_confidence_scores.append( - ( - confident.detach().cpu(), - highest_conf.detach().cpu(), - ), - ) - - if self.T is not None: - # sharpen labels - pseudolabels = sharpen_labels( - q=pseudolabels, - T=self.T, - ) - # ensure pseudolabels aren't attached to the - # computation graph - pseudolabels = pseudolabels.detach() - - return pseudolabels, confident - - def __call__( - self, - model: nn.Module, - labeled_sample: dict, - unlabeled_sample: dict, - **kwargs, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: - """ - Parameters - ---------- - model : nn.Module - model with parameters accessible via the `.parameters()` - method. - labeled_sample : dict - input - torch.FloatTensor - [Batch, Features] minibatch of labeled examples. - output - torch.LongTensor - one-hot labels. - unlabeled_sample : dict - input - torch.FloatTensor - [Batch, Features] minibatch of unlabeled samples. - output - torch.LongTensor - zeros. - - - Returns - ------- - supervised_loss : torch.FloatTensor - supervised loss computed using `sup_criterion` between - model predictions on mixed observations and true labels. - unsupervised_loss : torch.FloatTensor - unsupervised loss computed using `criterion` between - model predictions on mixed unlabeled observations - and pseudolabels generated as the mean - across `n_augmentations` augmentation runs. - supervised_outputs : torch.FloatTensor - [Batch, Classes] model outputs for augmented labeled examples. - """ - - ######################################## - # (0) Update the mean teacher - ######################################## - - self._update_teacher( - model, - ) - - ######################################## - # (1) Generate labels for unlabeled data - ######################################## - - pseudolabels, pseudolabel_confidence = self._generate_labels( - unlabeled_sample=unlabeled_sample, - ) - # make sure pseudolabels match real label dtype - # so that they can be concatenated - pseudolabels = pseudolabels.to(dtype=labeled_sample["output"].dtype) - - ######################################## - # (2) Augment the labeled data - ######################################## - - labeled_sample = self.augment( - labeled_sample, - ) - - ######################################## - # (3) Perform MixUp across both batches - ######################################## - n_unlabeled_original = unlabeled_sample["input"].size(0) - unlabeled_sample["output"] = pseudolabels - - # separate samples into confident and unconfident sample dicts - # we only allow samples with confident pseudolabels to - # participate in the MixUp operation - conf_unlabeled_sample = {} - ucnf_unlabeled_sample = {} - - for k in unlabeled_sample.keys(): - conf_unlabeled_sample[k] = unlabeled_sample[k][pseudolabel_confidence] - ucnf_unlabeled_sample[k] = unlabeled_sample[k][~pseudolabel_confidence] - - # unlabeled samples come BEFORE labeled samples - # in the concatenated sample - # NOTE: we only allow confident unlabeled samples - # into the concatenated sample used for MixUp - cat_sample = { - k: torch.cat( - [ - conf_unlabeled_sample[k], - labeled_sample[k], - ], - dim=0, - ) - for k in ["input", "output"] - } - - # mixup the concatenated sample - # NOTE: dominant observations are maintained - # by passing `keep_dominant_obs=True` in - # `self.__init__` - mixed_samples = self.mixup_op( - cat_sample, - ) - - ######################################## - # (4) Forward pass for mixed samples - ######################################## - - # split the mixed samples based on the dominant - # observation - n_unlabeled = conf_unlabeled_sample["input"].size(0) - unlabeled_m_ = mixed_samples["input"][:n_unlabeled] - unlabeled_y_ = mixed_samples["output"][:n_unlabeled] - - labeled_m_ = mixed_samples["input"][n_unlabeled:] - labeled_y_ = mixed_samples["output"][n_unlabeled:] - - # append low confidence samples to unlabeled_m_ and unlabeled_y_ - # this ensures that batch norm is still able to update it's - # statistics based on batches from the train AND target domain - unlabeled_m_ = torch.cat( - [ - unlabeled_m_, - ucnf_unlabeled_sample["input"], - ] - ) - unlabeled_y_ = torch.cat( - [ - unlabeled_y_, - ucnf_unlabeled_sample["output"], - ] - ) - - # perform a forward pass on mixed samples - # NOTE: Our unsupervised criterion operates on post-softmax - # probability vectors, so we transform the output here - unlabeled_z_ = F.softmax( - model(unlabeled_m_), - dim=1, - ) - # NOTE: Our supervised criterion operates directly on - # logits and performs a `logsoftmax()` internally - labeled_z_ = model(labeled_m_) - - ######################################## - # (5) Compute losses - ######################################## - - # compare mixed pseudolabels to the model guess - # on the mixed input - # NOTE: this returns an **unreduced** loss of size - # [Batch,] or [Batch, Classes] depending on the loss function - unsupervised_loss = self.unsup_criterion( - unlabeled_z_, - unlabeled_y_, - ) - # sum loss across classes if not reduced in the loss - if unsupervised_loss.dim() > 1: - unsupervised_loss = torch.sum(unsupervised_loss, dim=1) - - # scale the loss to 0 for all observations without confident pseudolabels - # this allows the loss to slowly ramp up as labels become more confident - scale_vec = ( - torch.zeros_like(unsupervised_loss) - .float() - .to(device=unsupervised_loss.device) - ) - scale_vec[:n_unlabeled] += 1.0 - unsupervised_loss = unsupervised_loss * scale_vec - unsupervised_loss = torch.mean(unsupervised_loss) - - # compute model guess on the mixed supervised input - # to the mixed labels - # NOTE: we didn't allow non-confident pseudolabels - # into the MixUp, so this shouldn't propogate any - # poor quality pseudolabel information - supervised_loss = self.sup_criterion( - labeled_z_, - labeled_y_, - ) - - self.step += 1 - - return supervised_loss, unsupervised_loss, labeled_z_ - - -class MultiTaskMixMatchWrapper(nn.Module): - def __init__( - self, - mixmatch_loss: MixMatchLoss, - sup_weight: Union[float, Callable] = 1.0, - unsup_weight: Union[float, Callable] = 1.0, - use_sup_eval: bool = True, - ) -> None: - """Wrapper around the `MixMatchLoss` class for use with `MultiTaskTrainer`. - The wrapper performs weighting of the supervised and unsupervised loss - internally, then returns a single `torch.FloatTensor` to `MultiTaskTrainer` - to maintain a consistent "one criterion, one loss" API. - - Parameters - ---------- - mixmatch_loss : MixMatchLoss - an instance of the `MixMatchLoss` class. - sup_weight : float, Callable - constant weight or callable weight schedule function for the - supervised MixMatch loss. - unsup_weight : float, Callable - constant weight or callable weight schedule function for the - unsupervised MixMatch loss. - use_sup_eval : bool - use only the supervised loss when in eval mode. - - Returns - ------- - None. - - Notes - ----- - Relies upon updating the `.epoch` attribute during the training - loop to properly enforce weight scheduling. - """ - super(MultiTaskMixMatchWrapper, self).__init__() - self.mixmatch_loss = mixmatch_loss - self.sup_weight = sup_weight - self.unsup_weight = unsup_weight - self.use_sup_eval = use_sup_eval - # initialize the epoch attribute so `MultiTaskTrainer` can find it - # `.epoch` will be updated in the training loop - self.epoch = 0 - return - - def __call__( - self, - *, - labeled_sample: dict, - unlabeled_sample: dict, - model: nn.Module, - weight: float = None, - ) -> torch.FloatTensor: - """Compute MixMatch losses, weight them internally, then return - the weighted sum. - - Parameters - ---------- - labeled_sample : dict - input - torch.FloatTensor - [Batch, Features] minibatch of labeled examples. - output - torch.LongTensor - one-hot labels. - unlabeled_sample : dict - input - torch.FloatTensor - [Batch, Features] minibatch of unlabeled samples. - output - torch.LongTensor - zeros. - model : nn.Module - model with parameters accessible via the `.parameters()` - method. - weight : float - unused weight parameter for compatability with the `MultiTaskTrainer` - API. - - Returns - ------- - loss : torch.FloatTensor - weighted sum of MixMatch supervised and unsupervised loss. - """ - sup_loss, unsup_loss, labeled_z_ = self.mixmatch_loss( - labeled_sample=labeled_sample, - unlabeled_sample=unlabeled_sample, - model=model, - ) - # get weights for each loss by either calling the function or keeping - # the constant value provided - sup_weight = ( - self.sup_weight(self.epoch) - if callable(self.sup_weight) - else self.sup_weight - ) - unsup_weight = ( - self.unsup_weight(self.epoch) - if callable(self.unsup_weight) - else self.unsup_weight - ) - - # don't use the unsupervised loss if we're in eval mode - # `use_sup_eval` is set - if self.use_sup_eval and not self.training: - unsup_weight = 0.0 - - loss = (sup_weight * sup_loss) + (unsup_weight * unsup_loss) - return loss - - -"""Domain adaptation losses""" - - -class DANLoss(nn.Module): - """Compute a domain adaptation network (DAN) loss.""" - - def __init__( - self, - dan_criterion: Callable, - model: CellTypeCLF, - use_conf_pseudolabels: bool = False, - scale_loss_pseudoconf: bool = False, - n_domains: int = 2, - **kwargs, - ) -> None: - """Compute a domain adaptation network loss. - - Parameters - ---------- - dan_criterion : Callable - domain classification criterion `Callable(output, target)`. - model : scnym.model.CellTypeCLF - `CellTypeCLF` model to use for embedding. - use_conf_pseudolabels : bool - only use unlabeled observations with confident pseudolabels - for discrimination. expects `pseudolabel_confidence` to be - passed in the `__call__()` if so. - scale_loss_pseudoconf : bool - scale the weight of the gradients passed to both models based - on the proportion of confident pseudolabels. - n_domains : int - number of domains of origin to predict using the adversary. - - Returns - ------- - None. - - Notes - ----- - **kwargs are passed to `scnym.model.DANN` - - See Also - -------- - scnym.model.DANN - scnym.trainer.MultiTaskTrainer - """ - super(DANLoss, self).__init__() - - self.dan_criterion = dan_criterion - - # build the DANN - self.dann = DANN( - model=model, - n_domains=n_domains, - **kwargs, - ) - self.dann.domain_clf = self.dann.domain_clf.to( - device=next(iter(model.parameters())).device, - ) - # instantiate with small tensor to simplify downstream size - # checking logic - self.x_embed = torch.zeros((1, 1)) - - self.use_conf_pseudolabels = use_conf_pseudolabels - self.scale_loss_pseudoconf = scale_loss_pseudoconf - # note that weighting is performed on gradients internally; - # accessed by `trainer.MultiTaskTrainer` - self.no_weight = True - return - - def __call__( - self, - labeled_sample: dict, - unlabeled_sample: dict = None, - weight: float = 1.0, - pseudolabel_confidence: torch.Tensor = None, - **kwargs, - ) -> torch.FloatTensor: - """Compute the domain adaptation loss on a labeled source - and unlabeled target domain batch. - - Parameters - ---------- - labeled_sample : dict - input - torch.FloatTensor - [BatchL, Features] minibatch of labeled examples. - output - torch.LongTensor - one-hot labels. - unlabeled_sample : dict - input - torch.FloatTensor - [BatchU, Features] minibatch of unlabeled samples. - output - torch.LongTensor - zeros. - weight : float - weight for reversed gradients passed up to the embedding - layer. gradients used for the domain classifier are normal - gradients, but we weight and reverse the gradients flowing - upward to the embedding layer by this constant. - pseudolabel_confidence : torch.Tensor - [BatchU,] boolean identifying observations in `unlabeled_sample` - with confident pseudolabels. - if not None and `self.use_conf_pseudolabels`, only performs - domain discrimination on unlabeled samples with confident - pseudolabels. - **kwargs : dict - kwargs are a no-op, included to allow for `model` kwarg per - `MultiTaskTrainer` API. - - Returns - ------- - dan_loss : torch.FloatTensor - domain adversarial loss term. - """ - # if no unlabeled data is provided, we create a dict of empty - # tensors. these tensors lead to no-ops for all the `.cat` ops - # below. - if unlabeled_sample is None: - t = torch.FloatTensor().to(device=labeled_sample["input"].device) - unlabeled_sample = {k: t for k in ["input", "domain"]} - - ######################################## - # (1) Create domain labels - ######################################## - - # check if domain labels are provided, if not assume - # train and target are separate domains - # domain labels of -1 indicate `None` was passed as a domain label - # to `SingleCellDS` - if torch.sum(labeled_sample.get("domain", torch.Tensor([-1])) == -1) > 0: - source_label = torch.zeros(labeled_sample["input"].size(0)).long() - source_label = torch.nn.functional.one_hot( - source_label, - num_classes=2, - ) - logger.debug("DAN source domain labels inferred.") - else: - # domain labels should already by one-hot - source_label = labeled_sample["domain"] - source_label = source_label.to(device=labeled_sample["input"].device) - - if torch.sum(unlabeled_sample.get("domain", torch.Tensor([-1])) == -1) > 0: - target_label = torch.ones(unlabeled_sample["input"].size(0)).long() - target_label = torch.nn.functional.one_hot( - target_label, - num_classes=2, - ) - logger.debug("DAN target domain labels inferred.") - else: - target_label = unlabeled_sample["domain"] - target_label = target_label.to(device=unlabeled_sample["input"].device) - - lx = labeled_sample["input"] - ux = unlabeled_sample["input"] - - ######################################## - # (2) Check confidence of unlabeled obs - ######################################## - - if self.use_conf_pseudolabels and pseudolabel_confidence is not None: - # check confidence of unlabeled observations and remove - # any unconfident observations from the minibatch - ux = ux[pseudolabel_confidence] - target_label = target_label[pseudolabel_confidence] - # store the number of confident unlabeled obs - self.n_conf_pseudolabels = ux.size(0) - self.n_total_unlabeled = unlabeled_sample["input"].size(0) - p_conf_pseudolabels = self.n_conf_pseudolabels / max(self.n_total_unlabeled, 1) - - ######################################## - # (3) Embed points and Classify domains - ######################################## - - x = torch.cat([lx, ux], 0) - dlabel = torch.cat([source_label, target_label], 0) - - self.dann.set_rev_grad_weight(weight=weight) - domain_pred, x_embed = self.dann(x) - - # store embeddings and labels - if x_embed.size(0) >= self.x_embed.size(0): - self.x_embed = copy.copy(x_embed.detach().cpu()) - self.dlabel = copy.copy(dlabel.detach().cpu()) - - ######################################## - # (4) Compute DAN loss - ######################################## - - dan_loss = self.dan_criterion( - domain_pred, - dlabel, - ) - - ######################################## - # (5) Compute DAN accuracy for logs - ######################################## - - _, dan_pred = torch.max(domain_pred, dim=1) - _, dlabel_int = torch.max(dlabel, dim=1) - self.dan_acc = ( - torch.sum( - dan_pred == dlabel_int, - ) - / float(dan_pred.size(0)) - ) - - if self.scale_loss_pseudoconf: - dan_loss *= p_conf_pseudolabels - - return dan_loss - - -"""Reconstruction losses""" - - -def poisson_loss( - input_: torch.FloatTensor, - target: torch.FloatTensor, - dispersion: torch.FloatTensor = None, -) -> torch.FloatTensor: - """Compute a Poisson loss for count data. - - Parameters - ---------- - input_ : torch.FloatTensor - [Batch, Feature] Poisson rate parameters. - target : torch.FloatTensor - [Batch, Features] count based target. - dispersion : torch.FloatTensor - Ignored for Poisson loss. - - Returns - ------- - nll : torch.FloatTensor - Poisson negative log-likelihood. - """ - # input_ are Poisson rates, compute likelihood of target data - # and sum likelihood across genes - nll = -1 * torch.sum( - torch.distributions.Poisson(input_).log_prob(target), - dim=-1, - ) - return nll - - -def negative_binomial_loss( - input_: torch.FloatTensor, - target: torch.FloatTensor, - dispersion: torch.FloatTensor, - eps: float = 1e-8, -) -> torch.FloatTensor: - """Compute a Negative Binomial loss for count data. - - Parameters - ---------- - input_ : torch.FloatTensor - [Batch, Feature] Negative Binomial mean parameters. - target : torch.FloatTensor - [Batch, Features] count based target. - dispersion : torch.FloatTensor - [Features,] Negative Binomial dispersion parameters. - eps : float - small constant to avoid numerical issues. - - Returns - ------- - nll : torch.FloatTensor - Negative Binomial negative log-likelihood. - - References - ---------- - Credit to `scvi-tools`: - https://github.com/YosefLab/scvi-tools/blob/42315756ba879b9421630696ea7afcd74e012a07/scvi/distributions/_negative_binomial.py#L67 - """ - res = -1 * (NegativeBinomial(mu=input_, theta=dispersion).log_prob(target).sum(-1)) - return res - - -def mse_loss( - input_: torch.FloatTensor, - target: torch.FloatTensor, - dispersion: torch.FloatTensor, -) -> torch.FloatTensor: - """MSELoss wrapped for scNym compatibility""" - return torch.nn.functional.mse_loss(input_, target) - - -class ReconstructionLoss(nn.Module): - """Computes a reconstruction of the input data from the - embedding""" - - def __init__( - self, - *, - model: nn.Module, - rec_criterion: Callable, - reduction: str = "mean", - norm_before_loss: float = None, - **kwargs, - ) -> None: - """Computes a reconstruction loss of the input data - from the embedding. - - Parameters - ---------- - model : nn.Module - cell type classification model to use for cellular - embedding. - rec_criterion : Callable - reconstruction loss that takes two arguments `(input_, target)`. - reduction : str - {"none", "mean", "sum"} reduction operation for [Batch,] loss values. - norm_before_loss : float - normalize profiles to the following depth before computing loss. - this helps balance loss contribution from cells with dramatically - different depths (e.g. Drop-seq and Smart-seq2). - if `None`, does not normalize before loss. - **kwargs : dict - passed to recontruction model `.model.AE`. - - Returns - ------- - None. - """ - super(ReconstructionLoss, self).__init__() - - self.rec_criterion = rec_criterion - self.model = model - self.reduction = reduction - if reduction not in (None, "none", "sum", "mean"): - msg = f"reduction argument {self.reduction} is invalid." - raise ValueError(msg) - self.norm_before_loss = norm_before_loss - - # build the reconstruction autoencoder - self.rec_model = AE( - model=model, - **kwargs, - ) - # move rec_model to the appropriate computing device - self.rec_model = self.rec_model.to( - device=list(self.model.parameters())[1].device, - ) - - return - - def __call__( - self, - labeled_sample: dict, - unlabeled_sample: dict = None, - weight: float = 1.0, - **kwargs, - ) -> torch.FloatTensor: - """Compute the domain adaptation loss on a labeled source - and unlabeled target domain batch. - - Parameters - ---------- - labeled_sample : dict - input - torch.FloatTensor - [BatchL, Features] minibatch of labeled examples. - output - torch.LongTensor - [BatchL,] one-hot labels. - embed - torch.FloatTensor, optional - [BatchL, n_hidden] minibatch embedding. - unlabeled_sample : dict, optional. - input - torch.FloatTensor - [BatchU, Features] minibatch of unlabeled samples. - output - torch.LongTensor - [BatchU,] zeros. - embed - torch.FloatTensor, optional - [BatchU, n_hidden] minibatch embedding. - weight : float - reconstruction loss weight. Not used, present for compatability with the - `MultiTaskTrainer` API. - kwargs : dict - currently not used, allows for compatibility with `Trainer` subclasses - that pass `model` to call by default (e.g. as used for the old `MixMatchLoss`). - - Returns - ------- - reconstruction_loss : torch.FloatTensor - reconstruction loss, reduced across the batch. - """ - if unlabeled_sample is None: - # if no unlabeled data is passed, we create empty FloatTensors - # to concat onto the labeled tensors below. - # cat of an empty tensor is a no-op. - t = torch.FloatTensor().to(device=labeled_sample["input"].device) - unlabeled_sample = { - "input": t, - "embed": t, - "domain": t, - } - - # join data into a single batch - x = torch.cat( - [ - labeled_sample["input"], - unlabeled_sample["input"], - ], - dim=0, - ) - - # use pre-computed embeddings if they're available from e.g. - # a previous loss function. - if "embed" in labeled_sample.keys() and "embed" in unlabeled_sample.keys(): - x_embed = torch.cat( - [ - labeled_sample["embed"], - unlabeled_sample["embed"], - ], - dim=0, - ) - else: - x_embed = None - - # pass domain arguments to the reconstruction model if specified - # domains are already [Batch, Domains] one-hot encoded. - if self.rec_model.n_domains > 0: - x_domain = torch.cat( - [ - labeled_sample["domain"], - unlabeled_sample["domain"], - ], - dim=0, - ).to(device=x.device) - else: - x_domain = None - - # perform embedding and reconstruction - # if `x_embed is None`, computes the embedding using the - # trunk of the classification model - x_rec, x_scaled, dispersion, x_embed = self.rec_model( - x, - x_embed=x_embed, - x_domain=x_domain, - ) - - if self.norm_before_loss is not None: - # normalize to a common depth (CP-TenThousand) before computing loss - x_scaled2use = x_scaled / x_scaled.sum(1).view(-1, 1) * 1e6 - x2use = x / x.sum(1).view(-1, 1) * self.norm_before_loss4 - else: - x_scaled2use = x_scaled - x2use = x - - # score reconstruction - reconstruction_loss = self.rec_criterion( - input_=x_scaled2use, - target=x2use, - dispersion=dispersion, - ) - if self.reduction == "mean": - reconstruction_loss = torch.mean(reconstruction_loss) - elif (self.reduction == "none") or (self.reduction is None): - reconstruction_loss = reconstruction_loss - elif self.reduction == "sum": - reconstruction_loss = torch.sum(reconstruction_loss) - else: - msg = f"reduction argument {self.reduction} is invalid." - raise ValueError(msg) - - return reconstruction_loss - - -class LatentL2(nn.Module): - def __init__( - self, - ) -> None: - """Compute an l2-norm penalty on the latent embedding. - This serves as a sufficient regularization in deterministic - regularized autoencoders (RAE), akin to the KL term in VAEs. - - References - ---------- - https://openreview.net/pdf?id=S1g7tpEYDS - """ - super(LatentL2, self).__init__() - - return - - def __call__( - self, - labeled_sample: dict, - unlabeled_sample: dict, - model: nn.Module = None, - weight: float = None, - ) -> torch.FloatTensor: - """Compute an l2 penalty on the latent space of a model""" - # is the embedding pre-computed for both samples? - embed_computed = "embed" in labeled_sample.keys() - if unlabeled_sample is not None: - embed_computed = embed_computed and ("embed" in unlabeled_sample.keys()) - keys = ["input"] - if embed_computed: - keys += ["embed"] - - if unlabeled_sample is not None: - # join tensors across samples - sample = { - k: torch.cat([labeled_sample[k], unlabeled_sample[k]], 0) for k in keys - } - else: - sample = labeled_sample - - if embed_computed: - x_embed = sample["embed"] - else: - data = sample["input"] - logits, x_embed = model(data, return_embed=True) - - l2 = 0.5 * torch.norm(x_embed, p=2) - return l2 - - -# TODO: Consider adding in one of the TC-VAE mutual information -# penalties for latent vars to substitute for the covariance penalty -# inherent in the mean field VAE KL term - - -class UnsupervisedLosses(object): - """Compute multiple unsupervised loss functions""" - - def __init__( - self, - losses: list, - weights: list = None, - ) -> None: - """Compute multiple unsupervised loss functions. - - Parameters - ---------- - losses : List[Callable] - each element in list is a Callable that takes arguments - `labeled_sample, unlabeled_sample` and returns a `torch.FloatTensor` - differentiable loss suitable for backprop. - methods can also take or ignore a `weight` argument. - weights : List[Callable] - matching weight functions for each loss that take an input int epoch - and return a float loss weight. - - Returns - ------- - None. - - Notes - ----- - Computes each loss in serial. - - """ - self.losses = losses - # if no weights are provided, use a uniform schedule with - # weight `1.` for each loss function. - self.weights = weights if weights is not None else [lambda x: 1.0] * len(losses) - return - - def __call__( - self, - labeled_sample: dict, - unlabeled_sample: dict, - ) -> torch.FloatTensor: - loss = torch.zeros( - 1, - ) - for i, fxn in enumerate(self.losses): - fxn_loss = fxn( - labeled_sample=labeled_sample, - unlabeled_sample=unlabeled_sample, - weight=self.weights[i], - ) - loss += fxn_loss - return loss - - -"""Loss weight scheduling""" - - -class ICLWeight(object): - def __init__( - self, - ramp_epochs: int, - burn_in_epochs: int = 0, - max_unsup_weight: float = 10.0, - sigmoid: bool = False, - ) -> None: - """Schedules the interpolation consistency loss - weights across a set of epochs. - - Parameters - ---------- - ramp_epochs : int - number of epochs to increase the unsupervised - loss weight until reaching a maximum value. - burn_in_epochs : int - epochs to wait before increasing the unsupervised loss. - max_unsup_weight : float - maximum weight for the unsupervised loss component. - sigmoid : bool - scale weight using a sigmoid function. - - Returns - ------- - None. - """ - self.ramp_epochs = ramp_epochs - self.burn_in_epochs = burn_in_epochs - self.max_unsup_weight = max_unsup_weight - self.sigmoid = sigmoid - # don't allow division by zero, set step size manually - if self.ramp_epochs == 0.0: - self.step_size = self.max_unsup_weight - else: - self.step_size = self.max_unsup_weight / self.ramp_epochs - print( - "Scaling ICL over %d epochs, %d epochs for burn in." - % (self.ramp_epochs, self.burn_in_epochs) - ) - return - - def _get_weight( - self, - epoch: int, - ) -> float: - """Compute the current weight""" - if epoch >= (self.ramp_epochs + self.burn_in_epochs): - weight = self.max_unsup_weight - elif self.sigmoid: - x = (epoch - self.burn_in_epochs) / self.ramp_epochs - coef = np.exp(-5 * (x - 1) ** 2) - weight = coef * self.max_unsup_weight - else: - weight = self.step_size * (epoch - self.burn_in_epochs) - - return weight - - def __call__( - self, - epoch: int, - ) -> float: - """Compute the weight for an unsupervised IC loss - given the epoch. - - Parameters - ---------- - epoch : int - current training epoch. - - Returns - ------- - weight : float - weight for the unsupervised component of IC loss. - """ - if type(epoch) != int: - raise TypeError(f"epoch must be int, you passed a {type(epoch)}") - if epoch < self.burn_in_epochs: - weight = 0.0 - else: - weight = self._get_weight(epoch) - return weight - - -"""Structured latent variable learning""" - - -class StructuredSparsity(object): - def __init__( - self, - n_genes: int, - n_hidden: int, - gene_sets: dict = None, - gene_names: Iterable = None, - prior_matrix: Union[np.ndarray, torch.Tensor] = None, - n_dense_latent: int = 0, - group_lasso: float = 0.0, - p_norm: int = 1, - nonnegative: bool = False, - ) -> None: - """Add structured sparsity penalties to regularize - weights of an encoding layer. - - Parameters - ---------- - n_genes : int - number of genes in the input layer. - n_hidden : int - number of hidden units in the input layer. - gene_sets : dict, optional. - keys are program names, values are lists of gene names. - must have fewer keys than `n_hidden`. - gene_names : Iterable, optional. - names for genes in `n_genes`. required for use of `gene_sets`. - prior_matrix : np.ndarray, torch.FloatTensor - [n_hidden, n_genes] binary matrix of prior constraints. - if provided with `gene_sets`, this matrix is used instead. - n_dense_latent : int - number of latent variables with no l1 loss applied. - applies to the final `n_dense_latent` variables. - group_lasso : float, optional. - weight for a group LASSO penalty on the second hidden - layer. [Default = 0]. - p_norm : int - p-norm to use for the prior penalty. [Default = 1] for lasso. - nonnegative : bool - apply an L1 penalty to *all* negative values. this implicitly enforces - a roughly non-negative projection matrix. - - Returns - ------- - None. - """ - self.n_genes = n_genes - self.n_hidden = n_hidden - self.gene_sets = gene_sets - self.gene_names = gene_names - self.prior_matrix = None - self.n_dense_latent = n_dense_latent - self.group_lasso = group_lasso - self.p_norm = p_norm - self.nonnegative = nonnegative - - if prior_matrix is None and gene_sets is None: - msg = "Must provide either a prior_matrix or gene_sets to use." - raise ValueError(msg) - - if gene_sets is not None and gene_names is None: - msg = "Must provide `gene_names` to use `gene_sets`." - raise ValueError(msg) - - if gene_sets is not None and gene_names is not None: - - if len(gene_sets.keys()) > self.n_hidden: - # check that we didn't provide too many gene sets - # given the size of our encoder - msg = f"{len(gene_sets.keys())} gene sets provided,\n" - msg += f"but there are only {n_hidden} hidden units.\n" - msg += "Must specify fewer programs than hidden units." - raise ValueError(msg) - - # set `self.prior_matrix` based on the gene sets - # also sets `self.gene_set_names` - self._set_prior_matrix_from_gene_sets() - - if prior_matrix is not None: - # if the prior_matrix was provided, always prefer it. - self.prior_matrix = prior_matrix - - assert self.prior_matrix is not None - return - - def _set_prior_matrix_from_gene_sets( - self, - ) -> None: - """Generate a prior matrix from a set of gene programs - and gene names for the input variables. - """ - self.gene_set_names = sorted(list(self.gene_sets.keys())) - - # [n_programs, n_genes] - P = torch.zeros( - ( - self.n_hidden, - self.n_genes, - ) - ).bool() - - # cast to set for list comprehension speed - gene_names = set(self.gene_names) - for i, k in enumerate(self.gene_set_names): - genes = self.gene_sets[k] - bidx = torch.tensor( - [x in genes for x in gene_names], - dtype=torch.bool, - ) - P[i, :] = bidx - - self.prior_matrix = P - return - - def __call__( - self, - model: nn.Module, - **kwargs, - ) -> torch.FloatTensor: - """Compute the l1 sparsity loss.""" - # get first layer weights - W = dict(model.named_parameters())["embed.0.weight"] - logger.debug(f"Weights {W}, sum: {W.sum()}") - # generate a "penalty" matrix `P` that we'll modify - # before computing the l1 - # this elem-mult zeros out the loss on any annotated - # genes in each gene program - P = W * torch.logical_not(self.prior_matrix).float().to(device=W.device) - logger.debug(f"Penalty {P}, sum {P.sum()}") - # omit the dense latent factors (if any) from the l1 - # computation - n_latent = P.size(0) - self.n_dense_latent - prior_norm = torch.norm(P[:n_latent], p=self.p_norm) - logger.debug(f"l1 {prior_norm}") - - # W1 = dict(model.named_parameters())['embed.4.weight'] - # group_l1 = torch.norm(W1, p=1) - - if self.nonnegative: - # place an optional non-negativity penalty on genes within the gene set - nonneg_inset = W * self.prior_matrix.float().to(device=W.device) - nonneg_norm = torch.norm(nonneg_inset[nonneg_inset < 0], p=self.p_norm) - else: - nonneg_norm = 0.0 - - r = prior_norm + nonneg_norm - return r diff --git a/build/lib/scnym/main.py b/build/lib/scnym/main.py deleted file mode 100644 index 59d84ee..0000000 --- a/build/lib/scnym/main.py +++ /dev/null @@ -1,1678 +0,0 @@ -"""Train scNym models and identify cell type markers""" -import numpy as np -import pandas as pd -from scipy import sparse -import os -import os.path as osp -import scanpy as sc -import logging - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.utils.data import DataLoader -from sklearn.model_selection import StratifiedKFold -from typing import Union, Tuple -import copy -import itertools -from functools import partial - -from .model import CellTypeCLF - -from .dataprep import SingleCellDS, SampleMixUp, balance_classes -from .dataprep import AUGMENTATION_SCHEMES -from .trainer import Trainer, SemiSupervisedTrainer, MultiTaskTrainer -from .trainer import cross_entropy, get_class_weight -from .trainer import InterpolationConsistencyLoss, ICLWeight, MixMatchLoss, DANLoss -from .losses import scNymCrossEntropy -from .predict import Predicter -from . import utils - -# allow tensorboard outputs even though TF2 is installed -# TF2 broke the tensorboard/pytorch API, so we need to alias -# the old API endpoint below -try: - import tensorflow as tf - tfv = int(tf.__version__.split(".")[0]) -except ImportError: - print("tensorflow is not installed, assuming tensorboard is independent") - tfv = 1 - -if tfv > 1: - import tensorboard as tb - - tf.io.gfile = tb.compat.tensorflow_stub.io.gfile - - -logger = logging.getLogger(__name__) - -# define optimizer map for cli selection -OPTIMIZERS = { - "adadelta": torch.optim.Adadelta, - "adam": torch.optim.Adam, - "adamw": torch.optim.AdamW, - "sgd": torch.optim.SGD, -} - -######################################################### -# Train scNym classification models -######################################################### - - -def repeater(data_loader): - """Use `itertools.repeat` to infinitely loop through - a dataloader. - - Parameters - ---------- - data_loader : torch.utils.data.DataLoader - data loader class. - - Yields - ------ - data : Iterable - batches from `data_loader`. - - Credit - ------ - https://bit.ly/2z0LGm8 - """ - for loader in itertools.repeat(data_loader): - for data in loader: - yield data - - -def fit_model( - X: Union[np.ndarray, sparse.csr.csr_matrix], - y: np.ndarray, - traintest_idx: Union[np.ndarray, tuple], - val_idx: np.ndarray, - batch_size: int, - n_epochs: int, - lr: float, - optimizer_name: str, - weight_decay: float, - ModelClass: nn.Module, - out_path: str, - n_genes: int = None, - mixup_alpha: float = None, - unlabeled_counts: np.ndarray = None, - unsup_max_weight: float = 2.0, - unsup_mean_teacher: bool = False, - ssl_method: str = "mixmatch", - ssl_kwargs: dict = {}, - weighted_classes: bool = False, - balanced_classes: bool = False, - input_domain: np.ndarray = None, - unlabeled_domain: np.ndarray = None, - pretrained: str = None, - patience: int = None, - save_freq: int = None, - tensorboard: bool = True, - **kwargs, -) -> Tuple[float, float]: - """Fit an scNym model given a set of observations and labels. - - Parameters - ---------- - X : np.ndarray - [Cells, Genes] of log1p transformed normalized values. - log1p and normalization performed using scanpy defaults. - y : np.ndarray - [Cells,] integer class labels. - traintest_idx : np.ndarray - [Int,] indices to use for training and early stopping. - a single array will be randomly partitioned, OR a tuple - of `(train_idx, test_idx)` can be passed. - val_idx : np.ndarray - [Int,] indices to hold-out for final model evaluation. - n_epochs : int - number of epochs for training. - lr : float - learning rate. - optimizer_name : str - optimizer to use. {"adadelta", "adam"}. - weight_decay : float - weight decay to apply to model weights. - ModelClass : nn.Module - a model class for construction classification models. - batch_size : int - batch size for training. - fold_indices : list - elements are 2-tuple, with training indices and held-out. - out_path : str - top level path for saving fold outputs. - n_genes : int - number of genes in the input. Not necessarily `X.shape[1]` if - the input matrix has been concatenated with other features. - mixup_alpha : float - alpha parameter for an optional MixUp augmentation during training. - unlabeled_counts : np.ndarray - [Cells', Genes] of log1p transformed normalized values for - unlabeled observations. - unsup_max_weight : float - maximum weight for the unsupervised loss term. - unsup_mean_teacher : bool - use a mean teacher for pseudolabel generation. - ssl_method : str - semi-supervised learning method to use. - ssl_kwargs : dict - arguments passed to the semi-supervised learning loss. - balanced_classes : bool - perform class balancing by undersampling majority classes. - weighted_classes : bool - weight loss for each class based on relative abundance of classes - in the training data. - input_domain : np.ndarray - [Cells,] integer domain labels for training data. - unlabeled_domain : np.ndarray - [Cells',] integer domain labels for unlabeled data. - pretrained : str - path to a pretrained model for initialization. - default: `None`. - patience : int - number of epochs to wait before early stopping. - `None` deactivates early stopping. - save_freq : int - frequency in epochs for saving model checkpoints. - if `None`, saves >=5 checkpoints per model. - tensorboard : bool - save logs to tensorboard. - - Returns - ------- - test_acc : float - classification accuracy on the test set. - test_loss : float - supervised loss on the test set. - """ - # count the number of cell types available - n_cell_types = len(np.unique(y)) - if n_genes is None: - n_genes = X.shape[1] - - if type(traintest_idx) != tuple: - # Set aside 10% of the traintest data for model selection in `test_idx` - train_idx = np.random.choice( - traintest_idx, - size=int(np.floor(0.9 * len(traintest_idx))), - replace=False, - ).astype("int") - test_idx = np.setdiff1d(traintest_idx, train_idx).astype("int") - elif type(traintest_idx) == tuple and len(traintest_idx) == 2: - # use the user provided train/test split - train_idx = traintest_idx[0] - test_idx = traintest_idx[1] - else: - # the user supplied an invalid argument - msg = "`traintest_idx` of type {type(traintest_idx)}\n" - msg += "and length {len(traintest_idx)} is invalid." - raise ValueError(msg) - - # save indices to CSVs for later retrieval - np.savetxt(osp.join(out_path, "train_idx.csv"), train_idx) - np.savetxt(osp.join(out_path, "test_idx.csv"), test_idx) - np.savetxt(osp.join(out_path, "val_idx.csv"), val_idx) - - # balance or weight classes if applicable - if balanced_classes and weighted_classes: - msg = "balancing AND weighting classes is not useful." - msg += "\nPick one mode of accounting for class imbalances." - raise ValueError(msg) - elif balanced_classes and not weighted_classes: - print("Setting up a stratified sampler...") - # we sample classes with weighted likelihood, rather than - # a uniform likelihood of sampling - # we use the inverse of the class count as a weight - # this is normalized in `WeightedRandomSample` - classes, counts = np.unique(y[train_idx], return_counts=True) - sample_weights = 1.0 / counts - - # `WeightedRandomSampler` is kind of funny and takes a weight - # **per example** in the training set, rather than per class. - # here we assign the appropriate class weight to each sample - # in the training set. - weight_per_example = sample_weights[y[train_idx]] - - # we instantiate the sampler with the relevant weight for - # each observation and set the number of total samples to the - # number of samples in our training set - # `WeightedRandomSampler` will sample indices from a multinomial - # with probabilities computed from the normalized vector - # of `weights_per_example`. - sampler = torch.utils.data.sampler.WeightedRandomSampler( - weight_per_example, - len(y[train_idx]), - ) - class_weight = None - elif weighted_classes and not balanced_classes: - # compute class weights - # class weights amplify the loss of some classes and reduce - # the loss of others, inversely proportional to the class - # frequency - print("Weighting classes for training...") - class_weight = get_class_weight(y[train_idx]) - print(class_weight) - print() - sampler = None - else: - print("Not weighting classes and not balancing classes.") - class_weight = None - sampler = None - - # Generate training and model selection Datasets and Dataloaders - X_train = X[train_idx, :] - y_train = y[train_idx] - - X_test = X[test_idx, :] - y_test = y[test_idx] - - # count the number of domains - if ( - (input_domain is None) - and (unlabeled_domain is None) - and (unlabeled_counts is not None) - ): - n_domains = 2 - elif ( - (input_domain is None) - and (unlabeled_domain is None) - and (unlabeled_counts is None) - ): - n_domains = 1 - elif (input_domain is not None) and (unlabeled_domain is None): - input_domain_max = input_domain.max() - n_domains = int(input_domain_max) - elif (input_domain is not None) and (unlabeled_domain is not None): - input_domain_max = input_domain.max() - unlabeled_domain_max = ( - 0 if len(unlabeled_domain) == 0 else unlabeled_domain.max() - ) - n_domains = ( - int( - np.max( - [ - input_domain_max, - unlabeled_domain_max, - ] - ) - ) - + 1 - ) - else: - msg = "domains supplied for only one set of data" - raise ValueError(msg) - print(f"Found {n_domains} unique domains.") - - if input_domain is not None: - d_train = input_domain[train_idx] - d_test = input_domain[test_idx] - else: - d_train = None - d_test = None - - train_ds = SingleCellDS( - X=X_train, - y=y_train, - num_classes=len(np.unique(y)), - domain=d_train, - num_domains=n_domains, - ) - test_ds = SingleCellDS( - X_test, - y_test, - num_classes=len(np.unique(y)), - domain=d_test, - num_domains=n_domains, - ) - logger.debug(f"{len(train_ds)} training samples in DS.") - logger.debug(f"{len(test_ds)} testing samples in DS.") - - train_dl = DataLoader( - train_ds, - batch_size=batch_size, - shuffle=True if sampler is None else False, - sampler=sampler, - drop_last=True, - ) - test_dl = DataLoader( - test_ds, - batch_size=batch_size, - shuffle=True, - ) - logger.debug(f"{len(train_dl)} training samples in DL.") - logger.debug(f"{len(test_dl)} testing samples in DL.") - - dataloaders = { - "train": train_dl, - "val": test_dl, - } - - # Define batch transformers - batch_transformers = {} - if mixup_alpha is not None and ssl_method != "mixmatch": - print("Using MixUp as a batch transformer.") - batch_transformers["train"] = SampleMixUp(alpha=mixup_alpha) - - # Build a cell type classification model and transfer to CUDA - model = ModelClass( - n_genes=n_genes, - n_cell_types=n_cell_types, - **kwargs, - ) - - if pretrained is not None: - # initialize with supplied weights - model.load_state_dict( - torch.load( - pretrained, - map_location="cpu", - ) - ) - - if torch.cuda.is_available(): - model = model.cuda() - - # Set up loss criterion and the model optimizer - # here we use our own cross_entropy loss to handle - # discrete probability distributions rather than - # categorical predictions - if class_weight is None: - criterion = cross_entropy - else: - criterion = partial( - cross_entropy, - class_weight=torch.from_numpy(class_weight).float(), - ) - - opt_callable = OPTIMIZERS[optimizer_name.lower()] - - if opt_callable != torch.optim.SGD: - optimizer = opt_callable( - model.parameters(), - weight_decay=weight_decay, - lr=lr, - ) - scheduler = None - else: - # use SGD as the optimizer with momentum - # and a learning rate scheduler - optimizer = opt_callable( - model.parameters(), - weight_decay=weight_decay, - lr=lr, - momentum=0.9, - ) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer=optimizer, - T_max=n_epochs, - eta_min=lr / 10000, - ) - - # Build the relevant trainer object for either supervised - # or semi-supervised learning with interpolation consistency - trainer_kwargs = { - "model": model, - "criterion": criterion, - "optimizer": optimizer, - "scheduler": scheduler, - "dataloaders": dataloaders, - "out_path": out_path, - "batch_transformers": batch_transformers, - "n_epochs": n_epochs, - "min_epochs": n_epochs // 20, - "save_freq": max(n_epochs // 5, 1) if save_freq is None else save_freq, - "reg_criterion": None, - "exp_name": osp.basename(out_path), - "verbose": False, - "tb_writer": osp.join(out_path, "tblog") if tensorboard else None, - "patience": patience, - } - - if unlabeled_counts is None and (n_domains == 1): - # perform fully supervised training - print("Performing fully supervised training with no domain adaptation.") - T = Trainer(**trainer_kwargs) - elif unlabeled_counts is None and (n_domains > 1): - print("Performing supervised training with a domain adversary.") - # perform supervised training with DA - # use the MultiTaskTrainer - dan_criterion = ssl_kwargs.get("dan_criterion", None) - if dan_criterion is not None: - # initialize the DAN Loss - - dan_criterion = DANLoss( - model=model, - dan_criterion=cross_entropy, - use_conf_pseudolabels=ssl_kwargs.get( - "dan_use_conf_pseudolabels", False - ), - scale_loss_pseudoconf=ssl_kwargs.get( - "dan_scale_loss_pseudoconf", False - ), - n_domains=n_domains, - ) - - # setup the DANN learning rate schedule - dan_weight = ICLWeight( - ramp_epochs=ssl_kwargs.get("dan_ramp_epochs", max(n_epochs // 4, 1)), - max_unsup_weight=ssl_kwargs.get("dan_max_weight", 1.0), - burn_in_epochs=ssl_kwargs.get("dan_burn_in_epochs", 0), - sigmoid=ssl_kwargs.get("sigmoid", True), - ) - # add DANN parameters to the optimizer - optimizer.add_param_group( - { - "params": dan_criterion.dann.domain_clf.parameters(), - "name": "domain_classifier", - } - ) - ce = scNymCrossEntropy() - criteria = [ - {"name": "dan", "function": dan_criterion, "weight": dan_weight}, - {"name": "ce", "function": ce, "weight": 1.0}, - ] - del trainer_kwargs["criterion"] - trainer_kwargs["criteria"] = criteria - T = MultiTaskTrainer(**trainer_kwargs) - else: - # perform semi-supervised training - unsup_dataset = SingleCellDS( - X=unlabeled_counts, - y=np.zeros(unlabeled_counts.shape[0]), - num_classes=len(np.unique(y)), - domain=unlabeled_domain, - num_domains=n_domains, - ) - - # Build a semi-supervised data loader that infinitely samples - # unsupervised data for interpolation consistency. - # This allows us to loop through the labeled data iterator - # without running out of unlabeled batches. - unsup_dataloader = DataLoader( - unsup_dataset, - batch_size=batch_size, - shuffle=True, - drop_last=True, - ) - unsup_dataloader = repeater(unsup_dataloader) - - # Set up the unsupervised loss - if ssl_method.lower() == "ict": - print("Using ICT for semi-supervised learning") - USL = InterpolationConsistencyLoss( - alpha=mixup_alpha if mixup_alpha is not None else 0.3, - unsup_criterion=nn.MSELoss(), - sup_criterion=criterion, - decay_coef=ssl_kwargs.get("decay_coef", 0.997), - mean_teacher=unsup_mean_teacher, - ) - - elif ssl_method.lower() == "mixmatch": - print("Using MixMatch for semi-supervised learning") - # we want the raw MSE per sample here, rather than the average - # so we set `reduction='none'`. - # this allows us to scale the weight of individual examples - # based on pseudolabel confidence. - unsup_criterion_name = ssl_kwargs.get("unsup_criterion", "mse") - if unsup_criterion_name.lower() == "mse": - unsup_criterion = nn.MSELoss(reduction="none") - elif unsup_criterion_name.lower() in ("crossentropy", "ce"): - unsup_criterion = partial( - cross_entropy, - reduction="none", - ) - USL = MixMatchLoss( - alpha=mixup_alpha if mixup_alpha is not None else 0.3, - unsup_criterion=unsup_criterion, - sup_criterion=criterion, - decay_coef=ssl_kwargs.get("decay_coef", 0.997), - mean_teacher=unsup_mean_teacher, - augment=AUGMENTATION_SCHEMES[ssl_kwargs.get("augment", "log1p_drop")], - n_augmentations=ssl_kwargs.get("n_augmentations", 1), - T=ssl_kwargs.get("T", 0.5), - augment_pseudolabels=ssl_kwargs.get("augment_pseudolabels", True), - pseudolabel_min_confidence=ssl_kwargs.get( - "pseudolabel_min_confidence", 0.0 - ), - ) - else: - msg = f"{ssl_method} is not a valid semi-supervised learning method.\n" - msg += 'must be one of {"ict", "mixmatch"}' - raise ValueError(msg) - - # set up the weight schedule - # we define a number of epochs for ramping, a number to wait - # ("burn_in_epochs") before we start the ramp up, and a maximum - # coefficient value - weight_schedule = ICLWeight( - ramp_epochs=ssl_kwargs.get("ramp_epochs", max(n_epochs // 4, 1)), - max_unsup_weight=unsup_max_weight, - burn_in_epochs=ssl_kwargs.get("burn_in_epochs", 20), - sigmoid=ssl_kwargs.get("sigmoid", False), - ) - # don't let early stopping save checkpoints from before the SSL - # ramp up has started - trainer_kwargs["min_epochs"] = max( - trainer_kwargs["min_epochs"], - weight_schedule.burn_in_epochs + weight_schedule.ramp_epochs // 5, - ) - - # if min_epochs are manually specified, use that number instead - if ssl_kwargs.get("min_epochs", None) is not None: - trainer_kwargs["min_epochs"] = ssl_kwargs["min_epochs"] - - # let the model save weights even if the ramp is - # longer than the total epochs we'll train for - trainer_kwargs["min_epochs"] = min( - trainer_kwargs["min_epochs"], - trainer_kwargs["n_epochs"] - 1, - ) - - dan_criterion = ssl_kwargs.get("dan_criterion", None) - if dan_criterion is not None: - # initialize the DAN Loss - - dan_criterion = DANLoss( - model=model, - dan_criterion=cross_entropy, - use_conf_pseudolabels=ssl_kwargs.get( - "dan_use_conf_pseudolabels", False - ), - scale_loss_pseudoconf=ssl_kwargs.get( - "dan_scale_loss_pseudoconf", False - ), - n_domains=n_domains, - ) - - # setup the DANN learning rate schedule - dan_weight = ICLWeight( - ramp_epochs=ssl_kwargs.get("dan_ramp_epochs", max(n_epochs // 4, 1)), - max_unsup_weight=ssl_kwargs.get("dan_max_weight", 1.0), - burn_in_epochs=ssl_kwargs.get("dan_burn_in_epochs", 0), - sigmoid=ssl_kwargs.get("sigmoid", True), - ) - # add DANN parameters to the optimizer - optimizer.add_param_group( - { - "params": dan_criterion.dann.domain_clf.parameters(), - "name": "domain_classifier", - } - ) - else: - dan_weight = None - - # initialize the trainer - T = SemiSupervisedTrainer( - unsup_dataloader=unsup_dataloader, - unsup_criterion=USL, - unsup_weight=weight_schedule, - dan_criterion=dan_criterion, - dan_weight=dan_weight, - **trainer_kwargs, - ) - - print("Training...") - T.train() - print("Training complete.") - print() - - # Perform model evaluation using the best set of weights on the - # totally unseen, held out data. - print("Evaluating model.") - model = ModelClass( - n_genes=n_genes, - n_cell_types=n_cell_types, - **kwargs, - ) - model.load_state_dict( - torch.load( - osp.join(out_path, "00_best_model_weights.pkl"), - ) - ) - model.eval() - - if torch.cuda.is_available(): - model = model.cuda() - - # Build a DataLoader for validation - X_val = X[val_idx, :] - y_val = y[val_idx] - val_ds = SingleCellDS( - X_val, - y_val, - num_classes=len(np.unique(y)), - ) - val_dl = DataLoader( - val_ds, - batch_size=batch_size, - shuffle=False, - ) - - # Without recording any gradients to speed things up, - # predict classes for all held out data and evaluate metrics. - with torch.no_grad(): - loss = 0.0 - running_corrects = 0.0 - running_total = 0.0 - all_predictions = [] - all_labels = [] - for data in val_dl: - input_ = data["input"] - - label_ = data["output"] # one-hot - - if torch.cuda.is_available(): - input_ = input_.cuda() - label_ = label_.cuda() - - # make an integer version of labels for convenience - int_label_ = torch.argmax(label_, 1) - - # Perform forward pass and compute predictions as the - # most likely class - output = model(input_) - _, predictions = torch.max(output, 1) - - corrects = torch.sum( - predictions.detach() == int_label_.detach(), - ) - - l = criterion(output, label_) - loss += float(l.detach().cpu().numpy()) - - running_corrects += float(corrects.item()) - running_total += float(label_.size(0)) - - all_labels.append(int_label_.detach().cpu().numpy()) - - all_predictions.append(predictions.detach().cpu().numpy()) - - norm_loss = loss / len(val_dl) - acc = running_corrects / running_total - print("EVAL LOSS: ", norm_loss) - print("EVAL ACC : ", acc) - - all_predictions = np.concatenate(all_predictions) - all_labels = np.concatenate(all_labels) - np.savetxt(osp.join(out_path, "predictions.csv"), all_predictions) - np.savetxt(osp.join(out_path, "labels.csv"), all_labels) - - PL = np.stack([all_predictions, all_labels], 0) - print("Predictions | Labels") - print(PL.T[:15, :]) - return acc, norm_loss - - -def train_cv( - X: Union[np.ndarray, sparse.csr.csr_matrix], - y: np.ndarray, - batch_size: int, - n_epochs: int, - lr: float, - optimizer_name: str, - weight_decay: float, - ModelClass: nn.Module, - fold_indices: list, - out_path: str, - n_genes: int = None, - mixup_alpha: float = None, - unlabeled_counts: np.ndarray = None, - unsup_max_weight: float = 2.0, - unsup_mean_teacher: bool = False, - ssl_method: str = "mixmatch", - ssl_kwargs: dict = {}, - weighted_classes: bool = False, - balanced_classes: bool = False, - **kwargs, -) -> Tuple[np.ndarray, np.ndarray]: - """Perform training using a provided set of training/hold-out - sample indices. - - Parameters - ---------- - X : np.ndarray - [Cells, Genes] of log1p transformed normalized values. - log1p and normalization performed using scanpy defaults. - y : np.ndarray - [Cells,] integer class labels. - n_epochs : int - number of epochs for training. - weight_decay : float - weight decay to apply to model weights. - lr : float - learning rate. - optimizer_name : str - optimizer to use. {"adadelta", "adam"}. - ModelClass : nn.Module - a model class for construction classification models. - batch_size : int - batch size for training. - fold_indices : list - elements are 2-tuple, with training indices and held-out. - out_path : str - top level path for saving fold outputs. - n_genes : int - number of genes in the input. Not necessarily `X.shape[1]` if - the input matrix has been concatenated with other features. - mixup_alpha : float - alpha parameter for an optional MixUp augmentation during training. - unsup_max_weight : float - maximum weight for the unsupervised loss term. - unsup_mean_teacher : bool - use a mean teacher for pseudolabel generation. - ssl_method : str - semi-supervised learning method to use. - ssl_kwargs : dict - arguments passed to the semi-supervised learning loss. - balanced_classes : bool - perform class balancing by undersampling majority classes. - weighted_classes : bool - weight loss for each class based on relative abundance of classes - in the training data. - - Returns - ------- - fold_eval_acc : np.ndarray - evaluation accuracies for each fold. - fold_eval_losses : np.ndarray - loss values for each fold. - """ - fold_eval_losses = np.zeros(len(fold_indices)) - fold_eval_acc = np.zeros(len(fold_indices)) - - # Perform training on each fold specified in `fold_indices` - for f in range(len(fold_indices)): - print("Training tissue independent, fold %d." % f) - fold_out_path = osp.join(out_path, "fold" + str(f).zfill(2)) - - os.makedirs(fold_out_path, exist_ok=True) - - traintest_idx = fold_indices[f][0].astype("int") - val_idx = fold_indices[f][1].astype("int") - - acc, loss = fit_model( - X=X, - y=y, - traintest_idx=traintest_idx, - val_idx=val_idx, - out_path=fold_out_path, - batch_size=batch_size, - n_epochs=n_epochs, - ModelClass=ModelClass, - n_genes=n_genes, - lr=lr, - optimizer_name=optimizer_name, - weight_decay=weight_decay, - mixup_alpha=mixup_alpha, - unlabeled_counts=unlabeled_counts, - unsup_max_weight=unsup_max_weight, - unsup_mean_teacher=unsup_mean_teacher, - ssl_method=ssl_method, - ssl_kwargs=ssl_kwargs, - weighted_classes=weighted_classes, - balanced_classes=balanced_classes, - **kwargs, - ) - - fold_eval_losses[f] = loss - fold_eval_acc[f] = acc - return fold_eval_acc, fold_eval_losses - - -def train_all( - X: Union[np.ndarray, sparse.csr.csr_matrix], - y: np.ndarray, - batch_size: int, - n_epochs: int, - ModelClass: nn.Module, - out_path: str, - n_genes: int = None, - lr: float = 1.0, - optimizer_name: str = "adadelta", - weight_decay: float = None, - mixup_alpha: float = None, - unlabeled_counts: np.ndarray = None, - unsup_max_weight: float = 2.0, - unsup_mean_teacher: bool = False, - ssl_method: str = "mixmatch", - ssl_kwargs: dict = {}, - weighted_classes: bool = False, - balanced_classes: bool = False, - **kwargs, -) -> Tuple[float, float]: - """Perform training using all provided samples. - - Parameters - ---------- - X : np.ndarray - [Cells, Genes] of log1p transformed normalized values. - log1p and normalization performed using scanpy defaults. - y : np.ndarray - [Cells,] integer class labels. - n_epochs : int - number of epochs for training. - ModelClass : nn.Module - a model class for construction classification models. - batch_size : int - batch size for training. - out_path : str - top level path for saving fold outputs. - n_genes : int - number of genes in the input. Not necessarily `X.shape[1]` if - the input matrix has been concatenated with other features. - lr : float - learning rate. - optimizer_name : str - optimizer to use. {"adadelta", "adam"}. - weight_decay : float - weight decay to apply to model weights. - balanced_classes : bool - perform class balancing by undersampling majority classes. - weighted_classes : bool - weight loss for each class based on relative abundance of classes - in the training data. - - Returns - ------- - loss : float - best loss on the testing set used for model selection. - acc : float - best accuracy on the testing set used for model selection. - """ - # Prepare a unique output directory - all_out_path = osp.join(out_path, "all_data") - if not osp.exists(all_out_path): - os.mkdir(all_out_path) - - # Generate training and model selection indices - traintest_idx = np.random.choice( - np.arange(X.shape[0]), - size=int(np.floor(0.9 * X.shape[0])), - replace=False, - ).astype("int") - val_idx = np.setdiff1d( - np.arange(X.shape[0]), - traintest_idx, - ).astype("int") - - acc, loss = fit_model( - X=X, - y=y, - traintest_idx=traintest_idx, - val_idx=val_idx, - batch_size=batch_size, - n_epochs=n_epochs, - ModelClass=ModelClass, - out_path=all_out_path, - n_genes=n_genes, - lr=lr, - optimizer_name=optimizer_name, - weight_decay=weight_decay, - mixup_alpha=mixup_alpha, - unlabeled_counts=unlabeled_counts, - unsup_max_weight=unsup_max_weight, - unsup_mean_teacher=unsup_mean_teacher, - ssl_method=ssl_method, - ssl_kwargs=ssl_kwargs, - weighted_classes=weighted_classes, - balanced_classes=balanced_classes, - **kwargs, - ) - - np.savetxt( - osp.join(all_out_path, "test_loss_acc.csv"), - np.array([loss, acc]).reshape(2, 1), - delimiter=",", - ) - - return loss, acc - - -def train_tissue_independent_cv( - X: Union[np.ndarray, sparse.csr.csr_matrix], - metadata: pd.DataFrame, - out_path: str, - balanced_classes: bool = False, - weighted_classes: bool = False, - batch_size: int = 256, - n_epochs: int = 200, - lower_group: str = "cell_ontology_class", - **kwargs, -) -> None: - """ - Trains a cell type classifier that is independent of tissue origin - - Parameters - ---------- - X : np.ndarray - [Cells, Genes] of log1p transformed, normalized values. - log1p and normalization performed using scanpy defaults. - metadata : pd.DataFrame - [Cells, Features] data with `upper_group` and `lower_group` columns. - out_path : str - path for saving trained model weights and evaluation performance. - balanced_classes : bool - perform class balancing by undersampling majority classes. - weighted_classes : bool - weight loss for each class based on relative abundance of classes - in the training data. - batch_size : int - batch size for training. - n_epochs : int - number of epochs for training. - lower_group : str - column in `metadata` corresponding to output classes. i.e. cell types. - - Returns - ------- - None. - - Notes - ----- - Passes `kwargs` to `CellTypeCLF`. - """ - - print("TRAINING TISSUE INDEPENDENT CLASSIFIER") - print("-" * 20) - print() - - if not os.path.exists(out_path): - os.mkdir(out_path) - - # identify all the `lower_group` levels and create - # an integer class vector corresponding to unique levels - celltypes = sorted(list(set(metadata[lower_group]))) - print("There are %d %s in the experiment.\n" % (len(celltypes), lower_group)) - - for t in celltypes: - print(t) - - # identify all the `lower_group` levels and create - # an integer class vector corresponding to unique levels - y = pd.Categorical(metadata[lower_group]).codes - y = y.astype("int32") - labels = pd.Categorical(metadata[lower_group]).categories - # save mapping of levels : integer values as a CSV - out_df = pd.DataFrame({"label": labels, "code": np.arange(len(labels))}) - out_df.to_csv(osp.join(out_path, "celltype_label.csv")) - - # generate k-fold cross-validation split indices - # & vectors for metrics evaluated at each fold. - kf = StratifiedKFold(n_splits=5, shuffle=True) - kf_indices = list(kf.split(X, y)) - - # Perform training on each fold specified in `kf_indices` - fold_eval_acc, fold_eval_losses = train_cv( - X=X, - y=y, - batch_size=batch_size, - n_epochs=n_epochs, - ModelClass=CellTypeCLF, - fold_indices=kf_indices, - out_path=out_path, - balanced_classes=balanced_classes, - weighted_classes=weighted_classes, - **kwargs, - ) - - # Save the per-fold results to CSVs - - print("Fold eval losses") - print(fold_eval_losses) - print("Fold eval accuracy") - print(fold_eval_acc) - print("Mean %f Std %f" % (fold_eval_losses.mean(), fold_eval_losses.std())) - np.savetxt( - osp.join( - out_path, - "fold_eval_losses.csv", - ), - fold_eval_losses, - ) - np.savetxt( - osp.join( - out_path, - "fold_eval_acc.csv", - ), - fold_eval_acc, - ) - - # Train a model using all available data (after class balancing) - val_loss, val_acc = train_all( - X=X, - y=y, - batch_size=batch_size, - n_epochs=n_epochs, - ModelClass=CellTypeCLF, - out_path=out_path, - balanced_classes=balanced_classes, - weighted_classes=weighted_classes, - **kwargs, - ) - - return - - -def train_one_tissue_cv( - X: Union[np.ndarray, sparse.csr.csr_matrix], - metadata: pd.DataFrame, - out_path: str, - balanced_classes: bool = False, - weighted_classes: bool = False, - batch_size: int = 256, - n_epochs: int = 200, - upper_group: str = "tissue", - lower_group: str = "cell_ontology_class", - **kwargs, -) -> None: - """ - Trains a cell type classifier for a single tissue - - Parameters - ---------- - X : np.ndarray - [Cells, Genes] of log1p transformed, normalized values. - log1p and normalization performed using scanpy defaults. - metadata : pd.DataFrame - [Cells, Features] data with `upper_group` and `lower_group` columns. - out_path : str - path for saving trained model weights and evaluation performance. - balanced_classes : bool, optional - perform class balancing by undersampling majority classes. - weighted_classes : bool - weight loss for each class based on relative abundance of classes - in the training data. - upper_group : str - column in `metadata` with subsets for training `lower_group` - classifiers independently. i.e. tissues. - lower_group : str - column in `metadata` corresponding to output classes. i.e. cell types. - - Returns - ------- - None. - """ - - tissue_str = str(list(metadata[upper_group])[0]).lower() - print( - "TRAINING %s DEPENDENT CLASSIFIER FOR: " % upper_group.upper(), - tissue_str.upper(), - ) - print("-" * 20) - print() - - celltypes = sorted(list(set(metadata[lower_group]))) - print("There are %d %s in the experiment.\n" % (len(celltypes), lower_group)) - for t in celltypes: - print(t) - print("") - y = pd.Categorical(metadata[lower_group]).codes - y = y.astype("int32") - labels = pd.Categorical(metadata[lower_group]).categories - out_df = pd.DataFrame({"label": labels, "code": np.arange(len(labels))}) - out_df.to_csv(osp.join(out_path, "celltype_label.csv")) - - kf = StratifiedKFold(n_splits=5, shuffle=True) - kf_indices = list(kf.split(X, y)) - - # Perform training on each fold specified in `kf_indices` - fold_eval_acc, fold_eval_losses = train_cv( - X=X, - y=y, - batch_size=batch_size, - n_epochs=n_epochs, - ModelClass=CellTypeCLF, - fold_indices=kf_indices, - out_path=out_path, - weighted_classes=weighted_classes, - balanced_classes=balanced_classes, - **kwargs, - ) - - print("Fold eval losses") - print(fold_eval_losses) - print("Fold eval accuracy") - print(fold_eval_acc) - print("Mean %f Std %f" % (fold_eval_losses.mean(), fold_eval_losses.std())) - np.savetxt( - osp.join( - out_path, - "fold_eval_losses.csv", - ), - fold_eval_losses, - ) - np.savetxt( - osp.join( - out_path, - "fold_eval_acc.csv", - ), - fold_eval_acc, - ) - - # Train a model using all available data (after class balancing) - val_loss, val_acc = train_all( - X=X, - y=y, - batch_size=batch_size, - n_epochs=n_epochs, - ModelClass=CellTypeCLF, - out_path=out_path, - weighted_classes=weighted_classes, - balanced_classes=balanced_classes, - **kwargs, - ) - return - - -######################################################### -# Predict cell types with a trained model -######################################################### - - -def predict_cell_types( - X: Union[np.ndarray, sparse.csr.csr_matrix], - model_path: str, - out_path: str, - upper_groups: Union[list, np.ndarray] = None, - lower_group_labels: list = None, - **kwargs, -) -> None: - """Predict cell types using a pretrained model - - Parameters - ---------- - X : np.ndarray, sparse.csr.csr_matrix - [Cells, Genes] of log1p transformed, normalized values. - log1p and normalization performed using scanpy defaults. - model_path : str - path to a set of pretrained model weights. - out_path : str - path for prediction outputs. - upper_groups : list, np.ndarray - [Cells,] iterable of str specifying the `upper_group` for each cell. - if provided, assumes an `upper_group` conditional model. - if `None`, assumes an `upper_group` independent model. - lower_group_labels : list - str labels corresponding to output nodes of the model. - - Returns - ------- - None. - - Notes - ----- - `**kwargs` passed to `scnym.predict.Predicter`. - """ - if upper_groups is not None: - print("Assuming conditional model.") - - X, categories = utils.append_categorical_to_data(X, upper_groups) - np.savetxt( - osp.join(out_path, "category_names.csv"), - categories, - fmt="%s", - delimiter=",", - ) - else: - print("Assuming independent model") - - # Intantiate a prediction object, which handles batch processing - P = Predicter( - model_weights=model_path, - n_genes=X.shape[1], - n_cell_types=None, # infer cell type # from weights - labels=lower_group_labels, - **kwargs, - ) - - predictions, names, scores = P.predict(X, output="score") - - probabilities = F.softmax(torch.from_numpy(scores), dim=1) - probabilities = probabilities.cpu().numpy() - - np.savetxt(osp.join(out_path, "predictions_idx.csv"), predictions, delimiter=",") - np.savetxt(osp.join(out_path, "probabilities.csv"), probabilities, delimiter=",") - np.savetxt(osp.join(out_path, "raw_scores.csv"), scores, delimiter=",") - if names is not None: - np.savetxt( - osp.join(out_path, "predictions_names.csv"), names, delimiter=",", fmt="%s" - ) - return - - -######################################################### -# utilities -######################################################### - - -def load_data( - path: str, -) -> Union[np.ndarray, sparse.csr.csr_matrix]: - """Load a counts matrix from a file path. - - Parameters - ---------- - path : str - path to [npy, csv, h5ad, loom] file. - - Returns - ------- - X : np.ndarray - [Cells, Genes] matrix. - """ - if osp.splitext(path)[-1] == ".npy": - print("Assuming sparse matrix...") - X_raw = np.load(path, allow_pickle=True) - X_raw = X_raw.item() - elif osp.splitext(path)[-1] == ".csv": - X_raw = np.loadtxt(path, delimiter=",") - elif osp.splitext(path)[-1] == ".h5ad": - adata = sc.read_h5ad(path) - X_raw = utils.get_adata_asarray(adata=adata) - elif osp.splitext(path)[-1] == ".loom": - adata = sc.read_loom(path) - X_raw = utils.get_adata_asarray(adata=adata) - else: - raise ValueError( - "unrecognized file type %s for counts" % osp.splitext(path)[-1] - ) - - return X_raw - - -######################################################### -# main() -######################################################### - - -def main(): - import configargparse - import yaml - - parser = configargparse.ArgParser( - description="Train cell type classifiers", - default_config_files=["./configs/default_config.txt"], - ) - parser.add_argument( - "command", - type=str, - help='action to perform. \ - ["train_tissue_independent", \ - "train_tissue_dependent", \ - "train_tissue_specific", \ - "find_cell_type_markers", \ - "predict_cell_types"]', - ) - parser.add_argument( - "-c", is_config_file=True, required=False, help="path to a configuration file." - ) - parser.add_argument( - "--input_counts", - type=str, - required=True, - help="path to input data [Cells, Genes] counts. \ - [npy, csv, h5ad, loom]", - ) - parser.add_argument( - "--input_gene_names", - type=str, - required=True, - help="path to gene names for the input data.", - ) - parser.add_argument( - "--training_gene_names", - type=str, - required=False, - help="path to training data gene names. \ - required for prediction.", - ) - parser.add_argument( - "--training_metadata", - type=str, - required=True, - help="CSV metadata for training. Requires `upper_group` and `lower_group` columns. \ - necessary for prediction to provide cell type names.", - ) - parser.add_argument( - "--lower_group", - type=str, - required=True, - default="cell_ontology_class", - help="column in `metadata` with to output labels. \ - i.e. cell types.", - ) - parser.add_argument( - "--upper_group", - type=str, - required=True, - default="tissue", - help="column in `metadata` with to subsets for independent training. \ - i.e. tissues.", - ) - parser.add_argument( - "--out_path", type=str, required=True, help="path for output files" - ) - parser.add_argument( - "--genes_to_use", - type=str, - default=None, - help="path to a text file of genes to use for training. \ - must be a subset of genes in `training_gene_names`", - ) - parser.add_argument( - "--input_domain_group", - type=str, - help="column in `training_metadata` that specifies domain of origin for each training observation.", - required=False, - default=None, - ) - parser.add_argument( - "--batch_size", type=int, default=256, help="batch size for training" - ) - parser.add_argument( - "--n_epochs", type=int, default=256, help="number of epochs for training" - ) - parser.add_argument( - "--init_dropout", - type=float, - default=0.3, - help="initial dropout to perform on gene inputs", - ) - parser.add_argument( - "--n_hidden", - type=int, - default=128, - help="number of hidden units in the classifier", - ) - parser.add_argument( - "--n_layers", type=int, default=2, help="number of hidden layers in the model" - ) - parser.add_argument( - "--residual", action="store_true", help="use residual layers in the model" - ) - parser.add_argument( - "--track_running_stats", - type=bool, - default=True, - help="track running statistics in batch normalization layers", - ) - parser.add_argument( - "--model_path", - type=str, - default=None, - help="path to pretrained model weights \ - for class marker identification.", - ) - parser.add_argument( - "--weight_decay", - type=float, - default=1e-5, - help="weight decay applied by the optimizer", - ) - parser.add_argument( - "--lr", type=float, default=1.0, help="learning rate for the optimizer." - ) - parser.add_argument( - "--optimizer", - type=str, - default="adadelta", - help="optimizer to use. {adadelta, adam}.", - ) - parser.add_argument( - "--l1_reg", - type=float, - default=1e-4, - help="l1 regularization strength \ - for class marker identification", - ) - parser.add_argument( - "--weight_classes", - type=bool, - default=False, - help="weight loss based on relative class abundance.", - ) - parser.add_argument( - "--balance_classes", type=bool, default=False, help="perform class balancing." - ) - parser.add_argument( - "--mixup_alpha", - type=float, - default=None, - help="alpha parameter for MixUp training. \ - if set performs MixUp, otherwise does not.", - ) - parser.add_argument( - "--unlabeled_counts", - type=str, - default=None, - help="path to unlabeled data [Cells, Genes]. \ - [npy, csv, h5ad, loom]. \ - if provided, uses interpolation consistency training.", - ) - parser.add_argument( - "--unlabeled_genes", - type=str, - default=None, - help="path to gene names for the unlabeled data.\ - if not provided, assumes same as `input_counts`.", - ) - parser.add_argument( - "--unlabeled_domain", - type=str, - help="path to a CSV of integer domain labels for each data point in `unlabeled_counts`.", - required=False, - default=None, - ) - parser.add_argument( - "--unsup_max_weight", - type=float, - default=2.0, - help="maximum weight for the unsupervised component of IC training.", - ) - parser.add_argument( - "--unsup_mean_teacher", - action="store_true", - help="use a mean teacher for IC training.", - ) - parser.add_argument( - "--ssl_method", - type=str, - default="mixmatch", - help='semi-supervised learning method to use. {"mixmatch", "ict"}.', - ) - parser.add_argument( - "--ssl_config", - type=str, - default=None, - help="path to a YAML configuration file of kwargs for the SSL method.", - ) - args = parser.parse_args() - - print(args) - print(parser.format_values()) - - COMMANDS = [ - "train_tissue_independent", - "train_tissue_dependent", - "train_tissue_specific", - "predict_cell_types", - ] - - if args.command not in COMMANDS: - raise ValueError("%s is not a valid command." % args.command) - - ##################################### - # LOAD DATA - ##################################### - - X_raw = load_data(args.input_counts) - - print("Loaded data.") - print("%d cells and %d genes in raw data." % X_raw.shape) - gene_names = np.loadtxt(args.input_gene_names, dtype="str") - print("Loaded gene names for the raw data. %d genes." % len(gene_names)) - - if args.genes_to_use is not None: - genes_to_use = np.loadtxt(args.genes_to_use, dtype="str") - print( - "Using a subset of %d genes as specified in \n %s." - % (len(genes_to_use), args.genes_to_use) - ) - else: - genes_to_use = gene_names - - if args.genes_to_use is not None: - # Filter the input matrix to use only the specified genes - print("Using %d genes for classification." % len(genes_to_use)) - gnl = gene_names.tolist() - keep_idx = np.array([gnl.index(x) for x in genes_to_use]) - X = X_raw[:, keep_idx] - else: - # leave all genes in the matrix - X = X_raw - - # Load metadata and identify output classes - metadata = pd.read_csv( - args.training_metadata, - ) - lower_groups = np.unique(metadata[args.lower_group]).tolist() - - # load domain labels if applicable - if args.input_domain_group is not None: - if args.input_domain_group not in metadata.columns: - msg = f"{args.input_domain_group} is not a column in `training_metadata`" - raise ValueError(msg) - else: - input_domain = np.array(metadata[args.input_domain_group]) - else: - input_domain = None - - # Load any provided unlabeled data for semi-supervised learning - if args.unlabeled_counts is not None: - unlabeled_counts = load_data(args.unlabeled_counts) - print("%d cells, %d genes in unlabeled data." % unlabeled_counts.shape) - - # parse any semi-supervised learning specific parameters - if args.ssl_config is not None: - print(f"Loading Semi-Supervised Learning parameters for {args.ssl_method}") - with open(args.ssl_config, "r") as f: - ssl_kwargs = yaml.load(f, Loader=yaml.Loader) - print("SSL kwargs:") - for k, v in ssl_kwargs.items(): - print(f"{k}\t\t:\t\t{v}") - print() - else: - ssl_kwargs = {} - - else: - unlabeled_counts = None - ssl_kwargs = {} - - if args.unlabeled_genes is not None and unlabeled_counts is not None: - # Contruct a matrix using the unlabeled counts where columns - # correspond to the same gene in `input_counts`. - print("Subsetting unlabeled counts to genes used for training...") - unlabeled_genes = np.loadtxt( - args.unlabeled_genes, - delimiter=",", - dtype="str", - ) - unlabeled_counts = utils.build_classification_matrix( - X=unlabeled_counts, - model_genes=genes_to_use, - sample_genes=unlabeled_genes, - ) - if args.unlabeled_domain is not None: - unlabeled_domain = np.loadtxt( - args.unlabeled_domain, - ).astype(np.int) - else: - unlabeled_domain = None - else: - unlabeled_domain = None - - # prepare output paths - if not os.path.exists(args.out_path): - os.mkdir(args.out_path) - - sub_dirs = [ - "tissues", - "tissue_independent_no_dropout", - "tissue_dependent", - "tissue_ind_class_optimums", - ] - for sd in sub_dirs: - if not os.path.exists(osp.join(args.out_path, sd)): - os.mkdir(osp.join(args.out_path, sd)) - - ##################################### - # TISSUE INDEPENDENT CLASSIFIERS - ##################################### - - if args.command == "train_tissue_independent": - train_tissue_independent_cv( - X, - metadata, - osp.join(args.out_path, "tissue_independent"), - balanced_classes=args.balance_classes, - weighted_classes=args.weight_classes, - batch_size=args.batch_size, - n_epochs=args.n_epochs, - init_dropout=args.init_dropout, - lower_group=args.lower_group, - n_hidden=args.n_hidden, - n_layers=args.n_layers, - lr=args.lr, - optimizer_name=args.optimizer, - weight_decay=args.weight_decay, - residual=args.residual, - track_running_stats=args.track_running_stats, - mixup_alpha=args.mixup_alpha, - unlabeled_counts=unlabeled_counts, - unsup_max_weight=args.unsup_max_weight, - unsup_mean_teacher=args.unsup_mean_teacher, - ssl_method=args.ssl_method, - ssl_kwargs=ssl_kwargs, - input_domain=input_domain, - unlabeled_domain=unlabeled_domain, - ) - - ##################################### - # PRETRAINED MODEL PREDICTION - ##################################### - - if args.command == "predict_cell_types": - if args.model_path is None: - raise ValueError("`model_path` required.") - if args.training_gene_names is None: - raise ValueError("must supply `training_gene_names`.") - training_genes = np.loadtxt( - args.training_gene_names, delimiter=",", dtype="str" - ).tolist() - - X = utils.build_classification_matrix( - X=X, - model_genes=training_genes, - sample_genes=gene_names, - ) - - predict_cell_types( - X, - model_path=args.model_path, - out_path=args.out_path, - lower_group_labels=lower_groups, - n_hidden=args.n_hidden, - n_layers=args.n_layers, - residual=args.residual, - ) - - -######################################################### -# __main__ -######################################################### - - -if __name__ == "__main__": - - main() diff --git a/build/lib/scnym/model.py b/build/lib/scnym/model.py deleted file mode 100644 index e94dde1..0000000 --- a/build/lib/scnym/model.py +++ /dev/null @@ -1,603 +0,0 @@ -import torch -import torch.nn as nn -from typing import Callable, Iterable, Union, Tuple -import logging - -logger = logging.getLogger(__name__) - - -class ResBlock(nn.Module): - """Residual block. - - References - ---------- - Deep Residual Learning for Image Recognition - Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun - arXiv:1512.03385 - """ - - def __init__( - self, - n_inputs: int, - n_hidden: int, - ) -> None: - """Residual block with fully-connected neural network - layers. - - Parameters - ---------- - n_inputs : int - number of input dimensions. - n_hidden : int - number of hidden dimensions in the Residual Block. - - Returns - ------- - None. - """ - super(ResBlock, self).__init__() - - self.n_inputs = n_inputs - self.n_hidden = n_hidden - - # Build the initial projection layer - self.linear00 = nn.Linear(self.n_inputs, self.n_hidden) - self.norm00 = nn.BatchNorm1d(num_features=self.n_hidden) - self.relu00 = nn.ReLU(inplace=True) - - # Map from the latent space to output space - self.linear01 = nn.Linear(self.n_hidden, self.n_hidden) - self.norm01 = nn.BatchNorm1d(num_features=self.n_hidden) - self.relu01 = nn.ReLU(inplace=True) - return - - def forward( - self, - x: torch.FloatTensor, - ) -> torch.FloatTensor: - """Residual block forward pass. - - Parameters - ---------- - x : torch.FloatTensor - [Batch, self.n_inputs] - - Returns - ------- - o : torch.FloatTensor - [Batch, self.n_hidden] - """ - identity = x - - # Project input to the latent space - o = self.norm00(self.linear00(x)) - o = self.relu00(o) - - # Project from the latent space to output space - o = self.norm01(self.linear01(o)) - - # Make this a residual connection - # by additive identity operation - o += identity - return self.relu01(o) - - -class CellTypeCLF(nn.Module): - """Cell type classifier from expression data. - - Attributes - ---------- - n_genes : int - number of input genes in the model. - n_cell_types : int - number of output classes in the model. - n_hidden : int - number of hidden units in the model. - n_layers : int - number of hidden layers in the model. - init_dropout : float - dropout proportion prior to the first layer. - residual : bool - use residual connections. - """ - - def __init__( - self, - n_genes: int, - n_cell_types: int, - n_hidden: int = 256, - n_hidden_init: int = 256, - n_layers: int = 2, - init_dropout: float = 0.0, - residual: bool = False, - batch_norm: bool = True, - track_running_stats: bool = True, - n_decoder_layers: int = 0, - use_raw_counts: bool = False, - ) -> None: - """ - Cell type classifier from expression data. - Linear layers with batch norm and dropout. - - Parameters - ---------- - n_genes : int - number of genes in the input - n_cell_types : int - number of cell types for the output - n_hidden : int - number of hidden unit - n_hidden_init : - number of hidden units for the initial encoding layer. - n_layers : int - number of hidden layers. - init_dropout : float - dropout proportion prior to the first layer. - residual : bool - use residual connections. - batch_norm : bool - use batch normalization in hidden layers. - track_running_stats : bool - track running statistics in batch norm layers. - n_decoder_layers : int - number of layers in the decoder. - use_raw_counts : bool - provide raw counts as input. - - Returns - ------- - None. - """ - super(CellTypeCLF, self).__init__() - - self.n_genes = n_genes - self.n_cell_types = n_cell_types - self.n_hidden = n_hidden - self.n_hidden_init = n_hidden_init - self.n_decoder_layers = n_decoder_layers - self.n_layers = n_layers - self.init_dropout = init_dropout - self.residual = residual - self.batch_norm = batch_norm - self.track_running_stats = track_running_stats - self.use_raw_counts = use_raw_counts - - # simulate technical dropout of scRNAseq - self.init_dropout = nn.Dropout(p=self.init_dropout) - - # Define a vanilla NN layer with batch norm, dropout, ReLU - vanilla_layer = [ - nn.Linear(self.n_hidden, self.n_hidden), - ] - if self.batch_norm: - vanilla_layer += [ - nn.BatchNorm1d( - num_features=self.n_hidden, - track_running_stats=self.track_running_stats, - ), - ] - vanilla_layer += [ - nn.Dropout(), - nn.ReLU(inplace=True), - ] - - # Define a residual NN layer with batch norm, dropout, ReLU - residual_layer = [ - ResBlock(self.n_hidden, self.n_hidden), - ] - if self.batch_norm: - residual_layer += [ - nn.BatchNorm1d( - num_features=self.n_hidden, - track_running_stats=self.track_running_stats, - ), - ] - - residual_layer += [ - nn.Dropout(), - nn.ReLU(inplace=True), - ] - - # Build the intermediary layers of the model - if self.residual: - hidden_layer = residual_layer - else: - hidden_layer = vanilla_layer - - hidden_layers = hidden_layer * (self.n_layers - 1) - - # Build the classifier `nn.Module`. - self.embed = nn.Sequential( - nn.Linear(self.n_genes, self.n_hidden_init), - nn.BatchNorm1d( - num_features=self.n_hidden_init, - track_running_stats=self.track_running_stats, - ), - nn.Dropout(), - nn.ReLU(inplace=True), - nn.Linear(self.n_hidden_init, self.n_hidden), - nn.BatchNorm1d( - num_features=self.n_hidden, - track_running_stats=self.track_running_stats, - ), - nn.Dropout(), - nn.ReLU(inplace=True), - *hidden_layers, - ) - - dec_hidden = hidden_layer * (self.n_decoder_layers - 1) - final_clf = nn.Linear(self.n_hidden, self.n_cell_types) - self.classif = nn.Sequential( - *dec_hidden, - final_clf, - ) - return - - def forward( - self, - x: torch.FloatTensor, - return_embed: bool = False, - ) -> torch.FloatTensor: - """Perform a forward pass through the model - - Parameters - ---------- - x : torch.FloatTensor - [Batch, self.n_genes] - return_embed : bool - return the embedding and the class predictions. - - Returns - ------- - pred : torch.FloatTensor - [Batch, self.n_cell_types] - embed : torch.FloatTensor, optional - [Batch, n_hidden], only returned if `return_embed`. - """ - # add initial dropout noise - if self.init_dropout.p > 0 and not self.use_raw_counts: - # counts are log1p(CPM) - # expm1 to normed counts - x = torch.expm1(x) - x = self.init_dropout(x) - # renorm to log1p CPM - size = torch.sum(x, dim=1).reshape(-1, 1) - prop_input_ = x / size - norm_input_ = prop_input_ * 1e6 - x = torch.log1p(norm_input_) - elif self.init_dropout.p > 0 and self.use_raw_counts: - x = self.init_dropout(x) - else: - # we don't need to do initial dropout - pass - x_embed = self.embed(x) - pred = self.classif(x_embed) - - if return_embed: - r = ( - pred, - x_embed, - ) - else: - r = pred - return r - - -class GradReverse(torch.autograd.Function): - """Layer that reverses and scales gradients before - passing them up to earlier ops in the computation graph - during backpropogation. - """ - - @staticmethod - def forward(ctx, x, weight): - """ - Perform a no-op forward pass that stores a weight for later - gradient scaling during backprop. - - Parameters - ---------- - x : torch.FloatTensor - [Batch, Features] - weight : float - weight for scaling gradients during backpropogation. - stored in the "context" ctx variable. - - Notes - ----- - We subclass `Function` and use only @staticmethod as specified - in the newstyle pytorch autograd functions. - https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function - - We define a "context" ctx of the class that will hold any values - passed during forward for use in the backward pass. - - `x.view_as(x)` and `*1` are necessary so that `GradReverse` - is actually called - `torch.autograd` tries to optimize backprop and - excludes no-ops, so we have to trick it :) - """ - # store the weight we'll use in backward in the context - ctx.weight = weight - return x.view_as(x) * 1.0 - - @staticmethod - def backward(ctx, grad_output): - """Return gradients - - Returns - ------- - rev_grad : torch.FloatTensor - reversed gradients scaled by `weight` passed in `.forward()` - None : None - a dummy "gradient" required since we passed a weight float - in `.forward()`. - """ - # here scale the gradient and multiply by -1 - # to reverse the gradients - return (grad_output * -1 * ctx.weight), None - - -class DANN(nn.Module): - """Build a domain adaptation neural network""" - - def __init__( - self, - model: CellTypeCLF, - n_domains: int = 2, - weight: float = 1.0, - n_layers: int = 1, - ) -> None: - """Build a domain adaptation neural network using - the embedding of a provided model. - - Parameters - ---------- - model : CellTypeCLF - cell type classification model. - n_domains : int - number of domains to adapt. - weight : float - weight for reversed gradients. - n_layers : int - number of hidden layers in the network. - - Returns - ------- - None. - """ - super(DANN, self).__init__() - - self.model = model - self.n_domains = n_domains - - self.embed = model.embed - - hidden_layers = [ - nn.Linear(self.model.n_hidden, self.model.n_hidden), - nn.ReLU(), - ] * n_layers - - self.domain_clf = nn.Sequential( - *hidden_layers, - nn.Linear(self.model.n_hidden, self.n_domains), - ) - return - - def set_rev_grad_weight( - self, - weight: float, - ) -> None: - """Set the weight term used after reversing gradients""" - self.weight = weight - return - - def forward( - self, - x: torch.FloatTensor, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: - """Perform a forward pass. - - Parameters - ---------- - x : torch.FloatTensor - [Batch, Features] input. - - Returns - ------- - domain_pred : torch.FloatTensor - [Batch, n_domains] logits. - x_embed : torch.FloatTensor - [Batch, n_hidden] - """ - # get the model embedding - x_embed = self.embed(x) - # reverse gradients and scale by a weight - # domain_pred -> x_rev -> GradReverse -> x_embed - # d+ -> d+ -> d- -> d- - x_rev = GradReverse.apply( - x_embed, - self.weight, - ) - # classify the domains - domain_pred = self.domain_clf(x_rev) - return domain_pred, x_embed - - -class AE(nn.Module): - """Build an autoencoder that shares the classifier embedding. - - Attributes - ---------- - model : CellTypeCLF - cell type classification model. - n_layers : int - number of hidden layers in the network. - n_hidden : int - number of hidden units in each hidden layer. - defaults to the hidden layer size of the model. - dispersion : torch.nn.Parameter - [model.n_genes,] dispersion parameters for each gene. - `None` unless `model.use_raw_counts`. - latent_libsize : bool - use a latent variable to store library size. if `False`, - uses the observed library size to scale abundance profiles. - """ - - noise_scale = 1.0 - - def __init__( - self, - model: CellTypeCLF, - n_layers: int = 2, - n_hidden: int = None, - n_domains: int = None, - latent_libsize: bool = False, - ) -> None: - """Build an autoencoder using the embedding of a provided model. - - Parameters - ---------- - model : CellTypeCLF - cell type classification model. - n_layers : int - number of hidden layers in the network. - n_hidden : int - number of hidden units in each hidden layer. - defaults to the hidden layer size of the model. - n_domains : int - number of domain covariates to include. - latent_libsize : bool - use a latent variable to store library size. if `False`, - uses the observed library size to scale abundance profiles. - - Returns - ------- - None. - - Notes - ----- - Maps gene expression vectors to an embedding using the same - trunk as the classification model. If `model.use_raw_counts`, - reconstructs library depth using the latent library size and - also learns a set of dispersion parameters for each gene. - Reconstructs profiles using a decoder model that mirrors the - classification embedding trunk. - """ - super(AE, self).__init__() - - self.model = model - self.n_hidden = self.model.n_hidden if n_hidden is None else n_hidden - self.latent_libsize = latent_libsize - self.n_domains = n_domains if n_domains is not None else 0 - - # extract the embedder from the classification model - self.embed = self.model.embed - - # append decoder layers - dec_input = [ - nn.Linear(self.model.n_hidden + self.n_domains, self.n_hidden), - nn.ReLU(), - ] - - hidden_layers = [ - nn.Linear(self.model.n_hidden, self.n_hidden), - nn.ReLU(), - ] * (n_layers - 1) - - self.decoder = nn.Sequential( - *dec_input, - *hidden_layers, - nn.Linear(self.n_hidden, self.model.n_genes), - ) - - if self.model.use_raw_counts: - # initialize dispersion parameters from a unit Gaussian - self.dispersion = nn.Parameter(torch.randn(self.model.n_genes)) - else: - self.dispersion = torch.ones((1,)) - - # encode log(library_size) as a latent variable - self.libenc = nn.Sequential( - nn.Linear(self.model.n_genes + self.n_domains, 1), - nn.ReLU(), - ) - - return - - def noise( - self, - x_embed: torch.FloatTensor, - ) -> torch.FloatTensor: - """Add white noise to the latent embedding""" - eps = torch.randn_like(x_embed) * self.noise_scale - return torch.nn.functional.relu(x_embed + eps) - - def forward( - self, - x: torch.FloatTensor, - x_embed: torch.FloatTensor = None, - x_domain: torch.FloatTensor = None, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: - """Perform a forward pass. - - Parameters - ---------- - x : torch.FloatTensor - [Batch, Features] input. - x_embed : torch.FloatTensor, optional. - [Batch, n_hidden] embedding. - x_domain : torch.FloatTensor, optional. - [Batch, Domains] one-hot labels. - used for conditional decoding. - - Returns - ------- - reconstructed_profiles : torch.FloatTensor - [Batch, Features] abundance profiles [0, 1]. - scaled_profiles : torch.FloatTensor - [Batch, Features] profiles scaled by latent depths. - dispersion : torch.FloatTensor - [Features,] dispersion parameters for each gene. - x_embed : torch.FloatTensor - [Batch, n_hidden] - """ - # get the model embedding, avoid recomputing if a precomputed - # embedding is passed in - x_embed = self.embed(x) if x_embed is None else x_embed - - if self.training: - x_embed = self.noise(x_embed) - - # check the dimensions are sane - if x_embed.size(-1) > 2048: - logger.warn( - f"AE `x_embed` dimension is larger than expected: {x_embed.size(1)}" - ) - - # add domain covariates if provided and initialized to use covars - if x_domain is None and self.n_domains > 0: - msg = "Must provide domain covariates for a conditional model. Received `None`." - raise TypeError(msg) - if x_domain is not None and self.n_domains > 0: - logger.debug(f"Domain covariates added. Size {x_domain.size()}.") - x_embed = torch.cat([x_embed, x_domain], dim=-1) - x2libsz = torch.cat([x, x_domain], dim=-1) - else: - x2libsz = x - - # reconstruct gene expression abundance profiles, first with raw - # activations - x_rec = self.decoder(x_embed) - # use softmax to go from logits to relative abundance profiles - x_rec = nn.functional.softmax(x_rec, dim=1) - - if self.latent_libsize: - # `libenc` returns the log of the library size - lib_size = self.libenc(x2libsz) - lib_size = torch.clamp(lib_size, max=12) # numerical stability - else: - lib_size = torch.log(x.sum(1)).view(-1, 1) # [Cells, 1] - x_scaled = x_rec * torch.exp(lib_size) - - return x_rec, x_scaled, torch.exp(self.dispersion), x_embed diff --git a/build/lib/scnym/predict.py b/build/lib/scnym/predict.py deleted file mode 100644 index 8b84b42..0000000 --- a/build/lib/scnym/predict.py +++ /dev/null @@ -1,216 +0,0 @@ -import numpy as np -from scipy import sparse -import os -import torch -import torch.nn.functional as F -from typing import Union -from .model import CellTypeCLF -from .dataprep import SingleCellDS -import tqdm - - -class Predicter(object): - """Predict cell types from expression data using `CellTypeCLF`. - - Attributes - ---------- - model_weights : list - paths to model weights for classification. - labels : list - str labels for output classes. - n_cell_types : int - number of output classes. - n_genes : int - number of input genes. - models : list - `nn.Module` for each set of weights in `.model_weights`. - """ - - def __init__( - self, - model_weights: Union[str, list, tuple], - n_genes: int = None, - n_cell_types: int = None, - labels: list = None, - **kwargs, - ) -> None: - """ - Predict cell types using pretrained weights for `CellTypeCLF`. - - Parameters - ---------- - model_weights : str, list, tuple - paths to pre-trained model weights. if more than one - path to weights is provided, predicts using an ensemble - of models. - n_genes : int - number of genes in the input frame. - n_cell_types : int - number of cell types in the output. - labels : list - string labels corresponding to each cell type output - **kwargs passed to `model.CellTypeCLF` - """ - if type(model_weights) == str: - self.model_weights = [model_weights] - else: - self.model_weights = model_weights - self.labels = labels - - if n_cell_types is None: - # get the number of output nodes from the pretrained model - print( - "Assuming `n_cell_types` is the same as in the \ - pretrained model weights." - ) - params = torch.load(self.model_weights[0], map_location="cpu") - fkey = list(params.keys())[-1] - self.n_cell_types = len(params[fkey]) - else: - self.n_cell_types = n_cell_types - - # check that all the specified weights exist - for weights in self.model_weights: - if not os.path.exists(weights): - raise FileNotFoundError() - - if n_genes is None: - # get the number of input genes from the model weights - print( - "Assuming `n_genes` is the same as in the \ - pretrained model weights." - ) - params = torch.load(model_weights, map_location="cpu") - fkey = list(params.keys())[0] - self.n_genes = params[fkey].shape[1] - else: - self.n_genes = n_genes - - # Load each set of weights in `model_weights` into a model - # to use in an ensemble prediction. - self.models = [] - for weights in self.model_weights: - model = CellTypeCLF( - n_genes=self.n_genes, - n_cell_types=self.n_cell_types, - **kwargs, - ) - model.load_state_dict(torch.load(weights, map_location="cpu")) - - if torch.cuda.is_available(): - model = model.cuda() - - self.models.append(model.eval()) - - return - - def predict( - self, - X: Union[np.ndarray, sparse.csr.csr_matrix, torch.FloatTensor], - output: str = None, - batch_size: int = 1024, - **kwargs, - ) -> (np.ndarray, list): - """ - Predict cell types given a matrix `X`. - - Parameters - ---------- - X : np.ndarray, sparse.csr.csr_matrix, torch.FloatTensor - [Cells, Genes] - output : str - additional output to include as an optional third tuple. - ('prob', 'score'). - batch_size : int - batch size to use for predictions. - - Returns - ------- - predictions : np.ndarray - [Cells,] ints of predicted class - names : list - [Cells,] str of predicted class names - probabilities : np.ndarray - [Cells, Types] probabilities (softmax outputs). - - Notes - ----- - acceptable **kwarg for legacy compatibility -- - return_prob : bool - return probabilities as an optional third output. - """ - if not X.shape[1] == self.n_genes: - gs = (X.shape[1], self.n_genes) - raise ValueError("%d genes in X, %d genes in model." % gs) - - if "return_prob" in kwargs: - return_prob = kwargs["return_prob"] - else: - return_prob = None - - if output not in ["prob", "score"] and output is not None: - msg = f"{output} is not a valid additional output." - raise ValueError(msg) - - # build a SingleCellDS so we can load cells onto the - # GPU in batches - ds = SingleCellDS(X=X, y=np.zeros(X.shape[0])) - dl = torch.utils.data.DataLoader( - ds, - batch_size=batch_size, - ) - - # For each cell vector, compute a prediction - # and a class probability vector. - predictions = [] - scores = [] - probabilities = [] - - # For each cell, compute predictions - for data in tqdm.tqdm(dl, desc="Finding cell types"): - - X_batch = data["input"] - - if torch.cuda.is_available(): - X_batch = X_batch.cuda() - - # take an average prediction across all models provided - outs = [] - for model in self.models: - out = model(X_batch) - outs.append(out) - outs = torch.stack(outs, dim=0) - out = torch.mean(outs, dim=0) - - # save most likely prediction and output probabilities - scores.append(out.detach().cpu().numpy()) - - _, pred = torch.max(out, 1) - predictions.append(pred.detach().cpu().numpy()) - - probs = F.softmax(out, dim=1) - probabilities.append(probs.detach().cpu().numpy()) - - predictions = np.concatenate(predictions, axis=0) # [Cells,] - scores = np.concatenate(scores, axis=0) # [Cells, Types] - probabilities = np.concatenate(probabilities, axis=0) # [Cells, Types] - - if self.labels is not None: - names = [] - for i in range(len(predictions)): - names += [self.labels[predictions[i]]] - else: - names = None - - # Parse the arguments to determine what to return - # N.B. that `return_prob` here is to support legacy code - # and may be removed in the future. - if return_prob is True: - return predictions, names, probabilities - elif output is not None: - if output == "prob": - return predictions, names, probabilities - elif output == "score": - return predictions, names, scores - else: - return predictions, names diff --git a/build/lib/scnym/scnym_ad.py b/build/lib/scnym/scnym_ad.py deleted file mode 100644 index 04a4f42..0000000 --- a/build/lib/scnym/scnym_ad.py +++ /dev/null @@ -1,217 +0,0 @@ -"""scNym model training from standard anndata objects""" -import anndata -import os -import os.path as osp -import uuid -import configargparse -import numpy as np -import pandas as pd - -from .main import train_cv, train_all -from .model import CellTypeCLF -from .utils import build_classification_matrix -from sklearn.model_selection import StratifiedKFold - - -def make_parser(): - parser = configargparse.ArgParser( - description="train an scNym cell type classification model." - ) - parser.add_argument( - "--config", - type=str, - is_config_file=True, - required=False, - help="path to a configuration file.", - ) - parser.add_argument( - "--data", type=str, help="path to an h5ad [Cells, Features] object." - ) - parser.add_argument( - "--groupby", - type=str, - help="categorical feature in `adata.obs` to use for classifier training.", - ) - parser.add_argument("--out_path", type=str, help="path for outputs.") - parser.add_argument( - "--batch_size", - type=int, - default=256, - help="batch size for training", - ) - parser.add_argument( - "--n_epochs", - type=int, - default=200, - help="number of epochs for training", - ) - parser.add_argument( - "--init_dropout", - type=float, - default=0.0, - help="initial dropout to perform on gene inputs", - ) - parser.add_argument( - "--n_hidden", - type=int, - default=128, - help="number of hidden units in the classifier", - ) - parser.add_argument( - "--n_layers", - type=int, - default=2, - help="number of hidden layers in the model", - ) - parser.add_argument( - "--residual", - action="store_true", - help="use residual layers in the model", - ) - parser.add_argument( - "--weight_decay", - type=float, - default=1e-5, - help="weight decay applied by the optimizer", - ) - parser.add_argument( - "--weight_classes", - type=bool, - default=True, - help="weight loss based on relative class abundance.", - ) - parser.add_argument( - "--mixup_alpha", - type=float, - default=None, - help="alpha parameter for MixUp training. if set performs MixUp, otherwise does not.", - ) - parser.add_argument( - "--unlabeled_counts", - type=str, - default=None, - help="path to h5ad [Cells, Features] object of unlabeled data.", - ) - parser.add_argument( - "--unsup_max_weight", - type=float, - default=2.0, - help="maximum weight for the unsupervised component of IC training.", - ) - parser.add_argument( - "--unsup_mean_teacher", - type=bool, - default=True, - help="use a mean teacher for IC training.", - ) - parser.add_argument( - "--cross_val_train", - action="store_true", - ) - return parser - - -def main(): - parser = make_parser() - args = parser.parse_args() - - adata = anndata.read_h5ad(args.data) - print(f"{adata.shape[0]} cells, {adata.shape[1]} genes in the training data.") - - if args.groupby not in adata.obs: - msg = f"{args.groupby} not in `adata.obs`" - raise ValueError(msg) - - os.makedirs(args.out_path, exist_ok=True) - - if args.unlabeled_counts is None: - unlabeled_counts = None - else: - # load unlabeled counts and build a matrix that follows - # gene dimension ordering of the training data - unlabeled_adata = anndata.read_h5ad(args.unlabeled_counts) - unlabeled_counts = build_classification_matrix( - X=unlabeled_adata.X - if type(unlabeled_adata.X) == np.ndarray - else unlabeled_adata.X.toarray(), - model_genes=np.array(adata.var_names), - sample_genes=np.array(unlabeled_adata.var_names), - ) - - X = adata.X if type(adata.X) == np.ndarray else adata.X.toarray() - y = pd.Categorical(adata.obs[args.groupby]).codes - - model_params = { - "n_hidden": args.n_hidden, - "residual": args.residual, - "n_layers": args.n_layers, - "init_dropout": args.init_dropout, - } - - if args.cross_val_train: - kf = StratifiedKFold(n_splits=5, shuffle=True) - fold_indices = list(kf.split(X, y)) - - fold_eval_acc, fold_eval_losses = train_cv( - X=X, - y=y, - batch_size=args.batch_size, - n_epochs=args.n_epochs, - weight_decay=args.weight_decay, - ModelClass=CellTypeCLF, - fold_indices=fold_indices, - out_path=args.out_path, - n_genes=adata.shape[1], - mixup_alpha=args.mixup_alpha, - unlabeled_counts=unlabeled_counts, - unsup_max_weight=args.unsup_max_weight, - unsup_mean_teacher=args.unsup_mean_teacher, - weighted_classes=args.weight_classes, - **model_params, - ) - np.savetxt( - osp.join( - args.out_path, - "fold_eval_losses.csv", - ), - fold_eval_losses, - ) - np.savetxt( - osp.join( - args.out_path, - "fold_eval_acc.csv", - ), - fold_eval_acc, - ) - - val_loss, val_acc = train_all( - X=X, - y=y, - batch_size=args.batch_size, - n_epochs=args.n_epochs, - weight_decay=args.weight_decay, - ModelClass=CellTypeCLF, - out_path=args.out_path, - n_genes=adata.shape[1], - mixup_alpha=args.mixup_alpha, - unlabeled_counts=unlabeled_counts, - unsup_max_weight=args.unsup_max_weight, - unsup_mean_teacher=args.unsup_mean_teacher, - weighted_classes=args.weight_classes, - **model_params, - ) - print(f"Final validation loss: {val_loss:08}") - print(f"Final validation acc : {val_acc:08}") - - # get exp id - exp_id = uuid.uuid4() - res = pd.DataFrame( - {"val_acc": val_acc, "val_loss": val_loss}, - index=[exp_id], - ).to_csv( - osp.join( - args.out_path, - "all_data_val_results.csv", - ) - ) - return diff --git a/build/lib/scnym/trainer.py b/build/lib/scnym/trainer.py deleted file mode 100644 index c3c8522..0000000 --- a/build/lib/scnym/trainer.py +++ /dev/null @@ -1,1412 +0,0 @@ -import numpy as np -import os -import os.path as osp -import torch -import torch.nn as nn -import torch.nn.functional as F -import json -import logging -from typing import Callable, Iterable, Union, List -from .dataprep import SampleMixUp -from .utils import compute_entropy_of_mixing -from .model import CellTypeCLF, DANN -import copy -from torch.utils.tensorboard import SummaryWriter - -from .dataprep import SampleMixUp -from .utils import compute_entropy_of_mixing -from .model import CellTypeCLF, DANN, AE -from .losses import * - - -logger = logging.getLogger(__name__) - - -class Trainer(object): - """ - Trains a PyTorch model. - - Attributes - ---------- - model : nn.Module - model with required `.forward(...)` method. - criterion : Callable - loss criterion to optimize. - optimizer : torch.optim.Optimizer - optimizer for the model parameters. - dataloaders : dict - keyed by ['train', 'val'] with values corresponding - to `torch.utils.data.DataLoader` for training - and validation sets. - out_path : str - output path for best model. - n_epochs : int - number of epochs for training. - min_epochs : int - minimum number of epochs before saving weights. - patience : int - maximum number of epochs to wait before early stopping. - if `None`, infinite patience is used (up to `n_epochs`). - waiting_time : int - number of epochs since the last best val loss. - reg_criterion : Callable - criterion to penalize layer weights. - use_gpu : bool - use CUDA acceleration. - verbose : bool - write all batch losses to stdout. - save_freq : int - Number of epochs between model checkpoints. Default = 10. - scheduler : learning rate scheduler. - """ - - def __init__( - self, - model: nn.Module, - criterion: Callable, - optimizer: torch.optim.Optimizer, - dataloaders: dict, - out_path: str, - batch_transformers: dict = {}, - n_epochs: int = 50, - min_epochs: int = 0, - patience: int = None, - exp_name: str = "", - reg_criterion: Callable = None, - use_gpu: bool = torch.cuda.is_available(), - verbose: bool = False, - save_freq: int = 10, - scheduler: torch.optim.lr_scheduler = None, - tb_writer: str = None, - ) -> None: - """ - Trains a PyTorch `nn.Module` object provided in `model` - on training and testing sets provided in `dataloaders` - using `criterion` and `optimizer`. - - Saves model weight snapshots every `save_freq` epochs and saves the - weights with the best testing loss at the end of training. - - Parameters - ---------- - model : nn.Module - model with required `.forward(...)` method. - criterion : Callable - loss criterion to optimize. - optimizer : torch.optim.Optimizer - optimizer for the model parameters. - dataloaders : dict - keyed by ['train', 'val'] with values corresponding - to `torch.utils.data.DataLoader` for training - and validation sets. - out_path : str - output path for best model. - batch_transformers : dict - apply transforms to minibatch inputs and targets. - keys are ['train', 'val'], values are Callable. - n_epochs : int - number of epochs for training. - min_epochs : int - minimum number of epochs before saving weights. - patience : int - maximum number of epochs to wait before early stopping. - if `None`, infinite patience is used (up to `n_epochs`). - reg_criterion : callable - criterion to penalize layer weights. - use_gpu : bool - use CUDA acceleration. - verbose : bool - write all batch losses to stdout. - save_freq : int - Number of epochs between model checkpoints. Default = 10. - scheduler : torch.optim.lr_scheduler - learning rate schedule. - - Returns - ------- - None. - """ - self.model = model - self.optimizer = optimizer - self.criterion = criterion - self.n_epochs = n_epochs - self.min_epochs = min_epochs - self.patience = patience if patience is not None else n_epochs - self.waiting_time = 0 - self.dataloaders = dataloaders - self.batch_transformers = batch_transformers - self.out_path = out_path - self.use_gpu = use_gpu - self.verbose = verbose - self.save_freq = save_freq - self.best_acc = 0.0 - self.best_loss = 1.0e10 - self.scheduler = scheduler - self.reg_criterion = reg_criterion - if tb_writer is not None: - self.tb_writer = SummaryWriter(log_dir=tb_writer) - os.makedirs(tb_writer, exist_ok=True) - else: - self.tb_writer = None - - if not os.path.exists(self.out_path): - os.mkdir(self.out_path) - # initialize log - - self.log_path = os.path.join(self.out_path, "_".join([exp_name, "log.csv"])) - with open(self.log_path, "w") as f: - header = "Epoch,Running_Loss,Mode\n" - f.write(header) - - self.parameters = { - "out_path": out_path, - "exp_name": exp_name, - "n_epochs": n_epochs, - "use_cuda": self.use_gpu, - "train_batch_size": self.dataloaders["train"].batch_size, - "val_batch_size": self.dataloaders["val"].batch_size, - "train_batch_sampler": str(type(self.dataloaders["train"].sampler)), - "val_batch_sampler": str(type(self.dataloaders["val"].sampler)), - "optimizer_type": str(type(self.optimizer)), - "learning_rate": self.optimizer.param_groups[0]["lr"], - "model_hidden": self.model.n_hidden, - "model_ngenes": self.model.n_genes, - "model_ncelltypes": self.model.n_cell_types, - } - - # write the log file header - with open(self.log_path, "w") as f: - header = "Epoch,Iter,Running_Loss,Mode\n" - f.write(header) - - def train_epoch(self): - """Perform training across one full iteration through - the data. - """ - self.model.train(True) - i = 0 - running_loss = 0.0 - running_corrects = 0.0 - running_total = 0.0 - - btrans = self.batch_transformers.get("train", None) - for data in self.dataloaders["train"]: - # if a batch transformer is present, - # transform the data before use - if btrans is not None: - data = btrans(data) - - inputs = data["input"] - labels = data["output"] # one-hot - - if self.use_gpu: - inputs = inputs.cuda() - labels = labels.cuda() - else: - pass - inputs.requires_grad_() - labels.requires_grad = False - - # zero gradients - self.optimizer.zero_grad() - - # forward pass - outputs = self.model(inputs) - # predictions are the output nodes with - # the highest values - _, predictions = torch.max(outputs, 1) - - # remake an integer version of the labels for quick checking - int_labels = torch.argmax(labels, 1) - - correct = torch.sum(predictions.detach() == int_labels.detach()) - - # compute loss - if self.reg_criterion is not None: - reg_loss = self.reg_criterion(self.model) - loss = self.criterion(outputs, labels) + reg_loss - else: - loss = self.criterion(outputs, labels) - - if self.verbose: - print("batch loss: ", loss.item()) - if np.isnan(loss.data.cpu().numpy()): - raise RuntimeError("NaN loss encountered in training") - - # compute gradients in a backward pass, update parameters - loss.backward() - self.optimizer.step() - - # statistics update - running_loss += loss.item() / inputs.size(0) - running_corrects += float(correct.item()) - running_total += float(labels.size(0)) - - if i % 100 == 0 and self.verbose: - print("Iter : ", i) - print("running_loss : ", running_loss / (i + 1)) - print("running_acc : ", running_corrects / running_total) - print("corrects: %f | total: %f" % (running_corrects, running_total)) - # append to log - with open(self.log_path, "a") as f: - f.write( - str(self.epoch) - + "," - + str(i) - + "," - + str(running_loss / (i + 1)) - + ",train\n" - ) - i += 1 - - epoch_loss = running_loss / len(self.dataloaders["train"]) - epoch_acc = running_corrects / running_total - - # append to log - with open(self.log_path, "a") as f: - f.write( - str(self.epoch) - + "," - + str(i) - + "," - + str(running_loss / (i + 1)) - + ",train_epoch\n" - ) - - if self.tb_writer is not None: - self.tb_writer.add_scalar("Loss/train", epoch_loss, self.epoch) - self.tb_writer.add_scalar("Acc/train", epoch_acc, self.epoch) - for i, p in enumerate(self.model.parameters()): - self.tb_writer.add_histogram( - f"Grad/param{i:04}", - p.grad, - self.epoch, - ) - - self.tb_writer.add_scalar( - "lr/lr", - self.optimizer.state_dict()["param_groups"][0]["lr"], - self.epoch, - ) - - if self.verbose: - print("{} Loss : {:.4f}".format("train", epoch_loss)) - print("{} Acc : {:.4f}".format("train", epoch_acc)) - print( - "TRAIN EPOCH corrects: %f | total: %f" - % (running_corrects, running_total) - ) - - @torch.no_grad() - def val_epoch(self): - """Perform a pass through the validation data. - Do not record gradients to speed things up. - """ - self.model.train(False) - i = 0 - running_loss = 0.0 - running_corrects = 0 - running_total = 0 - - btrans = self.batch_transformers.get("val", None) - for data in self.dataloaders["val"]: - # if a batch transformer is present, - # transform the data before use - if btrans is not None: - data = btrans(data) - - inputs = data["input"] - labels = data["output"] # one-hot - if self.use_gpu: - inputs = inputs.cuda() - labels = labels.cuda() - else: - pass - - # zero gradients - self.optimizer.zero_grad() - # forward pass - outputs = self.model(inputs) - _, predictions = torch.max(outputs, 1) - - # remake an integer version of the labels for quick checking - int_labels = torch.argmax(labels, 1) - correct = torch.sum(predictions.detach() == int_labels.detach()) - if self.verbose > 1: - print("PRED\n", predictions[:10, ...]) - print("LABEL\n", int_labels[:10, ...]) - print("CORRECT: ", correct) - - if self.reg_criterion is not None: - reg_loss = self.reg_criterion(self.model) - loss = self.criterion(outputs, labels) + reg_loss - else: - loss = self.criterion(outputs, labels) - - # statistics update - running_loss += loss.item() / inputs.size(0) - running_corrects += int(correct.item()) - running_total += int(labels.size(0)) - - if i % 1 == 10 and self.verbose > 1: - print("Iter : ", i) - print("running_loss : ", running_loss / (i + 1)) - print("running_acc : ", running_corrects / running_total) - print("corrects: %f | total: %f" % (running_corrects, running_total)) - # append to log - with open(self.log_path, "a") as f: - f.write( - str(self.epoch) - + "," - + str(i) - + "," - + str(running_loss / (i + 1)) - + ",val\n" - ) - i += 1 - - epoch_loss = running_loss / len(self.dataloaders["val"]) - epoch_acc = running_corrects / running_total - # append to log - with open(self.log_path, "a") as f: - f.write( - str(self.epoch) - + "," - + str(i) - + "," - + str(running_loss / (i + 1)) - + ",val_epoch\n" - ) - - # add one epoch to the waiting time for best loss - # if we had a new best loss, the counter is reset below - self.waiting_time += 1 - if (epoch_loss < self.best_loss) and (self.epoch >= self.min_epochs): - self.best_loss = epoch_loss - self.best_model_wts = self.model.state_dict() - self.waiting_time = 0 - torch.save( - self.model.state_dict(), - os.path.join( - self.out_path, - ("model_weights_" + str(self.epoch).zfill(3) + ".pkl"), - ), - ) - print("Saving best model weights...") - torch.save( - self.model.state_dict(), - os.path.join(self.out_path, "00_best_model_weights.pkl"), - ) - print("Saved best weights.") - - if hasattr(self, "dan_criterion"): - print("Trainer has a `dan_criterion`.") - if self.dan_criterion is not None: - print("Saving DAN weights...") - torch.save( - self.dan_criterion.dann.state_dict(), - os.path.join( - self.out_path, - "02_best_dan_weights.pkl", - ), - ) - - with open(self.log_path, "a") as f: - f.write( - str(self.epoch) - + "," - + str(i) - + "," - + str(running_loss / (i + 1)) - + ",best_model_weights\n", - ) - - if self.tb_writer is not None: - self.tb_writer.add_text( - "BestWeights", - f"Saved best weights at {self.epoch}, loss {epoch_loss}", - self.epoch, - ) - self.tb_writer.flush() - - elif self.epoch % self.save_freq == 0: - torch.save( - self.model.state_dict(), - os.path.join( - self.out_path, - "model_weights_" + str(self.epoch).zfill(3) + ".pkl", - ), - ) - - elif self.epoch == (self.n_epochs - 1): - torch.save( - self.model.state_dict(), - os.path.join(self.out_path, "01_final_model_weights.pkl"), - ) - if self.verbose: - print(f"{self.waiting_time} epochs since last best weights.\n") - - if self.tb_writer is not None: - self.tb_writer.add_scalar("Loss/val", epoch_loss, self.epoch) - self.tb_writer.add_scalar("Acc/val", epoch_acc, self.epoch) - self.tb_writer.flush() - - if self.verbose: - print("{} Loss : {:.4f}".format("val", epoch_loss)) - print("{} Acc : {:.4f}".format("val", epoch_acc)) - print( - "VAL EPOCH corrects: %f | total: %f" % (running_corrects, running_total) - ) - - def train(self): - for epoch in range(self.n_epochs): - self.epoch = epoch - msg = f"Epoch {epoch}/{self.n_epochs-1}" - p_complete = epoch / self.n_epochs - n_bars = int(np.floor(30 * p_complete)) - msg += "|" + "-" * n_bars + "_" * (30 - n_bars) + "|" - # print a new line so the progress bar isn't overwritten - # on the final stdout - end_char = "\n" if epoch == (self.n_epochs - 1) else "\r" - print(msg, end=end_char) - - # training epoch - self.train_epoch() - # evaluate model - self.val_epoch() - - # update learning rate - # NOTE: change in `torch>=1.1.0`, `scheduler.step()` - # is now called AFTER `optimizer.step()` - if self.scheduler is not None: - self.scheduler.step() - - if self.waiting_time > self.patience: - # we have waited a sufficient number of epochs - # to perform early stopping - logger.info(">" * 5) - logger.info(f"Early stopping at epoch {self.epoch}") - logger.info(">" * 5) - break - - self.model.load_state_dict( - torch.load( - os.path.join( - self.out_path, - "00_best_model_weights.pkl", - ) - ) - ) - - if self.tb_writer is not None: - # close tensorboard writer - self.tb_writer.flush() - self.tb_writer.close() - - return self.model - - -class SemiSupervisedTrainer(Trainer): - def __init__( - self, - unsup_criterion: Callable, - unsup_dataloader: torch.utils.data.DataLoader, - unsup_weight: Callable, - dan_criterion: Callable = None, - dan_weight: Callable = None, - **kwargs, - ) -> None: - """Train a PyTorch model using both a supervised and - unsupervised loss as described for Interpolation - Consistency Training. - - Parameters - ---------- - unsup_criterion : Callable - loss function for unlabeled samples. - takes both the current `nn.Module` model and a `torch.FloatTensor` - of unlabeled samples as input. - unsup_dataloader : torch.utils.data.DataLoader - data loader supplying unlabeled samples. - unsup_weight : Callable - takes an int epoch as input and returns a weight coefficient - to scale the importance of the unsupervised loss. - dan_criterion : Callable, optional - domain adaptation loss. takes in a model, labeled batch, and - unlabeled batch, and returns a `torch.Tensor` loss value. - dan_weight : Callable, optional - domain adaptation loss weight schedule. - takes an int epoch as input and returns a weight coefficient. - - Returns - ------- - None. - """ - super(SemiSupervisedTrainer, self).__init__(**kwargs) - self.unsup_criterion = unsup_criterion - self.unsup_dataloader = unsup_dataloader - self.unsup_weight = unsup_weight - self.dan_criterion = dan_criterion - if self.dan_criterion is not None: - print("Using a Domain Adaptation Loss.") - self.dan_weight = dan_weight - return - - def train_epoch( - self, - ) -> None: - """ - Perform training using both a supervised and semi-supervised loss. - - Notes - ----- - (1) Sample labeled examples, compute the standard supervised loss. - (2) Sample unlabeled examples, compute unsupervised loss. - (3) Perform backward pass and update parameters. - """ - self.model.train(True) - i = 0 - running_loss = 0.0 - running_sup_loss = 0.0 # supervised loss - running_uns_loss = 0.0 # unsupervised loss - running_dom_loss = 0.0 # domain adaptation loss - running_corrects = 0.0 - running_total = 0.0 - - btrans = self.batch_transformers.get("train", None) - - iter_unsup_dl = iter(self.unsup_dataloader) - for data in self.dataloaders["train"]: - - #################################### - # (1) Prepare data and graph - #################################### - - # get unlabeled batch - unsup_data = next(iter_unsup_dl) - - if btrans is not None: - data = btrans(data) - - if self.use_gpu: - # push all the data to the CUDA device - data["input"] = data["input"].cuda() - data["output"] = data["output"].cuda() - - unsup_data["input"] = unsup_data["input"].cuda() - - # capture gradients on labeled and unlabeled inputs - # do not store gradients on labels - data["input"].requires_grad = True - data["output"].requires_grad = False - - unsup_data["input"].requires_grad = True - - # zero gradients across the graph - self.optimizer.zero_grad() - - #################################### - # (2) Compute loss terms - #################################### - - sup_loss, unsup_loss, sup_outputs = self.unsup_criterion( - model=self.model, - labeled_sample=data, - unlabeled_sample=unsup_data, - ) - - # check supervised classification accuracy - _, predictions = torch.max(sup_outputs, 1) - int_labels = torch.argmax(data["output"], 1) - - correct = torch.sum(predictions.detach() == int_labels.detach()) - - # compute regularization loss - if self.reg_criterion is not None: - reg_loss = self.reg_criterion(self.model) - else: - reg_loss = 0.0 - - # compute the domain adaptation loss if desired - if self.dan_criterion is not None: - dan_weight = self.dan_weight(self.epoch) - # NOTE: pseudolabel confidence is only used if `use_conf_pseudolabels` - # was passed to the initiatilization of `DANLoss` - pseudolabel_confidence = self.unsup_criterion.running_confidence_scores[ - -1 - ][0] - dan_loss = self.dan_criterion( - labeled_sample=data, - unlabeled_sample=unsup_data, - weight=dan_weight, - pseudolabel_confidence=pseudolabel_confidence, - ) - else: - dan_loss = torch.zeros( - 1, - ).float() - dan_loss = dan_loss.to(device=sup_loss.device) - dan_weight = 0.0 - - #################################### - # (3) Perform backward pass - #################################### - - loss = ( - sup_loss - + reg_loss - + (self.unsup_weight(self.epoch) * unsup_loss) - + dan_loss - ) - - if self.verbose > 1: - print("sup. loss: ", sup_loss.item()) - print("usup. loss: ", unsup_loss.item()) - print("usup. weight: ", self.unsup_weight(self.epoch)) - if self.dan_criterion is not None: - print("Dom. loss: ", dan_loss.item()) - print("Dom. weight: ", dan_weight) - print("total loss: ", loss.item()) - if np.isnan(loss.data.cpu().numpy()): - raise RuntimeError("NaN loss encountered in training") - - # compute gradients in a backward pass, update parameters - loss.backward() - self.optimizer.step() - - # statistics update - labeled_n = data["input"].size(0) - unlabel_n = unsup_data["input"].size(0) - - running_loss += loss.item() - running_sup_loss += sup_loss.item() - running_uns_loss += unsup_loss.item() - running_dom_loss += dan_loss.item() - running_corrects += float(correct.item()) - running_total += float(data["input"].size(0)) - - if i % 100 == 0 and self.verbose: - print("Iter : ", i) - print("running_sup_loss : ", running_sup_loss / (i + 1)) - print("running_uns_loss : ", running_uns_loss / (i + 1)) - print("running_dom_loss : ", running_dom_loss / (i + 1)) - print("running_loss : ", running_loss / (i + 1)) - print("running_acc : ", running_corrects / running_total) - print("corrects: %f | total: %f" % (running_corrects, running_total)) - # append to log - with open(self.log_path, "a") as f: - f.write( - str(self.epoch) - + "," - + str(i) - + "," - + str(running_loss / (i + 1)) - + ",train\n" - ) - i += 1 - - epoch_sup_loss = running_sup_loss / len(self.dataloaders["train"]) - epoch_uns_loss = running_uns_loss / len(self.dataloaders["train"]) - epoch_dom_loss = running_dom_loss / len(self.dataloaders["train"]) - epoch_loss = running_loss / len(self.dataloaders["train"]) - epoch_acc = running_corrects / running_total - - if self.tb_writer is not None: - self.tb_writer.add_scalar( - "Loss/train", - epoch_loss, - self.epoch, - ) - self.tb_writer.add_scalar( - "Acc/train", - epoch_acc, - self.epoch, - ) - self.tb_writer.add_scalar( - "Loss/super", - epoch_sup_loss, - self.epoch, - ) - self.tb_writer.add_scalar( - "Loss/unsup", - epoch_uns_loss, - self.epoch, - ) - self.tb_writer.add_scalar( - "SSL/UnsWeight", - self.unsup_weight(self.epoch), - self.epoch, - ) - if self.dan_criterion is not None: - self.tb_writer.add_scalar( - "Loss/domain", - epoch_dom_loss, - self.epoch, - ) - self.tb_writer.add_scalar( - "SSL/DomWeight", - self.dan_weight(self.epoch), - self.epoch, - ) - - # add embedding - dlabel = self.dan_criterion.dlabel.numpy() - self.tb_writer.add_embedding( - self.dan_criterion.x_embed, - metadata=dlabel.tolist(), - global_step=self.epoch, - tag="Embed/DAN", - ) - - # compute the entropy of mixing - dan_embedding = self.dan_criterion.x_embed.numpy() - - eom = compute_entropy_of_mixing( - X=dan_embedding, - y=dlabel[:, 0], - n_neighbors=100, - n_iters=512, - n_jobs=-1, - ) - self.tb_writer.add_scalar( - "SSL/entropy_of_mixing", - np.mean(eom), - self.epoch, - ) - self.tb_writer.add_histogram( - "SSL/dist_entropy_of_mixing", - eom, - self.epoch, - ) - self.tb_writer.add_scalar( - "SSL/domain_acc", - self.dan_criterion.dan_acc, - self.epoch, - ) - - for i, param in enumerate( - self.dan_criterion.dann.domain_clf.parameters() - ): - self.tb_writer.add_histogram( - f"Grad/domain_clf_{i:04}", - param.grad, - self.epoch, - ) - self.tb_writer.add_scalar( - "SSL/dan_n_conf_pseudolabels", - self.dan_criterion.n_conf_pseudolabels, - self.epoch, - ) - self.tb_writer.add_scalar( - "SSL/dan_p_conf_pseudolabels", - self.dan_criterion.n_conf_pseudolabels - / self.dan_criterion.n_total_unlabeled, - self.epoch, - ) - - self.tb_writer.flush() - - for i, named_mod in enumerate(self.model.classif.named_modules()): - module_name = named_mod[0] - module = named_mod[1] - for j, param in enumerate(module.parameters()): - self.tb_writer.add_histogram( - f"Grad/{module_name}/{j:04}", - param.grad, - self.epoch, - ) - - # add the running confidence scores of unlabeled examples - # if we're using MixMatch - if hasattr(self.unsup_criterion, "running_confidence_scores"): - # get the number of confident pseudolabels - # and the total number of pseudolabels per batch - n_conf = torch.Tensor( - [ - torch.sum(s[0]).item() - for s in self.unsup_criterion.running_confidence_scores - ] - ) - n_total = torch.Tensor( - [ - s[0].size(0) - for s in self.unsup_criterion.running_confidence_scores - ] - ) - conf_dist = torch.cat( - [s[1] for s in self.unsup_criterion.running_confidence_scores], - dim=0, - ) - self.tb_writer.add_scalar( - "SSL/p_conf_pseudolabels", - torch.sum(n_conf) / torch.sum(n_total), - self.epoch, - ) - self.tb_writer.add_scalar( - "SSL/avg_pseudolabel_conf", - torch.mean(conf_dist), - self.epoch, - ) - self.tb_writer.add_histogram( - "SSL/dist_p_conf_pseudolabels", - n_conf / n_total, - self.epoch, - ) - self.tb_writer.add_histogram( - "SSL/pseudolabel_conf", - conf_dist, - self.epoch, - ) - - # append to log - with open(self.log_path, "a") as f: - f.write( - str(self.epoch) - + "," - + str(i) - + "," - + str(epoch_loss) - + ",train_epoch\n" - ) - # write out the supervised and unsupervised components - # of loss separately - f.write( - str(self.epoch) - + "," - + str(i) - + "," - + str(epoch_sup_loss) - + ",train_epoch_sup\n" - ) - f.write( - str(self.epoch) - + "," - + str(i) - + "," - + str(epoch_uns_loss) - + ",train_epoch_uns\n" - ) - f.write( - str(self.epoch) - + "," - + str(i) - + "," - + str(self.unsup_weight(self.epoch)) - + ",train_epoch_uns_weight\n" - ) - if self.verbose: - print("{} Sup. Loss : {:.6f}".format("train", epoch_sup_loss)) - print("{} Unsup. Loss : {:.6f}".format("train", epoch_uns_loss)) - print( - "{} Unsup. Weight : {:.6f}".format( - "train", self.unsup_weight(self.epoch) - ) - ) - if self.dan_criterion is not None: - print("{} Dom. Loss : {:.6f}".format("train", epoch_dom_loss)) - print(f"train Dom. Weight : {self.dan_weight(self.epoch)}") - print("{} Loss : {:.4f}".format("train", epoch_loss)) - print("{} Acc : {:.4f}".format("train", epoch_acc)) - print( - "TRAIN EPOCH corrects: %f | total: %f" - % (running_corrects, running_total) - ) - return - - -class MultiTaskTrainer(Trainer): - def __init__( - self, - criteria: List[dict], - unsup_dataloader: torch.utils.data.DataLoader = None, - **kwargs, - ) -> None: - """Train a multitask model with multiple criteria using - labeled and unlabeled dataloaders. - - Parameters - ---------- - criteria : List[dict] - dictionary describing a single task criterion, containing keys. - function - callable with `dict` kwargs `labeled_sample` - and `unlabeled_sample`, `nn.Module` kwarg `model`, - a `float` kwarg `weight`, and returns `torch.FloatTensor`. - weight - Callable, maps `int` epoch to `float` weight. - can also pass float value for constant weight. - validation - bool, use criterion for validation loss. - unsup_dataloader : torch.utils.data.DataLoader - data loader supplying unlabeled samples. - **kwargs : dict - passed to `Trainer` parent. Include: - model - nn.Module - criterion - Callable - optimizer - torch.optim.Optimizer - dataloaders - dict - out_path - str - n_epochs - int - min_epochs - int - patience - int - use_gpu - bool - scheduler - torch.optim.lr_scheduler - - Returns - ------- - None. - - Notes - ----- - criteria are applied sequentially, such that values extracted in one - criterion can be added to the dictionary and used in another. - if a criterion has a `no_weight=True` attribute, loss weights are not - applied in the train loop (useful for DAN, weights applied to rev'd grads). - all criteria should implement a `.train(bool)` method, even if they do not - contain trainable parameters. - """ - kwargs.update({"criterion": None}) - super(MultiTaskTrainer, self).__init__(**kwargs) - - self.criteria = criteria - # check that criteria provided are actually callable - for c in self.criteria: - fxn = c.get("function", None) - weight = c.get("weight", None) - if not callable(fxn): - msg = "One of the criteria provided is not callable.\n" - msg += f"\t{fxn}" - raise ValueError(fxn) - - if not callable(weight) and type(weight) != float: - msg = 'One of the criteria did not include a `"weight"` property.\n' - msg += f"\t{fxn}\n" - msg += f"\tweight : {weight}" - raise ValueError(msg) - - self.unsup_dataloader = unsup_dataloader - self.best_weights = None - return - - def train_epoch( - self, - ) -> float: - """Perform a training loop by evaluating all the criteria - in `self.criteria` sequentially, then computing the weighted - loss and backproping.""" - - self.model.train(True) - - i = 0 - # setup running values for all losses - running_losses = np.zeros(len(self.criteria)) - - btrans = self.batch_transformers.get("train", None) - - if self.unsup_dataloader is not None: - iter_unsup_dl = iter(self.unsup_dataloader) - - for data in self.dataloaders["train"]: - - #################################### - # (1) Prepare data and graph - #################################### - - if btrans is not None: - data = btrans(data) - - if self.use_gpu: - # push all the data to the CUDA device - data["input"] = data["input"].cuda() - data["output"] = data["output"].cuda() - - # get unlabeled batch - if self.unsup_dataloader is not None: - unsup_data = next(iter_unsup_dl) - unsup_data["input"] = unsup_data["input"].to( - device=data["input"].device, - ) - # unsup_data["input"].requires_grad = True - else: - unsup_data = None - - # capture gradients on labeled and unlabeled inputs - # do not store gradients on labels - # data["input"].requires_grad = True - # data["output"].requires_grad = False - - # zero gradients across the graph - self.optimizer.zero_grad() - - #################################### - # (2) Compute loss terms - #################################### - - loss = torch.zeros( - 1, - ).to(device=data["input"].device) - for crit_idx, crit_dict in enumerate(self.criteria): - - crit_fxn = crit_dict["function"] - weight_fxn = crit_dict["weight"] - - crit_name = crit_fxn.__class__.__name__ - crit_name = crit_dict.get("name", crit_name) - logger.debug(f"Computing criterion: {crit_name}") - - # get the current weight from the weight function, - # or use the constant weight value - weight = weight_fxn(self.epoch) if callable(weight_fxn) else weight_fxn - # prepare crit_fxn for loss computation - crit_fxn.train(True) - if hasattr(crit_fxn, "epoch"): - # update the epoch attribute for use by any internal functions - crit_fxn.epoch = self.epoch - - crit_loss = crit_fxn( - labeled_sample=data, - unlabeled_sample=unsup_data, - model=self.model, - weight=weight, - ) - - if hasattr(crit_fxn, "no_weight"): - # don't reweight the loss, already performed - # internally in the criterion - weight = 1.0 - - logger.debug(f"crit_loss: {crit_loss}") - logger.debug(f"weight: {weight}") - - # weight losses and accumulate - weighted_crit_loss = crit_loss * weight - logger.debug(f"weighted_crit_loss: {weighted_crit_loss}") - logger.debug(f"loss: {loss}, type {type(loss)}") - - loss += weighted_crit_loss - - running_losses[crit_idx] += crit_loss.item() - if self.verbose: - logger.debug(f"weight {crit_name} : {weight}") - logger.debug(f"batch {crit_name} : {weighted_crit_loss}") - - # backprop - loss.backward() - # update parameters - self.optimizer.step() - - # perform logging - n_batches = len(self.dataloaders["train"]) - - epoch_losses = running_losses / n_batches - - if self.verbose: - for crit_idx, crit_dict in enumerate(self.criteria): - crit_name = crit_dict["function"].__class__.__name__ - # get a stored name if it exists - crit_name = crit_dict.get("name", crit_name) - logger.info(f"{crit_name}: {epoch_losses[crit_idx]}") - - if self.tb_writer is not None: - for crit_idx in range(len(self.criteria)): - crit_dict = self.criteria[crit_idx] - crit_name = crit_dict["function"].__class__.__name__ - crit_name = crit_dict.get("name", crit_name) - self.tb_writer.add_scalar( - "loss/" + crit_name, - float(epoch_losses[crit_idx]), - self.epoch, - ) - weight_fxn = crit_dict["weight"] - weight = weight_fxn(self.epoch) if callable(weight_fxn) else weight_fxn - self.tb_writer.add_scalar( - "weight/" + crit_name, - float(weight), - self.epoch, - ) - - return np.sum(epoch_losses) - - @torch.no_grad() - def val_epoch(self): - """Perform a pass through the validation data.""" - self.model.train(False) - i = 0 - running_losses = np.zeros(len(self.criteria)) - running_corrects = 0 - running_total = 0 - - if self.unsup_dataloader is not None: - iter_unsup_dl = iter(self.unsup_dataloader) - - btrans = self.batch_transformers.get("val", None) - for data in self.dataloaders["val"]: - - # if a batch transformer is present, - # transform the data before use - if btrans is not None: - data = btrans(data) - - if self.use_gpu: - data["input"] = data["input"].cuda() - data["output"] = data["output"].cuda() - - if self.unsup_dataloader is not None: - unsup_data = next(iter_unsup_dl) - unsup_data["input"] = unsup_data["input"].to( - device=data["input"].device - ) - else: - unsup_data = None - - inputs = data["input"] - labels = data["output"] # one-hot - - # zero gradients - self.optimizer.zero_grad() - - # perform a forward pass to get prediction accuracies, regardless - # of what other tasks our model is performing - outputs = self.model(inputs) - _, predictions = torch.max(outputs, 1) - - # remake an integer version of the labels for quick checking - int_labels = torch.argmax(labels, 1) - correct = torch.sum(predictions.detach() == int_labels.detach()).item() - - running_corrects += float(correct) - running_total += int(int_labels.size(0)) - - logger.debug(f"PRED\n{predictions[:10, ...]}") - logger.debug(f"LABEL\n{int_labels[:10, ...]}") - logger.debug(f"CORRECT: {correct}") - - # compute losses - losses = [] - for crit_idx, crit_dict in enumerate(self.criteria): - - if not crit_dict.get("validation", False): - continue - - crit_fxn = crit_dict["function"] - weight_fxn = crit_dict["weight"] - # get the current weight from the weight function, - # or use the constant weight value - weight = weight_fxn(self.epoch) if callable(weight_fxn) else weight_fxn - - crit_fxn.train(False) - crit_loss = crit_fxn( - labeled_sample=data, - unlabeled_sample=unsup_data, - model=self.model, - weight=weight, - ) - - crit_name = crit_fxn.__class__.__name__ - - if hasattr(crit_fxn, "no_weight"): - # don't reweight the loss, already performed - # internally in the criterion - weight = 1.0 - # weight losses and accumulate - weighted_crit_loss = crit_loss * weight - losses.append(weighted_crit_loss) - running_losses[crit_idx] += weighted_crit_loss.item() - - logger.debug(f"{crit_name}: {crit_loss}") - logger.debug(f"\tweight : {weight}") - logger.debug(f"weighted {crit_name}: {weighted_crit_loss}") - - epoch_losses = running_losses / len(self.dataloaders["val"]) - epoch_acc = running_corrects / running_total - - epoch_loss = np.sum(epoch_losses) - - # append to log - with open(self.log_path, "a") as f: - f.write( - str(self.epoch) - + "," - + str(i) - + "," - + str(epoch_loss / (i + 1)) - + ",val_epoch\n" - ) - - # add one epoch to the waiting time for best loss - # if we had a new best loss, the counter is reset below - self.waiting_time += 1 - if (epoch_loss < self.best_loss) and (self.epoch >= self.min_epochs): - self.best_loss = epoch_loss - self.waiting_time = 0 - torch.save( - self.model.state_dict(), - os.path.join(self.out_path, f"model_weights_{self.epoch:03d}.pkl"), - ) - logger.info(f"Saving best model weights, epoch {self.epoch}...") - torch.save( - self.model.state_dict(), - os.path.join(self.out_path, "00_best_model_weights.pkl"), - ) - self.best_weights = copy.deepcopy(self.model.state_dict()) - logger.info("Saved best weights.") - - # also save the best weights of additional model components - for crit_fxn in self.criteria: - if crit_fxn["function"].__class__.__name__ == "DANLoss": - # save DAN weights - logger.info("Saving DAN weights...") - weights = crit_fxn["function"].dann.state_dict() - torch.save( - weights, - os.path.join( - self.out_path, - f"02_best_dan_weights.pkl", - ), - ) - elif crit_fxn["function"].__class__.__name__ == "ReconstructionLoss": - # save AE weights - logger.info("Saving Reconstruction weights...") - weights = crit_fxn["function"].rec_model.state_dict() - torch.save( - weights, - os.path.join( - self.out_path, - f"03_best_reconstruction_weights.pkl", - ), - ) - else: - pass - - with open(self.log_path, "a") as f: - f.write( - str(self.epoch) - + "," - + str(i) - + "," - + str(epoch_loss) - + ",best_model_weights\n", - ) - - if self.tb_writer is not None: - self.tb_writer.add_text( - "BestWeights", - f"Saved best weights at {self.epoch}, loss {epoch_loss}", - self.epoch, - ) - self.tb_writer.flush() - - elif self.epoch % self.save_freq == 0: - torch.save( - self.model.state_dict(), - os.path.join( - self.out_path, - "model_weights_" + str(self.epoch).zfill(3) + ".pkl", - ), - ) - - elif self.epoch == (self.n_epochs - 1): - torch.save( - self.model.state_dict(), - os.path.join(self.out_path, "01_final_model_weights.pkl"), - ) - if self.verbose: - logger.info(f"{self.waiting_time} epochs since last best weights.\n") - - if self.tb_writer is not None: - self.tb_writer.add_scalar("Loss/val", epoch_loss, self.epoch) - self.tb_writer.add_scalar("Acc/val", epoch_acc, self.epoch) - self.tb_writer.flush() - - if self.verbose: - logger.info("{} Loss : {:.4f}".format("val", epoch_loss)) - logger.info("{} Acc : {:.4f}".format("val", epoch_acc)) - logger.info( - "VAL EPOCH corrects: %f | total: %f" % (running_corrects, running_total) - ) - - return epoch_loss - - -"""Loss weight scheduling""" - - -class ICLWeight(object): - def __init__( - self, - ramp_epochs: int, - burn_in_epochs: int = 0, - max_unsup_weight: float = 10.0, - sigmoid: bool = False, - ) -> None: - """Schedules the interpolation consistency loss - weights across a set of epochs. - - Parameters - ---------- - ramp_epochs : int - number of epochs to increase the unsupervised - loss weight until reaching a maximum value. - burn_in_epochs : int - epochs to wait before increasing the unsupervised loss. - max_unsup_weight : float - maximum weight for the unsupervised loss component. - sigmoid : bool - scale weight using a sigmoid function. - - Returns - ------- - None. - """ - self.ramp_epochs = ramp_epochs - self.burn_in_epochs = burn_in_epochs - self.max_unsup_weight = max_unsup_weight - self.sigmoid = sigmoid - # don't allow division by zero, set step size manually - if self.ramp_epochs == 0.0: - self.step_size = self.max_unsup_weight - else: - self.step_size = self.max_unsup_weight / self.ramp_epochs - print( - "Scaling ICL over %d epochs, %d epochs for burn in." - % (self.ramp_epochs, self.burn_in_epochs) - ) - return - - def _get_weight( - self, - epoch: int, - ) -> float: - """Compute the current weight""" - if epoch >= (self.ramp_epochs + self.burn_in_epochs): - weight = self.max_unsup_weight - elif self.sigmoid: - x = (epoch - self.burn_in_epochs) / self.ramp_epochs - coef = np.exp(-5 * (x - 1) ** 2) - weight = coef * self.max_unsup_weight - else: - weight = self.step_size * (epoch - self.burn_in_epochs) - - return weight - - def __call__( - self, - epoch: int, - ) -> float: - """Compute the weight for an unsupervised IC loss - given the epoch. - - Parameters - ---------- - epoch : int - current training epoch. - - Returns - ------- - weight : float - weight for the unsupervised component of IC loss. - """ - if type(epoch) != int: - raise TypeError(f"epoch must be int, you passed a {type(epoch)}") - if epoch < self.burn_in_epochs: - weight = 0.0 - else: - weight = self._get_weight(epoch) - return weight diff --git a/build/lib/scnym/utils.py b/build/lib/scnym/utils.py deleted file mode 100644 index 1e4cab1..0000000 --- a/build/lib/scnym/utils.py +++ /dev/null @@ -1,743 +0,0 @@ -""" -Utility functions -""" -import torch -import numpy as np -import anndata -from scipy import sparse -import pandas as pd -import tqdm -from scipy import stats -import scanpy as sc -from sklearn.neighbors import NearestNeighbors, KNeighborsRegressor -from sklearn.metrics.pairwise import euclidean_distances -from typing import Union, Callable - - -def make_one_hot( - labels: torch.LongTensor, - C=2, -) -> torch.FloatTensor: - """ - Converts an integer label torch.autograd.Variable to a one-hot Variable. - - Parameters - ---------- - labels : torch.LongTensor or torch.cuda.LongTensor - [N, 1], where N is batch size. - Each value is an integer representing correct classification. - C : int - number of classes in labels. - - Returns - ------- - target : torch.FloatTensor or torch.cuda.FloatTensor - [N, C,], where C is class number. One-hot encoded. - """ - if labels.ndimension() < 2: - labels = labels.unsqueeze(1) - one_hot = torch.zeros( - [ - labels.size(0), - C, - ], - dtype=torch.float32, - device=labels.device, - ) - target = one_hot.scatter_(1, labels, 1) - - return target - - -def l1_layer0( - model: torch.nn.Module, -) -> torch.FloatTensor: - """Compute l1 norm for the first input layer of - a `CellTypeCLF` model. - - Parameters - ---------- - model : torch.nn.Module - CellTypeCLF model with `.classif` module. - - Returns - ------- - l1_reg : torch.FloatTensor - [1,] l1 norm for the first layer parameters. - """ - # get the parameters of the first classification layer - layer0 = list(model.classif.modules())[1] - params = layer0.parameters() - l1_reg = None - - # compute the l1_norm - for W in params: - if l1_reg is None: - l1_reg = W.norm(1) - else: - l1_reg = l1_reg + W.norm(1) - return l1_reg - - -def append_categorical_to_data( - X: Union[np.ndarray, sparse.csr.csr_matrix], - categorical: np.ndarray, -) -> (Union[np.ndarray, sparse.csr.csr_matrix], np.ndarray): - """Convert `categorical` to a one-hot vector and append - this vector to each sample in `X`. - - Parameters - ---------- - X : np.ndarray, sparse.csr.csr_matrix - [Cells, Features] - categorical : np.ndarray - [Cells,] - - Returns - ------- - Xa : np.ndarray - [Cells, Features + N_Categories] - categories : np.ndarray - [N_Categories,] str category descriptors. - """ - # `pd.Categorical(xyz).codes` are int values for each unique - # level in the vector `xyz` - labels = pd.Categorical(categorical) - idx = np.array(labels.codes) - idx = torch.from_numpy(idx.astype("int32")).long() - categories = np.array(labels.categories) - - one_hot_mat = make_one_hot( - idx, - C=len(categories), - ) - one_hot_mat = one_hot_mat.numpy() - assert X.shape[0] == one_hot_mat.shape[0], "dims unequal at %d, %d" % ( - X.shape[0], - one_hot_mat.shape[0], - ) - # append one hot vector to the [Cells, Features] matrix - if sparse.issparse(X): - X = sparse.hstack([X, one_hot_mat]) - else: - X = np.concatenate([X, one_hot_mat], axis=1) - return X, categories - - -def get_adata_asarray( - adata: anndata.AnnData, -) -> Union[np.ndarray, sparse.csr.csr_matrix]: - """Get the gene expression matrix `.X` of an - AnnData object as an array rather than a view. - - Parameters - ---------- - adata : anndata.AnnData - [Cells, Genes] AnnData experiment. - - Returns - ------- - X : np.ndarray, sparse.csr.csr_matrix - [Cells, Genes] `.X` attribute as an array - in memory. - - Notes - ----- - Returned `X` will match the type of `adata.X` view. - """ - if sparse.issparse(adata.X): - X = sparse.csr.csr_matrix(adata.X) - else: - X = np.array(adata.X) - return X - - -def build_classification_matrix( - X: Union[np.ndarray, sparse.csr.csr_matrix], - model_genes: np.ndarray, - sample_genes: np.ndarray, - gene_batch_size: int = 512, -) -> Union[np.ndarray, sparse.csr.csr_matrix]: - """ - Build a matrix for classification using only genes that overlap - between the current sample and the pre-trained model. - - Parameters - ---------- - X : np.ndarray, sparse.csr_matrix - [Cells, Genes] count matrix. - model_genes : np.ndarray - gene identifiers in the order expected by the model. - sample_genes : np.ndarray - gene identifiers for the current sample. - gene_batch_size : int - number of genes to copy between arrays per batch. - controls a speed vs. memory trade-off. - - Returns - ------- - N : np.ndarray, sparse.csr_matrix - [Cells, len(model_genes)] count matrix. - Values where a model gene was not present in the sample are left - as zeros. `type(N)` will match `type(X)`. - """ - # check types - if type(X) not in (np.ndarray, sparse.csr.csr_matrix): - msg = f"X is type {type(X)}, must `np.ndarray` or `sparse.csr_matrix`" - raise TypeError(msg) - n_cells = X.shape[0] - # check if gene names already match exactly - if len(model_genes) == len(sample_genes): - if np.all(model_genes == sample_genes): - print("Gene names match exactly, returning input.") - return X - - # instantiate a new [Cells, model_genes] matrix where columns - # retain the order used during training - if type(X) == np.ndarray: - N = np.zeros((n_cells, len(model_genes))) - else: - # use sparse matrices if the input is sparse - N = sparse.lil_matrix( - ( - n_cells, - len(model_genes), - ) - ) - - # map gene indices from the model to the sample genes - model_genes_indices = [] - sample_genes_indices = [] - common_genes = 0 - for i, g in tqdm.tqdm(enumerate(sample_genes), desc="mapping genes"): - if np.sum(g == model_genes) > 0: - model_genes_indices.append(int(np.where(g == model_genes)[0])) - sample_genes_indices.append( - i, - ) - common_genes += 1 - - # copy the data in batches to the new array to avoid memory overflows - gene_idx = 0 - n_batches = int(np.ceil(N.shape[1] / gene_batch_size)) - for b in tqdm.tqdm(range(n_batches), desc="copying gene batches"): - model_batch_idx = model_genes_indices[gene_idx : gene_idx + gene_batch_size] - sample_batch_idx = sample_genes_indices[gene_idx : gene_idx + gene_batch_size] - N[:, model_batch_idx] = X[:, sample_batch_idx] - gene_idx += gene_batch_size - - if sparse.issparse(N): - # convert to `csr` from `csc` - N = sparse.csr_matrix(N) - print("Found %d common genes." % common_genes) - return N - - -def knn_smooth_pred_class( - X: np.ndarray, - pred_class: np.ndarray, - grouping: np.ndarray = None, - k: int = 15, -) -> np.ndarray: - """ - Smooths class predictions by taking the modal class from each cell's - nearest neighbors. - - Parameters - ---------- - X : np.ndarray - [N, Features] embedding space for calculation of nearest neighbors. - pred_class : np.ndarray - [N,] array of unique class labels. - groupings : np.ndarray - [N,] unique grouping labels for i.e. clusters. - if provided, only considers nearest neighbors *within the cluster*. - k : int - number of nearest neighbors to use for smoothing. - - Returns - ------- - smooth_pred_class : np.ndarray - [N,] unique class labels, smoothed by kNN. - - Examples - -------- - >>> smooth_pred_class = knn_smooth_pred_class( - ... X = X, - ... pred_class = raw_predicted_classes, - ... grouping = louvain_cluster_groups, - ... k = 15,) - - Notes - ----- - scNym classifiers do not incorporate neighborhood information. - By using a simple kNN smoothing heuristic, we can leverage neighborhood - information to improve classification performance, smoothing out cells - that have an outlier prediction relative to their local neighborhood. - """ - if grouping is None: - # do not use a grouping to restrict local neighborhood - # associations, create a universal pseudogroup `0`. - grouping = np.zeros(X.shape[0]) - - smooth_pred_class = np.zeros_like(pred_class) - for group in np.unique(grouping): - # identify only cells in the relevant group - group_idx = np.where(grouping == group)[0].astype("int") - X_group = X[grouping == group, :] - # if there are < k cells in the group, change `k` to the - # group size - if X_group.shape[0] < k: - k_use = X_group.shape[0] - else: - k_use = k - # compute a nearest neighbor graph and identify kNN - nns = NearestNeighbors( - n_neighbors=k_use, - ).fit(X_group) - dist, idx = nns.kneighbors(X_group) - - # for each cell in the group, assign a class as - # the majority class of the kNN - for i in range(X_group.shape[0]): - classes = pred_class[group_idx[idx[i, :]]] - uniq_classes, counts = np.unique(classes, return_counts=True) - maj_class = uniq_classes[int(np.argmax(counts))] - smooth_pred_class[group_idx[i]] = maj_class - return smooth_pred_class - - -class RBFWeight(object): - def __init__( - self, - alpha: float = None, - ) -> None: - """Generate a set of weights based on distances to a point - with a radial basis function kernel. - - Parameters - ---------- - alpha : float - radial basis function parameter. inverse of sigma - for a standard Gaussian pdf. - - Returns - ------- - None. - """ - self.alpha = alpha - return - - def set_alpha( - self, - X: np.ndarray, - n_max: int = None, - dm: np.ndarray = None, - ) -> None: - """Set the alpha parameter of a Gaussian RBF kernel - as the median distance between points in an array of - observations. - - Parameters - ---------- - X : np.ndarray - [N, P] matrix of observations and features. - n_max : int - maximum number of observations to use for median - distance computation. - dm : np.ndarray, optional - [N, N] distance matrix for setting the RBF kernel parameter. - speeds computation if pre-computed. - - Returns - ------- - None. Sets `self.alpha`. - - References - ---------- - A Kernel Two-Sample Test - Arthur Gretton, Karsten M. Borgwardt, Malte J. Rasch, - Bernhard Schölkopf, Alexander Smola. - JMLR, 13(Mar):723−773, 2012. - http://jmlr.csail.mit.edu/papers/v13/gretton12a.html - """ - if n_max is None: - n_max = X.shape[0] - - if dm is None: - # compute a distance matrix from observations - if X.shape[0] > n_max: - ridx = np.random.choice( - X.shape[0], - size=n_max, - replace=False, - ) - X_p = X[ridx, :] - else: - X_p = X - - dm = euclidean_distances( - X_p, - ) - - upper = dm[np.triu_indices_from(dm, k=1)] - - # overwrite_input = True saves memory by overwriting - # the upper indices in the distance matrix array during - # median computation - sigma = np.median( - upper, - overwrite_input=True, - ) - self.alpha = 1.0 / (2 * (sigma ** 2)) - return - - def __call__( - self, - distances: np.ndarray, - ) -> np.ndarray: - """Generate a set of weights based on distances to a point - with a radial basis function kernel. - - Parameters - ---------- - distances : np.ndarray - [N,] distances used to generate weights. - - Returns - ------- - weights : np.ndarray - [N,] weights from the radial basis function kernel. - - Notes - ----- - We weight distances with a Gaussian RBF. - - .. math:: - - f(r) = \exp -(\alpha r)^2 - - """ - # check that alpha parameter is set - if self.alpha is None: - msg = "must set `alpha` attribute before computing weights.\n" - msg += "use `.set_alpha() method to estimate from data." - raise ValueError(msg) - - # generate weights with an RBF kernel - weights = np.exp(-((self.alpha * distances) ** 2)) - return weights - - -def knn_smooth_pred_class_prob( - X: np.ndarray, - pred_probs: np.ndarray, - names: np.ndarray, - grouping: np.ndarray = None, - k: Union[Callable, int] = 15, - dm: np.ndarray = None, - **kwargs, -) -> np.ndarray: - """ - Smooths class predictions by taking the modal class from each cell's - nearest neighbors. - - Parameters - ---------- - X : np.ndarray - [N, Features] embedding space for calculation of nearest neighbors. - pred_probs : np.ndarray - [N, C] array of class prediction probabilities. - names : np.ndarray, - [C,] names of predicted classes in `pred_probs`. - groupings : np.ndarray - [N,] unique grouping labels for i.e. clusters. - if provided, only considers nearest neighbors *within the cluster*. - k : int - number of nearest neighbors to use for smoothing. - dm : np.ndarray, optional - [N, N] distance matrix for setting the RBF kernel parameter. - speeds computation if pre-computed. - - Returns - ------- - smooth_pred_class : np.ndarray - [N,] unique class labels, smoothed by kNN. - - Examples - -------- - >>> smooth_pred_class = knn_smooth_pred_class_prob( - ... X = X, - ... pred_probs = predicted_class_probs, - ... grouping = louvain_cluster_groups, - ... k = 15,) - - Notes - ----- - scNym classifiers do not incorporate neighborhood information. - By using a simple kNN smoothing heuristic, we can leverage neighborhood - information to improve classification performance, smoothing out cells - that have an outlier prediction relative to their local neighborhood. - """ - if grouping is None: - # do not use a grouping to restrict local neighborhood - # associations, create a universal pseudogroup `0`. - grouping = np.zeros(X.shape[0]) - - smooth_pred_probs = np.zeros_like(pred_probs) - smooth_pred_class = np.zeros(pred_probs.shape[0], dtype="object") - for group in np.unique(grouping): - # identify only cells in the relevant group - group_idx = np.where(grouping == group)[0].astype("int") - X_group = X[grouping == group, :] - y_group = pred_probs[grouping == group, :] - # if k is a Callable, use it to define k for this group - if callable(k): - k_use = k(X_group.shape[0]) - else: - k_use = k - - # if there are < k cells in the group, change `k` to the - # group size - if X_group.shape[0] < k_use: - k_use = X_group.shape[0] - - # set up weights using a radial basis function kernel - rbf = RBFWeight() - rbf.set_alpha( - X=X_group, - n_max=None, - dm=dm, - ) - - if "dm" in kwargs: - del kwargs["dm"] - # fit a nearest neighbor regressor - nns = KNeighborsRegressor( - n_neighbors=k_use, - weights=rbf, - **kwargs, - ).fit(X_group, y_group) - smoothed_probs = nns.predict(X_group) - - smooth_pred_probs[group_idx, :] = smoothed_probs - g_classes = names[np.argmax(smoothed_probs, axis=1)] - smooth_pred_class[group_idx] = g_classes - - return smooth_pred_class - - -def argmax_pred_class( - grouping: np.ndarray, - prediction: np.ndarray, -): - """Assign class to elements in groups based on the - most common predicted class for that group. - - Parameters - ---------- - grouping : np.ndarray - [N,] partition values defining groups to be classified. - prediction : np.ndarray - [N,] predicted values for each element in `grouping`. - - Returns - ------- - assigned_classes : np.ndarray - [N,] class labels based on the most common class assigned - to elements in the group partition. - - Examples - -------- - >>> grouping = np.array([0,0,0,1,1,1,2,2,2,2]) - >>> prediction = np.array(['A','A','A','B','A','B','C','A','B','C']) - >>> argmax_pred_class(grouping, prediction) - np.ndarray(['A','A','A','B','B','B','C','C','C','C',]) - - Notes - ----- - scNym classifiers do not incorporate neighborhood information. - This simple heuristic leverages cluster information obtained by - an orthogonal method and assigns all cells in a given cluster - the majority class label within that cluster. - """ - assert ( - grouping.shape[0] == prediction.shape[0] - ), "`grouping` and `prediction` must be the same length" - groups = sorted(list(set(grouping.tolist()))) - - assigned_classes = np.zeros(grouping.shape[0], dtype="object") - - for i, group in enumerate(groups): - classes, counts = np.unique(prediction[grouping == group], return_counts=True) - majority_class = classes[np.argmax(counts)] - assigned_classes[grouping == group] = majority_class - return assigned_classes - - -def compute_entropy_of_mixing( - X: np.ndarray, - y: np.ndarray, - n_neighbors: int, - n_iters: int = None, - **kwargs, -) -> np.ndarray: - """Compute the entropy of mixing among groups given - a distance matrix. - - Parameters - ---------- - X : np.ndarray - [N, P] feature matrix. - y : np.ndarray - [N,] group labels. - n_neighbors : int - number of nearest neighbors to draw for each iteration - of the entropy computation. - n_iters : int - number of iterations to perform. - if `n_iters is None`, uses every point. - - Returns - ------- - entropy_of_mixing : np.ndarray - [n_iters,] entropy values for each iteration. - - Notes - ----- - The entropy of batch mixing is computed by sampling `n_per_sample` - cells from a local neighborhood in the nearest neighbor graph - and contructing a probability vector based on their group membership. - The entropy of this probability vector is computed as a metric of - intermixing between groups. - - If groups are more mixed, the probability vector will have higher - entropy, and vice-versa. - """ - # build nearest neighbor graph - n_neighbors = min(n_neighbors, X.shape[0]) - nn = NearestNeighbors( - n_neighbors=n_neighbors, - metric="euclidean", - **kwargs, - ) - nn.fit(X) - nn_idx = nn.kneighbors(return_distance=False) - - # define query points - if n_iters is not None: - # don't duplicate points when sampling - n_iters = min(n_iters, X.shape[0]) - - if (n_iters is None) or (n_iters == X.shape[0]): - # sample all points - query_points = np.arange(X.shape[0]) - else: - # subset random query points for entropy - # computation - assert n_iters < X.shape[0] - query_points = np.random.choice( - X.shape[0], - size=n_iters, - replace=False, - ) - - entropy_of_mixing = np.zeros(len(query_points)) - for i, ridx in enumerate(query_points): - # get the nearest neighbors of a point - nn_y = y[nn_idx[ridx, :]] - - nn_y_p = np.zeros(len(np.unique(y))) - for j, v in enumerate(np.unique(y)): - nn_y_p[j] = sum(nn_y == v) - nn_y_p = nn_y_p / nn_y_p.sum() - - # use base 2 to return values in bits rather - # than the default nats - H = stats.entropy(nn_y_p) - entropy_of_mixing[i] = H - return entropy_of_mixing - - -"""Find new cell state based on scNym confidence scores""" - -from sklearn.metrics import calinski_harabasz_score - - -def _optimize_clustering(adata, resolution: list = [0.1, 0.2, 0.3, 0.5, 1.0]): - scores = [] - for r in resolution: - sc.tl.leiden(adata, resolution=r) - s = calinski_harabasz_score(adata.obsm["X_scnym"], adata.obs["leiden"]) - scores.append(s) - cl_opt_df = pd.DataFrame({"resolution": resolution, "score": scores}) - best_idx = np.argmax(cl_opt_df["score"]) - res = cl_opt_df.iloc[best_idx, 0] - sc.tl.leiden(adata, resolution=res) - print("Best resolution: ", res) - return cl_opt_df - - -def find_low_confidence_cells( - adata: anndata.AnnData, - confidence_threshold: float = 0.5, - confidence_key: str = "Confidence", - use_rep: str = "X_scnym", - n_neighbors: int = 15, -) -> pd.DataFrame: - """Find cells with low confidence predictions and suggest a potential - number of cell states within the low confidence cell population. - - Parameters - ---------- - adata : anndata.AnnData - [Cells, Genes] experiment containing an scNym embedding and scNym - confidence scores. - confidence_threshold : float - threshold for low confidence cells. - confidence_key : str - key in `adata.obs` containing confidence scores. - use_rep : str - tensor in `adata.obsm` containing the scNym embedding. - n_neighbors : int - number of nearest neighbors to use for NN graph construction - prior to community detection. - - Returns - ------- - None. - Adds `adata.uns["scNym_low_confidence_cells"]`, a `dict` containing - keys `"cluster_optimization", "n_clusters", "embedding"`. - Adds key to `adata.obs["scNym_low_confidence_cluster"]`. - - Notes - ----- - """ - # identify low confidence cells - adata.obs["scNym Discovery"] = ( - adata.obs[confidence_key] < confidence_threshold - ).astype(bool) - low_conf_bidx = adata.obs["scNym Discovery"] - - # embed low confidence cells - lc_ad = adata[adata.obs["scNym Discovery"], :].copy() - sc.pp.neighbors(lc_ad, use_rep=use_rep, n_neighbors=n_neighbors) - sc.tl.umap(lc_ad, min_dist=0.3) - - cl_opt_df = _optimize_clustering(lc_ad) - - lc_embed = lc_ad.obs.copy() - for k in range(1, 3): - lc_embed[f"UMAP{k}"] = lc_ad.obsm["X_umap"][:, k - 1] - - # set the outputs - adata.uns["scNym_low_confidence_cells"] = { - "cluster_optimization": cl_opt_df, - "n_clusters": len(np.unique(lc_ad.obs["leiden"])), - "embedding": lc_embed, - } - adata.obs["scNym_low_confidence_cluster"] = "High Confidence" - adata.obs.loc[low_conf_bidx, "scNym_low_confidence_cluster",] = lc_ad.obs[ - "leiden" - ].apply(lambda x: f"Low Confidence {x}") - return diff --git a/scnym.egg-info/PKG-INFO b/scnym.egg-info/PKG-INFO deleted file mode 100644 index 7fb0b57..0000000 --- a/scnym.egg-info/PKG-INFO +++ /dev/null @@ -1,46 +0,0 @@ -Metadata-Version: 2.1 -Name: scnym -Version: 0.3.3 -Summary: Semi supervised adversarial network networks for single cell classification -Home-page: http://github.com/calico/scnym -Author: Jacob C. Kimmel, David R. Kelley -Author-email: jacobkimmel+scnym@gmail.com, drk@calicolabs.com -License: Apache -Classifier: Environment :: Console -Classifier: Intended Audience :: Science/Research -Classifier: Topic :: Scientific/Engineering :: Bio-Informatics -Requires-Python: >=3.6 -License-File: LICENSE -Requires-Dist: anndata==0.8.0 -Requires-Dist: ConfigArgParse==1.1 -Requires-Dist: h5py==3.10.0 -Requires-Dist: leidenalg==0.8.0 -Requires-Dist: louvain==0.7.0 -Requires-Dist: numba==0.49.1 -Requires-Dist: numpy==1.21.0 -Requires-Dist: numpy-groupies==0.9.13 -Requires-Dist: pandas==1.5.3 -Requires-Dist: pytest==5.4.1 -Requires-Dist: python-dateutil==2.8.2 -Requires-Dist: PyYAML==5.3.1 -Requires-Dist: requests==2.26.0 -Requires-Dist: requests-cache==0.5.2 -Requires-Dist: requests-oauthlib==1.3.0 -Requires-Dist: requests-toolbelt==0.9.1 -Requires-Dist: matplotlib==3.6.3 -Requires-Dist: scanpy==1.6.0 -Requires-Dist: scikit-learn==0.22.2.post1 -Requires-Dist: scikit-misc==0.1.3 -Requires-Dist: scipy==1.4.1 -Requires-Dist: six==1.14.0 -Requires-Dist: tensorboard==2.2.1 -Requires-Dist: tensorboard-plugin-wit==1.6.0.post2 -Requires-Dist: tensorboardX==2.1 -Requires-Dist: torch==1.4.0 -Requires-Dist: torchvision==0.5.0 -Requires-Dist: tqdm==4.44.1 -Requires-Dist: umap-learn==0.3.10 -Requires-Dist: urllib3==1.26.6 -Requires-Dist: protobuf==3.20.* - -scNym uses the semi-supervised MixMatch framework and domain adversarial training to take advantage of information in both the labeled and unlabeled datasets. diff --git a/scnym.egg-info/SOURCES.txt b/scnym.egg-info/SOURCES.txt deleted file mode 100644 index 1e4a6c6..0000000 --- a/scnym.egg-info/SOURCES.txt +++ /dev/null @@ -1,60 +0,0 @@ -LICENSE -README.md -VERSION -demo_script.sh -requirements.txt -setup.py -.github/workflows/python-package.yml -assets/processed_data.md -assets/scnym_icon.png -assets/scnym_mixmatch_diagram.png -assets/scnym_mmdan_diagram.png -baseline/README.md -baseline/baseline.R -baseline/baseline.py -configs/default_config.txt -notebooks/scnym_classif_tutorial.ipynb -scnym/__init__.py -scnym/__main__.py -scnym/api.py -scnym/attributionpriors.py -scnym/dataprep.py -scnym/distributions.py -scnym/interpret.py -scnym/losses.py -scnym/main.py -scnym/model.py -scnym/predict.py -scnym/scnym_ad.py -scnym/trainer.py -scnym/utils.py -scnym.egg-info/PKG-INFO -scnym.egg-info/SOURCES.txt -scnym.egg-info/dependency_links.txt -scnym.egg-info/entry_points.txt -scnym.egg-info/requires.txt -scnym.egg-info/top_level.txt -scnym/__pycache__/__init__.cpython-38.pyc -scnym/__pycache__/api.cpython-38.pyc -scnym/__pycache__/attributionpriors.cpython-38.pyc -scnym/__pycache__/dataprep.cpython-38.pyc -scnym/__pycache__/distributions.cpython-38.pyc -scnym/__pycache__/interpret.cpython-38.pyc -scnym/__pycache__/losses.cpython-38.pyc -scnym/__pycache__/main.cpython-38.pyc -scnym/__pycache__/model.cpython-38.pyc -scnym/__pycache__/predict.cpython-38.pyc -scnym/__pycache__/trainer.cpython-38.pyc -scnym/__pycache__/utils.cpython-38.pyc -tests/test_api.py -tests/test_da.py -tests/test_dataprep.py -tests/test_guide.py -tests/test_interpret.py -tests/test_main.py -tests/test_mixmatch.py -tests/test_model.py -tests/test_multitask.py -tests/test_reconstruction.py -tests/test_trainer.py -tests/test_utils.py \ No newline at end of file diff --git a/scnym.egg-info/dependency_links.txt b/scnym.egg-info/dependency_links.txt deleted file mode 100644 index 8b13789..0000000 --- a/scnym.egg-info/dependency_links.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/scnym.egg-info/entry_points.txt b/scnym.egg-info/entry_points.txt deleted file mode 100644 index 5d7f4c0..0000000 --- a/scnym.egg-info/entry_points.txt +++ /dev/null @@ -1,3 +0,0 @@ -[console_scripts] -scnym = scnym.main:main -scnym_ad = scnym.scnym_ad:main diff --git a/scnym.egg-info/requires.txt b/scnym.egg-info/requires.txt deleted file mode 100644 index e9ae8ff..0000000 --- a/scnym.egg-info/requires.txt +++ /dev/null @@ -1,31 +0,0 @@ -anndata==0.8.0 -ConfigArgParse==1.1 -h5py==3.10.0 -leidenalg==0.8.0 -louvain==0.7.0 -numba==0.49.1 -numpy==1.21.0 -numpy-groupies==0.9.13 -pandas==1.5.3 -pytest==5.4.1 -python-dateutil==2.8.2 -PyYAML==5.3.1 -requests==2.26.0 -requests-cache==0.5.2 -requests-oauthlib==1.3.0 -requests-toolbelt==0.9.1 -matplotlib==3.6.3 -scanpy==1.6.0 -scikit-learn==0.22.2.post1 -scikit-misc==0.1.3 -scipy==1.4.1 -six==1.14.0 -tensorboard==2.2.1 -tensorboard-plugin-wit==1.6.0.post2 -tensorboardX==2.1 -torch==1.4.0 -torchvision==0.5.0 -tqdm==4.44.1 -umap-learn==0.3.10 -urllib3==1.26.6 -protobuf==3.20.* diff --git a/scnym.egg-info/top_level.txt b/scnym.egg-info/top_level.txt deleted file mode 100644 index 0431c2e..0000000 --- a/scnym.egg-info/top_level.txt +++ /dev/null @@ -1 +0,0 @@ -scnym diff --git a/scnym/__pycache__/__init__.cpython-38.pyc b/scnym/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 5d3d7e7..0000000 Binary files a/scnym/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/scnym/__pycache__/api.cpython-38.pyc b/scnym/__pycache__/api.cpython-38.pyc deleted file mode 100644 index 0b2a6b3..0000000 Binary files a/scnym/__pycache__/api.cpython-38.pyc and /dev/null differ diff --git a/scnym/__pycache__/attributionpriors.cpython-38.pyc b/scnym/__pycache__/attributionpriors.cpython-38.pyc deleted file mode 100644 index a1817da..0000000 Binary files a/scnym/__pycache__/attributionpriors.cpython-38.pyc and /dev/null differ diff --git a/scnym/__pycache__/dataprep.cpython-38.pyc b/scnym/__pycache__/dataprep.cpython-38.pyc deleted file mode 100644 index dacf78b..0000000 Binary files a/scnym/__pycache__/dataprep.cpython-38.pyc and /dev/null differ diff --git a/scnym/__pycache__/distributions.cpython-38.pyc b/scnym/__pycache__/distributions.cpython-38.pyc deleted file mode 100644 index ad0c70d..0000000 Binary files a/scnym/__pycache__/distributions.cpython-38.pyc and /dev/null differ diff --git a/scnym/__pycache__/interpret.cpython-38.pyc b/scnym/__pycache__/interpret.cpython-38.pyc deleted file mode 100644 index 31c90e1..0000000 Binary files a/scnym/__pycache__/interpret.cpython-38.pyc and /dev/null differ diff --git a/scnym/__pycache__/losses.cpython-38.pyc b/scnym/__pycache__/losses.cpython-38.pyc deleted file mode 100644 index 9857556..0000000 Binary files a/scnym/__pycache__/losses.cpython-38.pyc and /dev/null differ diff --git a/scnym/__pycache__/main.cpython-38.pyc b/scnym/__pycache__/main.cpython-38.pyc deleted file mode 100644 index 38d96bc..0000000 Binary files a/scnym/__pycache__/main.cpython-38.pyc and /dev/null differ diff --git a/scnym/__pycache__/model.cpython-38.pyc b/scnym/__pycache__/model.cpython-38.pyc deleted file mode 100644 index 1d5407c..0000000 Binary files a/scnym/__pycache__/model.cpython-38.pyc and /dev/null differ diff --git a/scnym/__pycache__/predict.cpython-38.pyc b/scnym/__pycache__/predict.cpython-38.pyc deleted file mode 100644 index 2717676..0000000 Binary files a/scnym/__pycache__/predict.cpython-38.pyc and /dev/null differ diff --git a/scnym/__pycache__/trainer.cpython-38.pyc b/scnym/__pycache__/trainer.cpython-38.pyc deleted file mode 100644 index 2ef7490..0000000 Binary files a/scnym/__pycache__/trainer.cpython-38.pyc and /dev/null differ diff --git a/scnym/__pycache__/utils.cpython-38.pyc b/scnym/__pycache__/utils.cpython-38.pyc deleted file mode 100644 index b85fbc4..0000000 Binary files a/scnym/__pycache__/utils.cpython-38.pyc and /dev/null differ