diff --git a/dynamo/__init__.py b/dynamo/__init__.py index 870f8c30e..1a4a5fcf0 100755 --- a/dynamo/__init__.py +++ b/dynamo/__init__.py @@ -22,6 +22,7 @@ from . import sample_data from . import configuration from . import ext +from . import multi from .data_io import * from .dynamo_logger import ( diff --git a/dynamo/multi.py b/dynamo/multi.py new file mode 100644 index 000000000..9b876f67e --- /dev/null +++ b/dynamo/multi.py @@ -0,0 +1 @@ +from .multivelo import * \ No newline at end of file diff --git a/dynamo/multivelo/ATACseqTools.py b/dynamo/multivelo/ATACseqTools.py new file mode 100644 index 000000000..cf30145a7 --- /dev/null +++ b/dynamo/multivelo/ATACseqTools.py @@ -0,0 +1,406 @@ +import anndata as ad +from anndata import AnnData + +from concurrent.futures import as_completed, ThreadPoolExecutor + +from mudata import MuData +import numpy as np +from os import PathLike +import pandas as pd + +import scanpy as sc +from scipy.sparse import coo_matrix, csr_matrix, diags, hstack +from tqdm import tqdm +from typing import ( + Literal, + Union +) + +# Imports from dynamo +from ..dynamo_logger import ( + LoggerManager, + main_info, +) + + +# Imports from MultiDynamo +from .MultiConfiguration import MDKM + + +def extend_gene_coordinates( + bedtool, + upstream: int = 2000, + downstream: int = 0 +): + from pybedtools import BedTool + extended_genes = [] + for feature in bedtool: + if feature[2] == 'gene': + start = max(0, int(feature.start) - upstream) + end = int(feature.end) + downstream + extended_genes.append((feature.chrom, start, end, feature.name)) + return BedTool(extended_genes) + + +def annotate_integrated_mdata(mdata: MuData, + celltypist_model: str = 'Immune_All_Low.pkl' + ) -> MuData: + import celltypist + from celltypist import models + # Extract the RNA data + rna_adata = mdata.mod['rna'].copy() + + # ... revert to counts + rna_adata.X = rna_adata.layers['counts'].copy() + + # ... normalize counts so total number per cell is 10,000 (required by celltypist) + sc.pp.normalize_total(rna_adata, + target_sum=1e4) + + # ... pseudo-log transform (x -> log(1 + x)) for better dynamical range (and required by celltypist) + sc.pp.log1p(rna_adata) + + # ... rerun PCA - CellTypist can need larger number than we already computed + sc.pp.pca(rna_adata, n_comps=50) + + # ... 
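To make the intent of `extend_gene_coordinates` concrete — a minimal sketch, assuming `pybedtools` is installed; the gene names and coordinates below are made up, a small DataFrame stands in for a real GTF, and (like the helper) strand is ignored:

```python
import pandas as pd
from pybedtools import BedTool

# Two toy gene intervals standing in for GTF 'gene' features.
genes = BedTool.from_dataframe(pd.DataFrame({
    'chrom': ['1', '1'],
    'start': [11869, 65419],
    'end':   [14409, 71585],
    'name':  ['DDX11L1', 'OR4F5'],
}))

# Widen each interval 2 kb upstream (clamped at 0), as the helper does, so that
# promoter-proximal peaks fall inside the gene's regulatory window.
extended = BedTool.from_dataframe(pd.DataFrame(
    [(g.chrom, max(0, g.start - 2000), g.end, g.name) for g in genes],
    columns=['chrom', 'start', 'end', 'name'],
))
```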
recompute the neighborhood graph for majority voting + sc.pp.neighbors(rna_adata, + n_neighbors=50, + n_pcs=50) + + # Download celltypist models for annotation + models.download_models(force_update=True) + + # Select the low resolution immun cell model + model = models.Model.load(model=celltypist_model) + + # Compute cell type labels + predictions = celltypist.annotate(rna_adata, + model=celltypist_model, + majority_voting=True) + + # Transfer the predictions back to the RNA AnnData object + rna_adata = predictions.to_adata() + + # Create dictionary from cell indices to cell types + cellindex_to_celltype_dict = rna_adata.obs['majority_voting'].to_dict() + + # Apply the index map to both RNA and ATAC AnnData objects + atac_adata, rna_adata = mdata.mod['atac'].copy(), mdata.mod['rna'].copy() + atac_adata.obs['cell_type'] = atac_adata.obs.index.map( + lambda cell_idx: cellindex_to_celltype_dict.get(cell_idx, 'Undefined')) + rna_adata.obs['cell_type'] = rna_adata.obs.index.map( + lambda cell_idx: cellindex_to_celltype_dict.get(cell_idx, 'Undefined')) + + return MuData({'atac': atac_adata.copy(), 'rna': rna_adata.copy()}) + + +def gene_activity( + atac_adata: AnnData, + gtf_path: PathLike, + upstream: int = 2000, + downstream: int = 0 +) -> Union[AnnData, None]: + from pybedtools import BedTool + # Drop UCSC convention for naming of chromosomes - This assumes we are using ENSEMBL-format of GTF + atac_adata.var.index = [c.lstrip('chr') for c in atac_adata.var.index] + + # Read GTF to annotate genes + main_info('reading GTF', indent_level=3) + gene_annotations = BedTool(gtf_path) + + # Extend gene coordinates + main_info('extending genes to estimate regulatory regions', indent_level=3) + peak_extension_logger = LoggerManager.gen_logger('extend_genes') + peak_extension_logger.log_time() + + extended_genes = extend_gene_coordinates(gene_annotations, + upstream=upstream, + downstream=downstream) + + peak_extension_logger.finish_progress(progress_name='extend_genes', indent_level=3) + + # Extract ATAC-seq peak coordinates + chrom_list = atac_adata.var_names.str.split(':').str[0] # .astype(int) + start_list = atac_adata.var_names.str.split(':').str[1].str.split('-').str[0] # .astype(int).astype(int) + end_list = atac_adata.var_names.str.split(':').str[1].str.split('-').str[1] # .astype(int).astype(int) + + # Convert ATAC-seq peak data to BedTool format + atac_peaks = BedTool.from_dataframe(pd.DataFrame({ + 'chrom': chrom_list, + 'start': start_list, + 'end': end_list + })) + + # Find overlaps between peaks and extended genes + main_info('overlapping peaks and extended genes', indent_level=3) + linked_peaks = atac_peaks.intersect(extended_genes, wa=True, wb=True) + + # Create a DataFrame from the linked peaks + linked_peaks_df = linked_peaks.to_dataframe( + names=['chrom', 'peak_start', 'peak_end', 'chrom_gene', 'gene_start', 'gene_end', 'gene_name']) + + # Create a dictionary to map peak indices to gene names + main_info('building dictionaries', indent_level=3) + peak_to_gene = linked_peaks_df.set_index(['chrom', 'peak_start', 'peak_end'])['gene_name'].to_dict() + peak_to_gene = {f'{chrom}:{start}-{end}': gene_name for (chrom, start, end), gene_name in peak_to_gene.items()} + + # Get the list of peaks from the ATAC-seq data + peaks = atac_adata.var.index + gene_names = np.array([peak_to_gene.get(peak, '') for peak in peaks]) + + # Get the unique genes + unique_genes = np.unique(gene_names) + + # Initialize a sparse matrix for gene activity scores + n_cells, n_genes = atac_adata.n_obs, 
len(unique_genes) + + # Create a mapping from gene names to column indices in the sparse matrix + gene_to_idx = {gene: idx for idx, gene in enumerate(unique_genes)} + + def process_peak(i): + gene = gene_names[i] + return gene_to_idx.get(gene, -1), atac_adata[:, i].X + + # Fill the sparse matrix with aggregated counts in parallel + results = [] + with ThreadPoolExecutor() as executor: + futures = [executor.submit(process_peak, i) for i in range(len(peaks))] + with tqdm(total=len(peaks), desc="Processing peaks") as pbar: + for future in as_completed(futures): + result = future.result() + if result is not None: + results.append(result) + pbar.update(1) + + # Aggregate results in batches to minimize overhead + main_info('aggregating results', indent_level=3) + aggregation_logger = LoggerManager.gen_logger('aggregating_results') + aggregation_logger.log_time() + + data = [] + rows = [] + cols = [] + + # Loop through the results to gather data for COO matrix + for col_idx, sparse_col_vector in results: + # Extract row indices and data from the sparse column vector + coo = sparse_col_vector.tocoo() + data.extend(coo.data) + rows.extend(coo.row) + cols.extend([col_idx] * len(coo.row)) + + # Create a COO matrix from collected data + coo_matrix_all = coo_matrix((data, (rows, cols)), shape=(n_cells, n_genes)) + + # Convert COO matrix to CSR format + gene_activity_matrix = coo_matrix_all.tocsr() + + aggregation_logger.finish_progress(progress_name='aggregating_results', indent_level=3) + + # Add the sparse gene activity matrix as a new .obsm + atac_adata.obsm[MDKM.ATAC_GENE_ACTIVITY_KEY] = gene_activity_matrix + atac_adata.uns[MDKM.ATAC_GENE_ACTIVITY_GENES_KEY] = pd.Index(unique_genes) + + return atac_adata + + +def integrate(mdata: MuData, + integration_method: Literal['moscot', 'multivi'] = 'multivi', + alpha: float = 0.5, + entropic_regularization: float = 0.01, + gtf_path: Union[PathLike, str] = None, + max_epochs: int = 500, + lr: float = 0.0001, + ) -> MuData: + # Split into scATAC-seq and scRNA-seq AnnData objects + atac_adata, rna_adata = mdata.mod['atac'].copy(), mdata.mod['rna'].copy() + atac_adata.obs['modality'], rna_adata.obs['modality'] = 'atac', 'rna' + + if atac_adata.uns[MDKM.MATCHED_ATAC_RNA_DATA_KEY]: + main_info('Integration: matched multiome, so just filtering cells') + + # Restrict to cells common to both AnnData objects + shared_cells = pd.Index(np.intersect1d(rna_adata.obs_names, atac_adata.obs_names)) + atac_adata_filtered = atac_adata[shared_cells, :].copy() + rna_adata_filtered = rna_adata[shared_cells, :].copy() + + return MuData({'atac': atac_adata_filtered, 'rna': rna_adata_filtered}) + elif integration_method == 'moscot': + return integrate_via_moscot(mdata=mdata, + alpha=alpha, + entropic_regularization=entropic_regularization) + elif integration_method == 'multivi': + return integrate_via_multivi(mdata=mdata, + gtf_path=gtf_path, + lr=lr, + max_epochs=max_epochs) + else: + raise ValueError(f'Unknown integration method {integration_method} requested.') + + +def integrate_via_moscot(mdata: MuData, + alpha: float = 0.7, + entropic_regularization: float = 0.01, + gtf_path: Union[PathLike, str] = None, + ) -> MuData: + pass + + +def integrate_via_multivi(mdata: MuData, + gtf_path: Union[PathLike, str] = None, + lr: float = 0.0001, + max_epochs: int = 500, + ) -> MuData: + import scvi + main_info('Integration via MULTIVI ...') + integration_logger = LoggerManager.gen_logger('integration_via_multivi') + integration_logger.log_time() + + # Split into scATAC-seq and scRNA-seq 
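A hypothetical call pattern for `integrate` above — the MuData construction, the flag assignment, and the GTF path are placeholders, not part of this diff:

```python
from mudata import MuData

mdata = MuData({'atac': atac_adata, 'rna': rna_adata})         # pre-built AnnData objects
mdata.mod['atac'].uns[MDKM.MATCHED_ATAC_RNA_DATA_KEY] = False  # unmatched assays

integrated = integrate(
    mdata,
    integration_method='multivi',
    gtf_path='refdata/genes.gtf',  # needed to derive gene activities from peaks
    max_epochs=500,
)
```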
AnnData objects + atac_adata, rna_adata = mdata.mod['atac'].copy(), mdata.mod['rna'].copy() + atac_adata.obs['modality'], rna_adata.obs['modality'] = 'atac', 'rna' + + # Check whether cell indices need to be prepended by 'atac' and 'rna' + if ':' not in atac_adata.obs_names[0]: + atac_adata.obs_names = atac_adata.obs_names.map(lambda x: f'atac:{x}') + num_atac_cells, num_atac_peaks = atac_adata.n_obs, atac_adata.n_vars + + if ':' not in rna_adata.obs_names[0]: + rna_adata.obs_names = rna_adata.obs_names.map(lambda x: f'rna:{x}') + num_rna_cells = rna_adata.n_obs + + # Check whether gene activity was pre-computed + if MDKM.ATAC_GENE_ACTIVITY_KEY not in atac_adata.obsm.keys(): + main_info('Computing gene activities', indent_level=2) + atac_adata = gene_activity(atac_adata=atac_adata, + gtf_path=gtf_path) + gene_activity_matrix = atac_adata.obsm[MDKM.ATAC_GENE_ACTIVITY_KEY] + + # Restrict to gene names common to gene activity matrix from atac-seq data and + # counts matrix from rna-seq data + gene_names_atac = atac_adata.uns[MDKM.ATAC_GENE_ACTIVITY_GENES_KEY] + gene_names_rna = rna_adata.var_names + common_genes = gene_names_rna.intersection(gene_names_atac) + num_genes = len(common_genes) + + # Filter gene activity and scATAC-seq data into a single AnnData object, with a + # batch label indicating the origin + main_info('Preparing ATAC-seq data for MULTIVI', indent_level=2) + gene_activity_filtered = gene_activity_matrix[:, [gene_names_atac.get_loc(gene) for gene in common_genes]] + + # Assemble multi-ome for the ATAC-seq data + # ... X + atac_multiome_X = hstack([gene_activity_filtered, atac_adata.X]) + + # ... obs + atac_multiome_obs = atac_adata.obs[['modality']].copy() + + # ... var + multiome_var = pd.concat((rna_adata.var.loc[common_genes].copy(), atac_adata.var.copy()), axis=1) + + atac_multiome = AnnData(X=csr_matrix(atac_multiome_X), + obs=atac_multiome_obs, + var=multiome_var) + + # Assemble multi-ome for RNA-seq data + main_info('Preparing RNA-seq data for MULTIVI', indent_level=2) + rna_adata_filtered = rna_adata[:, common_genes].copy() + + # ... X + rna_multiome_X = hstack([rna_adata_filtered.X.copy(), csr_matrix((num_rna_cells, num_atac_peaks))]) + + # ... obs + rna_multiome_obs = rna_adata_filtered.obs[['modality']].copy() + + # ... var - NTD + + rna_multiome = AnnData(X=csr_matrix(rna_multiome_X), + obs=rna_multiome_obs, + var=multiome_var) + + # Concatenate the data + combined_adata = ad.concat([atac_multiome, rna_multiome], axis=0) + + # Setup AnnData object for scvi-tools + main_info('Setting up combined data for MULTIVI', indent_level=2) + scvi.model.MULTIVI.setup_anndata(combined_adata, batch_key='modality') + + # Instantiate the SCVI model + main_info('Instantiating MULTIVI model', indent_level=2) + multivi_model = scvi.model.MULTIVI(adata=combined_adata, n_genes=num_genes, n_regions=num_atac_peaks) + + # Train the model + main_info('Training MULTIVI model', indent_level=2) + multivi_model.train(max_epochs=max_epochs, lr=lr) + + # Extract the latent representation + combined_adata.obsm['latent'] = multivi_model.get_latent_representation() + + # Impute counts from latent space + # ... X + main_info('Imputing RNA expression', indent_level=2) + imputed_rna_X = multivi_model.get_normalized_expression() + + # ... obs + multiome_obs = pd.concat((atac_multiome_obs, rna_multiome_obs)) + + # ... var + rna_multiome_var = rna_adata.var.loc[common_genes].copy() + + imputed_rna_adata = AnnData(X=imputed_rna_X, + obs=multiome_obs, + var=rna_multiome_var, + ) + + # ... 
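The block matrices assembled here are easy to get wrong dimensionally; a self-contained shape check with toy sizes (nothing below comes from the PR):

```python
import numpy as np
from scipy.sparse import csr_matrix, hstack

n_rna_cells, n_genes, n_atac_peaks = 3, 4, 5

# RNA cells contribute observed gene counts plus an all-zero block for the
# peak features they lack, mirroring the hstack for rna_multiome_X above.
rna_block = hstack([csr_matrix(np.ones((n_rna_cells, n_genes))),
                    csr_matrix((n_rna_cells, n_atac_peaks))])
print(rna_block.shape)  # (3, 9) == (cells, genes + peaks)
```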
X + main_info('Imputing accessibility', indent_level=2) + imputed_atac_X = multivi_model.get_accessibility_estimates() + + # ... obs - NTD + + # ... var + atac_multiome_var = atac_adata.var.copy() + + imputed_atac_adata = AnnData(X=imputed_atac_X, + obs=multiome_obs, + var=atac_multiome_var, + ) + + # Knit together into one harmonized MuData object + harmonized_mdata = MuData({'atac': imputed_atac_adata, 'rna': imputed_rna_adata}) + + integration_logger.finish_progress(progress_name='integration_via_multivi', indent_level=3) + + return harmonized_mdata + + +def tfidf_normalize( + atac_adata: AnnData, + log_tf: bool = True, + log_idf: bool = True, + log_tfidf: bool = False, + mv_algorithm: bool = True, + scale_factor: float = 1e4, +) -> None: + import muon as mu + # This computes the term frequency / inverse domain frequency normalization. + if mv_algorithm: + # MultiVelo's method + npeaks = atac_adata.X.sum(1) + npeaks_inv = csr_matrix(1.0 / npeaks) + tf = atac_adata.X.multiply(npeaks_inv) + idf = diags(np.ravel(atac_adata.X.shape[0] / atac_adata.X.sum(0))).log1p() + tf_idf = tf.dot(idf) * scale_factor + atac_adata.layers[MDKM.ATAC_TFIDF_LAYER] = np.log1p(tf_idf) + else: + atac_adata = mu.atac.pp.tfidf(data=atac_adata, + log_tf=log_tf, + log_idf=log_idf, + log_tfidf=log_tfidf, + scale_factor=scale_factor, + from_layer='counts', + to_layer=MDKM.ATAC_TFIDF_LAYER, + copy=True) + + return atac_adata diff --git a/dynamo/multivelo/ChromatinVelocity.py b/dynamo/multivelo/ChromatinVelocity.py new file mode 100644 index 000000000..6501f3659 --- /dev/null +++ b/dynamo/multivelo/ChromatinVelocity.py @@ -0,0 +1,213 @@ +import numpy as np +from scipy.sparse import issparse +from typing import Literal + +# Import from dynamo +from ..dynamo_logger import ( + main_exception, +) + + +# ChromatinVelocity class - patterned after MultiVelo, but retains accessibility at individual CRE +class ChromatinVelocity: + def __init__(self, + c, + u, + s, + ss, + us, + uu, + fit_args=None, + gene=None, + r2_adjusted=False): + self.gene = gene + self.outlier = np.clip(fit_args['outlier'], a_min=80, a_max=100) + self.r2_adjusted = r2_adjusted + self.total_n = len(u) + + # Convert all sparse vectors to dense ones + c = c.A if issparse(c) else c + s = s.A if issparse(s) else s + u = u.A if issparse(u) else u + ss = ss.A if ((ss is not None) and issparse(ss)) else ss + us = us.A if ((us is not None) and issparse(us)) else us + uu = uu.A if ((uu is not None) and issparse(uu)) else uu + + # In distinction to MultiVelo c will be (total_n, n_peak) array + # Sweep the minimum value in each column from the array + self.offset_c = np.min(c, axis=0) + self.c_all = c - self.offset_c + + # The other moments are (total_n, ) arrays + self.s_all, self.u_all = np.ravel(np.array(s, dtype=np.float64)), np.ravel(np.array(u, dtype=np.float64)) + self.offset_s, self.offset_u = np.min(self.s_all), np.min(self.u_all) + self.s_all -= self.offset_s + self.u_all -= self.offset_u + + # For 'stochastic' method also need second moments + if ss is not None: + self.ss_all = np.ravel(np.array(ss, dtype=np.float64)) + if us is not None: + self.us_all = np.ravel(np.array(us, dtype=np.float64)) + if uu is not None: + self.uu_all = np.ravel(np.array(uu, dtype=np.float64)) + + # Ensure at least one element in each cell is positive + any_c_positive = np.any(self.c_all > 0, axis=1) + self.non_zero = np.ravel(any_c_positive) | np.ravel(self.u_all > 0) | np.ravel(self.s_all > 0) + + # remove outliers + # ... 
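A numeric sanity check of the MultiVelo-style branch of `tfidf_normalize` above, on a toy matrix (the scale factor is shrunk for readability):

```python
import numpy as np
from scipy.sparse import csr_matrix, diags

X = csr_matrix(np.array([[1.0, 0.0, 3.0],
                         [0.0, 2.0, 2.0]]))
npeaks = X.sum(1)                                     # per-cell total counts, shape (2, 1)
tf = X.multiply(csr_matrix(1.0 / npeaks))             # term frequency
idf = diags(np.ravel(X.shape[0] / X.sum(0))).log1p()  # inverse document frequency
tf_idf = tf.dot(idf) * 10                             # scale_factor = 10 here
print(np.log1p(tf_idf.toarray()))
```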
for chromatin, we'll be more stringent - if *any* peak count for a cell + # is an outlier, we'll remove that cell + self.non_outlier = np.all(self.c_all <= np.percentile(self.c_all, self.outlier, axis=0), axis=1) + self.non_outlier &= np.ravel(self.u_all <= np.percentile(self.u_all, self.outlier)) + self.non_outlier &= np.ravel(self.s_all <= np.percentile(self.s_all, self.outlier)) + self.c = self.c_all[self.non_zero & self.non_outlier] + self.u = self.u_all[self.non_zero & self.non_outlier] + self.s = self.s_all[self.non_zero & self.non_outlier] + self.ss = (None if ss is None + else self.ss_all[self.non_zero & self.non_outlier]) + self.us = (None if us is None + else self.us_all[self.non_zero & self.non_outlier]) + self.uu = (None if uu is None + else self.uu_all[self.non_zero & self.non_outlier]) + self.low_quality = len(self.u) < 10 + + # main_info(f'{len(self.u)} cells passed filter and will be used to fit regressions.') + + # 4 rate parameters + self.alpha_c = 0.1 + self.alpha = 0.0 + self.beta = 0.0 + self.gamma_det = 0.0 + self.gamma_stoch = 0.0 + + # other parameters or results + self.loss_det = np.inf + self.loss_stoch = np.inf + self.r2_det = 0 + self.r2_stoch = 0 + self.residual_det = None + self.residual_stoch = None + self.residual2_stoch = None + + self.steady_state_func = None + + # Select the cells for regression + w_sub_for_c = np.any(self.c >= 0.1 * np.max(self.c, axis=0), axis=1) + w_sub = w_sub_for_c & (self.u >= 0.1 * np.max(self.u)) & (self.s >= 0.1 * np.max(self.s)) + c_sub = self.c[w_sub] + w_sub_for_c = np.any(self.c >= np.mean(c_sub, axis=0) + np.std(c_sub, axis=0)) + w_sub = w_sub_for_c & (self.u >= 0.1 * np.max(self.u)) & (self.s >= 0.1 * np.max(self.s)) + self.w_sub = w_sub + if np.sum(self.w_sub) < 10: + self.low_quality = True + + # This method originated from MultiVelo - Corrected R^2 + def compute_deterministic(self): + # Steady state slope - no different than usual transcriptomic version + u_high = self.u[self.w_sub] + s_high = self.s[self.w_sub] + wu_high = u_high >= np.percentile(u_high, 95) + ws_high = s_high >= np.percentile(s_high, 95) + ss_u = u_high[wu_high | ws_high] + ss_s = s_high[wu_high | ws_high] + + gamma_det = np.dot(ss_u, ss_s) / np.dot(ss_s, ss_s) + self.steady_state_func = lambda x: gamma_det * x + residual_det = self.u_all - self.steady_state_func(self.s_all) + + loss_det = np.dot(residual_det, residual_det) / len(self.u_all) + + if self.r2_adjusted: + gamma_det = np.dot(self.u, self.s) / np.dot(self.s, self.s) + residual_det = self.u_all - gamma_det * self.s_all + + total_det = self.u_all - np.mean(self.u_all) + # total_det = self.u_all # Since fitting only slope with zero intercept, should not include mean + + self.gamma_det = gamma_det + self.loss_det = loss_det + self.residual_det = residual_det + + self.r2_det = 1 - np.dot(residual_det, residual_det) / np.dot(total_det, total_det) + + + # This method originated from MultiVelo + def compute_stochastic(self): + self.compute_deterministic() + + var_ss = 2 * self.ss - self.s + cov_us = 2 * self.us + self.u + s_all_ = 2 * self.s_all ** 2 - (2 * self.ss_all - self.s_all) + u_all_ = (2 * self.us_all + self.u_all) - 2 * self.u_all * self.s_all + gamma2 = np.dot(cov_us, var_ss) / np.dot(var_ss, var_ss) + residual2 = cov_us - gamma2 * var_ss + std_first = np.std(self.residual_det) + std_second = np.std(residual2) + + # chromatin adjusted steady-state slope + u_high = self.u[self.w_sub] + s_high = self.s[self.w_sub] + wu_high = u_high >= np.percentile(u_high, 95) + ws_high = s_high >= 
np.percentile(s_high, 95) + ss_u = u_high * (wu_high | ws_high) + ss_s = s_high * (wu_high | ws_high) + a = np.hstack((ss_s / std_first, var_ss[self.w_sub] / std_second)) + b = np.hstack((ss_u / std_first, cov_us[self.w_sub] / std_second)) + + gamma_stoch = np.dot(b, a) / np.dot(a, a) + self.steady_state_func = lambda x: gamma_stoch * x + self.residual_stoch = self.u_all - self.steady_state_func(self.s_all) + self.residual2_stoch = u_all_ - self.steady_state_func(s_all_) + loss_stoch = np.dot(self.residual_stoch, self.residual_stoch) / len(self.u_all) + + self.gamma_stoch = gamma_stoch + self.loss_stoch = loss_stoch + self.r2_stoch = 1 - np.dot(self.residual_stoch, self.residual_stoch) / np.dot(self.u_all, self.u_all) + + def get_gamma(self, + mode: Literal['deterministic', 'stochastic'] = 'stochastic'): + if mode == 'deterministic': + return self.gamma_det + elif mode == 'stochastic': + return self.gamma_stoch + else: + main_exception(f"Unknown mode {mode} - must be one of 'deterministic' or 'stochastic'") + + def get_loss(self, + mode: Literal['deterministic', 'stochastic'] = 'stochastic'): + if mode == 'deterministic': + return self.loss_det + elif mode == 'stochastic': + return self.loss_stoch + else: + main_exception(f"Unknown mode {mode} - must be one of 'deterministic' or 'stochastic'") + + def get_r2(self, + mode: Literal['deterministic', 'stochastic'] = 'stochastic'): + if mode == 'deterministic': + return self.r2_det + elif mode == 'stochastic': + return self.r2_stoch + else: + main_exception(f"Unknown mode {mode} - must be one of 'deterministic' or 'stochastic'") + + def get_variance_velocity(self, + mode: Literal['deterministic', 'stochastic'] = 'stochastic'): + if mode == 'stochastic': + return self.residual2_stoch + else: + main_exception("Should not call get_variance_velocity for mode other than 'stochastic'") + + def get_velocity(self, + mode: Literal['deterministic', 'stochastic'] = 'stochastic'): + vel = None # Make the lint checker happy + if mode == 'deterministic': + vel = self.residual_det + elif mode == 'stochastic': + vel = self.residual_stoch + else: + main_exception(f"Unknown mode {mode} - must be one of 'deterministic' or 'stochastic'") + + return vel diff --git a/dynamo/multivelo/MultiConfiguration.py b/dynamo/multivelo/MultiConfiguration.py new file mode 100644 index 000000000..41b036704 --- /dev/null +++ b/dynamo/multivelo/MultiConfiguration.py @@ -0,0 +1,97 @@ +from ..configuration import DynamoAdataKeyManager + +class MultiDynamoMdataKeyManager(DynamoAdataKeyManager): + # A class to manage the keys used in MuData object used for MultiDynamo + # Universal keys - independent of modality + INFERRED_BATCH_KEY = 'inferred_batch' + + # .mod + # ... 'atac' + # ... ... layers + ATAC_COUNTS_LAYER = 'counts' + ATAC_FIRST_MOMENT_CHROM_LAYER = 'M_c' + ATAC_TFIDF_LAYER = 'X_tfidf' # Also X? + ATAC_CHROMATIN_VELOCITY_LAYER = 'lifted_velo_c' + + # ... ... .obs + + # ... ... .obsm + ATAC_GENE_ACTIVITY_KEY = 'gene_activity' # Computed gene activity matrix - for unmatched data only + ATAC_OBSM_LSI_KEY = 'X_lsi' + ATAC_OBSM_PC_KEY = 'X_pca' + + # ... ... .obsp + + # ... ... .uns + ATAC_GENE_ACTIVITY_GENES_KEY = 'gene_activity_genes' # Genes for gene activity matrix + MATCHED_ATAC_RNA_DATA_KEY = 'matched_atac_rna_data' # Indicates whether ATAC- and RNA-seq data are matched + + # ... ... .var (atac:*) + + # ... ... .varm + ATAC_VARM_LSI_KEY = 'LSI' + + # ... 'cite' + # ... ... layers + + # ... ... .obs + + # ... ... .obsm + + # ... ... .obsp + + # ... ... 
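The steady-state fit in `compute_deterministic` above reduces to a zero-intercept least squares of unspliced on spliced over the "steady-state" cells (top 5% in either coordinate); a minimal sketch on synthetic data:

```python
import numpy as np

rng = np.random.default_rng(0)
s = rng.gamma(2.0, 1.0, 500)
u = 0.3 * s + rng.normal(0, 0.05, 500)           # true gamma = 0.3

w = (u >= np.percentile(u, 95)) | (s >= np.percentile(s, 95))
gamma = np.dot(u[w], s[w]) / np.dot(s[w], s[w])  # slope through the origin
velocity = u - gamma * s                         # residual = velocity estimate
print(round(gamma, 3))
```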
.uns + MATCHED_CITE_RNA_DATA_KEY = 'matched_cite_rna_data' # Indicates whether CITE- and RNA-seq data are matched + + # ... ... .var (cite:*) + + # ... ... .varm + + # ... 'hic' + # ... ... layers + + # ... ... .obs + + # ... ... .obsm + + # ... ... .obsp + + # ... ... .uns + MATCHED_HIC_RNA_DATA_KEY = 'matched_hic_rna_data' # Indicates whether HiC- and RNA-seq data are matched + + # ... ... .var (hic:*) + + # ... ... .varm + + # ... 'rna' + # Most things are handled by DynamoAdataKeyManager; these are in addition to thos defined in dynamo + # ... ... layers + RNA_COUNTS_LAYER = 'counts' + RNA_COUNTS_LAYER_FROM_LOOM = 'matrix' + RNA_FIRST_MOMENT_CHROM_LAYER = 'M_c' + RNA_FIRST_MOMENT_SPLICED_LAYER = 'M_s' + RNA_FIRST_MOMENT_UNSPLICED_LAYER = 'M_u' + RNA_SECOND_MOMENT_SS_LAYER = 'M_ss' + RNA_SECOND_MOMENT_US_LAYER = 'M_us' + RNA_SECOND_MOMENT_UU_LAYER = 'M_uu' + RNA_SPLICED_LAYER = 'spliced' + RNA_SPLICED_VELOCITY_LAYER = 'velocity_S' + RNA_UNSPLICED_LAYER = 'unspliced' + + # ... ... .obs + + # ... ... .obsm + RNA_OBSM_PC_KEY = 'X_pca' + + # ... ... .obsp + + # ... ... .uns + + # ... ... .var (rna:*) + + # ... ... .varm + + def bogus_function(self): + pass + +MDKM = MultiDynamoMdataKeyManager diff --git a/dynamo/multivelo/MultiIO.py b/dynamo/multivelo/MultiIO.py new file mode 100644 index 000000000..2fd97171e --- /dev/null +++ b/dynamo/multivelo/MultiIO.py @@ -0,0 +1,415 @@ +from anndata import ( + AnnData, + read_loom +) +from .MultiConfiguration import MDKM + +from mudata import MuData + + +import numpy as np +import os +from os import PathLike +import pandas as pd +from pathlib import Path +import re +from typing import ( + Dict, + Literal, + Union +) + +# Imports from dynamo +from ..dynamo_logger import ( + LoggerManager, + main_exception, + main_info, +) + +# Imports from MultiDynamo +from .old_MultiVelocity import MultiVelocity +from .MultiPreprocessor import aggregate_peaks_10x + + +def add_splicing_data( + mdata: MuData, + multiome_base_path: PathLike | str, + rna_splicing_loom: PathLike | str = 'multiome.loom', + cellranger_path_structure: bool = True +) -> MuData: + # Extract accessibility and transcriptomic counts + atac_adata, rna_adata = mdata.mod['atac'], mdata.mod['rna'] + + # Read in spicing data + splicing_data_path = os.path.join(multiome_base_path, + 'velocyto' if cellranger_path_structure else '', + rna_splicing_loom) + ldata = read_loom(filename=Path(splicing_data_path)) + + # Merge splicing data with transcriptomic data + rna_adata.var_names_make_unique() + ldata.var_names_make_unique() + + common_obs = pd.unique(rna_adata.obs_names.intersection(ldata.obs_names)) + common_vars = pd.unique(rna_adata.var_names.intersection(ldata.var_names)) + + if len(common_obs) == 0: + # Try cleaning cell indices, if intersection of indices is vacuous + clean_obs_names(rna_adata) + clean_obs_names(ldata) + common_obs = rna_adata.obs_names.intersection(ldata.obs_names) + + # Restrict to common cell indices and genes + rna_adata = rna_adata[common_obs, common_vars].copy() + ldata = ldata[common_obs, common_vars].copy() + + # Transfer layers from ldata + for key, data in ldata.layers.items(): + if key not in rna_adata.layers: + rna_adata.layers[key] = data.copy() + + # Copy over the loom counts to a counts layer + rna_adata.layers[MDKM.RNA_COUNTS_LAYER] = rna_adata.layers[MDKM.RNA_COUNTS_LAYER_FROM_LOOM].copy() + + mdata = MuData({'atac': atac_adata, 'rna': rna_adata}) + + return mdata + + +# These are convenience functions pattern after (but not identical to) ones in scvelo +def 
clean_obs_names(
+    adata: AnnData,
+    alphabet: Literal['[AGTCBDHKMNRSVWY]'] = '[AGTCBDHKMNRSVWY]',
+    batch_key: str = MDKM.INFERRED_BATCH_KEY,
+    id_length: int = 16
+) -> AnnData:
+    if adata.obs_names.map(len).unique().size == 1:
+        # Here if all cell indices have the same number of characters
+        # ... find the (first) instance of id_length valid nucleotides in the first cell index
+        start, end = re.search(alphabet * id_length, adata.obs_names[0]).span()
+
+        # ... truncate the cell indices to the valid nucleotides
+        new_obs_names = [obs_name[start:end] for obs_name in adata.obs_names]
+
+        # ... any characters prior to the characters that define the new cell index
+        #     might specify the batch, so save them alongside the new cell indices
+        prefixes = [
+            obs_name.replace(new_obs_name, "")
+            for obs_name, new_obs_name in zip(adata.obs_names, new_obs_names)
+        ]
+    else:
+        # Here if cell indices have different lengths
+        prefixes, new_obs_names = [], []
+        for obs_name in adata.obs_names:
+            # ... loop over the cell indices individually; find the (first) instance
+            #     of id_length valid nucleotides in each cell index
+            start, end = re.search(alphabet * id_length, obs_name).span()
+
+            # ... truncate the cell index to the valid nucleotides
+            new_obs_names.append(obs_name[start:end])
+
+            # ... any characters prior to the characters that define the new cell index
+            #     might specify the batch, so save them alongside the new cell indices
+            prefixes.append(obs_name.replace(obs_name[start:end], ""))
+
+    adata.obs_names = new_obs_names
+    adata.obs_names_make_unique()
+
+    if len(prefixes[0]) > 0 and len(np.unique(prefixes)) > 1:
+        # If the list of prefixes is non-trivial (non-empty and with more than one
+        # distinct value), then add MDKM.INFERRED_BATCH_KEY to the cell metadata
+        adata.obs[batch_key] = (
+            pd.Categorical(prefixes)
+            if len(np.unique(prefixes)) < adata.n_obs
+            else prefixes
+        )
+
+    return adata
+
+
+def homogenize_mudata_obs_names(
+        mdata: MuData,
+        alphabet: Literal['[AGTCBDHKMNRSVWY]'] = '[AGTCBDHKMNRSVWY]',
+        batch_key: str = MDKM.INFERRED_BATCH_KEY,
+        id_length: int = 16
+) -> MuData:
+    cleaned_modality_dict = {}
+    for modality, modality_adata in mdata.mod.items():
+        cleaned_modality_adata = clean_obs_names(adata=modality_adata,
+                                                 alphabet=alphabet,
+                                                 batch_key=batch_key,
+                                                 id_length=id_length)
+        cleaned_modality_dict[modality] = cleaned_modality_adata.copy()
+    return MuData(cleaned_modality_dict)
+
+
+def read(path_dict: Dict) -> MultiVelocity:
+    pass  # Can simplify this significantly
+
+# ... from unmatched scRNA-seq and scATAC-seq data
+def read_10x_atac_rna_h5_old(
+    atac_path: Union[PathLike, str],
+    rna_path: Union[PathLike, str],
+    atac_counts_matrix: Union[PathLike, str] = 'filtered_peak_bc_matrix',
+    rna_h5_fn: Union[PathLike, str] = 'filtered_feature_bc_matrix.h5',
+    rna_splicing_loom: Union[PathLike, str] = 'multiome.loom',
+    alphabet: Literal['[AGTCBDHKMNRSVWY]'] = '[AGTCBDHKMNRSVWY]',
+    batch_key: str = MDKM.INFERRED_BATCH_KEY,
+    cellranger_path_structure: bool = True,
+    id_length: int = 16
+) -> MuData:
+    from muon import atac as ac
+    import muon as mu
+    import scvi
+    main_info('Deserializing UNMATCHED scATAC-seq and scRNA-seq data ...')
+    temp_logger = LoggerManager.gen_logger('read_10x_atac_rna_h5')
+    temp_logger.log_time()
+
+    # Read scATAC-seq h5 file
+    # ... 
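What `alphabet * id_length` builds in `clean_obs_names` above is the character class repeated `id_length` times, i.e. a run of 16 consecutive IUPAC nucleotide codes; a quick standalone check:

```python
import re

pattern = '[AGTCBDHKMNRSVWY]' * 16   # matches 16 consecutive nucleotide codes
start, end = re.search(pattern, 'sample1_AAACCCAAGTACGTCA-1').span()
print(start, end)                    # 8 24 -> barcode 'AAACCCAAGTACGTCA'
```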
counts + main_info(f'reading scATAC-seq data', indent_level=2) + atac_matrix_path = os.path.join(atac_path, + 'outs' if cellranger_path_structure else '', + atac_counts_matrix) + + atac_adata = scvi.data.read_10x_atac(atac_matrix_path) + + # Read scRNA-seq h5 file + main_info(f'reading scRNA-seq data', indent_level=2) + rna_h5_path = os.path.join(rna_path, + 'outs' if cellranger_path_structure else '', + rna_h5_fn) + + rna_adata = mu.read_10x_h5(filename=Path(rna_h5_path)).mod['rna'] + + # Assemble MuData object + main_info(f'combining scATAC-seq data and scRNA-seq data into MuData object ...', indent_level=2) + mdata = MuData({'atac': atac_adata, 'rna': rna_adata}) + + # Flag the scATAC-seq data as unmatched to the scRNA-seq data + main_info(f' .uns[{MDKM.MATCHED_ATAC_RNA_DATA_KEY}] = False', indent_level=3) + mdata.mod['atac'].uns[MDKM.MATCHED_ATAC_RNA_DATA_KEY] = False + + # Add path to fragment file + main_info(f" path to fragments file in .uns['files']", indent_level=3) + mdata.mod['atac'].uns['files'] = {'fragments': os.path.join(atac_path, + 'outs' if cellranger_path_structure else '', + 'fragments.tsv.gz')} + + # Add 'outs' paths + # ... atac + mdata.mod['atac'].uns['base_data_path'] = atac_path + + # ... rna + mdata.mod['rna'].uns['base_data_path'] = rna_path + + # Add peak annotation + main_info(f'adding peak annotation ...', indent_level=2) + ac.tl.add_peak_annotation(data=mdata, annotation=os.path.join(atac_path, + 'outs' if cellranger_path_structure else '', + 'peak_annotation.tsv')) + + # Homogenize cell indices across modalities + main_info(f'homogenizing cell indices ...', indent_level=2) + mdata = homogenize_mudata_obs_names(mdata=mdata, + alphabet=alphabet, + batch_key=batch_key, + id_length=id_length) + + # Add transcriptomic splicing data + main_info(f'adding splicing data ...', indent_level=2) + mdata = add_splicing_data(mdata=mdata, + multiome_base_path=rna_path, + rna_splicing_loom=rna_splicing_loom, + cellranger_path_structure=cellranger_path_structure) + + temp_logger.finish_progress(progress_name='read_10x_atac_rna_h5') + + return mdata + + +# ... from matched 10X multiome +def read_10x_multiome_h5_old( + multiome_base_path: Union[PathLike, str], + multiome_h5_fn: Union[PathLike, str] = 'filtered_feature_bc_matrix.h5', + rna_splicing_loom: Union[PathLike, str] = 'multiome.loom', + alphabet: Literal['[AGTCBDHKMNRSVWY]'] = '[AGTCBDHKMNRSVWY]', + batch_key: str = MDKM.INFERRED_BATCH_KEY, + cellranger_path_structure: bool = True, + id_length: int = 16 +) -> MuData: + import muon as mu + from muon import atac as ac + + main_info('Deserializing MATCHED scATAC-seq and scRNA-seq data ...') + temp_logger = LoggerManager.gen_logger('read_10x_multiome_h5') + temp_logger.log_time() + + # Assemble absolute path to multiomic data + full_multiome_path = os.path.join(multiome_base_path, + 'outs' if cellranger_path_structure else '', + multiome_h5_fn) + + # Read the multiome h5 file + main_info(f'reading the multiome h5 file ...', indent_level=2) + mdata = mu.read_10x_h5(Path(full_multiome_path), extended=True) + + # Flag the scATAC-seq data as matched to the scRNA-seq data + main_info(f' .uns[{MDKM.MATCHED_ATAC_RNA_DATA_KEY}] = True', indent_level=3) + mdata.mod['atac'].uns[MDKM.MATCHED_ATAC_RNA_DATA_KEY] = True + + # Add 'outs' paths - Note: for multiome they are identical + # ... atac + mdata.mod['atac'].uns['base_data_path'] = multiome_base_path + + # ... 
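A hypothetical invocation of the unmatched reader above; the directory names are placeholders for CellRanger ATAC and GEX run folders (each containing an `outs/` subdirectory, with the velocyto loom expected under `<rna_path>/velocyto/`):

```python
mdata = read_10x_atac_rna_h5_old(
    atac_path='runs/sample_atac',
    rna_path='runs/sample_rna',
    rna_splicing_loom='multiome.loom',
)
print(mdata.mod['atac'].uns[MDKM.MATCHED_ATAC_RNA_DATA_KEY])  # False for unmatched data
```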
rna + mdata.mod['rna'].uns['base_data_path'] = multiome_base_path + + # Add path to fragment file + main_info(f" path to fragments file in .uns['files'] ...", indent_level=3) + mdata.mod['atac'].uns['files'] = {'fragments': os.path.join(multiome_base_path, + 'outs' if cellranger_path_structure else '', + 'fragments.tsv.gz')} + + # Add peak annotation + main_info(f'adding peak annotation ...', indent_level=2) + ac.tl.add_peak_annotation(data=mdata, annotation=os.path.join(multiome_base_path, + 'outs' if cellranger_path_structure else '', + 'peak_annotation.tsv')) + + # Homogenize cell indices across modalities + main_info(f'homogenizing cell indices ...', indent_level=2) + mdata = homogenize_mudata_obs_names(mdata=mdata, + alphabet=alphabet, + batch_key=batch_key, + id_length=id_length) + + # Add transcriptomic splicing data + main_info(f'adding splicing data ...', indent_level=2) + mdata = add_splicing_data(mdata=mdata, + multiome_base_path=multiome_base_path, + rna_splicing_loom=rna_splicing_loom, + cellranger_path_structure=cellranger_path_structure) + + temp_logger.finish_progress(progress_name='read_10x_multiome_h5') + + return mdata + + +def read_10x_multiome_h5( + multiome_base_path: Union[PathLike, str], + multiome_h5_fn: Union[PathLike, str] = 'filtered_feature_bc_matrix.h5', + rna_splicing_loom: Union[PathLike, str] = 'multiome.loom', + alphabet: Literal['[AGTCBDHKMNRSVWY]'] = '[AGTCBDHKMNRSVWY]', + batch_key: str = MDKM.INFERRED_BATCH_KEY, + cellranger_path_structure: bool = True, + id_length: int = 16, + gtf_path: Union[PathLike, str] = None, +): + import muon as mu + from muon import atac as ac + + main_info('Deserializing MATCHED scATAC-seq and scRNA-seq data ...') + temp_logger = LoggerManager.gen_logger('read_10x_multiome_h5') + temp_logger.log_time() + + # Assemble absolute path to multiomic data + full_multiome_path = os.path.join(multiome_base_path, + 'outs' if cellranger_path_structure else '', + multiome_h5_fn) + + # Read the multiome h5 file + main_info(f'reading the multiome h5 file ...', indent_level=2) + mdata = mu.read_10x_h5(Path(full_multiome_path), extended=True) + + # Flag the scATAC-seq data as matched to the scRNA-seq data + main_info(f' .uns[{MDKM.MATCHED_ATAC_RNA_DATA_KEY}] = True', indent_level=3) + mdata.mod['atac'].uns[MDKM.MATCHED_ATAC_RNA_DATA_KEY] = True + + # Add 'outs' paths - Note: for multiome they are identical + # ... atac + mdata.mod['atac'].uns['base_data_path'] = multiome_base_path + + # ... 
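Likewise for the matched reader (placeholder path; the newer `read_10x_multiome_h5` below behaves the same but degrades gracefully when the fragments, peak-annotation, or loom files are absent):

```python
mdata = read_10x_multiome_h5_old(
    multiome_base_path='runs/sample_multiome',
    multiome_h5_fn='filtered_feature_bc_matrix.h5',
)
print(mdata.mod['atac'].uns[MDKM.MATCHED_ATAC_RNA_DATA_KEY])  # True for matched data
```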
rna + mdata.mod['rna'].uns['base_data_path'] = multiome_base_path + + #Add path of fragments file if exist + fragments_path = os.path.join(multiome_base_path, + 'outs' if cellranger_path_structure else '', + 'fragments.tsv.gz') + if os.path.exists(fragments_path): + main_info(f" path to fragments file in .uns['files'] ...", indent_level=3) + mdata.mod['atac'].uns['files'] = {'fragments': fragments_path} + else: + main_info(f"fragments file not found in {fragments_path}", indent_level=3) + + # Add peak annotation file if exist + peak_annotation_path = os.path.join(multiome_base_path, + 'outs' if cellranger_path_structure else '', + 'peak_annotation.tsv') + if os.path.exists(peak_annotation_path): + main_info(f'adding peak annotation ...', indent_level=2) + ac.tl.add_peak_annotation(data=mdata, annotation=peak_annotation_path) + + elif gtf_path is not None: + main_info(f'adding peak annotation from gtf file ...', indent_level=2) + import Epiverse as ev + atac_anno=ev.utils.Annotation(gtf_path) + atac_anno.tss_init(upstream=1000, + downstream=100) + atac_anno.distal_init(upstream=[1000,200000], + downstream=[1000,200000]) + atac_anno.body_init() + + import pandas as pd + k=0 + for chr in mdata['atac'].var['seqnames'].unique(): + if k==0: + merge_pd=atac_anno.query_multi(query_list=mdata['atac'].var.loc[mdata['atac'].var['seqnames']==chr].index.tolist(), + chrom=chr,batch=4,ncpus=8) + else: + merge_pd1=atac_anno.query_multi(query_list=mdata['atac'].var.loc[mdata['atac'].var['seqnames']==chr].index.tolist(), + chrom=chr,batch=4,ncpus=8) + merge_pd=pd.concat([merge_pd,merge_pd1]) + k+=1 + merge_pd=atac_anno.merge_info(merge_pd) + atac_anno.add_gene_info(mdata['atac'],merge_pd, + columns=['peaktype','neargene','neargene_tss']) + else: + main_info(f"peak annotation file not found in {peak_annotation_path} and gtf file not provided", indent_level=3) + + # Homogenize cell indices across modalities + main_info(f'homogenizing cell indices ...', indent_level=2) + mdata = homogenize_mudata_obs_names(mdata=mdata, + alphabet=alphabet, + batch_key=batch_key, + id_length=id_length) + + # Add transcriptomic splicing data if exist + rna_splicing_loom_path = os.path.join(multiome_base_path, + 'velocyto' if cellranger_path_structure else '', + rna_splicing_loom) + if os.path.exists(rna_splicing_loom_path): + main_info(f'adding splicing data ...', indent_level=2) + mdata = add_splicing_data(mdata=mdata, + multiome_base_path=multiome_base_path, + rna_splicing_loom=rna_splicing_loom, + cellranger_path_structure=cellranger_path_structure) + else: + main_info(f"splicing data file not found in {rna_splicing_loom_path}", indent_level=3) + + # Aggregate_peaks_10x + main_info(f'aggregating peaks ...', indent_level=2) + feature_linkage_path=os.path.join(multiome_base_path, + 'outs' if cellranger_path_structure else '', + 'analysis/feature_linkage/feature_linkage.bedpe') + adata_aggr = aggregate_peaks_10x(mdata['atac'], + peak_annotation_path, + feature_linkage_path) + + mdata.mod['aggr']=adata_aggr + + + temp_logger.finish_progress(progress_name='read_10x_multiome_h5') + + return mdata \ No newline at end of file diff --git a/dynamo/multivelo/MultiPreprocessor.py b/dynamo/multivelo/MultiPreprocessor.py new file mode 100644 index 000000000..297ea91c4 --- /dev/null +++ b/dynamo/multivelo/MultiPreprocessor.py @@ -0,0 +1,718 @@ +# Imports from external modules +from anndata import AnnData +from .MultiConfiguration import MDKM +from mudata import MuData + +import numpy as np +import os +import pandas as pd +import scanpy as 
sc +from tqdm import tqdm +from scipy.sparse import coo_matrix, csr_matrix, diags + +from typing import Any, Callable, Dict, List, Literal, Optional, TypedDict + +# Imports from dynamo +from ..dynamo_logger import ( + LoggerManager, + main_debug, + main_exception, + main_info, + main_info_insert_adata, + main_warning, +) +from ..preprocessing.gene_selection import ( + select_genes_monocle +) +from ..preprocessing.normalization import ( + calc_sz_factor, + normalize +) +from ..preprocessing.pca import ( + pca +) +from ..preprocessing.Preprocessor import ( + Preprocessor +) +from ..preprocessing.QC import ( + filter_cells_by_highly_variable_genes, + filter_cells_by_outliers as monocle_filter_cells_by_outliers, + filter_genes_by_outliers as monocle_filter_genes_by_outliers +) +from ..preprocessing.transform import ( + log1p +) +from ..preprocessing.utils import ( + collapse_species_adata, + convert2symbol +) + +# Imports from MultiDynamo +from .ATACseqTools import ( + tfidf_normalize +) +from .MultiQC import ( + modality_basic_stats, + modality_filter_cells_by_outliers, + modality_filter_features_by_outliers +) + +# Define a custom type for the recipe dictionary using TypedDict +ATACType = Literal['archR', 'cicero', 'muon', 'signac'] +CITEType = Literal['seurat'] +HiCType = Literal['periwal'] +ModalityType = Literal['atac', 'cite', 'hic', 'rna'] +RNAType = Literal['monocle', 'seurat', 'sctransform', 'pearson_residuals', 'monocle_pearson_residuals'] + +class RecipeDataType(TypedDict, total=False): # total=False allows partial dictionary to be valid + atac: ATACType + cite: CITEType + hic: HiCType + rna: RNAType + + +# The Multiomic Preprocessor class, MultiPreprocessor +class MultiPreprocessor(Preprocessor): + def __init__( + self, + cell_cycle_score_enable: bool=False, + cell_cycle_score_kwargs: Dict[str, Any] = {}, + collapse_species_adata_function: Callable = collapse_species_adata, + convert_gene_name_function: Callable=convert2symbol, + filter_cells_by_highly_variable_genes_function: Callable = filter_cells_by_highly_variable_genes, + filter_cells_by_highly_variable_genes_kwargs: Dict[str, Any] = {}, + filter_cells_by_outliers_function: Callable=monocle_filter_cells_by_outliers, + filter_cells_by_outliers_kwargs: Dict[str, Any] = {}, + filter_genes_by_outliers_function: Callable=monocle_filter_genes_by_outliers, + filter_genes_by_outliers_kwargs: Dict[str, Any] = {}, + force_gene_list: Optional[List[str]]=None, + gene_append_list: List[str] = [], + gene_exclude_list: List[str] = {}, + norm_method: Callable=log1p, + norm_method_kwargs: Dict[str, Any] = {}, + normalize_by_cells_function: Callable=normalize, + normalize_by_cells_function_kwargs: Dict[str, Any] = {}, + normalize_selected_genes_function: Callable=None, + normalize_selected_genes_kwargs: Dict[str, Any] = {}, + pca_function: Callable=pca, + pca_kwargs: Dict[str, Any] = {}, + regress_out_kwargs: Dict[List[str], Any] = {}, + sctransform_kwargs: Dict[str, Any] = {}, + select_genes_function: Callable = select_genes_monocle, + select_genes_kwargs: Dict[str, Any] = {}, + size_factor_function: Callable=calc_sz_factor, + size_factor_kwargs: Dict[str, Any] = {}) -> None: + super().__init__( + collapse_species_adata_function = collapse_species_adata_function, + convert_gene_name_function = convert_gene_name_function, + filter_cells_by_outliers_function = filter_cells_by_outliers_function, + filter_cells_by_outliers_kwargs = filter_cells_by_outliers_kwargs, + filter_genes_by_outliers_function = filter_genes_by_outliers_function, + 
filter_genes_by_outliers_kwargs = filter_genes_by_outliers_kwargs, + filter_cells_by_highly_variable_genes_function = filter_cells_by_highly_variable_genes_function, + filter_cells_by_highly_variable_genes_kwargs = filter_cells_by_highly_variable_genes_kwargs, + normalize_by_cells_function = normalize_by_cells_function, + normalize_by_cells_function_kwargs = normalize_by_cells_function_kwargs, + size_factor_function = size_factor_function, + size_factor_kwargs = size_factor_kwargs, + select_genes_function = select_genes_function, + select_genes_kwargs = select_genes_kwargs, + normalize_selected_genes_function = normalize_selected_genes_function, + normalize_selected_genes_kwargs = normalize_selected_genes_kwargs, + norm_method = norm_method, + norm_method_kwargs = norm_method_kwargs, + pca_function = pca_function, + pca_kwargs = pca_kwargs, + gene_append_list = gene_append_list, + gene_exclude_list = gene_exclude_list, + force_gene_list = force_gene_list, + sctransform_kwargs = sctransform_kwargs, + regress_out_kwargs = regress_out_kwargs, + cell_cycle_score_enable = cell_cycle_score_enable, + cell_cycle_score_kwargs = cell_cycle_score_kwargs + ) + + def preprocess_atac( + self, + mdata: MuData, + recipe: ATACType = 'muon', + tkey: Optional[str] = None, + experiment_type: Optional[str] = None + ) -> None: + if recipe == 'archR': + self.preprocess_atac_archr(mdata, + tkey=tkey, + experiment_type=experiment_type) + elif recipe == 'cicero': + self.preprocess_atac_cicero(mdata, + tkey=tkey, + experiment_type=experiment_type) + elif recipe == 'muon': + self.preprocess_atac_muon(mdata, + tkey=tkey, + experiment_type=experiment_type) + elif recipe == 'signac': + self.preprocess_atac_signac(mdata, + tkey=tkey, + experiment_type=experiment_type) + else: + raise NotImplementedError("preprocess recipe chosen not implemented: %s" % recipe) + + def preprocess_atac_archr( + self, + mdata: MuData, + tkey: Optional[str] = None, + experiment_type: Optional[str] = None + ) -> None: + pass + + def preprocess_atac_cicero( + self, + mdata: MuData, + tkey: Optional[str] = None, + experiment_type: Optional[str] = None + ) -> None: + pass + + def preprocess_atac_muon( + self, + mdata: MuData, + tkey: Optional[str] = None, + experiment_type: Optional[str] = None + ) -> None: + from muon import atac as ac + main_info('Running muon preprocessing pipeline for scATAC-seq data ...') + preprocess_logger = LoggerManager.gen_logger('preprocess_atac_muon') + preprocess_logger.log_time() + + # Standardize MuData object + self.standardize_mdata(mdata, tkey, experiment_type) + + # Filter peaks + modality_filter_features_by_outliers(mdata, + modality='atac', + quantiles=[0.01, 0.99], + var_key='n_cells_by_counts') + + # Filter cells + modality_filter_cells_by_outliers(mdata, + modality='atac', + quantiles=[0.01, 0.99], + obs_key='n_genes_by_counts') + + modality_filter_cells_by_outliers(mdata, + modality='atac', + quantiles=[0.01, 0.99], + obs_key='total_counts') + + # Extract chromatin accessibility and transcriptome + atac_adata, rna_adata = mdata.mod['atac'], mdata.mod['rna'] + + # ... store counts layer used for SCVI's variational autoencoders + atac_adata.layers[MDKM.ATAC_COUNTS_LAYER] = atac_adata.X + rna_adata.layers[MDKM.RNA_COUNTS_LAYER] = rna_adata.X + + # ... 
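The quantile-based cell and feature filters above boil down to keeping observations inside a percentile band; a synthetic illustration of the `[0.01, 0.99]` defaults used by the MultiQC helpers:

```python
import pandas as pd

counts = pd.Series([120, 340, 90, 5000, 260, 15, 410, 330])
lo, hi = counts.quantile([0.01, 0.99]).tolist()
keep = counts.between(lo, hi)   # cells inside the 1st-99th percentile band
print(int(keep.sum()), 'of', len(counts), 'cells retained')
```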
compute TF-IDF + main_info(f'computing TF-IDF', indent_level=1) + atac_adata = tfidf_normalize(atac_adata=atac_adata, mv_algorithm=False) + + # Normalize + main_info(f'normalizing', indent_level=1) + sc.pp.normalize_total(atac_adata, target_sum=1e4) + sc.pp.log1p(atac_adata) + + # Feature selection + main_info(f'feature selection', indent_level=1) + sc.pp.highly_variable_genes(atac_adata, min_mean=0.05, max_mean=1.5, min_disp=0.5) + main_info(f'identified {np.sum(atac_adata.var.highly_variable)} highly variable features', indent_level=2) + + # Store current AnnData object in raw + atac_adata.raw = atac_adata + + # Latent sematic indexing + main_info(f'computing latent sematic indexing', indent_level=1) + ac.tl.lsi(atac_adata) + + # ... drop first component (size related) + main_info(f' X_lsi key in .obsm', indent_level=2) + atac_adata.obsm[MDKM.ATAC_OBSM_LSI_KEY] = atac_adata.obsm[MDKM.ATAC_OBSM_LSI_KEY][:, 1:] + main_info(f' LSI key in .varm', indent_level=2) + atac_adata.varm[MDKM.ATAC_VARM_LSI_KEY] = atac_adata.varm[MDKM.ATAC_VARM_LSI_KEY][:, 1:] + main_info(f' [lsi][stdev] key in .uns', indent_level=2) + atac_adata.uns['lsi']['stdev'] = atac_adata.uns['lsi']['stdev'][1:] + + # ... perhaps gratuitous deep copy + mdata.mod['atac'] = atac_adata.copy() + + preprocess_logger.finish_progress(progress_name='preprocess_atac_muon') + + def preprocess_atac_signac( + self, + mdata: MuData, + recipe: ATACType = 'muon', + tkey: Optional[str] = None, + experiment_type: Optional[str] = None + ) -> None: + pass + + def preprocess_cite( + self, + mdata: MuData, + recipe: CITEType + ) -> None: + pass + + def preprocess_hic( + self, + mdata: MuData, + recipe: HiCType + ) -> None: + pass + + def preprocess_mdata( + self, + mdata: MuData, + recipe_dict: RecipeDataType = None, + tkey: Optional[str] = None, + experiment_type: Optional[str] = None, + ) -> None: + """Preprocess the MuData object with the recipe specified. + + Args: + mdata: An AnnData object. + recipe_dict: The recipe used to preprocess the data. Current modalities are scATAC-seq, CITE-seq, scHi-C + and scRNA-seq + tkey: the key for time information (labeling time period for the cells) in .obs. Defaults to None. + experiment_type: the experiment type of the data. If not provided, would be inferred from the data. + + Raises: + NotImplementedError: the recipe is invalid. + """ + + if recipe_dict is None: + # Default recipe + recipe_dict = {'atac': 'signac', 'rna': 'seurat'} + + for mod, recipe in recipe_dict.items(): + if mod not in mdata.mod: + main_exception((f'Modality {mod} not found in MuData object')) + + if mod == 'atac': + self.preprocess_atac(mdata=mdata, + recipe=recipe, + tkey=tkey, + experiment_type=experiment_type) + + elif mod == 'cite': + self.preprocess_cite(mdata=mdata, + recipe=recipe, + tkey=tkey, + experiment_type=experiment_type) + elif mod == 'hic': + self.preprocess_hic(mdata=mdata, + recipe=recipe, + tkey=tkey, + experiment_type=experiment_type) + elif mod == 'rna': + rna_adata = mdata.mod.get('rna', None) + + self.preprocess_adata(adata=rna_adata, + recipe=recipe, + tkey=tkey, + experiment_type=experiment_type) + else: + raise NotImplementedError(f'Preprocess recipe not implemented for modality: {mod}') + + # Integrate modalities - at this point have filtered out poor quality cells for individual + # modalities. Next we need to + + def standardize_mdata( + self, + mdata: MuData, + tkey: str, + experiment_type: str + ) -> None: + """Process the scATAC-seq modality within MuData to make it meet the standards of dynamo. 
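A hypothetical end-to-end call for `preprocess_mdata` above (note that the default recipe `{'atac': 'signac', 'rna': 'seurat'}` currently dispatches the ATAC modality to a stub; `'muon'` is the implemented ATAC recipe):

```python
pp = MultiPreprocessor()
pp.preprocess_mdata(mdata, recipe_dict={'atac': 'muon', 'rna': 'monocle'})
```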
+
+        The index of the observations will be ensured to be unique. Layers stored as sparse matrices will be
+        converted to compressed csr_matrix format. MDKM.allowed_layer_raw_names() will be used to define
+        only_splicing, only_labeling and splicing_labeling keys.
+
+        Args:
+            mdata: a MuData object.
+            tkey: the key for time information (labeling time period for the cells) in .obs.
+            experiment_type: the experiment type.
+        """
+
+        for modality, modality_adata in mdata.mod.items():
+            if modality == 'rna':
+                # Handled by dynamo
+                continue
+
+            # Compute basic QC metrics
+            modality_basic_stats(mdata=mdata, modality=modality)
+
+            self.add_experiment_info(modality_adata, tkey, experiment_type)
+            main_info_insert_adata("tkey=%s" % tkey, "uns['pp']", indent_level=2)
+            main_info_insert_adata("experiment_type=%s" % modality_adata.uns["pp"]["experiment_type"],
+                                   "uns['pp']",
+                                   indent_level=2)
+
+            self.convert_layers2csr(modality_adata)
+
+
+def aggregate_peaks_10x(adata_atac, peak_annot_file, linkage_file,
+                        peak_dist=10000, min_corr=0.5, gene_body=False,
+                        return_dict=False, parallel=False, n_jobs=1):
+
+    """Peak to gene aggregation.
+
+    This function aggregates promoter and enhancer peaks to genes based on the
+    10X linkage file.
+
+    Parameters
+    ----------
+    adata_atac: :class:`~anndata.AnnData`
+        ATAC anndata object which stores raw peak counts.
+    peak_annot_file: `str`
+        Peak annotation file from 10X CellRanger ARC.
+    linkage_file: `str`
+        Peak-gene linkage file from 10X CellRanger ARC. This file stores highly
+        correlated peak-peak and peak-gene pair information.
+    peak_dist: `int` (default: 10000)
+        Maximum distance for peaks to be included for a gene.
+    min_corr: `float` (default: 0.5)
+        Minimum correlation for a peak to be considered as an enhancer.
+    gene_body: `bool` (default: `False`)
+        Whether to add gene body peaks to the associated promoters.
+    return_dict: `bool` (default: `False`)
+        Whether to return promoter and enhancer dictionaries.
+    parallel: `bool` (default: `False`)
+        Whether to aggregate peaks for genes in parallel with joblib.
+    n_jobs: `int` (default: 1)
+        Number of parallel jobs; with the default of 1, aggregation runs serially.
+
+    Returns
+    -------
+    A new ATAC anndata object which stores gene aggregated peak counts.
+    Additionally, if `return_dict==True`:
+    A dictionary which stores genes and promoter peaks.
+    And a dictionary which stores genes and enhancer peaks.
+    
+ """ + promoter_dict = {} + distal_dict = {} + gene_body_dict = {} + corr_dict = {} + + # read annotations + with open(peak_annot_file) as f: + header = next(f) + tmp = header.split('\t') + if len(tmp) == 4: + cellranger_version = 1 + elif len(tmp) == 6: + cellranger_version = 2 + else: + raise ValueError('Peak annotation file should contain 4 columns ' + '(CellRanger ARC 1.0.0) or 6 columns (CellRanger ' + 'ARC 2.0.0)') + + main_info(f'CellRanger ARC identified as {cellranger_version}.0.0', + indent_level=1) + + if cellranger_version == 1: + for line in f: + tmp = line.rstrip().split('\t') + tmp1 = tmp[0].split('_') + peak = f'{tmp1[0]}:{tmp1[1]}-{tmp1[2]}' + if tmp[1] != '': + genes = tmp[1].split(';') + dists = tmp[2].split(';') + types = tmp[3].split(';') + for i, gene in enumerate(genes): + dist = dists[i] + annot = types[i] + if annot == 'promoter': + if gene not in promoter_dict: + promoter_dict[gene] = [peak] + else: + promoter_dict[gene].append(peak) + elif annot == 'distal': + if dist == '0': + if gene not in gene_body_dict: + gene_body_dict[gene] = [peak] + else: + gene_body_dict[gene].append(peak) + else: + if gene not in distal_dict: + distal_dict[gene] = [peak] + else: + distal_dict[gene].append(peak) + else: + for line in f: + tmp = line.rstrip().split('\t') + peak = f'{tmp[0]}:{tmp[1]}-{tmp[2]}' + gene = tmp[3] + dist = tmp[4] + annot = tmp[5] + if annot == 'promoter': + if gene not in promoter_dict: + promoter_dict[gene] = [peak] + else: + promoter_dict[gene].append(peak) + elif annot == 'distal': + if dist == '0': + if gene not in gene_body_dict: + gene_body_dict[gene] = [peak] + else: + gene_body_dict[gene].append(peak) + else: + if gene not in distal_dict: + distal_dict[gene] = [peak] + else: + distal_dict[gene].append(peak) + + # read linkages + with open(linkage_file) as f: + for line in f: + tmp = line.rstrip().split('\t') + if tmp[12] == "peak-peak": + peak1 = f'{tmp[0]}:{tmp[1]}-{tmp[2]}' + peak2 = f'{tmp[3]}:{tmp[4]}-{tmp[5]}' + tmp2 = tmp[6].split('><')[0][1:].split(';') + tmp3 = tmp[6].split('><')[1][:-1].split(';') + corr = float(tmp[7]) + for t2 in tmp2: + gene1 = t2.split('_') + for t3 in tmp3: + gene2 = t3.split('_') + # one of the peaks is in promoter, peaks belong to the + # same gene or are close in distance + if (((gene1[1] == "promoter") != + (gene2[1] == "promoter")) and + ((gene1[0] == gene2[0]) or + (float(tmp[11]) < peak_dist))): + + if gene1[1] == "promoter": + gene = gene1[0] + else: + gene = gene2[0] + if gene in corr_dict: + # peak 1 is in promoter, peak 2 is not in gene + # body -> peak 2 is added to gene 1 + if (peak2 not in corr_dict[gene] and + gene1[1] == "promoter" and + (gene2[0] not in gene_body_dict or + peak2 not in gene_body_dict[gene2[0]])): + + corr_dict[gene][0].append(peak2) + corr_dict[gene][1].append(corr) + # peak 2 is in promoter, peak 1 is not in gene + # body -> peak 1 is added to gene 2 + if (peak1 not in corr_dict[gene] and + gene2[1] == "promoter" and + (gene1[0] not in gene_body_dict or + peak1 not in gene_body_dict[gene1[0]])): + + corr_dict[gene][0].append(peak1) + corr_dict[gene][1].append(corr) + else: + # peak 1 is in promoter, peak 2 is not in gene + # body -> peak 2 is added to gene 1 + if (gene1[1] == "promoter" and + (gene2[0] not in + gene_body_dict + or peak2 not in + gene_body_dict[gene2[0]])): + + corr_dict[gene] = [[peak2], [corr]] + # peak 2 is in promoter, peak 1 is not in gene + # body -> peak 1 is added to gene 2 + if (gene2[1] == "promoter" and + (gene1[0] not in + gene_body_dict + or peak1 not in + 
gene_body_dict[gene1[0]])): + + corr_dict[gene] = [[peak1], [corr]] + elif tmp[12] == "peak-gene": + peak1 = f'{tmp[0]}:{tmp[1]}-{tmp[2]}' + tmp2 = tmp[6].split('><')[0][1:].split(';') + gene2 = tmp[6].split('><')[1][:-1] + corr = float(tmp[7]) + for t2 in tmp2: + gene1 = t2.split('_') + # peak 1 belongs to gene 2 or are close in distance + # -> peak 1 is added to gene 2 + if ((gene1[0] == gene2) or (float(tmp[11]) < peak_dist)): + gene = gene1[0] + if gene in corr_dict: + if (peak1 not in corr_dict[gene] and + gene1[1] != "promoter" and + (gene1[0] not in gene_body_dict or + peak1 not in gene_body_dict[gene1[0]])): + + corr_dict[gene][0].append(peak1) + corr_dict[gene][1].append(corr) + else: + if (gene1[1] != "promoter" and + (gene1[0] not in gene_body_dict or + peak1 not in gene_body_dict[gene1[0]])): + corr_dict[gene] = [[peak1], [corr]] + elif tmp[12] == "gene-peak": + peak2 = f'{tmp[3]}:{tmp[4]}-{tmp[5]}' + gene1 = tmp[6].split('><')[0][1:] + tmp3 = tmp[6].split('><')[1][:-1].split(';') + corr = float(tmp[7]) + for t3 in tmp3: + gene2 = t3.split('_') + # peak 2 belongs to gene 1 or are close in distance + # -> peak 2 is added to gene 1 + if ((gene1 == gene2[0]) or (float(tmp[11]) < peak_dist)): + gene = gene1 + if gene in corr_dict: + if (peak2 not in corr_dict[gene] and + gene2[1] != "promoter" and + (gene2[0] not in gene_body_dict or + peak2 not in gene_body_dict[gene2[0]])): + + corr_dict[gene][0].append(peak2) + corr_dict[gene][1].append(corr) + else: + if (gene2[1] != "promoter" and + (gene2[0] not in gene_body_dict or + peak2 not in gene_body_dict[gene2[0]])): + + corr_dict[gene] = [[peak2], [corr]] + + gene_dict = promoter_dict + enhancer_dict = {} + promoter_genes = list(promoter_dict.keys()) + main_info(f'Found {len(promoter_genes)} genes with promoter peaks', indent_level=1) + for gene in promoter_genes: + if gene_body: # add gene-body peaks + if gene in gene_body_dict: + for peak in gene_body_dict[gene]: + if peak not in gene_dict[gene]: + gene_dict[gene].append(peak) + enhancer_dict[gene] = [] + if gene in corr_dict: # add enhancer peaks + for j, peak in enumerate(corr_dict[gene][0]): + corr = corr_dict[gene][1][j] + if corr > min_corr: + if peak not in gene_dict[gene]: + gene_dict[gene].append(peak) + enhancer_dict[gene].append(peak) + + # aggregate to genes + adata_atac_X_copy = adata_atac.X.A + gene_mat = np.zeros((adata_atac.shape[0], len(promoter_genes))) + var_names = adata_atac.var_names.to_numpy() + var_dict = {} + + for i, name in enumerate(var_names): + var_dict.update({name: i}) + + # if we only want to run one job at a time, then no parallelization + # is necessary + if n_jobs == 1: + parallel = False + + if parallel: + from joblib import Parallel, delayed + # if we want to run in parallel, modify the gene_mat variable with + # multiple cores, calling prepare_gene_mat with joblib.Parallel() + Parallel(n_jobs=n_jobs, + require='sharedmem')( + delayed(prepare_gene_mat)(var_dict, + gene_dict[promoter_genes[i]], + gene_mat, + adata_atac_X_copy, + i)for i in tqdm(range( + len(promoter_genes)))) + + else: + # if we aren't running in parallel, just call prepare_gene_mat + # from a for loop + for i, gene in tqdm(enumerate(promoter_genes), + total=len(promoter_genes)): + prepare_gene_mat(var_dict, + gene_dict[promoter_genes[i]], + gene_mat, + adata_atac_X_copy, + i) + + gene_mat[gene_mat < 0] = 0 + gene_mat = AnnData(X=csr_matrix(gene_mat)) + gene_mat.obs_names = pd.Index(list(adata_atac.obs_names)) + gene_mat.var_names = pd.Index(promoter_genes) + gene_mat = 
gene_mat[:, gene_mat.X.sum(0) > 0]
+    if return_dict:
+        return gene_mat, promoter_dict, enhancer_dict
+    else:
+        return gene_mat
+
+
+def prepare_gene_mat(var_dict, peaks, gene_mat, adata_atac_X_copy, i):
+    # accumulate the counts of every peak linked to gene i into column i
+    for peak in peaks:
+        if peak in var_dict:
+            peak_index = var_dict[peak]
+            gene_mat[:, i] += adata_atac_X_copy[:, peak_index]
+
+
+def knn_smooth_chrom(adata_atac, nn_idx=None, nn_dist=None, conn=None,
+                     n_neighbors=None):
+    """KNN smoothing.
+
+    This function smooths (imputes) the count matrix with k nearest
+    neighbors. The inputs can be either KNN index and distance matrices or a
+    pre-computed connectivities matrix (for example, the one in an adata_rna
+    object).
+
+    Parameters
+    ----------
+    adata_atac: :class:`~anndata.AnnData`
+        ATAC anndata object.
+    nn_idx: `np.ndarray` (default: `None`)
+        KNN index matrix of size (cells, k).
+    nn_dist: `np.ndarray` (default: `None`)
+        KNN distance matrix of size (cells, k).
+    conn: `csr_matrix` (default: `None`)
+        Pre-computed connectivities matrix.
+    n_neighbors: `int` (default: `None`)
+        Top N neighbors to extract for each cell in the connectivities matrix.
+
+    Returns
+    -------
+    `.layers['Mc']` stores imputed values.
+    """
+    if nn_idx is not None and nn_dist is not None:
+        if nn_idx.shape[0] != adata_atac.shape[0]:
+            raise ValueError('Number of rows of KNN indices does not equal '
+                             'the number of observations.')
+        if nn_dist.shape[0] != adata_atac.shape[0]:
+            raise ValueError('Number of rows of KNN distances does not equal '
+                             'the number of observations.')
+        X = coo_matrix(([], ([], [])), shape=(nn_idx.shape[0], 1))
+        from umap.umap_ import fuzzy_simplicial_set
+        conn, sigma, rho, dists = fuzzy_simplicial_set(X, nn_idx.shape[1],
+                                                       None, None,
+                                                       knn_indices=nn_idx-1,
+                                                       knn_dists=nn_dist,
+                                                       return_dists=True)
+    elif conn is not None:
+        pass
+    else:
+        raise ValueError('Please input nearest neighbor indices and distances,'
+                         ' or a connectivities matrix of size n x n, with '
+                         'columns being neighbors.'
+                         ' For example, RNA connectivities can usually be '
+                         'found in adata.obsp.')
+
+    conn = conn.tocsr().copy()
+    n_counts = (conn > 0).sum(1).A1
+    if n_neighbors is not None and n_neighbors < n_counts.min():
+        from .sparse_matrix_utils import top_n_sparse
+        conn = top_n_sparse(conn, n_neighbors)
+    conn.setdiag(1)
+    conn_norm = conn.multiply(1.0 / conn.sum(1)).tocsr()
+    adata_atac.layers['Mc'] = csr_matrix.dot(conn_norm, adata_atac.X)
+    adata_atac.obsp['connectivities'] = conn
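+
+
+# Illustrative use of knn_smooth_chrom (the names `adata_rna` and
+# `adata_atac` are placeholders for objects prepared elsewhere): reuse the
+# RNA kNN graph computed by scanpy to impute the aggregated ATAC counts.
+#
+#     import scanpy as sc
+#     sc.pp.neighbors(adata_rna, n_neighbors=30)
+#     knn_smooth_chrom(adata_atac,
+#                      conn=adata_rna.obsp['connectivities'])
+#
+# Afterwards adata_atac.layers['Mc'] holds the smoothed counts and
+# adata_atac.obsp['connectivities'] the graph that produced them.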
diff --git a/dynamo/multivelo/MultiQC.py b/dynamo/multivelo/MultiQC.py
new file mode 100644
index 000000000..8604af14e
--- /dev/null
+++ b/dynamo/multivelo/MultiQC.py
@@ -0,0 +1,141 @@
+import anndata as ad
+from anndata import AnnData
+from .MultiConfiguration import MDKM
+from mudata import MuData
+
+import numpy as np
+import pandas as pd
+import scanpy as sc
+from scipy.sparse import (
+    issparse
+)
+from typing import (
+    List,
+    Literal,
+    Optional,
+    Union,
+)
+
+# Define several Literals - might move to MDKM
+ModalityType = Literal['atac', 'cite', 'hic', 'rna']
+ObsKeyType = Literal['n_genes_by_counts', 'total_counts']
+VarKeyType = Literal['n_cells_by_counts']
+
+# Imports from dynamo
+from ..dynamo_logger import (
+    LoggerManager,
+    main_debug,
+    main_exception,
+    main_finish_progress,
+    main_info,
+    main_info_insert_adata,
+    main_warning,
+)
+
+
+def modality_basic_stats(
+    mdata: MuData,
+    modality: Optional[ModalityType] = None
+) -> None:
+    """Compute basic QC stats for one modality of a MuData object.
+
+    Args:
+        mdata: a MuData object.
+        modality: the modality ('atac', 'cite', 'hic', or 'rna') whose AnnData object should be profiled.
+
+    Returns:
+        None. The modality's AnnData object is updated in place with a number of QC metrics: 'n_cells_by_counts',
+        'n_genes_by_counts', and 'total_counts'. (Note: since most modalities do not have direct information about
+        related genes, fractions of mitochondrial genes cannot be computed.) For ATAC data, the nucleosome signal
+        is computed as well.
+    """
+    from muon import atac as ac
+    modality_adata = mdata.mod.get(modality, None)
+    if modality_adata is None:
+        raise ValueError(f'Modality {modality} not found in MuData object.')
+
+    # Compute QC metrics via functionality in scanpy
+    sc.pp.calculate_qc_metrics(modality_adata, percent_top=None, log1p=False, inplace=True)
+
+    # Compute modality specific QC metrics
+    if modality == 'atac':
+        ac.tl.nucleosome_signal(mdata, n=1e6)
+
+
+def modality_filter_cells_by_outliers(
+    mdata: MuData,
+    modality: ModalityType = 'atac',
+    obs_key: ObsKeyType = 'total_counts',
+    quantiles: Optional[Union[List[float], float]] = [0.01, 0.99],
+    thresholds: Optional[Union[List[float], float]] = None
+) -> None:
+    import muon as mu
+    modality_adata = mdata.mod.get(modality, None)
+    if modality_adata is None:
+        raise ValueError(f'Modality {modality} not found in MuData object.')
+
+    if quantiles is not None:
+        # Thresholds were specified as quantiles
+        qc_parameter_series = modality_adata.obs[obs_key]
+
+        if isinstance(quantiles, list):
+            if len(quantiles) > 2:
+                raise ValueError(f'More than 2 quantiles were specified ({len(quantiles)}).')
+
+            min_cell_thresh, max_cell_thresh = qc_parameter_series.quantile(quantiles).tolist()
+        else:
+            min_cell_thresh, max_cell_thresh = qc_parameter_series.quantile(quantiles), np.inf
+    else:
+        # Thresholds were specified as absolute thresholds
+        if isinstance(thresholds, list):
+            if len(thresholds) > 2:
+                raise ValueError(f'More than 2 thresholds were specified ({len(thresholds)}).')
+
+            min_cell_thresh, max_cell_thresh = thresholds
+        else:
+            min_cell_thresh, max_cell_thresh = thresholds, np.inf
+
+    # Carry out the actual filtering
+    pre_filter_n_cells = modality_adata.n_obs
+    mu.pp.filter_obs(modality_adata, obs_key, lambda x: (x >= min_cell_thresh) & (x <= max_cell_thresh))
+    post_filter_n_cells = modality_adata.n_obs
+    main_info(f'filtered out {pre_filter_n_cells - post_filter_n_cells} outlier cells', indent_level=2)
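+
+# Illustrative driver for the QC helpers in this module (hypothetical
+# `mdata`; cutoffs are placeholders, not recommended defaults).
+# `modality_filter_features_by_outliers` is defined just below:
+#
+#     modality_basic_stats(mdata, modality='rna')
+#     modality_filter_cells_by_outliers(mdata, modality='rna',
+#                                       obs_key='total_counts',
+#                                       quantiles=[0.01, 0.99])
+#     modality_filter_features_by_outliers(mdata, modality='rna',
+#                                          quantiles=[0.01, 0.99],
+#                                          var_key='n_cells_by_counts')
+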
+
+def modality_filter_features_by_outliers(
+    mdata: MuData,
+    modality: ModalityType = 'atac',
+    quantiles: Optional[Union[List[float], float]] = [0.01, 0.99],
+    thresholds: Optional[Union[List[float], float]] = None,
+    var_key: VarKeyType = 'n_cells_by_counts'
+) -> None:
+    import muon as mu
+    modality_adata = mdata.mod.get(modality, None)
+    if modality_adata is None:
+        raise ValueError(f'Modality {modality} not found in MuData object.')
+
+    if quantiles is not None:
+        # Thresholds were specified as quantiles
+        qc_parameter_series = modality_adata.var[var_key]
+
+        if isinstance(quantiles, list):
+            if len(quantiles) > 2:
+                raise ValueError(f'More than 2 quantiles were specified ({len(quantiles)}).')
+
+            min_feature_thresh, max_feature_thresh = qc_parameter_series.quantile(quantiles).tolist()
+        else:
+            min_feature_thresh, max_feature_thresh = qc_parameter_series.quantile(quantiles), np.inf
+    else:
+        # Thresholds were specified as absolute thresholds
+        if isinstance(thresholds, list):
+            if len(thresholds) > 2:
+                raise ValueError(f'More than 2 thresholds were specified ({len(thresholds)}).')
+
+            min_feature_thresh, max_feature_thresh = thresholds
+        else:
+            min_feature_thresh, max_feature_thresh = thresholds, np.inf
+
+    # Carry out the actual filtering
+    pre_filter_n_vars = modality_adata.n_vars
+    mu.pp.filter_var(modality_adata, var_key, lambda x: (x >= min_feature_thresh) & (x <= max_feature_thresh))
+    post_filter_n_vars = modality_adata.n_vars
+    main_info(f'filtered out {pre_filter_n_vars - post_filter_n_vars} outlier features', indent_level=2)
diff --git a/dynamo/multivelo/MultiVelo.py b/dynamo/multivelo/MultiVelo.py
new file mode 100644
index 000000000..7e1405d23
--- /dev/null
+++ b/dynamo/multivelo/MultiVelo.py
@@ -0,0 +1,115 @@
+import pandas as pd
+import numpy as np
+from anndata import AnnData
+from mudata import MuData
+from typing import Dict
+
+from ..tl import dynamics, reduceDimension, cell_velocities
+from .MultiPreprocessor import knn_smooth_chrom
+from .dynamical_chrom_func import recover_dynamics_chrom
+
+
+def multi_velocities(
+    mdata: MuData,
+    model: str = 'stochastic',
+    method: str = 'pearson',
+    other_kernels_dict: Dict = {'transform': 'sqrt'},
+    core: int = 3,
+    device: str = 'cpu',
+    extra_color_key: str = None,
+    max_iter: int = 5,
+    velo_arg: Dict = {},
+    vkey: str = 'velo_s',
+    **kwargs
+) -> AnnData:
+    """
+    Calculate the velocities using the scRNA-seq and scATAC-seq data.
+
+    Args:
+        mdata: MuData object containing the RNA and ATAC data.
+        model: The model used to calculate the dynamics. Default is 'stochastic'.
+        method: The method used to calculate the velocity. Default is 'pearson'.
+        other_kernels_dict: The dictionary containing the parameters for the other kernels. Default is {'transform': 'sqrt'}.
+        core: The number of cores used for the calculation. Default is 3.
+        device: The device used for the calculation. Default is 'cpu'.
+        extra_color_key: The extra color key used for the calculation. Default is None.
+        max_iter: The maximum number of iterations used for the calculation. Default is 5.
+        velo_arg: The dictionary containing the parameters for the velocity calculation. Default is {}.
+        vkey: The key used for the velocity calculation. Default is 'velo_s'.
+        **kwargs: The other parameters used for the calculation.
+
+    Returns:
+        An updated AnnData object with the velocities calculated.
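+
+    Example:
+        An illustrative call, assuming ``mdata`` holds the preprocessed 'rna'
+        and 'aggr' (aggregated peaks) modalities used below, and that the
+        function is re-exported through the ``dyn.multi`` namespace::
+
+            import dynamo as dyn
+            adata_result = dyn.multi.multi_velocities(mdata, core=4)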
+    """
+    # We need to calculate the dynamics of the RNA data first and reduce the dimensionality
+    dynamics(mdata['rna'], model=model, cores=core)
+    reduceDimension(mdata['rna'])
+    cell_velocities(mdata['rna'], method=method,
+                    other_kernels_dict=other_kernels_dict,
+                    **velo_arg
+                    )
+
+    # And we use the connectivity matrix from the RNA data to smooth the ATAC data and calculate Mc
+    knn_smooth_chrom(mdata['aggr'], conn=mdata['rna'].obsp['connectivities'])
+
+    # We then select the cells and genes that are present in both datasets
+    shared_cells = pd.Index(np.intersect1d(mdata['rna'].obs_names, mdata['aggr'].obs_names))
+    shared_genes = pd.Index(np.intersect1d(
+        [i.split('rna:')[-1] for i in mdata['rna'][:, mdata['rna'].var['use_for_dynamics']].var_names],
+        [i.split('aggr:')[-1] for i in mdata['aggr'].var_names]
+    ))
+
+    # We then create the AnnData objects for the RNA and ATAC data
+    adata_rna = mdata['rna'][shared_cells, [f'rna:{i}' for i in shared_genes]].copy()
+    adata_atac = mdata['aggr'][shared_cells, [f'aggr:{i}' for i in shared_genes]].copy()
+    adata_rna.var.index = [i.split('rna:')[-1] for i in adata_rna.var.index]
+    adata_atac.var.index = [i.split('aggr:')[-1] for i in adata_atac.var.index]
+
+    adata_rna.layers['Ms'] = adata_rna.layers['M_s']
+    adata_rna.layers['Mu'] = adata_rna.layers['M_u']
+
+    # Now we use MultiVelo's recover_dynamics_chrom function to calculate the dynamics of the RNA and ATAC data
+    adata_result = recover_dynamics_chrom(adata_rna,
+                                          adata_atac,
+                                          max_iter=max_iter,
+                                          init_mode="invert",
+                                          parallel=True,
+                                          n_jobs=core,
+                                          save_plot=False,
+                                          rna_only=False,
+                                          fit=True,
+                                          n_anchors=500,
+                                          extra_color_key=extra_color_key,
+                                          device=device,
+                                          **kwargs
+                                          )
+
+    # We need to add some information about the new RNA velocity to the result
+    if vkey not in adata_result.layers.keys():
+        raise ValueError('Velocity matrix is not found. Please run multivelo'
+                         '.recover_dynamics_chrom function first.')
+    if vkey + '_norm' not in adata_result.layers.keys():
+        adata_result.layers[vkey + '_norm'] = adata_result.layers[vkey] / np.sum(
+            np.abs(adata_result.layers[vkey]), 0)
+        adata_result.layers[vkey + '_norm'] /= np.mean(adata_result.layers[vkey + '_norm'])
+        adata_result.uns[vkey + '_norm_params'] = adata_result.uns[vkey + '_params']
+    if vkey + '_norm_genes' not in adata_result.var.columns:
+        adata_result.var[vkey + '_norm_genes'] = adata_result.var[vkey + '_genes']
+
+    # Transition genes identification and velocity calculation
+    transition_genes = adata_result.var.index[adata_result.var['velo_s_norm_genes']].tolist()
+    if 'pearson_transition_matrix' in adata_result.obsp.keys():
+        del adata_result.obsp['pearson_transition_matrix']
+    if 'velocity_umap' in adata_result.obsm.keys():
+        del adata_result.obsm['velocity_umap']
+    cell_velocities(adata_result, vkey='velo_s',
+                    X=adata_result[:, transition_genes].layers['Ms'],
+                    V=adata_result[:, transition_genes].layers['velo_s'],
+                    transition_genes=transition_genes,
+                    method=method,
+                    other_kernels_dict=other_kernels_dict,
+                    **velo_arg
+                    )
+    return adata_result
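+
+# Illustrative numpy-only sketch of the normalization applied above: each
+# gene's velocity column is scaled by its total absolute velocity, then the
+# matrix is rescaled so its mean entry is 1 (toy array, for intuition only):
+#
+#     V = np.array([[1., -2.], [3., 4.]])
+#     V_norm = V / np.sum(np.abs(V), 0)   # per-gene normalization
+#     V_norm /= np.mean(V_norm)           # global rescaling
+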
diff --git a/dynamo/multivelo/__init__.py b/dynamo/multivelo/__init__.py
new file mode 100644
index 000000000..3cc49c2d8
--- /dev/null
+++ b/dynamo/multivelo/__init__.py
@@ -0,0 +1,10 @@
+from .ATACseqTools import *
+from .ChromatinVelocity import *
+from .MultiConfiguration import *
+from .MultiIO import *
+from .MultiQC import *
+from .old_MultiomicVectorField import *
+from .MultiPreprocessor import *
+from .old_MultiVelocity import *
+from .pyWNN import *
+from .MultiVelo import *
diff --git a/dynamo/multivelo/dynamical_chrom_func.py b/dynamo/multivelo/dynamical_chrom_func.py
new file mode 100644
index 000000000..3adc096f4
--- /dev/null
+++ b/dynamo/multivelo/dynamical_chrom_func.py
@@ -0,0 +1,6386 @@
+from dynamo.multivelo import settings
+
+import os
+import sys
+import numpy as np
+from numpy.linalg import norm
+import matplotlib.pyplot as plt
+from scipy import sparse
+from scipy.sparse import coo_matrix
+from scipy.optimize import minimize
+from scipy.spatial import KDTree
+from sklearn.metrics import pairwise_distances
+from sklearn.mixture import GaussianMixture
+
+import scvelo as scv
+import pandas as pd
+import seaborn as sns
+from numba import njit
+import numba
+from numba.typed import List
+from tqdm.auto import tqdm
+
+import math
+import torch
+from torch import nn
+
+current_path = os.path.dirname(__file__)
+src_path = os.path.join(current_path, "..")
+sys.path.append(src_path)
+
+from ..dynamo_logger import (
+    LoggerManager,
+    main_exception,
+    main_info,
+)
+
+
+# a function to check for invalid values of different parameters
+def check_params(alpha_c,
+                 alpha,
+                 beta,
+                 gamma,
+                 c0=None,
+                 u0=None,
+                 s0=None):
+
+    new_alpha_c = alpha_c
+    new_alpha = alpha
+    new_beta = beta
+    new_gamma = gamma
+
+    new_c0 = c0
+    new_u0 = u0
+    new_s0 = s0
+
+    inf_fix = 1e10
+    zero_fix = 1e-10
+
+    # check if any of our parameters are infinite
+    if c0 is not None and math.isinf(c0):
+        main_info("c0 is infinite.", indent_level=1)
+        new_c0 = inf_fix
+    if u0 is not None and math.isinf(u0):
+        main_info("u0 is infinite.", indent_level=1)
+        new_u0 = inf_fix
+    if s0 is not None and math.isinf(s0):
+        main_info("s0 is infinite.", indent_level=1)
+        new_s0 = inf_fix
+    if math.isinf(alpha_c):
+        new_alpha_c = inf_fix
+        main_info("alpha_c is infinite.",
indent_level=1) + if math.isinf(alpha): + new_alpha = inf_fix + main_info("alpha is infinite.", indent_level=1) + if math.isinf(beta): + new_beta = inf_fix + main_info("beta is infinite.", indent_level=1) + if math.isinf(gamma): + new_gamma = inf_fix + main_info("gamma is infinite.", indent_level=1) + + # check if any of our parameters are nan + if c0 is not None and math.isnan(c0): + main_info("c0 is Nan.", indent_level=1) + new_c0 = zero_fix + if u0 is not None and math.isnan(u0): + main_info("u0 is Nan.", indent_level=1) + new_u0 = zero_fix + if s0 is not None and math.isnan(s0): + main_info("s0 is Nan.", indent_level=1) + new_s0 = zero_fix + if math.isnan(alpha_c): + new_alpha_c = zero_fix + main_info("alpha_c is Nan.", indent_level=1) + if math.isnan(alpha): + new_alpha = zero_fix + main_info("alpha is Nan.", indent_level=1) + if math.isnan(beta): + new_beta = zero_fix + main_info("beta is Nan.", indent_level=1) + if math.isnan(gamma): + new_gamma = zero_fix + main_info("gamma is Nan.", indent_level=1) + + # check if any of our rate parameters are 0 + if alpha_c < 1e-7: + new_alpha_c = zero_fix + main_info("alpha_c is zero.", indent_level=1) + if alpha < 1e-7: + new_alpha = zero_fix + main_info("alpha is zero.", indent_level=1) + if beta < 1e-7: + new_beta = zero_fix + main_info("beta is zero.", indent_level=1) + if gamma < 1e-7: + new_gamma = zero_fix + main_info("gamma is zero.", indent_level=1) + + if beta == alpha_c: + new_beta += zero_fix + main_info("alpha_c and beta are equal, leading to divide by zero", + indent_level=1) + if beta == gamma: + new_gamma += zero_fix + main_info("gamma and beta are equal, leading to divide by zero", + indent_level=1) + if alpha_c == gamma: + new_gamma += zero_fix + main_info("gamma and alpha_c are equal, leading to divide by zero", + indent_level=1) + + if c0 is not None and u0 is not None and s0 is not None: + return new_alpha_c, new_alpha, new_beta, new_gamma, new_c0, new_u0, \ + new_s0 + + return new_alpha_c, new_alpha, new_beta, new_gamma + + +@njit( + locals={ + "res": numba.types.float64[:, ::1], + "eat": numba.types.float64[::1], + "ebt": numba.types.float64[::1], + "egt": numba.types.float64[::1], + }, + fastmath=True) +def predict_exp(tau, + c0, + u0, + s0, + alpha_c, + alpha, + beta, + gamma, + scale_cc=1, + pred_r=True, + chrom_open=True, + backward=False, + rna_only=False): + + if len(tau) == 0: + return np.empty((0, 3)) + if backward: + tau = -tau + res = np.empty((len(tau), 3)) + eat = np.exp(-alpha_c * tau) + ebt = np.exp(-beta * tau) + egt = np.exp(-gamma * tau) + if rna_only: + kc = 1 + c0 = 1 + else: + if chrom_open: + kc = 1 + else: + kc = 0 + alpha_c *= scale_cc + + const = (kc - c0) * alpha / (beta - alpha_c) + + res[:, 0] = kc - (kc - c0) * eat + + if pred_r: + + res[:, 1] = u0 * ebt + (alpha * kc / beta) * (1 - ebt) + res[:, 1] += const * (ebt - eat) + + res[:, 2] = s0 * egt + (alpha * kc / gamma) * (1 - egt) + res[:, 2] += ((beta / (gamma - beta)) * + ((alpha * kc / beta) - u0 - const) * (egt - ebt)) + res[:, 2] += (beta / (gamma - alpha_c)) * const * (egt - eat) + + else: + res[:, 1] = np.zeros(len(tau)) + res[:, 2] = np.zeros(len(tau)) + return res + + +@njit(locals={ + "exp_sw1": numba.types.float64[:, ::1], + "exp_sw2": numba.types.float64[:, ::1], + "exp_sw3": numba.types.float64[:, ::1], + "exp1": numba.types.float64[:, ::1], + "exp2": numba.types.float64[:, ::1], + "exp3": numba.types.float64[:, ::1], + "exp4": numba.types.float64[:, ::1], + "tau_sw1": numba.types.float64[::1], + "tau_sw2": numba.types.float64[::1], 
+ "tau_sw3": numba.types.float64[::1], + "tau1": numba.types.float64[::1], + "tau2": numba.types.float64[::1], + "tau3": numba.types.float64[::1], + "tau4": numba.types.float64[::1] + }, + fastmath=True) +def generate_exp(tau_list, + t_sw_array, + alpha_c, + alpha, + beta, + gamma, + scale_cc=1, + model=1, + rna_only=False): + + if beta == alpha_c: + beta += 1e-3 + if gamma == beta or gamma == alpha_c: + gamma += 1e-3 + switch = len(t_sw_array) + if switch >= 1: + tau_sw1 = np.array([t_sw_array[0]]) + if switch >= 2: + tau_sw2 = np.array([t_sw_array[1] - t_sw_array[0]]) + if switch == 3: + tau_sw3 = np.array([t_sw_array[2] - t_sw_array[1]]) + exp_sw1, exp_sw2, exp_sw3 = (np.empty((0, 3)), + np.empty((0, 3)), + np.empty((0, 3))) + if tau_list is None: + if model == 0: + if switch >= 1: + exp_sw1 = predict_exp(tau_sw1, 0, 0, 0, alpha_c, alpha, beta, + gamma, pred_r=False, scale_cc=scale_cc, + rna_only=rna_only) + if switch >= 2: + exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], + exp_sw1[0, 1], exp_sw1[0, 2], + alpha_c, alpha, beta, gamma, + pred_r=False, chrom_open=False, + scale_cc=scale_cc, rna_only=rna_only) + if switch >= 3: + exp_sw3 = predict_exp(tau_sw3, exp_sw2[0, 0], + exp_sw2[0, 1], exp_sw2[0, 2], + alpha_c, alpha, beta, gamma, + chrom_open=False, + scale_cc=scale_cc, + rna_only=rna_only) + elif model == 1: + if switch >= 1: + exp_sw1 = predict_exp(tau_sw1, 0, 0, 0, alpha_c, alpha, beta, + gamma, pred_r=False, scale_cc=scale_cc, + rna_only=rna_only) + if switch >= 2: + exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], + exp_sw1[0, 1], exp_sw1[0, 2], + alpha_c, alpha, beta, gamma, + scale_cc=scale_cc, rna_only=rna_only) + if switch >= 3: + exp_sw3 = predict_exp(tau_sw3, exp_sw2[0, 0], + exp_sw2[0, 1], exp_sw2[0, 2], + alpha_c, alpha, beta, gamma, + chrom_open=False, + scale_cc=scale_cc, + rna_only=rna_only) + elif model == 2: + if switch >= 1: + exp_sw1 = predict_exp(tau_sw1, 0, 0, 0, alpha_c, alpha, beta, + gamma, pred_r=False, scale_cc=scale_cc, + rna_only=rna_only) + if switch >= 2: + exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], + exp_sw1[0, 1], exp_sw1[0, 2], + alpha_c, alpha, beta, gamma, + scale_cc=scale_cc, rna_only=rna_only) + if switch >= 3: + exp_sw3 = predict_exp(tau_sw3, exp_sw2[0, 0], + exp_sw2[0, 1], exp_sw2[0, 2], + alpha_c, 0, beta, gamma, + scale_cc=scale_cc, + rna_only=rna_only) + + return (np.empty((0, 3)), np.empty((0, 3)), np.empty((0, 3)), + np.empty((0, 3))), (exp_sw1, exp_sw2, exp_sw3) + + tau1 = tau_list[0] + if switch >= 1: + tau2 = tau_list[1] + if switch >= 2: + tau3 = tau_list[2] + if switch == 3: + tau4 = tau_list[3] + exp1, exp2, exp3, exp4 = (np.empty((0, 3)), np.empty((0, 3)), + np.empty((0, 3)), np.empty((0, 3))) + if model == 0: + exp1 = predict_exp(tau1, 0, 0, 0, alpha_c, alpha, beta, gamma, + pred_r=False, scale_cc=scale_cc, rna_only=rna_only) + if switch >= 1: + exp_sw1 = predict_exp(tau_sw1, 0, 0, 0, alpha_c, alpha, beta, + gamma, pred_r=False, scale_cc=scale_cc, + rna_only=rna_only) + exp2 = predict_exp(tau2, exp_sw1[0, 0], exp_sw1[0, 1], + exp_sw1[0, 2], alpha_c, alpha, beta, gamma, + pred_r=False, chrom_open=False, + scale_cc=scale_cc, rna_only=rna_only) + if switch >= 2: + exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1], + exp_sw1[0, 2], alpha_c, alpha, beta, + gamma, pred_r=False, chrom_open=False, + scale_cc=scale_cc, rna_only=rna_only) + exp3 = predict_exp(tau3, exp_sw2[0, 0], exp_sw2[0, 1], + exp_sw2[0, 2], alpha_c, alpha, beta, gamma, + chrom_open=False, scale_cc=scale_cc, + rna_only=rna_only) + if switch == 3: + exp_sw3 = 
predict_exp(tau_sw3, exp_sw2[0, 0],
+                                          exp_sw2[0, 1], exp_sw2[0, 2],
+                                          alpha_c, alpha, beta, gamma,
+                                          chrom_open=False, scale_cc=scale_cc,
+                                          rna_only=rna_only)
+                    exp4 = predict_exp(tau4, exp_sw3[0, 0], exp_sw3[0, 1],
+                                       exp_sw3[0, 2], alpha_c, 0, beta, gamma,
+                                       chrom_open=False, scale_cc=scale_cc,
+                                       rna_only=rna_only)
+    elif model == 1:
+        exp1 = predict_exp(tau1, 0, 0, 0, alpha_c, alpha, beta, gamma,
+                           pred_r=False, scale_cc=scale_cc, rna_only=rna_only)
+        if switch >= 1:
+            exp_sw1 = predict_exp(tau_sw1, 0, 0, 0, alpha_c, alpha, beta,
+                                  gamma, pred_r=False, scale_cc=scale_cc,
+                                  rna_only=rna_only)
+            exp2 = predict_exp(tau2, exp_sw1[0, 0], exp_sw1[0, 1],
+                               exp_sw1[0, 2], alpha_c, alpha, beta, gamma,
+                               scale_cc=scale_cc, rna_only=rna_only)
+            if switch >= 2:
+                exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1],
+                                      exp_sw1[0, 2], alpha_c, alpha, beta,
+                                      gamma, scale_cc=scale_cc,
+                                      rna_only=rna_only)
+                exp3 = predict_exp(tau3, exp_sw2[0, 0], exp_sw2[0, 1],
+                                   exp_sw2[0, 2], alpha_c, alpha, beta, gamma,
+                                   chrom_open=False, scale_cc=scale_cc,
+                                   rna_only=rna_only)
+                if switch == 3:
+                    exp_sw3 = predict_exp(tau_sw3, exp_sw2[0, 0],
+                                          exp_sw2[0, 1], exp_sw2[0, 2],
+                                          alpha_c, alpha, beta, gamma,
+                                          chrom_open=False, scale_cc=scale_cc,
+                                          rna_only=rna_only)
+                    exp4 = predict_exp(tau4, exp_sw3[0, 0], exp_sw3[0, 1],
+                                       exp_sw3[0, 2], alpha_c, 0, beta, gamma,
+                                       chrom_open=False, scale_cc=scale_cc,
+                                       rna_only=rna_only)
+    elif model == 2:
+        exp1 = predict_exp(tau1, 0, 0, 0, alpha_c, alpha, beta, gamma,
+                           pred_r=False, scale_cc=scale_cc, rna_only=rna_only)
+        if switch >= 1:
+            exp_sw1 = predict_exp(tau_sw1, 0, 0, 0, alpha_c, alpha, beta,
+                                  gamma, pred_r=False, scale_cc=scale_cc,
+                                  rna_only=rna_only)
+            exp2 = predict_exp(tau2, exp_sw1[0, 0], exp_sw1[0, 1],
+                               exp_sw1[0, 2], alpha_c, alpha, beta, gamma,
+                               scale_cc=scale_cc, rna_only=rna_only)
+            if switch >= 2:
+                exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1],
+                                      exp_sw1[0, 2], alpha_c, alpha, beta,
+                                      gamma, scale_cc=scale_cc,
+                                      rna_only=rna_only)
+                exp3 = predict_exp(tau3, exp_sw2[0, 0], exp_sw2[0, 1],
+                                   exp_sw2[0, 2], alpha_c, 0, beta, gamma,
+                                   scale_cc=scale_cc, rna_only=rna_only)
+                if switch == 3:
+                    exp_sw3 = predict_exp(tau_sw3, exp_sw2[0, 0],
+                                          exp_sw2[0, 1], exp_sw2[0, 2],
+                                          alpha_c, 0, beta, gamma,
+                                          scale_cc=scale_cc,
+                                          rna_only=rna_only)
+                    exp4 = predict_exp(tau4, exp_sw3[0, 0], exp_sw3[0, 1],
+                                       exp_sw3[0, 2], alpha_c, 0, beta, gamma,
+                                       chrom_open=False, scale_cc=scale_cc,
+                                       rna_only=rna_only)
+    return (exp1, exp2, exp3, exp4), (exp_sw1, exp_sw2, exp_sw3)
+
+
+@njit(locals={
+            "exp_sw1": numba.types.float64[:, ::1],
+            "exp_sw2": numba.types.float64[:, ::1],
+            "exp_sw3": numba.types.float64[:, ::1],
+            "exp1": numba.types.float64[:, ::1],
+            "exp2": numba.types.float64[:, ::1],
+            "exp3": numba.types.float64[:, ::1],
+            "exp4": numba.types.float64[:, ::1],
+            "tau_sw1": numba.types.float64[::1],
+            "tau_sw2": numba.types.float64[::1],
+            "tau_sw3": numba.types.float64[::1],
+            "tau1": numba.types.float64[::1],
+            "tau2": numba.types.float64[::1],
+            "tau3": numba.types.float64[::1],
+            "tau4": numba.types.float64[::1]
+            },
+      fastmath=True)
+def generate_exp_backward(tau_list, t_sw_array, alpha_c, alpha, beta, gamma,
+                          scale_cc=1, model=1):
+    if beta == alpha_c:
+        beta += 1e-3
+    if gamma == beta or gamma == alpha_c:
+        gamma += 1e-3
+    switch = len(t_sw_array)
+    if switch >= 1:
+        tau_sw1 = np.array([t_sw_array[0]])
+    if switch >= 2:
+        tau_sw2 = np.array([t_sw_array[1] - t_sw_array[0]])
+    if tau_list is None:
+        if model == 0:
+            exp_sw1 = predict_exp(tau_sw1, 1e-3,
1e-3, 1e-3, alpha_c, 0, beta, + gamma, scale_cc=scale_cc, chrom_open=False, + backward=True) + exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1], + exp_sw1[0, 2], alpha_c, alpha, beta, gamma, + scale_cc=scale_cc, chrom_open=False, + backward=True) + elif model == 1: + exp_sw1 = predict_exp(tau_sw1, 1e-3, 1e-3, 1e-3, alpha_c, 0, beta, + gamma, scale_cc=scale_cc, chrom_open=False, + backward=True) + exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1], + exp_sw1[0, 2], alpha_c, alpha, beta, gamma, + scale_cc=scale_cc, chrom_open=False, + backward=True) + elif model == 2: + exp_sw1 = predict_exp(tau_sw1, 1e-3, 1e-3, 1e-3, alpha_c, 0, beta, + gamma, scale_cc=scale_cc, chrom_open=False, + backward=True) + exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1], + exp_sw1[0, 2], alpha_c, 0, beta, gamma, + scale_cc=scale_cc, backward=True) + return (np.empty((0, 0)), + np.empty((0, 0)), + np.empty((0, 0))), (exp_sw1, exp_sw2) + + tau1 = tau_list[0] + if switch >= 1: + tau2 = tau_list[1] + if switch >= 2: + tau3 = tau_list[2] + + exp1, exp2, exp3 = np.empty((0, 3)), np.empty((0, 3)), np.empty((0, 3)) + if model == 0: + exp1 = predict_exp(tau1, 1e-3, 1e-3, 1e-3, alpha_c, 0, beta, gamma, + scale_cc=scale_cc, chrom_open=False, backward=True) + if switch >= 1: + exp_sw1 = predict_exp(tau_sw1, 1e-3, 1e-3, 1e-3, alpha_c, 0, beta, + gamma, scale_cc=scale_cc, chrom_open=False, + backward=True) + exp2 = predict_exp(tau2, exp_sw1[0, 0], exp_sw1[0, 1], + exp_sw1[0, 2], alpha_c, alpha, beta, gamma, + scale_cc=scale_cc, chrom_open=False, + backward=True) + if switch >= 2: + exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1], + exp_sw1[0, 2], alpha_c, alpha, beta, + gamma, scale_cc=scale_cc, + chrom_open=False, backward=True) + exp3 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1], + exp_sw1[0, 2], alpha_c, alpha, beta, gamma, + scale_cc=scale_cc, chrom_open=False, + backward=True) + elif model == 1: + exp1 = predict_exp(tau1, 1e-3, 1e-3, 1e-3, alpha_c, 0, beta, gamma, + scale_cc=scale_cc, chrom_open=False, backward=True) + if switch >= 1: + exp_sw1 = predict_exp(tau_sw1, 1e-3, 1e-3, 1e-3, alpha_c, 0, beta, + gamma, scale_cc=scale_cc, chrom_open=False, + backward=True) + exp2 = predict_exp(tau2, exp_sw1[0, 0], exp_sw1[0, 1], + exp_sw1[0, 2], alpha_c, alpha, beta, gamma, + scale_cc=scale_cc, chrom_open=False, + backward=True) + if switch >= 2: + exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1], + exp_sw1[0, 2], alpha_c, alpha, beta, + gamma, scale_cc=scale_cc, + chrom_open=False, backward=True) + exp3 = predict_exp(tau3, exp_sw2[0, 0], exp_sw2[0, 1], + exp_sw2[0, 2], alpha_c, alpha, beta, gamma, + scale_cc=scale_cc, backward=True) + elif model == 2: + exp1 = predict_exp(tau1, 1e-3, 1e-3, 1e-3, alpha_c, 0, beta, gamma, + scale_cc=scale_cc, chrom_open=False, backward=True) + if switch >= 1: + exp_sw1 = predict_exp(tau_sw1, 1e-3, 1e-3, 1e-3, alpha_c, alpha, + beta, gamma, scale_cc=scale_cc, + chrom_open=False, backward=True) + exp2 = predict_exp(tau2, exp_sw1[0, 0], exp_sw1[0, 1], + exp_sw1[0, 2], alpha_c, 0, beta, gamma, + scale_cc=scale_cc, backward=True) + if switch >= 2: + exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1], + exp_sw1[0, 2], alpha_c, 0, beta, gamma, + scale_cc=scale_cc, backward=True) + exp3 = predict_exp(tau3, exp_sw2[0, 0], exp_sw2[0, 1], + exp_sw2[0, 2], alpha_c, alpha, beta, gamma, + scale_cc=scale_cc, backward=True) + return (exp1, exp2, exp3), (exp_sw1, exp_sw2) + + +@njit(locals={ + "res": numba.types.float64[:, ::1], + }, + 
fastmath=True) +def ss_exp(alpha_c, alpha, beta, gamma, pred_r=True, chrom_open=True): + res = np.empty((1, 3)) + if not chrom_open: + res[0, 0] = 0 + res[0, 1] = 0 + res[0, 2] = 0 + else: + res[0, 0] = 1 + if pred_r: + res[0, 1] = alpha / beta + res[0, 2] = alpha / gamma + else: + res[0, 1] = 0 + res[0, 2] = 0 + return res + + +@njit(locals={ + "ss1": numba.types.float64[:, ::1], + "ss2": numba.types.float64[:, ::1], + "ss3": numba.types.float64[:, ::1], + "ss4": numba.types.float64[:, ::1] + }, + fastmath=True) +def compute_ss_exp(alpha_c, alpha, beta, gamma, model=0): + if model == 0: + ss1 = ss_exp(alpha_c, alpha, beta, gamma, pred_r=False) + ss2 = ss_exp(alpha_c, alpha, beta, gamma, pred_r=False, + chrom_open=False) + ss3 = ss_exp(alpha_c, alpha, beta, gamma, chrom_open=False) + ss4 = ss_exp(alpha_c, 0, beta, gamma, chrom_open=False) + elif model == 1: + ss1 = ss_exp(alpha_c, alpha, beta, gamma, pred_r=False) + ss2 = ss_exp(alpha_c, alpha, beta, gamma) + ss3 = ss_exp(alpha_c, alpha, beta, gamma, chrom_open=False) + ss4 = ss_exp(alpha_c, 0, beta, gamma, chrom_open=False) + elif model == 2: + ss1 = ss_exp(alpha_c, alpha, beta, gamma, pred_r=False) + ss2 = ss_exp(alpha_c, alpha, beta, gamma) + ss3 = ss_exp(alpha_c, 0, beta, gamma) + ss4 = ss_exp(alpha_c, 0, beta, gamma, chrom_open=False) + return np.vstack((ss1, ss2, ss3, ss4)) + + +@njit(fastmath=True) +def velocity_equations(c, u, s, alpha_c, alpha, beta, gamma, scale_cc=1, + pred_r=True, chrom_open=True, rna_only=False): + if rna_only: + c = np.full(len(u), 1.0) + if not chrom_open: + alpha_c *= scale_cc + if pred_r: + return -alpha_c * c, alpha * c - beta * u, beta * u - gamma * s + else: + return -alpha_c * c, np.zeros(len(u)), np.zeros(len(u)) + else: + if pred_r: + return (alpha_c - alpha_c * c), (alpha * c - beta * u), (beta * u + - gamma + * s) + else: + return alpha_c - alpha_c * c, np.zeros(len(u)), np.zeros(len(u)) + + +@njit(locals={ + "state0": numba.types.boolean[::1], + "state1": numba.types.boolean[::1], + "state2": numba.types.boolean[::1], + "state3": numba.types.boolean[::1], + "tau1": numba.types.float64[::1], + "tau2": numba.types.float64[::1], + "tau3": numba.types.float64[::1], + "tau4": numba.types.float64[::1], + "exp_list": numba.types.Tuple((numba.types.float64[:, ::1], + numba.types.float64[:, ::1], + numba.types.float64[:, ::1], + numba.types.float64[:, ::1])), + "exp_sw_list": numba.types.Tuple((numba.types.float64[:, ::1], + numba.types.float64[:, ::1], + numba.types.float64[:, ::1])), + "c": numba.types.float64[::1], + "u": numba.types.float64[::1], + "s": numba.types.float64[::1], + "vc_vec": numba.types.float64[::1], + "vu_vec": numba.types.float64[::1], + "vs_vec": numba.types.float64[::1] + }, + fastmath=True) +def compute_velocity(t, + t_sw_array, + state, + alpha_c, + alpha, + beta, + gamma, + rescale_c, + rescale_u, + scale_cc=1, + model=1, + total_h=20, + rna_only=False): + + if state is None: + state0 = t <= t_sw_array[0] + state1 = (t_sw_array[0] < t) & (t <= t_sw_array[1]) + state2 = (t_sw_array[1] < t) & (t <= t_sw_array[2]) + state3 = t_sw_array[2] < t + else: + state0 = np.equal(state, 0) + state1 = np.equal(state, 1) + state2 = np.equal(state, 2) + state3 = np.equal(state, 3) + + tau1 = t[state0] + tau2 = t[state1] - t_sw_array[0] + tau3 = t[state2] - t_sw_array[1] + tau4 = t[state3] - t_sw_array[2] + tau_list = [tau1, tau2, tau3, tau4] + switch = np.sum(t_sw_array < total_h) + typed_tau_list = List() + [typed_tau_list.append(x) for x in tau_list] + exp_list, exp_sw_list = 
generate_exp(typed_tau_list, + t_sw_array[:switch], + alpha_c, + alpha, + beta, + gamma, + model=model, + scale_cc=scale_cc, + rna_only=rna_only) + + c = np.empty(len(t)) + u = np.empty(len(t)) + s = np.empty(len(t)) + for i, ii in enumerate([state0, state1, state2, state3]): + if np.any(ii): + c[ii] = exp_list[i][:, 0] + u[ii] = exp_list[i][:, 1] + s[ii] = exp_list[i][:, 2] + + vc_vec = np.zeros(len(u)) + vu_vec = np.zeros(len(u)) + vs_vec = np.zeros(len(u)) + + if model == 0: + if np.any(state0): + vc_vec[state0], vu_vec[state0], vs_vec[state0] = \ + velocity_equations(c[state0], u[state0], s[state0], alpha_c, + alpha, beta, gamma, pred_r=False, + scale_cc=scale_cc, rna_only=rna_only) + if np.any(state1): + vc_vec[state1], vu_vec[state1], vs_vec[state1] = \ + velocity_equations(c[state1], u[state1], s[state1], alpha_c, + alpha, beta, gamma, pred_r=False, + chrom_open=False, scale_cc=scale_cc, + rna_only=rna_only) + if np.any(state2): + vc_vec[state2], vu_vec[state2], vs_vec[state2] = \ + velocity_equations(c[state2], u[state2], s[state2], alpha_c, + alpha, beta, gamma, chrom_open=False, + scale_cc=scale_cc, rna_only=rna_only) + if np.any(state3): + vc_vec[state3], vu_vec[state3], vs_vec[state3] = \ + velocity_equations(c[state3], u[state3], s[state3], alpha_c, 0, + beta, gamma, chrom_open=False, + scale_cc=scale_cc, rna_only=rna_only) + elif model == 1: + if np.any(state0): + vc_vec[state0], vu_vec[state0], vs_vec[state0] = \ + velocity_equations(c[state0], u[state0], s[state0], alpha_c, + alpha, beta, gamma, pred_r=False, + scale_cc=scale_cc, rna_only=rna_only) + if np.any(state1): + vc_vec[state1], vu_vec[state1], vs_vec[state1] = \ + velocity_equations(c[state1], u[state1], s[state1], alpha_c, + alpha, beta, gamma, scale_cc=scale_cc, + rna_only=rna_only) + if np.any(state2): + vc_vec[state2], vu_vec[state2], vs_vec[state2] = \ + velocity_equations(c[state2], u[state2], s[state2], alpha_c, + alpha, beta, gamma, chrom_open=False, + scale_cc=scale_cc, rna_only=rna_only) + if np.any(state3): + vc_vec[state3], vu_vec[state3], vs_vec[state3] = \ + velocity_equations(c[state3], u[state3], s[state3], alpha_c, 0, + beta, gamma, chrom_open=False, + scale_cc=scale_cc, rna_only=rna_only) + elif model == 2: + if np.any(state0): + vc_vec[state0], vu_vec[state0], vs_vec[state0] = \ + velocity_equations(c[state0], u[state0], s[state0], alpha_c, + alpha, beta, gamma, pred_r=False, + scale_cc=scale_cc, rna_only=rna_only) + if np.any(state1): + vc_vec[state1], vu_vec[state1], vs_vec[state1] = \ + velocity_equations(c[state1], u[state1], s[state1], alpha_c, + alpha, beta, gamma, scale_cc=scale_cc, + rna_only=rna_only) + if np.any(state2): + vc_vec[state2], vu_vec[state2], vs_vec[state2] = \ + velocity_equations(c[state2], u[state2], s[state2], alpha_c, + 0, beta, gamma, scale_cc=scale_cc, + rna_only=rna_only) + if np.any(state3): + vc_vec[state3], vu_vec[state3], vs_vec[state3] = \ + velocity_equations(c[state3], u[state3], s[state3], alpha_c, 0, + beta, gamma, chrom_open=False, + scale_cc=scale_cc, rna_only=rna_only) + return vc_vec * rescale_c, vu_vec * rescale_u, vs_vec + + +def log_valid(x): + return np.log(np.clip(x, 1e-3, 1 - 1e-3)) + + +def approx_tau(u, s, u0, s0, alpha, beta, gamma): + if gamma == beta: + gamma -= 1e-3 + u_inf = alpha / beta + if beta > gamma: + b_new = beta / (gamma - beta) + s_inf = alpha / gamma + s_inf_new = s_inf - b_new * u_inf + s_new = s - b_new * u + s0_new = s0 - b_new * u0 + tau = -1.0 / gamma * log_valid((s_new - s_inf_new) / + (s0_new - s_inf_new)) + else: + tau = 
-1.0 / beta * log_valid((u - u_inf) / (u0 - u_inf)) + return tau + + +def anchor_points(t_sw_array, total_h=20, t=1000, mode='uniform', + return_time=False): + t_ = np.linspace(0, total_h, t) + tau1 = t_[t_ <= t_sw_array[0]] + tau2 = t_[(t_sw_array[0] < t_) & (t_ <= t_sw_array[1])] - t_sw_array[0] + tau3 = t_[(t_sw_array[1] < t_) & (t_ <= t_sw_array[2])] - t_sw_array[1] + tau4 = t_[t_sw_array[2] < t_] - t_sw_array[2] + + if mode == 'log': + if len(tau1) > 0: + tau1 = np.expm1(tau1) + tau1 = tau1 / np.max(tau1) * (t_sw_array[0]) + if len(tau2) > 0: + tau2 = np.expm1(tau2) + tau2 = tau2 / np.max(tau2) * (t_sw_array[1] - t_sw_array[0]) + if len(tau3) > 0: + tau3 = np.expm1(tau3) + tau3 = tau3 / np.max(tau3) * (t_sw_array[2] - t_sw_array[1]) + if len(tau4) > 0: + tau4 = np.expm1(tau4) + tau4 = tau4 / np.max(tau4) * (total_h - t_sw_array[2]) + + tau_list = [tau1, tau2, tau3, tau4] + if return_time: + return t_, tau_list + else: + return tau_list + + +# @jit(nopython=True, fastmath=True, debug=True) +def pairwise_distance_square(X, Y): + res = np.empty((X.shape[0], Y.shape[0]), dtype=X.dtype) + for a in range(X.shape[0]): + for b in range(Y.shape[0]): + val = 0.0 + for i in range(X.shape[1]): + tmp = X[a, i] - Y[b, i] + val += tmp**2 + res[a, b] = val + return res + + +def calculate_dist_and_time(c, u, s, + t_sw_array, + alpha_c, alpha, beta, gamma, + rescale_c, rescale_u, + scale_cc=1, + scale_factor=None, + model=1, + conn=None, + t=1000, k=1, + direction='complete', + total_h=20, + rna_only=False, + penalize_gap=True, + all_cells=True): + + n = len(u) + if scale_factor is None: + scale_factor = np.array([np.std(c), np.std(u), np.std(s)]) + tau_list = anchor_points(t_sw_array, total_h, t) + switch = np.sum(t_sw_array < total_h) + typed_tau_list = List() + [typed_tau_list.append(x) for x in tau_list] + alpha_c, alpha, beta, gamma = check_params(alpha_c, alpha, beta, gamma) + exp_list, exp_sw_list = generate_exp(typed_tau_list, + t_sw_array[:switch], + alpha_c, + alpha, + beta, + gamma, + model=model, + scale_cc=scale_cc, + rna_only=rna_only) + rescale_factor = np.array([rescale_c, rescale_u, 1.0]) + exp_list = [x*rescale_factor for x in exp_list] + exp_sw_list = [x*rescale_factor for x in exp_sw_list] + max_c = 0 + max_u = 0 + max_s = 0 + if rna_only: + exp_mat = (np.hstack((np.reshape(u, (-1, 1)), np.reshape(s, (-1, 1)))) + / scale_factor[1:]) + else: + exp_mat = np.hstack((np.reshape(c, (-1, 1)), np.reshape(u, (-1, 1)), + np.reshape(s, (-1, 1)))) / scale_factor + + dists = np.full((n, 4), np.inf) + taus = np.zeros((n, 4), dtype=u.dtype) + ts = np.zeros((n, 4), dtype=u.dtype) + anchor_exp, anchor_t = None, None + + for i in range(switch+1): + if not all_cells: + max_ci = (np.max(exp_list[i][:, 0]) if exp_list[i].shape[0] > 0 + else 0) + max_c = max_ci if max_ci > max_c else max_c + max_ui = np.max(exp_list[i][:, 1]) if exp_list[i].shape[0] > 0 else 0 + max_u = max_ui if max_ui > max_u else max_u + max_si = np.max(exp_list[i][:, 2]) if exp_list[i].shape[0] > 0 else 0 + max_s = max_si if max_si > max_s else max_s + + skip_phase = False + if direction == 'off': + if (model in [1, 2]) and (i < 2): + skip_phase = True + elif direction == 'on': + if (model in [1, 2]) and (i >= 2): + skip_phase = True + if rna_only and i == 0: + skip_phase = True + + if not skip_phase: + if rna_only: + tmp = exp_list[i][:, 1:] / scale_factor[1:] + else: + tmp = exp_list[i] / scale_factor + if anchor_exp is None: + anchor_exp = exp_list[i] + anchor_t = (tau_list[i] + t_sw_array[i-1] if i >= 1 + else tau_list[i]) + 
else: + anchor_exp = np.vstack((anchor_exp, exp_list[i])) + anchor_t = np.hstack((anchor_t, tau_list[i] + t_sw_array[i-1] + if i >= 1 else tau_list[i])) + + if not all_cells: + anchor_dist = np.diff(tmp, axis=0, prepend=np.zeros((1, 2)) + if rna_only else np.zeros((1, 3))) + anchor_dist = np.sqrt((anchor_dist**2).sum(axis=1)) + remove_cand = anchor_dist < (0.01*np.max(exp_mat[1]) + if rna_only + else 0.01*np.max(exp_mat[2])) + step_idx = np.arange(0, len(anchor_dist), 1) % 3 > 0 + remove_cand &= step_idx + keep_idx = np.where(~remove_cand)[0] + tmp = tmp[keep_idx, :] + + tree = KDTree(tmp) + dd, ii = tree.query(exp_mat, k=k) + dd = dd**2 + if k > 1: + dd = np.mean(dd, axis=1) + if conn is not None: + dd = conn.dot(dd) + dists[:, i] = dd + + if not all_cells: + ii = keep_idx[ii] + if k == 1: + taus[:, i] = tau_list[i][ii] + else: + for j in range(n): + taus[j, i] = tau_list[i][ii[j, :]] + ts[:, i] = taus[:, i] + t_sw_array[i-1] if i >= 1 else taus[:, i] + + min_dist = np.min(dists, axis=1) + state_pred = np.argmin(dists, axis=1) + t_pred = ts[np.arange(n), state_pred] + + anchor_t1_list = [] + anchor_t2_list = [] + t_sw_adjust = np.zeros(3, dtype=u.dtype) + + if direction == 'complete': + t_sorted = np.sort(t_pred) + dt = np.diff(t_sorted, prepend=0) + gap_thresh = 3*np.percentile(dt, 99) + idx = np.where(dt > gap_thresh)[0] + for i in idx: + t1 = t_sorted[i-1] if i > 0 else 0 + t2 = t_sorted[i] + anchor_t1 = anchor_exp[np.argmin(np.abs(anchor_t - t1)), :] + anchor_t2 = anchor_exp[np.argmin(np.abs(anchor_t - t2)), :] + if all_cells: + anchor_t1_list.append(np.ravel(anchor_t1)) + anchor_t2_list.append(np.ravel(anchor_t2)) + if not all_cells: + for j in range(1, switch): + crit1 = ((t1 > t_sw_array[j-1]) and (t2 > t_sw_array[j-1]) + and (t1 <= t_sw_array[j]) + and (t2 <= t_sw_array[j])) + crit2 = ((np.abs(anchor_t1[2] - exp_sw_list[j][0, 2]) + < 0.02 * max_s) and + (np.abs(anchor_t2[2] - exp_sw_list[j][0, 2]) + < 0.01 * max_s)) + crit3 = ((np.abs(anchor_t1[1] - exp_sw_list[j][0, 1]) + < 0.02 * max_u) and + (np.abs(anchor_t2[1] - exp_sw_list[j][0, 1]) + < 0.01 * max_u)) + crit4 = ((np.abs(anchor_t1[0] - exp_sw_list[j][0, 0]) + < 0.02 * max_c) and + (np.abs(anchor_t2[0] - exp_sw_list[j][0, 0]) + < 0.01 * max_c)) + if crit1 and crit2 and crit3 and crit4: + t_sw_adjust[j] += t2 - t1 + if penalize_gap: + dist_gap = np.sum(((anchor_t1[1:] - anchor_t2[1:]) / + scale_factor[1:])**2) + idx_to_adjust = t_pred >= t2 + t_sw_array_ = np.append(t_sw_array, total_h) + state_to_adjust = np.where(t_sw_array_ > t2)[0] + dists[np.ix_(idx_to_adjust, state_to_adjust)] += dist_gap + min_dist = np.min(dists, axis=1) + state_pred = np.argmin(dists, axis=1) + if all_cells: + t_pred = ts[np.arange(n), state_pred] + + if all_cells: + exp_ss_mat = compute_ss_exp(alpha_c, alpha, beta, gamma, model=model) + if rna_only: + exp_ss_mat[:, 0] = 1 + dists_ss = pairwise_distance_square(exp_mat, exp_ss_mat * + rescale_factor / scale_factor) + + reach_ss = np.full((n, 4), False) + for i in range(n): + for j in range(4): + if min_dist[i] > dists_ss[i, j]: + reach_ss[i, j] = True + late_phase = np.full(n, -1) + for i in range(3): + late_phase[np.abs(t_pred - t_sw_array[i]) < 0.1] = i + return min_dist, t_pred, state_pred, reach_ss, late_phase, max_u, \ + max_s, anchor_t1_list, anchor_t2_list + else: + return min_dist, state_pred, max_u, max_s, t_sw_adjust + + +def t_of_c(alpha_c, k_c, c_o, c, rescale_factor, sw_t): + + coef = -float(1)/alpha_c + + c_val = np.clip(c / rescale_factor, a_min=0, a_max=1) + + in_log = (float(k_c) - 
c_val) / float((k_c) - (c_o)) + + epsilon = 1e-9 + + return_val = coef * np.log(in_log + epsilon) + + if k_c == 0: + return_val += sw_t + + return return_val + + +def make_X(c, u, s, + max_u, + max_s, + alpha_c, alpha, beta, gamma, + gene_sw_t, + c0, c_sw1, c_sw2, c_sw3, + u0, u_sw1, u_sw2, u_sw3, + s0, s_sw1, s_sw2, s_sw3, + model, direction, state): + + if direction == "complete": + dire = 0 + elif direction == "on": + dire = 1 + elif direction == "off": + dire = 2 + + n = c.shape[0] + + epsilon = 1e-5 + + if dire == 0: + x = np.concatenate((np.array([c, + np.log(u + epsilon), + np.log(s + epsilon)]), + np.full((n, 17), [np.log(alpha_c + epsilon), + np.log(alpha + epsilon), + np.log(beta + epsilon), + np.log(gamma + epsilon), + c_sw1, c_sw2, c_sw3, + np.log(u_sw2 + epsilon), + np.log(u_sw3 + epsilon), + np.log(s_sw2 + epsilon), + np.log(s_sw3 + epsilon), + np.log(max_u), + np.log(max_s), + gene_sw_t[0], + gene_sw_t[1], + gene_sw_t[2], + model]).T, + np.full((n, 1), state).T + )).T.astype(np.float32) + + elif dire == 1: + x = np.concatenate((np.array([c, + np.log(u + epsilon), + np.log(s + epsilon)]), + np.full((n, 12), [np.log(alpha_c + epsilon), + np.log(alpha + epsilon), + np.log(beta + epsilon), + np.log(gamma + epsilon), + c_sw1, c_sw2, + np.log(u_sw1 + epsilon), + np.log(u_sw2 + epsilon), + np.log(s_sw1 + epsilon), + np.log(s_sw2 + epsilon), + gene_sw_t[0], + model]).T, + np.full((n, 1), state).T + )).T.astype(np.float32) + + elif dire == 2: + if model == 1: + + max_u_t = -(float(1)/alpha_c)*np.log((max_u*beta) + / (alpha*c0[2])) + + x = np.concatenate((np.array([np.log(c + epsilon), + np.log(u + epsilon), + np.log(s + epsilon)]), + np.full((n, 14), [np.log(alpha_c + epsilon), + np.log(alpha + epsilon), + np.log(beta + epsilon), + np.log(gamma + epsilon), + c_sw2, c_sw3, + np.log(u_sw2 + epsilon), + np.log(u_sw3 + epsilon), + np.log(s_sw2 + epsilon), + np.log(s_sw3 + epsilon), + max_u_t, + np.log(max_u), + np.log(max_s), + gene_sw_t[2]]).T, + np.full((n, 1), state).T + )).T.astype(np.float32) + elif model == 2: + x = np.concatenate((np.array([c, + np.log(u + epsilon), + np.log(s + epsilon)]), + np.full((n, 12), [np.log(alpha_c + epsilon), + np.log(alpha + epsilon), + np.log(beta + epsilon), + np.log(gamma + epsilon), + c_sw2, c_sw3, + np.log(u_sw2 + epsilon), + np.log(u_sw3 + epsilon), + np.log(s_sw2 + epsilon), + np.log(s_sw3 + epsilon), + np.log(max_u), + gene_sw_t[2]]).T, + np.full((n, 1), state).T + )).T.astype(np.float32) + + return x + + +def calculate_dist_and_time_nn(c, u, s, + max_u, max_s, + t_sw_array, + alpha_c, alpha, beta, gamma, + rescale_c, rescale_u, + ode_model_0, ode_model_1, + ode_model_2_m1, ode_model_2_m2, + device, + scale_cc=1, + scale_factor=None, + model=1, + conn=None, + t=1000, k=1, + direction='complete', + total_h=20, + rna_only=False, + penalize_gap=True, + all_cells=True): + + rescale_factor = np.array([rescale_c, rescale_u, 1.0]) + + exp_list_net, exp_sw_list_net = generate_exp(None, + t_sw_array, + alpha_c, + alpha, + beta, + gamma, + model=model, + scale_cc=scale_cc, + rna_only=rna_only) + + N = len(c) + N_list = np.arange(N) + + if scale_factor is None: + cur_scale_factor = np.array([np.std(c), + np.std(u), + np.std(s)]) + else: + cur_scale_factor = scale_factor + + t_pred_per_state = [] + dists_per_state = [] + + dire = 0 + + if direction == "on": + states = [0, 1] + dire = 1 + + elif direction == "off": + states = [2, 3] + dire = 2 + + else: + states = [0, 1, 2, 3] + dire = 0 + + dists_per_state = np.zeros((N, len(states))) + t_pred_per_state = 
np.zeros((N, len(states))) + u_pred_per_state = np.zeros((N, len(states))) + s_pred_per_state = np.zeros((N, len(states))) + + increment = 0 + + # determine when we can consider u and s close to zero + zero_us = np.logical_and((u < 0.1 * max_u), (s < 0.1 * max_s)) + + t_pred = np.zeros(N) + dists = None + + # pass all the data through the neural net as each valid state + for state in states: + + # when u and s = 0, it's better to use the inverse c equation + # instead of the neural network, which happens for part of + # state 3 and all of state 0 + inverse_c = np.logical_or(state == 0, + np.logical_and(state == 3, zero_us)) + + not_inverse_c = np.logical_not(inverse_c) + + # if we want to use the inverse c equation... + if np.any(inverse_c): + + # find out at what switch time chromatin closes + c_sw_t = t_sw_array[int(model)] + + # figure out whether chromatin is opening/closing and what + # the initial c value is + if state <= model: + k_c = 1 + c_0_for_t_guess = 0 + elif state > model: + k_c = 0 + c_0_for_t_guess = exp_sw_list_net[int(model)][0, 0] + + # calculate predicted time from the inverse c equation + t_pred[inverse_c] = t_of_c(alpha_c, + k_c, c_0_for_t_guess, + c[inverse_c], + rescale_factor[0], + c_sw_t) + + # if there are points where we want to use the neural network... + if np.any(not_inverse_c): + + # create an input matrix from the data + x = make_X(c[not_inverse_c] / rescale_factor[0], + u[not_inverse_c] / rescale_factor[1], + s[not_inverse_c] / rescale_factor[2], + max_u, + max_s, + alpha_c*(scale_cc if state > model else 1), + alpha, beta, gamma, + t_sw_array, + 0, + exp_sw_list_net[0][0, 0], + exp_sw_list_net[1][0, 0], + exp_sw_list_net[2][0, 0], + 0, + exp_sw_list_net[0][0, 1], + exp_sw_list_net[1][0, 1], + exp_sw_list_net[2][0, 1], + 0, + exp_sw_list_net[0][0, 2], + exp_sw_list_net[1][0, 2], + exp_sw_list_net[2][0, 2], + model, direction, state) + + # do a forward pass + if dire == 0: + t_pred_ten = ode_model_0(torch.tensor(x, + dtype=torch.float, + device=device) + .reshape(-1, x.shape[1])) + + elif dire == 1: + t_pred_ten = ode_model_1(torch.tensor(x, + dtype=torch.float, + device=device) + .reshape(-1, x.shape[1])) + + elif dire == 2: + if model == 1: + t_pred_ten = ode_model_2_m1(torch.tensor(x, + dtype=torch.float, + device=device) + .reshape(-1, x.shape[1])) + elif model == 2: + t_pred_ten = ode_model_2_m2(torch.tensor(x, + dtype=torch.float, + device=device) + .reshape(-1, x.shape[1])) + + # make a numpy array out of our tensor of predicted time points + t_pred[not_inverse_c] = (t_pred_ten.cpu().detach().numpy() + .flatten()*21) - 1 + + # calculate tau values from our predicted time points + if state == 0: + t_pred = np.clip(t_pred, a_min=0, a_max=t_sw_array[0]) + tau1 = t_pred + tau2 = [] + tau3 = [] + tau4 = [] + elif state == 1: + tau1 = [] + t_pred = np.clip(t_pred, a_min=t_sw_array[0], a_max=t_sw_array[1]) + tau2 = t_pred - t_sw_array[0] + tau3 = [] + tau4 = [] + elif state == 2: + tau1 = [] + tau2 = [] + t_pred = np.clip(t_pred, a_min=t_sw_array[1], a_max=t_sw_array[2]) + tau3 = t_pred - t_sw_array[1] + tau4 = [] + elif state == 3: + tau1 = [] + tau2 = [] + tau3 = [] + t_pred = np.clip(t_pred, a_min=t_sw_array[2], a_max=20) + tau4 = t_pred - t_sw_array[2] + + tau_list = [tau1, tau2, tau3, tau4] + + valid_vals = [] + + for i in range(len(tau_list)): + if len(tau_list[i]) == 0: + tau_list[i] = np.array([0.0]) + else: + valid_vals.append(i) + + # take the time points and get predicted c/u/s values from them + exp_list, exp_sw_list_2 = generate_exp(tau_list, 
+ t_sw_array, + alpha_c, + alpha, + beta, + gamma, + model=model, + scale_cc=scale_cc, + rna_only=rna_only) + + pred_c = np.concatenate([exp_list[x][:, 0] * rescale_factor[0] + for x in valid_vals]) + pred_u = np.concatenate([exp_list[x][:, 1] * rescale_factor[1] + for x in valid_vals]) + pred_s = np.concatenate([exp_list[x][:, 2] * rescale_factor[2] + for x in valid_vals]) + + # calculate distance between predicted and real values + c_diff = (c - pred_c) / cur_scale_factor[0] + u_diff = (u - pred_u) / cur_scale_factor[1] + s_diff = (s - pred_s) / cur_scale_factor[2] + + dists = (c_diff*c_diff) + (u_diff*u_diff) + (s_diff*s_diff) + + if conn is not None: + dists = conn.dot(dists) + + # store the distances, times, and predicted u and s values for + # each state + dists_per_state[:, increment] = dists + t_pred_per_state[:, increment] = t_pred + u_pred_per_state[:, increment] = pred_u + s_pred_per_state[:, increment] = pred_s + + increment += 1 + + # whichever state has the smallest distance for a given data point + # is our predicted state + state_pred = np.argmin(dists_per_state, axis=1) + + # slice dists and predicted time over the correct state + dists = dists_per_state[N_list, state_pred] + t_pred = t_pred_per_state[N_list, state_pred] + + max_t = t_pred.max() + min_t = t_pred.min() + + penalty = 0 + + # for induction and complete genes, add a penalty to ensure that not + # all points are in state 0 + if direction == "on" or direction == "complete": + + if t_sw_array[0] >= max_t: + penalty += (t_sw_array[0] - max_t) + 10 + + # for induction genes, add a penalty to ensure that predicted time + # points are not "out of bounds" by being greater than the + # second switch time + if direction == "on": + + if min_t > t_sw_array[1]: + penalty += (min_t - t_sw_array[1]) + 10 + + # for repression genes, add a penalty to ensure that predicted time + # points are not "out of bounds" by being smaller than the + # second switch time + if direction == "off": + + if t_sw_array[1] >= max_t: + penalty += (t_sw_array[1] - max_t) + 10 + + # add penalty to ensure that the time points aren't concentrated to + # one spot + if np.abs(max_t - min_t) <= 1e-2: + penalty += np.abs(max_t - min_t) + 10 + + # because the indices chosen by np.argmin are just indices, + # we need to increment by two to get the true state number for + # our "off" genes (e.g. 
so that they're in the domain of [2,3] instead + # of [0,1]) + if direction == "off": + state_pred += 2 + + if all_cells: + return dists, t_pred, state_pred, max_u, max_s, penalty + else: + return dists, state_pred, max_u, max_s, penalty + + +# @jit(nopython=True, fastmath=True) +def compute_likelihood(c, u, s, + t_sw_array, + alpha_c, alpha, beta, gamma, + rescale_c, rescale_u, + t_pred, + state_pred, + scale_cc=1, + scale_factor=None, + model=1, + weight=None, + total_h=20, + rna_only=False): + + if weight is None: + weight = np.full(c.shape, True) + c_ = c[weight] + u_ = u[weight] + s_ = s[weight] + t_pred_ = t_pred[weight] + state_pred_ = state_pred[weight] + + n = len(u_) + if scale_factor is None: + scale_factor = np.ones(3) + tau1 = t_pred_[state_pred_ == 0] + tau2 = t_pred_[state_pred_ == 1] - t_sw_array[0] + tau3 = t_pred_[state_pred_ == 2] - t_sw_array[1] + tau4 = t_pred_[state_pred_ == 3] - t_sw_array[2] + tau_list = [tau1, tau2, tau3, tau4] + switch = np.sum(t_sw_array < total_h) + typed_tau_list = List() + [typed_tau_list.append(x) for x in tau_list] + alpha_c, alpha, beta, gamma = check_params(alpha_c, alpha, beta, gamma) + exp_list, _ = generate_exp(typed_tau_list, + t_sw_array[:switch], + alpha_c, + alpha, + beta, + gamma, + model=model, + scale_cc=scale_cc, + rna_only=rna_only) + rescale_factor = np.array([rescale_c, rescale_u, 1.0]) + exp_list = [x*rescale_factor*scale_factor for x in exp_list] + exp_mat = np.hstack((np.reshape(c_, (-1, 1)), np.reshape(u_, (-1, 1)), + np.reshape(s_, (-1, 1)))) * scale_factor + diffs = np.empty((n, 3), dtype=u.dtype) + likelihood_c = 0 + likelihood_u = 0 + likelihood_s = 0 + ssd_c, var_c = 0, 0 + for i in range(switch+1): + index = state_pred_ == i + if np.sum(index) > 0: + diff = exp_mat[index, :] - exp_list[i] + diffs[index, :] = diff + if rna_only: + diff_u = np.ravel(diffs[:, 0]) + diff_s = np.ravel(diffs[:, 1]) + dist_us = diff_u ** 2 + diff_s ** 2 + var_us = np.var(np.sign(diff_s) * np.sqrt(dist_us)) + nll = (0.5 * np.log(2 * np.pi * var_us) + 0.5 / n / + var_us * np.sum(dist_us)) + else: + diff_c = np.ravel(diffs[:, 0]) + diff_u = np.ravel(diffs[:, 1]) + diff_s = np.ravel(diffs[:, 2]) + dist_c = diff_c ** 2 + dist_u = diff_u ** 2 + dist_s = diff_s ** 2 + var_c = np.var(diff_c) + var_u = np.var(diff_u) + var_s = np.var(diff_s) + ssd_c = np.sum(dist_c) + nll_c = (0.5 * np.log(2 * np.pi * var_c) + 0.5 / n / + var_c * np.sum(dist_c)) + nll_u = (0.5 * np.log(2 * np.pi * var_u) + 0.5 / n / + var_u * np.sum(dist_u)) + nll_s = (0.5 * np.log(2 * np.pi * var_s) + 0.5 / n / + var_s * np.sum(dist_s)) + nll = nll_c + nll_u + nll_s + likelihood_c = np.exp(-nll_c) + likelihood_u = np.exp(-nll_u) + likelihood_s = np.exp(-nll_s) + likelihood = np.exp(-nll) + return likelihood, likelihood_c, ssd_c, var_c, likelihood_u, likelihood_s + + +class ChromatinDynamical: + def __init__(self, c, u, s, + gene=None, + model=None, + max_iter=10, + init_mode="grid", + device="cpu", + neural_net=False, + adam=False, + adam_lr=None, + adam_beta1=None, + adam_beta2=None, + batch_size=None, + local_std=None, + embed_coord=None, + connectivities=None, + plot=False, + save_plot=False, + plot_dir=None, + fit_args=None, + partial=None, + direction=None, + rna_only=False, + fit_decoupling=True, + extra_color=None, + rescale_u=None, + alpha=None, + beta=None, + gamma=None, + t_=None + ): + + self.device = device + self.gene = gene + self.local_std = local_std + self.conn = connectivities + + self.neural_net = neural_net + self.adam = adam + self.adam_lr = adam_lr + 
self.adam_beta1 = adam_beta1 + self.adam_beta2 = adam_beta2 + self.batch_size = batch_size + + self.torch_type = type(u[0].item()) + + # fitting arguments + self.init_mode = init_mode + self.rna_only = rna_only + self.fit_decoupling = fit_decoupling + self.max_iter = max_iter + self.n_anchors = np.clip(int(fit_args['t']), 201, 2000) + self.k_dist = np.clip(int(fit_args['k']), 1, 20) + self.tm = np.clip(fit_args['thresh_multiplier'], 0.4, 2) + self.weight_c = np.clip(fit_args['weight_c'], 0.1, 5) + self.outlier = np.clip(fit_args['outlier'], 80, 100) + self.model = int(model) if isinstance(model, float) else model + self.model_ = None + if self.model == 0 and self.init_mode == 'invert': + self.init_mode = 'grid' + + # plot parameters + self.plot = plot + self.save_plot = save_plot + self.extra_color = extra_color + self.fig_size = fit_args['fig_size'] + self.point_size = fit_args['point_size'] + if plot_dir is None: + self.plot_path = 'rna_plots' if self.rna_only else 'plots' + else: + self.plot_path = plot_dir + self.color = ['tab:red', 'tab:orange', 'tab:green', 'tab:blue'] + self.fig = None + self.ax = None + + # input + self.total_n = len(u) + if sparse.issparse(c): + c = c.A + if sparse.issparse(u): + u = u.A + if sparse.issparse(s): + s = s.A + self.c_all = np.ravel(np.array(c, dtype=np.float64)) + self.u_all = np.ravel(np.array(u, dtype=np.float64)) + self.s_all = np.ravel(np.array(s, dtype=np.float64)) + + # adjust offset + self.offset_c, self.offset_u, self.offset_s = np.min(self.c_all), \ + np.min(self.u_all), np.min(self.s_all) + self.offset_c = 0 if self.rna_only else self.offset_c + self.c_all -= self.offset_c + self.u_all -= self.offset_u + self.s_all -= self.offset_s + # remove zero counts + self.non_zero = (np.ravel(self.c_all > 0) | np.ravel(self.u_all > 0) | + np.ravel(self.s_all > 0)) + # remove outliers + self.non_outlier = np.ravel(self.c_all <= np.percentile(self.c_all, + self.outlier)) + self.non_outlier &= np.ravel(self.u_all <= np.percentile(self.u_all, + self.outlier)) + self.non_outlier &= np.ravel(self.s_all <= np.percentile(self.s_all, + self.outlier)) + self.c = self.c_all[self.non_zero & self.non_outlier] + self.u = self.u_all[self.non_zero & self.non_outlier] + self.s = self.s_all[self.non_zero & self.non_outlier] + self.low_quality = len(self.u) < 10 + # scale modalities + self.std_c, self.std_u, self.std_s = (np.std(self.c_all) + if not self.rna_only + else 1.0, np.std(self.u_all), + np.std(self.s_all)) + if self.std_u == 0 or self.std_s == 0: + self.low_quality = True + self.scale_c, self.scale_u, self.scale_s = np.max(self.c_all) \ + if not self.rna_only else 1.0, self.std_u/self.std_s, 1.0 + + # if we're on neural net mode, check to see if c is way bigger than + # u or s, which would be very hard for the neural net to fit + if not self.low_quality and neural_net: + max_c_orig = np.max(self.c) + if max_c_orig / np.max(self.u) > 500: + self.low_quality = True + + if not self.low_quality: + if max_c_orig / np.max(self.s) > 500: + self.low_quality = True + + self.c_all /= self.scale_c + self.u_all /= self.scale_u + self.s_all /= self.scale_s + self.c /= self.scale_c + self.u /= self.scale_u + self.s /= self.scale_s + self.scale_factor = np.array([np.std(self.c_all) / self.std_s / + self.weight_c, 1.0, 1.0]) + self.scale_factor[0] = 1 if self.rna_only else self.scale_factor[0] + self.max_u, self.max_s = np.max(self.u), np.max(self.s) + self.max_u_all, self.max_s_all = np.max(self.u_all), np.max(self.s_all) + if self.conn is not None: + self.conn_sub = 
self.conn[np.ix_(self.non_zero & self.non_outlier, + self.non_zero & self.non_outlier)] + else: + self.conn_sub = None + + main_info(f'{len(self.u)} cells passed filter and will be used to ' + 'compute trajectories.', indent_level=2) + self.known_pars = (True + if None not in [rescale_u, alpha, beta, gamma, t_] + else False) + if self.known_pars: + main_info(f'known parameters for gene {self.gene} are ' + f'scaling={rescale_u}, alpha={alpha}, beta={beta},' + f' gamma={gamma}, t_={t_}.', indent_level=1) + + # define neural networks + self.ode_model_0 = nn.Sequential( + nn.Linear(21, 150), + nn.ReLU(), + nn.Linear(150, 112), + nn.ReLU(), + nn.Linear(112, 75), + nn.ReLU(), + nn.Linear(75, 1), + nn.Sigmoid() + ) + + self.ode_model_1 = nn.Sequential( + nn.Linear(16, 64), + nn.ReLU(), + nn.Linear(64, 48), + nn.ReLU(), + nn.Linear(48, 32), + nn.ReLU(), + nn.Linear(32, 1), + nn.Sigmoid() + ) + + self.ode_model_2_m1 = nn.Sequential( + nn.Linear(18, 220), + nn.ReLU(), + nn.Linear(220, 165), + nn.ReLU(), + nn.Linear(165, 110), + nn.ReLU(), + nn.Linear(110, 1), + nn.Sigmoid() + ) + + self.ode_model_2_m2 = nn.Sequential( + nn.Linear(16, 150), + nn.ReLU(), + nn.Linear(150, 112), + nn.ReLU(), + nn.Linear(112, 75), + nn.ReLU(), + nn.Linear(75, 1), + nn.Sigmoid() + ) + + self.ode_model_0.to(torch.device(self.device)) + self.ode_model_1.to(torch.device(self.device)) + self.ode_model_2_m1.to(torch.device(self.device)) + self.ode_model_2_m2.to(torch.device(self.device)) + + # load in neural network + net_path = os.path.dirname(os.path.abspath(__file__)) + \ + "/neural_nets/" + + self.ode_model_0.load_state_dict(torch.load(net_path+"dir0.pt")) + self.ode_model_1.load_state_dict(torch.load(net_path+"dir1.pt")) + self.ode_model_2_m1.load_state_dict(torch.load(net_path+"dir2_m1.pt")) + self.ode_model_2_m2.load_state_dict(torch.load(net_path+"dir2_m2.pt")) + + # 4 rate parameters + self.alpha_c = 0.1 + self.alpha = alpha if alpha is not None else 0.0 + self.beta = beta if beta is not None else 0.0 + self.gamma = gamma if gamma is not None else 0.0 + # 3 possible switch time points + self.t_sw_1 = 0.1 if t_ is not None else 0.0 + self.t_sw_2 = t_+0.1 if t_ is not None else 0.0 + self.t_sw_3 = 20.0 if t_ is not None else 0.0 + # 2 rescale factors + self.rescale_c = 1.0 + self.rescale_u = rescale_u if rescale_u is not None else 1.0 + self.rates = None + self.t_sw_array = None + self.fit_rescale = True if rescale_u is None else False + self.params = None + + # other parameters or results + self.t = None + self.state = None + self.loss = [np.inf] + self.likelihood = -1.0 + self.l_c = 0 + self.ssd_c, self.var_c = 0, 0 + self.scale_cc = 1.0 + self.fitting_flag_ = 0 + self.velocity = None + self.anchor_t1_list, self.anchor_t2_list = None, None + self.anchor_exp = None + self.anchor_exp_sw = None + self.anchor_min_idx, self.anchor_max_idx, self.anchor_velo_min_idx, \ + self.anchor_velo_max_idx = None, None, None, None + self.anchor_velo = None + self.c0 = self.u0 = self.s0 = 0.0 + self.realign_ratio = 1.0 + self.partial = False + self.direction = 'complete' + self.steady_state_func = None + + # for fit and update + self.cur_iter = 0 + self.cur_loss = None + self.cur_state_pred = None + self.cur_t_sw_adjust = None + + # partial checking and model examination + determine_model = model is None + if partial is None and direction is None: + if embed_coord is not None: + self.embed_coord = embed_coord[self.non_zero & + self.non_outlier] + else: + self.embed_coord = None + 
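The four nn.Sequential stacks registered above are small sigmoid-headed MLPs, one per direction/model combination, whose trained weights ship as .pt files in neural_nets/ next to this module. A standalone sketch of the same build-and-load pattern; load_direction_net and the map_location argument are assumptions added here (map_location lets CPU-only hosts read weights saved from a GPU session):

    import torch
    import torch.nn as nn

    def load_direction_net(path: str, device: str = 'cpu') -> nn.Module:
        # same 21 -> 150 -> 112 -> 75 -> 1 sigmoid head as ode_model_0
        net = nn.Sequential(
            nn.Linear(21, 150), nn.ReLU(),
            nn.Linear(150, 112), nn.ReLU(),
            nn.Linear(112, 75), nn.ReLU(),
            nn.Linear(75, 1), nn.Sigmoid(),
        )
        net.load_state_dict(torch.load(path, map_location=torch.device(device)))
        return net.to(device).eval()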
self.check_partial_trajectory(determine_model=determine_model) + elif direction is not None: + self.direction = direction + if direction in ['on', 'off']: + self.partial = True + else: + self.partial = False + self.check_partial_trajectory(fit_gmm=False, fit_slope=False, + determine_model=determine_model) + elif partial is not None: + self.partial = partial + self.check_partial_trajectory(fit_gmm=False, + determine_model=determine_model) + else: + self.check_partial_trajectory(fit_gmm=False, fit_slope=False, + determine_model=determine_model) + + # intialize steady state parameters + if not self.known_pars and not self.low_quality: + self.initialize_steady_state_params(model_mismatch=self.model + != self.model_) + if self.known_pars: + self.params = np.array([self.t_sw_1, + self.t_sw_2-self.t_sw_1, + self.t_sw_3-self.t_sw_2, + self.alpha_c, + self.alpha, + self.beta, + self.gamma, + self.scale_cc, + self.rescale_c, + self.rescale_u]) + + # the torch tensor version of the anchor points function + def anchor_points_ten(self, t_sw_array, total_h=20, t=1000, mode='uniform', + return_time=False): + + t_ = torch.linspace(0, total_h, t, device=self.device, + dtype=self.torch_type) + tau1 = t_[t_ <= t_sw_array[0]] + tau2 = t_[(t_sw_array[0] < t_) & (t_ <= t_sw_array[1])] - t_sw_array[0] + tau3 = t_[(t_sw_array[1] < t_) & (t_ <= t_sw_array[2])] - t_sw_array[1] + tau4 = t_[t_sw_array[2] < t_] - t_sw_array[2] + + if mode == 'log': + if len(tau1) > 0: + tau1 = torch.expm1(tau1) + tau1 = tau1 / torch.max(tau1) * (t_sw_array[0]) + if len(tau2) > 0: + tau2 = torch.expm1(tau2) + tau2 = tau2 / torch.max(tau2) * (t_sw_array[1] - t_sw_array[0]) + if len(tau3) > 0: + tau3 = torch.expm1(tau3) + tau3 = tau3 / torch.max(tau3) * (t_sw_array[2] - t_sw_array[1]) + if len(tau4) > 0: + tau4 = torch.expm1(tau4) + tau4 = tau4 / torch.max(tau4) * (total_h - t_sw_array[2]) + + tau_list = [tau1, tau2, tau3, tau4] + if return_time: + return t_, tau_list + else: + return tau_list + + # the torch version of the predict_exp function + def predict_exp_ten(self, + tau, + c0, + u0, + s0, + alpha_c, + alpha, + beta, + gamma, + scale_cc=None, + pred_r=True, + chrom_open=True, + backward=False, + rna_only=False): + + if scale_cc is None: + scale_cc = torch.tensor(1.0, requires_grad=True, + device=self.device, + dtype=self.torch_type) + + if len(tau) == 0: + return torch.empty((0, 3), + requires_grad=True, + device=self.device, + dtype=self.torch_type) + if backward: + tau = -tau + + eat = torch.exp(-alpha_c * tau) + ebt = torch.exp(-beta * tau) + egt = torch.exp(-gamma * tau) + if rna_only: + kc = 1 + c0 = 1 + else: + if chrom_open: + kc = 1 + else: + kc = 0 + alpha_c = alpha_c * scale_cc + + const = (kc - c0) * alpha / (beta - alpha_c) + + res0 = kc - (kc - c0) * eat + + if pred_r: + + res1 = u0 * ebt + (alpha * kc / beta) * (1 - ebt) + res1 += const * (ebt - eat) + + res2 = s0 * egt + (alpha * kc / gamma) * (1 - egt) + res2 += ((beta / (gamma - beta)) * + ((alpha * kc / beta) - u0 - const) * (egt - ebt)) + res2 += (beta / (gamma - alpha_c)) * const * (egt - eat) + + else: + res1 = torch.zeros(len(tau), device=self.device, + requires_grad=True, + dtype=self.torch_type) + res2 = torch.zeros(len(tau), device=self.device, + requires_grad=True, + dtype=self.torch_type) + + res = torch.stack((res0, res1, res2), 1) + + return res + + # the torch tensor version of the generate_exp function + def generate_exp_tens(self, + tau_list, + t_sw_array, + alpha_c, + alpha, + beta, + gamma, + scale_cc=None, + model=1, + rna_only=False): + + if 
scale_cc is None: + scale_cc = torch.tensor(1.0, requires_grad=True, + device=self.device, + dtype=self.torch_type) + + if beta == alpha_c: + beta += 1e-3 + if gamma == beta or gamma == alpha_c: + gamma += 1e-3 + switch = int(t_sw_array.size(dim=0)) + if switch >= 1: + tau_sw1 = torch.tensor([t_sw_array[0]], requires_grad=True, + device=self.device, + dtype=self.torch_type) + if switch >= 2: + tau_sw2 = torch.tensor([t_sw_array[1] - t_sw_array[0]], + requires_grad=True, + device=self.device, + dtype=self.torch_type) + if switch == 3: + tau_sw3 = torch.tensor([t_sw_array[2] - t_sw_array[1]], + requires_grad=True, + device=self.device, + dtype=self.torch_type) + exp_sw1, exp_sw2, exp_sw3 = (torch.empty((0, 3), + requires_grad=True, + device=self.device, + dtype=self.torch_type), + torch.empty((0, 3), + requires_grad=True, + device=self.device, + dtype=self.torch_type), + torch.empty((0, 3), + requires_grad=True, + device=self.device, + dtype=self.torch_type)) + if tau_list is None: + if model == 0: + if switch >= 1: + exp_sw1 = self.predict_exp_ten(tau_sw1, 0, 0, 0, alpha_c, + alpha, beta, gamma, + pred_r=False, + scale_cc=scale_cc, + rna_only=rna_only) + if switch >= 2: + exp_sw2 = self.predict_exp_ten(tau_sw2, exp_sw1[0, 0], + exp_sw1[0, 1], + exp_sw1[0, 2], + alpha_c, alpha, beta, + gamma, pred_r=False, + chrom_open=False, + scale_cc=scale_cc, + rna_only=rna_only) + if switch >= 3: + exp_sw3 = self.predict_exp_ten(tau_sw3, + exp_sw2[0, 0], + exp_sw2[0, 1], + exp_sw2[0, 2], + alpha_c, alpha, + beta, gamma, + chrom_open=False, + scale_cc=scale_cc, + rna_only=rna_only) + elif model == 1: + if switch >= 1: + exp_sw1 = self.predict_exp_ten(tau_sw1, 0, 0, 0, alpha_c, + alpha, beta, gamma, + pred_r=False, + scale_cc=scale_cc, + rna_only=rna_only) + if switch >= 2: + exp_sw2 = self.predict_exp_ten(tau_sw2, exp_sw1[0, 0], + exp_sw1[0, 1], + exp_sw1[0, 2], + alpha_c, alpha, + beta, gamma, + scale_cc=scale_cc, + rna_only=rna_only) + if switch >= 3: + exp_sw3 = self.predict_exp_ten(tau_sw3, + exp_sw2[0, 0], + exp_sw2[0, 1], + exp_sw2[0, 2], + alpha_c, alpha, + beta, gamma, + chrom_open=False, + scale_cc=scale_cc, + rna_only=rna_only) + elif model == 2: + if switch >= 1: + exp_sw1 = self.predict_exp_ten(tau_sw1, 0, 0, 0, alpha_c, + alpha, beta, gamma, + pred_r=False, + scale_cc=scale_cc, + rna_only=rna_only) + if switch >= 2: + exp_sw2 = self.predict_exp_ten(tau_sw2, exp_sw1[0, 0], + exp_sw1[0, 1], + exp_sw1[0, 2], alpha_c, + alpha, beta, gamma, + scale_cc=scale_cc, + rna_only=rna_only) + if switch >= 3: + exp_sw3 = self.predict_exp_ten(tau_sw3, + exp_sw2[0, 0], + exp_sw2[0, 1], + exp_sw2[0, 2], + alpha_c, 0, beta, + gamma, + scale_cc=scale_cc, + rna_only=rna_only) + + return [torch.empty((0, 3), requires_grad=True, + device=self.device, + dtype=self.torch_type), + torch.empty((0, 3), requires_grad=True, + device=self.device, + dtype=self.torch_type), + torch.empty((0, 3), requires_grad=True, + device=self.device, + dtype=self.torch_type), + torch.empty((0, 3), requires_grad=True, + device=self.device, + dtype=self.torch_type)], \ + [exp_sw1, exp_sw2, exp_sw3] + + tau1 = tau_list[0] + if switch >= 1: + tau2 = tau_list[1] + if switch >= 2: + tau3 = tau_list[2] + if switch == 3: + tau4 = tau_list[3] + exp1, exp2, exp3, exp4 = (torch.empty((0, 3), requires_grad=True, + device=self.device, + dtype=self.torch_type), + torch.empty((0, 3), requires_grad=True, + device=self.device, + dtype=self.torch_type), + torch.empty((0, 3), requires_grad=True, + device=self.device, + dtype=self.torch_type), + 
torch.empty((0, 3), requires_grad=True, + device=self.device, + dtype=self.torch_type)) + if model == 0: + exp1 = self.predict_exp_ten(tau1, 0, 0, 0, alpha_c, alpha, beta, + gamma, pred_r=False, scale_cc=scale_cc, + rna_only=rna_only) + if switch >= 1: + exp_sw1 = self.predict_exp_ten(tau_sw1, 0, 0, 0, alpha_c, + alpha, beta, gamma, + pred_r=False, scale_cc=scale_cc, + rna_only=rna_only) + exp2 = self.predict_exp_ten(tau2, exp_sw1[0, 0], exp_sw1[0, 1], + exp_sw1[0, 2], alpha_c, alpha, + beta, gamma, pred_r=False, + chrom_open=False, + scale_cc=scale_cc, + rna_only=rna_only) + if switch >= 2: + exp_sw2 = self.predict_exp_ten(tau_sw2, exp_sw1[0, 0], + exp_sw1[0, 1], + exp_sw1[0, 2], + alpha_c, alpha, beta, gamma, + pred_r=False, + chrom_open=False, + scale_cc=scale_cc, + rna_only=rna_only) + exp3 = self.predict_exp_ten(tau3, exp_sw2[0, 0], + exp_sw2[0, 1], exp_sw2[0, 2], + alpha_c, alpha, beta, gamma, + chrom_open=False, + scale_cc=scale_cc, + rna_only=rna_only) + if switch == 3: + exp_sw3 = self.predict_exp_ten(tau_sw3, exp_sw2[0, 0], + exp_sw2[0, 1], + exp_sw2[0, 2], + alpha_c, alpha, beta, + gamma, + chrom_open=False, + scale_cc=scale_cc, + rna_only=rna_only) + exp4 = self.predict_exp_ten(tau4, exp_sw3[0, 0], + exp_sw3[0, 1], + exp_sw3[0, 2], + alpha_c, 0, beta, gamma, + chrom_open=False, + scale_cc=scale_cc, + rna_only=rna_only) + elif model == 1: + exp1 = self.predict_exp_ten(tau1, 0, 0, 0, alpha_c, alpha, beta, + gamma, pred_r=False, scale_cc=scale_cc, + rna_only=rna_only) + if switch >= 1: + exp_sw1 = self.predict_exp_ten(tau_sw1, 0, 0, 0, alpha_c, + alpha, beta, gamma, + pred_r=False, scale_cc=scale_cc, + rna_only=rna_only) + exp2 = self.predict_exp_ten(tau2, exp_sw1[0, 0], exp_sw1[0, 1], + exp_sw1[0, 2], alpha_c, alpha, + beta, gamma, scale_cc=scale_cc, + rna_only=rna_only) + if switch >= 2: + exp_sw2 = self.predict_exp_ten(tau_sw2, exp_sw1[0, 0], + exp_sw1[0, 1], + exp_sw1[0, 2], alpha_c, + alpha, beta, gamma, + scale_cc=scale_cc, + rna_only=rna_only) + exp3 = self.predict_exp_ten(tau3, exp_sw2[0, 0], + exp_sw2[0, 1], exp_sw2[0, 2], + alpha_c, alpha, beta, gamma, + chrom_open=False, + scale_cc=scale_cc, + rna_only=rna_only) + if switch == 3: + exp_sw3 = self.predict_exp_ten(tau_sw3, exp_sw2[0, 0], + exp_sw2[0, 1], + exp_sw2[0, 2], + alpha_c, alpha, beta, + gamma, + chrom_open=False, + scale_cc=scale_cc, + rna_only=rna_only) + exp4 = self.predict_exp_ten(tau4, exp_sw3[0, 0], + exp_sw3[0, 1], + exp_sw3[0, 2], alpha_c, 0, + beta, gamma, + chrom_open=False, + scale_cc=scale_cc, + rna_only=rna_only) + elif model == 2: + exp1 = self.predict_exp_ten(tau1, 0, 0, 0, alpha_c, alpha, beta, + gamma, pred_r=False, scale_cc=scale_cc, + rna_only=rna_only) + if switch >= 1: + exp_sw1 = self.predict_exp_ten(tau_sw1, 0, 0, 0, alpha_c, + alpha, beta, gamma, + pred_r=False, scale_cc=scale_cc, + rna_only=rna_only) + exp2 = self.predict_exp_ten(tau2, exp_sw1[0, 0], exp_sw1[0, 1], + exp_sw1[0, 2], alpha_c, alpha, + beta, gamma, scale_cc=scale_cc, + rna_only=rna_only) + if switch >= 2: + exp_sw2 = self.predict_exp_ten(tau_sw2, exp_sw1[0, 0], + exp_sw1[0, 1], + exp_sw1[0, 2], alpha_c, + alpha, beta, gamma, + scale_cc=scale_cc, + rna_only=rna_only) + exp3 = self.predict_exp_ten(tau3, exp_sw2[0, 0], + exp_sw2[0, 1], + exp_sw2[0, 2], alpha_c, 0, + beta, gamma, scale_cc=scale_cc, + rna_only=rna_only) + if switch == 3: + exp_sw3 = self.predict_exp_ten(tau_sw3, exp_sw2[0, 0], + exp_sw2[0, 1], + exp_sw2[0, 2], + alpha_c, 0, beta, gamma, + scale_cc=scale_cc, + rna_only=rna_only) + exp4 = self.predict_exp_ten(tau4, 
exp_sw3[0, 0], + exp_sw3[0, 1], + exp_sw3[0, 2], + alpha_c, 0, beta, gamma, + chrom_open=False, + scale_cc=scale_cc, + rna_only=rna_only) + return [exp1, exp2, exp3, exp4], [exp_sw1, exp_sw2, exp_sw3] + + def check_partial_trajectory(self, fit_gmm=True, fit_slope=True, + determine_model=True): + w_non_zero = ((self.c >= 0.1 * np.max(self.c)) & + (self.u >= 0.1 * np.max(self.u)) & + (self.s >= 0.1 * np.max(self.s))) + u_non_zero = self.u[w_non_zero] + s_non_zero = self.s[w_non_zero] + if len(u_non_zero) < 10: + self.low_quality = True + return + + # GMM + w_low = ((np.percentile(s_non_zero, 30) <= s_non_zero) & + (s_non_zero <= np.percentile(s_non_zero, 40))) + if np.sum(w_low) < 10: + fit_gmm = False + self.partial = True + if self.local_std is None: + main_info('local standard deviation not provided. ' + 'Skipping GMM..', indent_level=2) + if self.embed_coord is None: + main_info('Warning: embedded coordinates not provided. ' + 'Skipping GMM..') + if (fit_gmm and self.local_std is not None and self.embed_coord + is not None): + + pdist = pairwise_distances( + self.embed_coord[w_non_zero, :][w_low, :]) + dists = (np.ravel(pdist[np.triu_indices_from(pdist, k=1)]) + .reshape(-1, 1)) + model = GaussianMixture(n_components=2, covariance_type='tied', + random_state=2021).fit(dists) + mean_diff = np.abs(model.means_[1][0] - model.means_[0][0]) + criterion1 = mean_diff > self.local_std / self.tm + main_info(f'GMM: difference between means = {mean_diff}, ' + f'threshold = {self.local_std / self.tm}.', indent_level=2) + criterion2 = np.all(model.weights_[1] > 0.2 / self.tm) + main_info('GMM: weight of the second Gaussian =' + f' {model.weights_[1]}.', indent_level=2) + if criterion1 and criterion2: + self.partial = False + else: + self.partial = True + main_info(f'GMM decides {"" if self.partial else "not "}' + 'partial.', indent_level=2) + + # steady-state slope + wu = self.u >= np.percentile(u_non_zero, 95) + ws = self.s >= np.percentile(s_non_zero, 95) + ss_u = self.u[wu | ws] + ss_s = self.s[wu | ws] + if np.all(ss_u == 0) or np.all(ss_s == 0): + self.low_quality = True + return + gamma = np.dot(ss_u, ss_s) / np.dot(ss_s, ss_s) + self.steady_state_func = lambda x: gamma*x + + # thickness of phase portrait + u_norm = u_non_zero / np.max(self.u) + s_norm = s_non_zero / np.max(self.s) + exp = np.hstack((np.reshape(u_norm, (-1, 1)), + np.reshape(s_norm, (-1, 1)))) + U, S, Vh = np.linalg.svd(exp) + self.thickness = S[1] + + # slope-based direction decision + with np.errstate(divide='ignore', invalid='ignore'): + slope = self.u / self.s + non_nan = ~np.isnan(slope) + slope = slope[non_nan] + on = slope >= gamma + off = slope < gamma + if len(ss_u) < 10 or len(u_non_zero) < 10: + fit_slope = False + self.direction = 'complete' + if fit_slope: + slope_ = u_non_zero / s_non_zero + on_ = slope_ >= gamma + off_ = slope_ < gamma + on_dist = np.sum((u_non_zero[on_] - gamma * s_non_zero[on_])**2) + off_dist = np.sum((gamma * s_non_zero[off_] - u_non_zero[off_])**2) + main_info(f'Slope: SSE on induction phase = {on_dist},' + f' SSE on repression phase = {off_dist}.', indent_level=2) + if self.thickness < 1.5 / np.sqrt(self.tm): + narrow = True + else: + narrow = False + main_info(f'Thickness of trajectory = {self.thickness}. 
' + f'Trajectory is {"narrow" if narrow else "normal"}.', + indent_level=2) + if on_dist > 10 * self.tm**2 * off_dist: + self.direction = 'on' + self.partial = True + elif off_dist > 10 * self.tm**2 * on_dist: + self.direction = 'off' + self.partial = True + else: + if self.partial: + if on_dist > 3 * self.tm * off_dist: + self.direction = 'on' + elif off_dist > 3 * self.tm * on_dist: + self.direction = 'off' + else: + if narrow: + self.direction = 'on' + else: + self.direction = 'complete' + self.partial = False + else: + if narrow: + self.direction = ('off' + if off_dist > 2 * self.tm * on_dist + else 'on') + self.partial = True + else: + self.direction = 'complete' + + # model pre-determination + if self.direction == 'on': + self.model_ = 1 + elif self.direction == 'off': + self.model_ = 2 + else: + c_high = self.c >= np.mean(self.c) + 2 * np.std(self.c) + c_high = c_high[non_nan] + if np.sum(c_high) < 10: + c_high = self.c >= np.mean(self.c) + np.std(self.c) + c_high = c_high[non_nan] + if np.sum(c_high) < 10: + c_high = self.c >= np.percentile(self.c, 90) + c_high = c_high[non_nan] + if np.sum(self.c[non_nan][c_high] == 0) > 0.5*np.sum(c_high): + self.low_quality = True + return + c_high_on = np.sum(c_high & on) + c_high_off = np.sum(c_high & off) + if c_high_on > c_high_off: + self.model_ = 1 + else: + self.model_ = 2 + if determine_model: + self.model = self.model_ + + if not self.known_pars: + if fit_gmm or fit_slope: + main_info(f'predicted partial trajectory: {self.partial}', + indent_level=1) + main_info('predicted trajectory direction: ' + f'{self.direction}', indent_level=1) + if determine_model: + main_info(f'predicted model: {self.model}', indent_level=1) + + def initialize_steady_state_params(self, model_mismatch=False): + self.scale_cc = 1.0 + self.rescale_c = 1.0 + # estimate rescale factor for u + s_norm = self.s / self.max_s + u_mid = (self.u >= 0.4 * self.max_u) & (self.u <= 0.6 * self.max_u) + if np.sum(u_mid) < 10: + self.rescale_u = self.thickness / 5 + else: + s_low, s_high = np.percentile(s_norm[u_mid], [2, 98]) + s_dist = s_high - s_low + self.rescale_u = s_dist + if self.rescale_u == 0: + self.low_quality = True + return + + c = self.c / self.rescale_c + u = self.u / self.rescale_u + s = self.s + + # some extreme values + wu = u >= np.percentile(u, 97) + ws = s >= np.percentile(s, 97) + ss_u = u[wu | ws] + ss_s = s[wu | ws] + c_upper = np.mean(c[wu | ws]) + + c_high = c >= np.mean(c) + # _r stands for repressed state + c0_r = np.mean(c[c_high]) + u0_r = np.mean(ss_u) + s0_r = np.mean(ss_s) + if c0_r < c_upper: + c0_r = c_upper + 0.1 + + # adjust chromatin level for reasonable initialization + if model_mismatch or not self.fit_decoupling: + c_indu = np.mean(c[self.u > self.steady_state_func(self.s)]) + c_repr = np.mean(c[self.u < self.steady_state_func(self.s)]) + # np.nan never compares equal to itself, so detect it with np.isnan + if np.isnan(c_indu) or np.isnan(c_repr): + self.low_quality = True + return + c0_r = np.mean(c[c >= np.min([c_indu, c_repr])]) + + # initialize rates + self.alpha_c = 0.1 + self.beta = 1.0 + self.gamma = np.dot(ss_u, ss_s) / np.dot(ss_s, ss_s) + alpha = u0_r + self.alpha = u0_r + self.rates = np.array([self.alpha_c, self.alpha, self.beta, + self.gamma]) + + # RNA-only + if self.rna_only: + t_sw_1 = 0.1 + t_sw_3 = 20.0 + if self.init_mode == 'grid': + # arange returns sequence [2,6,10,14,18] + for t_sw_2 in np.arange(2, 20, 4, dtype=np.float64): + # build the candidate parameter vector for this grid point + params = np.array([t_sw_1, + t_sw_2-t_sw_1, + t_sw_3-t_sw_2, + self.alpha_c, + self.alpha, + self.beta, + self.gamma, + self.scale_cc, + self.rescale_c, + self.rescale_u]) + self.update(params, initialize=True, adjust_time=False, + plot=False) + + elif self.init_mode == 'simple': + t_sw_2 = 10 + self.params = 
np.array([t_sw_1, + t_sw_2-t_sw_1, + t_sw_3-t_sw_2, + self.alpha_c, + self.alpha, + self.beta, + self.gamma, + self.scale_cc, + self.rescale_c, + self.rescale_u]) + + elif self.init_mode == 'invert': + t_sw_2 = approx_tau(u0_r, s0_r, 0, 0, alpha, self.beta, + self.gamma) + if t_sw_2 <= 0.2: + t_sw_2 = 1.0 + elif t_sw_2 >= 19.9: + t_sw_2 = 19.0 + self.params = np.array([t_sw_1, + t_sw_2-t_sw_1, + t_sw_3-t_sw_2, + self.alpha_c, + self.alpha, + self.beta, + self.gamma, + self.scale_cc, + self.rescale_c, + self.rescale_u]) + + # chromatin-RNA + else: + if self.init_mode == 'grid': + # arange returns sequence [1,5,9,13,17] + for t_sw_1 in np.arange(1, 18, 4, dtype=np.float64): + # arange returns sequence 2,6,10,14,18 + for t_sw_2 in np.arange(t_sw_1+1, 19, 4, dtype=np.float64): + # arange returns sequence [3,7,11,15,19] + for t_sw_3 in np.arange(t_sw_2+1, 20, 4, + dtype=np.float64): + if not self.fit_decoupling: + t_sw_3 = t_sw_2 + 30 / self.n_anchors + params = np.array([t_sw_1, + t_sw_2-t_sw_1, + t_sw_3-t_sw_2, + self.alpha_c, + self.alpha, + self.beta, + self.gamma, + self.scale_cc, + self.rescale_c, + self.rescale_u]) + self.update(params, initialize=True, + adjust_time=False, plot=False) + if not self.fit_decoupling: + break + + elif self.init_mode == 'simple': + t_sw_1, t_sw_2, t_sw_3 = 5, 10, 15 \ + if not self.fit_decoupling \ + else 10.1 + self.params = np.array([t_sw_1, + t_sw_2-t_sw_1, + t_sw_3-t_sw_2, + self.alpha_c, + self.alpha, + self.beta, + self.gamma, + self.scale_cc, + self.rescale_c, + self.rescale_u]) + + elif self.init_mode == 'invert': + self.alpha = u0_r / c_upper + if model_mismatch or not self.fit_decoupling: + self.alpha = u0_r / c0_r + rna_interval = approx_tau(u0_r, s0_r, 0, 0, alpha, self.beta, + self.gamma) + rna_interval = np.clip(rna_interval, 3, 12) + if self.model == 1: + for t_sw_1 in np.arange(1, rna_interval-1, 2, + dtype=np.float64): + t_sw_3 = rna_interval + t_sw_1 + for t_sw_2 in np.arange(t_sw_1+1, rna_interval, 2, + dtype=np.float64): + if not self.fit_decoupling: + t_sw_2 = t_sw_3 - 30 / self.n_anchors + + alpha_c = -np.log(1 - c0_r) / t_sw_2 + params = np.array([t_sw_1, + t_sw_2-t_sw_1, + t_sw_3-t_sw_2, + alpha_c, + self.alpha, + self.beta, + self.gamma, + self.scale_cc, + self.rescale_c, + self.rescale_u]) + self.update(params, initialize=True, + adjust_time=False, plot=False) + if not self.fit_decoupling: + break + + elif self.model == 2: + for t_sw_1 in np.arange(1, rna_interval, 2, + dtype=np.float64): + t_sw_2 = rna_interval + t_sw_1 + for t_sw_3 in np.arange(t_sw_2+1, t_sw_2+6, 2, + dtype=np.float64): + if not self.fit_decoupling: + t_sw_3 = t_sw_2 + 30 / self.n_anchors + + alpha_c = -np.log(1 - c0_r) / t_sw_3 + params = np.array([t_sw_1, + t_sw_2-t_sw_1, + t_sw_3-t_sw_2, + alpha_c, + self.alpha, + self.beta, + self.gamma, + self.scale_cc, + self.rescale_c, + self.rescale_u]) + self.update(params, initialize=True, + adjust_time=False, plot=False) + if not self.fit_decoupling: + break + + self.loss = [self.mse(self.params)] + self.t_sw_array = np.array([self.params[0], + self.params[0]+self.params[1], + self.params[0]+self.params[1] + + self.params[2]]) + self.t_sw_1, self.t_sw_2, self.t_sw_3 = self.t_sw_array + + main_info(f'initial params:\nswitch time array = {self.t_sw_array},' + '\n' + f'rates = {self.rates},\ncc scale = {self.scale_cc},\n' + f'c rescale factor = {self.rescale_c},\n' + f'u rescale factor = {self.rescale_u}', indent_level=1) + main_info(f'initial loss: {self.loss[-1]}', indent_level=1) + + def fit(self): + if 
self.low_quality: + return self.loss + + if self.plot: + plt.ion() + self.fig = plt.figure(figsize=self.fig_size) + if self.rna_only: + self.ax = self.fig.add_subplot(111) + else: + self.ax = self.fig.add_subplot(111, projection='3d') + + if not self.known_pars: + self.fit_dyn() + + self.update(self.params, perform_update=True, fit_outlier=True, + plot=True) + + # remove long gaps in the last observed state + t_sorted = np.sort(self.t) + dt = np.diff(t_sorted, prepend=0) + mean_dt = np.mean(dt) + std_dt = np.std(dt) + gap_thresh = np.clip(mean_dt+3*std_dt, 3*20/self.n_anchors, None) + if gap_thresh > 0: + idx = np.where(dt > gap_thresh)[0] + gap_sum = 0 + last_t_sw = np.max(self.t_sw_array[self.t_sw_array < 20]) + for i in idx: + t1 = t_sorted[i-1] if i > 0 else 0 + t2 = t_sorted[i] + if t1 > last_t_sw and t2 <= 20: + gap_sum += np.clip(t2 - t1 - mean_dt, 0, None) + if last_t_sw > np.max(self.t): + gap_sum += 20 - last_t_sw + realign_ratio = np.clip(20/(20 - gap_sum), None, 20/last_t_sw) + main_info(f'removing gaps and realigning by {realign_ratio}..', + indent_level=1) + self.rates /= realign_ratio + self.alpha_c, self.alpha, self.beta, self.gamma = self.rates + self.params[:3] *= realign_ratio + self.params[3:7] = self.rates + self.t_sw_array = np.array([self.params[0], + self.params[0]+self.params[1], + self.params[0]+self.params[1] + + self.params[2]]) + self.t_sw_1, self.t_sw_2, self.t_sw_3 = self.t_sw_array + self.update(self.params, perform_update=True, fit_outlier=True, + plot=True) + + if self.plot: + plt.ioff() + plt.show(block=True) + + # likelihood + main_info('computing likelihood..', indent_level=1) + keep = self.non_zero & self.non_outlier & \ + (self.u_all > 0.2 * np.percentile(self.u_all, 99.5)) & \ + (self.s_all > 0.2 * np.percentile(self.s_all, 99.5)) + scale_factor = np.array([self.scale_c / self.std_c, + self.scale_u / self.std_u, + self.scale_s / self.std_s]) + if np.sum(keep) >= 10: + self.likelihood, self.l_c, self.ssd_c, self.var_c, l_u, l_s = \ + compute_likelihood(self.c_all, + self.u_all, + self.s_all, + self.t_sw_array, + self.alpha_c, + self.alpha, + self.beta, + self.gamma, + self.rescale_c, + self.rescale_u, + self.t, + self.state, + scale_cc=self.scale_cc, + scale_factor=scale_factor, + model=self.model, + weight=keep, + rna_only=self.rna_only) + else: + self.likelihood, self.l_c, self.ssd_c, self.var_c, l_u = \ + 0, 0, 0, 0, 0 + # TODO: Keep? Remove?? 
+ l_s = 0 + + if not self.rna_only: + main_info(f'likelihood of c: {self.l_c}, likelihood of u: {l_u},' + f' likelihood of s: {l_s}', indent_level=1) + + # velocity + main_info('computing velocities..', indent_level=1) + self.velocity = np.empty((len(self.u_all), 3)) + if self.conn is not None: + new_time = self.conn.dot(self.t) + new_time[new_time > 20] = 20 + new_state = self.state.copy() + new_state[new_time <= self.t_sw_1] = 0 + new_state[(self.t_sw_1 < new_time) & (new_time <= self.t_sw_2)] = 1 + new_state[(self.t_sw_2 < new_time) & (new_time <= self.t_sw_3)] = 2 + new_state[self.t_sw_3 < new_time] = 3 + + else: + new_time = self.t + new_state = self.state + + self.alpha_c, self.alpha, self.beta, self.gamma = \ + check_params(self.alpha_c, self.alpha, self.beta, self.gamma) + vc, vu, vs = compute_velocity(new_time, + self.t_sw_array, + new_state, + self.alpha_c, + self.alpha, + self.beta, + self.gamma, + self.rescale_c, + self.rescale_u, + scale_cc=self.scale_cc, + model=self.model, + rna_only=self.rna_only) + + self.velocity[:, 0] = vc * self.scale_c + self.velocity[:, 1] = vu * self.scale_u + self.velocity[:, 2] = vs * self.scale_s + + # anchor expression and velocity + anchor_time, tau_list = anchor_points(self.t_sw_array, 20, + self.n_anchors, return_time=True) + switch = np.sum(self.t_sw_array < 20) + typed_tau_list = List() + [typed_tau_list.append(x) for x in tau_list] + self.alpha_c, self.alpha, self.beta, self.gamma, \ + self.c0, self.u0, self.s0 = \ + check_params(self.alpha_c, self.alpha, self.beta, self.gamma, + c0=self.c0, u0=self.u0, s0=self.s0) + exp_list, exp_sw_list = generate_exp(typed_tau_list, + self.t_sw_array[:switch], + self.alpha_c, + self.alpha, + self.beta, + self.gamma, + scale_cc=self.scale_cc, + model=self.model, + rna_only=self.rna_only) + rescale_factor = np.array([self.rescale_c, self.rescale_u, 1.0]) + exp_list = [x*rescale_factor for x in exp_list] + exp_sw_list = [x*rescale_factor for x in exp_sw_list] + c = np.ravel(np.concatenate([exp_list[x][:, 0] + for x in range(switch+1)])) + u = np.ravel(np.concatenate([exp_list[x][:, 1] + for x in range(switch+1)])) + s = np.ravel(np.concatenate([exp_list[x][:, 2] + for x in range(switch+1)])) + c_sw = np.ravel(np.concatenate([exp_sw_list[x][:, 0] + for x in range(switch)])) + u_sw = np.ravel(np.concatenate([exp_sw_list[x][:, 1] + for x in range(switch)])) + s_sw = np.ravel(np.concatenate([exp_sw_list[x][:, 2] + for x in range(switch)])) + self.alpha_c, self.alpha, self.beta, self.gamma = \ + check_params(self.alpha_c, self.alpha, self.beta, self.gamma) + vc, vu, vs = compute_velocity(anchor_time, + self.t_sw_array, + None, + self.alpha_c, + self.alpha, + self.beta, + self.gamma, + self.rescale_c, + self.rescale_u, + scale_cc=self.scale_cc, + model=self.model, + rna_only=self.rna_only) + + # scale and shift back to original scale + c_ = c * self.scale_c + self.offset_c + u_ = u * self.scale_u + self.offset_u + s_ = s * self.scale_s + self.offset_s + c_sw_ = c_sw * self.scale_c + self.offset_c + u_sw_ = u_sw * self.scale_u + self.offset_u + s_sw_ = s_sw * self.scale_s + self.offset_s + vc = vc * self.scale_c + vu = vu * self.scale_u + vs = vs * self.scale_s + + self.anchor_exp = np.empty((len(u_), 3)) + self.anchor_exp[:, 0], self.anchor_exp[:, 1], self.anchor_exp[:, 2] = \ + c_, u_, s_ + self.anchor_exp_sw = np.empty((len(u_sw_), 3)) + self.anchor_exp_sw[:, 0], self.anchor_exp_sw[:, 1], \ + self.anchor_exp_sw[:, 2] = c_sw_, u_sw_, s_sw_ + self.anchor_velo = np.empty((len(u_), 3)) + self.anchor_velo[:, 0] = vc 
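The vc, vu, vs components stored into the velocity and anchor_velo arrays here are the model ODE evaluated at each assigned time: chromatin accessibility relaxes toward its open or closed target while unspliced and spliced RNA follow standard splicing kinetics. A sketch of the induction-phase form, assuming fully open chromatin (kc = 1); chromatin_rna_velocity is an illustrative name, not the patch's compute_velocity:

    import numpy as np

    def chromatin_rna_velocity(c, u, s, alpha_c, alpha, beta, gamma, kc=1.0):
        vc = alpha_c * (kc - c)    # accessibility relaxes toward its target
        vu = alpha * c - beta * u  # transcription gated by accessibility
        vs = beta * u - gamma * s  # splicing in, degradation out
        return vc, vu, vs

    print(chromatin_rna_velocity(0.5, 0.2, 0.1, 0.1, 1.0, 1.2, 0.9))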
self.anchor_velo[:, 1] = vu + self.anchor_velo[:, 2] = vs + self.anchor_velo_min_idx = np.sum(anchor_time < np.min(new_time)) + self.anchor_velo_max_idx = np.sum(anchor_time < np.max(new_time)) - 1 + + if self.save_plot: + main_info('saving plots..', indent_level=1) + self.save_dyn_plot(c_, u_, s_, c_sw_, u_sw_, s_sw_, tau_list) + + self.realign_time_and_velocity(c, u, s, anchor_time) + + main_info(f'final params:\nswitch time array = {self.t_sw_array},\n' + f'rates = {self.rates},\ncc scale = {self.scale_cc},\n' + f'c rescale factor = {self.rescale_c},\n' + f'u rescale factor = {self.rescale_u}', + indent_level=1) + main_info(f'final loss: {self.loss[-1]}', indent_level=1) + main_info(f'final likelihood: {self.likelihood}', indent_level=1) + + return self.loss + + # the Adam algorithm + # NOTE: The starting point for this function was an example on the + # GeeksForGeeks website. The particular article is linked below: + # www.geeksforgeeks.org/how-to-implement-adam-gradient-descent-from-scratch-using-python/ + def AdamMin(self, x, n_iter, tol, eps=1e-8): + + n = len(x) + + x_ten = torch.tensor(x, requires_grad=True, device=self.device, + dtype=self.torch_type) + + # record lowest loss as a benchmark + # (right now the lowest loss is the current loss) + lowest_loss = torch.tensor(np.array(self.loss[-1], dtype=self.u.dtype), + device=self.device, + dtype=self.torch_type) + + # record the tensor of the parameters that cause the lowest loss + lowest_x_ten = x_ten + + # the m and v variables used in the adam calculations + m = torch.zeros(n, device=self.device, requires_grad=True, + dtype=self.torch_type) + v = torch.zeros(n, device=self.device, requires_grad=True, + dtype=self.torch_type) + + # the update amount to add to the x tensor after the appropriate + # calculations are made + u = torch.ones(n, device=self.device, requires_grad=True, + dtype=self.torch_type) * float("inf") + + # how many times the new loss is lower than the lowest loss + update_count = 0 + + iterations = 0 + + # run the gradient descent updates + for t in range(n_iter): + + iterations += 1 + + # calculate the loss + loss = self.mse_ten(x_ten) + + # if the loss is lower than the lowest loss... + if loss < lowest_loss: + + # record the new best tensor + lowest_x_ten = x_ten + update_count += 1 + + # if the percentage difference in x tensors and loss values + # is less than the tolerance parameter and we've updated the + # loss 3 times by now... + if torch.all((torch.abs(u) / lowest_x_ten) < tol) and \ + (torch.abs(loss - lowest_loss) / lowest_loss) < tol and \ + update_count >= 3: + + # ...we've updated enough. Break! + break + + # record the new lowest loss + lowest_loss = loss + + # take the gradient of mse w.r.t. our current parameter values + loss.backward(inputs=x_ten) + g = x_ten.grad + + # calculate the new update value using the Adam formula + m = (self.adam_beta1 * m) + ((1.0 - self.adam_beta1) * g) + v = (self.adam_beta2 * v) + ((1.0 - self.adam_beta2) * g * g) + + mhat = m / (1.0 - (self.adam_beta1**(t+1))) + vhat = v / (1.0 - (self.adam_beta2**(t+1))) + + u = -(self.adam_lr * mhat) / (torch.sqrt(vhat) + eps) + + # update the x tensor + x_ten = x_ten + u + + # as long as we've found at least one better x tensor... 
+ if update_count > 1: + + # record the final lowest loss + if loss < lowest_loss: + lowest_loss = loss + + # set the new loss for the gene to the new lowest loss + self.cur_loss = lowest_loss.item() + + # use the update() function so the gene's parameters + # are the new best one we found + updated = self.update(lowest_x_ten.cpu().detach().numpy()) + + # if we never found a better x tensor, then the return value should + # state that we did not update it + else: + updated = False + + # return whether we updated the x tensor or not + return updated + + def fit_dyn(self): + + while self.cur_iter < self.max_iter: + self.cur_iter += 1 + + # RNA-only + if self.rna_only: + main_info('Nelder Mead on t_sw_2 and alpha..', indent_level=2) + self.fitting_flag_ = 0 + if self.cur_iter == 1: + var_test = (self.alpha + + np.array([-2, -1, -0.5, 0.5, 1, 2]) * 0.1 + * self.alpha) + new_params = self.params.copy() + for var in var_test: + new_params[4] = var + self.update(new_params, adjust_time=False, + penalize_gap=False) + res = minimize(self.mse, x0=[self.params[1], self.params[4]], + method='Nelder-Mead', tol=1e-2, + callback=self.update, options={'maxiter': 3}) + + if self.fit_rescale: + main_info('Nelder Mead on t_sw_2, beta, and rescale u..', + indent_level=2) + res = minimize(self.mse, x0=[self.params[1], + self.params[5], + self.params[9]], + method='Nelder-Mead', tol=1e-2, + callback=self.update, + options={'maxiter': 5}) + + main_info('Nelder Mead on alpha and gamma..', indent_level=2) + self.fitting_flag_ = 1 + res = minimize(self.mse, x0=[self.params[4], self.params[6]], + method='Nelder-Mead', tol=1e-2, + callback=self.update, options={'maxiter': 3}) + + main_info('Nelder Mead on t_sw_2..', indent_level=2) + res = minimize(self.mse, x0=[self.params[1]], + method='Nelder-Mead', tol=1e-2, + callback=self.update, options={'maxiter': 2}) + + main_info('Full Nelder Mead..', indent_level=2) + res = minimize(self.mse, x0=[self.params[1], self.params[4], + self.params[5], self.params[6]], + method='Nelder-Mead', tol=1e-2, + callback=self.update, options={'maxiter': 5}) + + # chromatin-RNA + else: + + if not self.adam: + main_info('Nelder Mead on t_sw_1, chromatin switch time,' + 'and alpha_c..', indent_level=2) + self.fitting_flag_ = 1 + if self.cur_iter == 1: + var_test = (self.gamma + np.array([-1, -0.5, 0.5, 1]) + * 0.1 * self.gamma) + new_params = self.params.copy() + for var in var_test: + new_params[6] = var + self.update(new_params, adjust_time=False) + if self.model == 0 or self.model == 1: + res = minimize(self.mse, x0=[self.params[0], + self.params[1], + self.params[3]], + method='Nelder-Mead', tol=1e-2, + callback=self.update, + options={'maxiter': 20}) + elif self.model == 2: + res = minimize(self.mse, x0=[self.params[0], + self.params[2], + self.params[3]], + method='Nelder-Mead', tol=1e-2, + callback=self.update, + options={'maxiter': 20}) + + main_info('Nelder Mead on chromatin switch time,' + 'chromatin closing rate scaling, and rescale' + 'c..', indent_level=2) + self.fitting_flag_ = 2 + if self.model == 0 or self.model == 1: + res = minimize(self.mse, x0=[self.params[1], + self.params[7], + self.params[8]], + method='Nelder-Mead', tol=1e-2, + callback=self.update, + options={'maxiter': 20}) + elif self.model == 2: + res = minimize(self.mse, x0=[self.params[2], + self.params[7], + self.params[8]], + method='Nelder-Mead', tol=1e-2, + callback=self.update, + options={'maxiter': 20}) + + main_info('Nelder Mead on rna switch time and alpha..', + indent_level=2) + self.fitting_flag_ = 1 
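Rather than one joint solve, fit_dyn runs short Nelder-Mead stages over small parameter subsets (switch times, rates, rescale factors), each seeded from the current best and capped at a few iterations, and cycles those stages up to max_iter times. A toy sketch of that alternating pattern under an assumed quadratic stand-in for mse (names and parameter splits illustrative):

    import numpy as np
    from scipy.optimize import minimize

    def f(p):  # stand-in objective: p = [t_sw, alpha, gamma]
        return (p[0] - 3.0) ** 2 + (p[1] - 1.0) ** 2 + (p[2] - 0.5) ** 2

    p = np.array([1.0, 0.1, 2.0])
    for _ in range(3):  # a few sweeps, like cur_iter < max_iter
        r = minimize(lambda x: f([x[0], p[1], p[2]]), x0=[p[0]],
                     method='Nelder-Mead', tol=1e-2, options={'maxiter': 5})
        p[0] = r.x[0]
        r = minimize(lambda x: f([p[0], x[0], x[1]]), x0=p[1:],
                     method='Nelder-Mead', tol=1e-2, options={'maxiter': 5})
        p[1:] = r.x
    print(p)  # approaches [3.0, 1.0, 0.5]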
+ if self.model == 0 or self.model == 1: + res = minimize(self.mse, x0=[self.params[2], + self.params[4]], + method='Nelder-Mead', tol=1e-2, + callback=self.update, + options={'maxiter': 10}) + elif self.model == 2: + res = minimize(self.mse, x0=[self.params[1], + self.params[4]], + method='Nelder-Mead', tol=1e-2, + callback=self.update, + options={'maxiter': 10}) + + main_info('Nelder Mead on rna switch time, beta, and ' + 'rescale u..', indent_level=2) + self.fitting_flag_ = 3 + if self.model == 0 or self.model == 1: + res = minimize(self.mse, x0=[self.params[2], + self.params[5], + self.params[9]], + method='Nelder-Mead', tol=1e-2, + callback=self.update, + options={'maxiter': 20}) + elif self.model == 2: + res = minimize(self.mse, x0=[self.params[1], + self.params[5], + self.params[9]], + method='Nelder-Mead', tol=1e-2, + callback=self.update, + options={'maxiter': 20}) + + main_info('Nelder Mead on alpha and gamma..', indent_level=2) + self.fitting_flag_ = 2 + res = minimize(self.mse, x0=[self.params[4], + self.params[6]], + method='Nelder-Mead', tol=1e-2, + callback=self.update, + options={'maxiter': 10}) + + main_info('Nelder Mead on t_sw..', indent_level=2) + self.fitting_flag_ = 4 + res = minimize(self.mse, x0=self.params[:3], + method='Nelder-Mead', tol=1e-2, + callback=self.update, + options={'maxiter': 20}) + + else: + + main_info('Adam on all parameters', indent_level=2) + self.AdamMin(np.array(self.params, dtype=self.u.dtype), 20, + tol=1e-2) + + main_info('Nelder Mead on t_sw..', indent_level=2) + self.fitting_flag_ = 4 + res = minimize(self.mse, x0=self.params[:3], + method='Nelder-Mead', tol=1e-2, + callback=self.update, + options={'maxiter': 15}) + + main_info(f'iteration {self.cur_iter} finished', indent_level=2) + + def _variables(self, x): + scale_cc = self.scale_cc + rescale_c = self.rescale_c + rescale_u = self.rescale_u + + # RNA-only + if self.rna_only: + if len(x) == 1: # fit t_sw_2 + t3 = np.array([self.t_sw_1, x[0], + self.t_sw_3 - self.t_sw_1 - x[0]]) + r4 = self.rates + + elif len(x) == 2: + if self.fitting_flag_: # fit alpha and gamma + t3 = self.params[:3] + r4 = np.array([self.alpha_c, x[0], self.beta, x[1]]) + else: # fit t_sw_2 and alpha + t3 = np.array([self.t_sw_1, x[0], + self.t_sw_3 - self.t_sw_1 - x[0]]) + r4 = np.array([self.alpha_c, x[1], self.beta, self.gamma]) + + elif len(x) == 3: # fit t_sw_2, beta, and rescale u + t3 = np.array([self.t_sw_1, + x[0], self.t_sw_3 - self.t_sw_1 - x[0]]) + r4 = np.array([self.alpha_c, self.alpha, x[1], self.gamma]) + rescale_u = x[2] + + elif len(x) == 4: # fit all + t3 = np.array([self.t_sw_1, x[0], self.t_sw_3 - self.t_sw_1 + - x[0]]) + r4 = np.array([self.alpha_c, x[1], x[2], x[3]]) + + elif len(x) == 10: # all available + t3 = x[:3] + r4 = x[3:7] + scale_cc = x[7] + rescale_c = x[8] + rescale_u = x[9] + + else: + return + + # chromatin-RNA + else: + + if len(x) == 2: + if self.fitting_flag_ == 1: # fit rna switch time and alpha + if self.model == 0 or self.model == 1: + t3 = np.array([self.t_sw_1, self.params[1], x[0]]) + elif self.model == 2: + t3 = np.array([self.t_sw_1, x[0], + self.t_sw_3 - self.t_sw_1 - x[0]]) + r4 = np.array([self.alpha_c, x[1], self.beta, self.gamma]) + elif self.fitting_flag_ == 2: # fit alpha and gamma + t3 = self.params[:3] + r4 = np.array([self.alpha_c, x[0], self.beta, x[1]]) + + elif len(x) == 3: + # fit t_sw_1, chromatin switch time, and alpha_c + if self.fitting_flag_ == 1: + if self.model == 0 or self.model == 1: + t3 = np.array([x[0], x[1], self.t_sw_3 - x[0] - x[1]]) + elif 
self.model == 2: + t3 = np.array([x[0], self.t_sw_2 - x[0], x[1]]) + r4 = np.array([x[2], self.alpha, self.beta, self.gamma]) + # fit chromatin switch time, chromatin closing rate scaling, + # and rescale c + elif self.fitting_flag_ == 2: + if self.model == 0 or self.model == 1: + t3 = np.array([self.t_sw_1, x[0], + self.t_sw_3 - self.t_sw_1 - x[0]]) + elif self.model == 2: + t3 = np.array([self.t_sw_1, self.params[1], x[0]]) + r4 = self.rates + scale_cc = x[1] + rescale_c = x[2] + # fit rna switch time, beta, and rescale u + elif self.fitting_flag_ == 3: + if self.model == 0 or self.model == 1: + t3 = np.array([self.t_sw_1, self.params[1], x[0]]) + elif self.model == 2: + t3 = np.array([self.t_sw_1, x[0], + self.t_sw_3 - self.t_sw_1 - x[0]]) + r4 = np.array([self.alpha_c, self.alpha, x[1], self.gamma]) + rescale_u = x[2] + # fit three switch times + elif self.fitting_flag_ == 4: + t3 = x + r4 = self.rates + + elif len(x) == 7: + t3 = x[:3] + r4 = x[3:] + + elif len(x) == 10: + t3 = x[:3] + r4 = x[3:7] + scale_cc = x[7] + rescale_c = x[8] + rescale_u = x[9] + + else: + return + + # clip to meaningful values + if self.fitting_flag_ and not self.adam: + scale_cc = np.clip(scale_cc, + np.max([0.5*self.scale_cc, 0.25]), + np.min([2*self.scale_cc, 4])) + + if not self.known_pars: + if self.fit_decoupling: + t3 = np.clip(t3, 0.1, None) + else: + t3[2] = 30 / self.n_anchors + t3[:2] = np.clip(t3[:2], 0.1, None) + r4 = np.clip(r4, 0.001, 1000) + rescale_c = np.clip(rescale_c, 0.75, 1.5) + rescale_u = np.clip(rescale_u, 0.2, 3) + + return t3, r4, scale_cc, rescale_c, rescale_u + + # the tensor version of the calculate_dist_and_time function + def calculate_dist_and_time_ten(self, + c, u, s, + t_sw_array, + alpha_c, alpha, beta, gamma, + rescale_c, rescale_u, + scale_cc=1, + scale_factor=None, + model=1, + conn=None, + t=1000, k=1, + direction='complete', + total_h=20, + rna_only=False, + penalize_gap=True, + all_cells=True): + + conn = torch.tensor(conn.todense(), + device=self.device, + dtype=self.torch_type) + + c_ten = torch.tensor(c, device=self.device, dtype=self.torch_type) + u_ten = torch.tensor(u, device=self.device, dtype=self.torch_type) + s_ten = torch.tensor(s, device=self.device, dtype=self.torch_type) + + n = len(u) + if scale_factor is None: + scale_factor_ten = torch.stack((torch.std(c_ten), torch.std(u_ten), + torch.std(s_ten))) + else: + scale_factor_ten = torch.tensor(scale_factor, device=self.device, + dtype=self.torch_type) + + tau_list = self.anchor_points_ten(t_sw_array, total_h, t) + + switch = torch.sum(t_sw_array < total_h) + + exp_list, exp_sw_list = self.generate_exp_tens(tau_list, + t_sw_array[:switch], + alpha_c, + alpha, + beta, + gamma, + model=model, + scale_cc=scale_cc, + rna_only=rna_only) + + rescale_factor = torch.stack((rescale_c, rescale_u, + torch.tensor(1.0, device=self.device, + requires_grad=True, + dtype=self.torch_type))) + + for i in range(len(exp_list)): + exp_list[i] = exp_list[i]*rescale_factor + + if i < len(exp_list)-1: + exp_sw_list[i] = exp_sw_list[i]*rescale_factor + + max_c = 0 + max_u = 0 + max_s = 0 + + if rna_only: + exp_mat = (torch.hstack((torch.reshape(u_ten, (-1, 1)), + torch.reshape(s_ten, (-1, 1)))) + / scale_factor_ten[1:]) + else: + exp_mat = torch.hstack((torch.reshape(c_ten, (-1, 1)), + torch.reshape(u_ten, (-1, 1)), + torch.reshape(s_ten, (-1, 1))))\ + / scale_factor_ten + + taus = torch.zeros((1, n), device=self.device, + requires_grad=True, + dtype=self.torch_type) + anchor_exp, anchor_t = None, None + + dists0 = torch.full((1, 
n), 0.0 if direction == "on" + or direction == "complete" else np.inf, + device=self.device, + requires_grad=True, + dtype=self.torch_type) + dists1 = torch.full((1, n), 0.0 if direction == "on" + or direction == "complete" else np.inf, + device=self.device, + requires_grad=True, + dtype=self.torch_type) + dists2 = torch.full((1, n), 0.0 if direction == "off" + or direction == "complete" else np.inf, + device=self.device, + requires_grad=True, + dtype=self.torch_type) + dists3 = torch.full((1, n), 0.0 if direction == "off" + or direction == "complete" else np.inf, + device=self.device, + requires_grad=True, + dtype=self.torch_type) + + ts0 = torch.zeros((1, n), device=self.device, + requires_grad=True, + dtype=self.torch_type) + ts1 = torch.zeros((1, n), device=self.device, + requires_grad=True, + dtype=self.torch_type) + ts2 = torch.zeros((1, n), device=self.device, + requires_grad=True, + dtype=self.torch_type) + ts3 = torch.zeros((1, n), device=self.device, + requires_grad=True, + dtype=self.torch_type) + + for i in range(switch+1): + + if not all_cells: + max_ci = (torch.max(exp_list[i][:, 0]) + if exp_list[i].shape[0] > 0 + else 0) + max_c = max_ci if max_ci > max_c else max_c + max_ui = torch.max(exp_list[i][:, 1]) if exp_list[i].shape[0] > 0 \ + else 0 + max_u = max_ui if max_ui > max_u else max_u + max_si = torch.max(exp_list[i][:, 2]) if exp_list[i].shape[0] > 0 \ + else 0 + max_s = max_si if max_si > max_s else max_s + + skip_phase = False + if direction == 'off': + if (model in [1, 2]) and (i < 2): + skip_phase = True + elif direction == 'on': + if (model in [1, 2]) and (i >= 2): + skip_phase = True + if rna_only and i == 0: + skip_phase = True + + if not skip_phase: + if rna_only: + tmp = exp_list[i][:, 1:] / scale_factor_ten[1:] + else: + tmp = exp_list[i] / scale_factor_ten + if anchor_exp is None: + anchor_exp = exp_list[i] + anchor_t = (tau_list[i] + t_sw_array[i-1] if i >= 1 + else tau_list[i]) + else: + anchor_exp = torch.vstack((anchor_exp, exp_list[i])) + anchor_t = torch.hstack((anchor_t, + tau_list[i] + t_sw_array[i-1] + if i >= 1 else tau_list[i])) + + if not all_cells: + anchor_prepend_rna = torch.zeros((1, 2), + device=self.device, + dtype=self.torch_type) + anchor_prepend_chrom = torch.zeros((1, 3), + device=self.device, + dtype=self.torch_type) + anchor_dist = torch.diff(tmp, dim=0, + prepend=anchor_prepend_rna + if rna_only + else anchor_prepend_chrom) + + anchor_dist = torch.sqrt((anchor_dist*anchor_dist) + .sum(axis=1)) + remove_cand = anchor_dist < (0.01*torch.max(exp_mat[1]) + if rna_only + else + 0.01*torch.max(exp_mat[2])) + step_idx = torch.arange(0, anchor_dist.size()[0], 1, + device=self.device, + dtype=self.torch_type) % 3 > 0 + remove_cand &= step_idx + keep_idx = torch.where(~remove_cand)[0] + + tmp = tmp[keep_idx, :] + + model = NearestNeighbors(n_neighbors=k, output_type="numpy") + model.fit(tmp.detach()) + dd, ii = model.kneighbors(exp_mat.detach()) + ii = ii.T[0] + + new_dd = ((exp_mat[:, 0] - tmp[ii, 0]) + * (exp_mat[:, 0] - tmp[ii, 0]) + + (exp_mat[:, 1] - tmp[ii, 1]) + * (exp_mat[:, 1] - tmp[ii, 1]) + + (exp_mat[:, 2] - tmp[ii, 2]) + * (exp_mat[:, 2] - tmp[ii, 2])) + + if k > 1: + new_dd = torch.mean(new_dd, dim=1) + if conn is not None: + new_dd = torch.matmul(conn, new_dd) + + if i == 0: + dists0 = dists0 + new_dd + elif i == 1: + dists1 = dists1 + new_dd + elif i == 2: + dists2 = dists2 + new_dd + elif i == 3: + dists3 = dists3 + new_dd + + if not all_cells: + ii = keep_idx[ii] + if k == 1: + taus = tau_list[i][ii] + else: + for j in 
range(n): + taus[j] = tau_list[i][ii[j, :]] + + if i == 0: + ts0 = ts0 + taus + elif i == 1: + ts1 = ts1 + taus + t_sw_array[0] + elif i == 2: + ts2 = ts2 + taus + t_sw_array[1] + elif i == 3: + ts3 = ts3 + taus + t_sw_array[2] + + dists = torch.cat((dists0, dists1, dists2, dists3), 0) + + ts = torch.cat((ts0, ts1, ts2, ts3), 0) + + state_pred = torch.argmin(dists, axis=0) + + t_pred = ts[state_pred, torch.arange(n, device=self.device)] + + anchor_t1_list = [] + anchor_t2_list = [] + + t_sw_adjust = torch.zeros(3, device=self.device, dtype=self.torch_type) + + if direction == 'complete': + + dist_gap_add = torch.zeros((1, n), device=self.device, + dtype=self.torch_type) + + t_sorted = torch.clone(t_pred) + t_sorted, t_sorted_indices = torch.sort(t_sorted) + + dt = torch.diff(t_sorted, dim=0, + prepend=torch.zeros(1, device=self.device, + dtype=self.torch_type)) + + gap_thresh = 3*torch.quantile(dt, 0.99) + + idx = torch.where(dt > gap_thresh)[0] + + if len(idx) > 0 and penalize_gap: + h_tens = torch.tensor([total_h], device=self.device, + dtype=self.torch_type) + + for i in idx: + + t1 = t_sorted[i-1] if i > 0 else 0 + t2 = t_sorted[i] + anchor_t1 = anchor_exp[torch.argmin(torch.abs(anchor_t - t1)), + :] + anchor_t2 = anchor_exp[torch.argmin(torch.abs(anchor_t - t2)), + :] + if all_cells: + anchor_t1_list.append(torch.ravel(anchor_t1)) + anchor_t2_list.append(torch.ravel(anchor_t2)) + if not all_cells: + for j in range(1, switch): + crit1 = ((t1 > t_sw_array[j-1]) + and (t2 > t_sw_array[j-1]) + and (t1 <= t_sw_array[j]) + and (t2 <= t_sw_array[j])) + crit2 = ((torch.abs(anchor_t1[2] + - exp_sw_list[j][0, 2]) + < 0.02 * max_s) and + (torch.abs(anchor_t2[2] + - exp_sw_list[j][0, 2]) + < 0.01 * max_s)) + crit3 = ((torch.abs(anchor_t1[1] + - exp_sw_list[j][0, 1]) + < 0.02 * max_u) and + (torch.abs(anchor_t2[1] + - exp_sw_list[j][0, 1]) + < 0.01 * max_u)) + crit4 = ((torch.abs(anchor_t1[0] + - exp_sw_list[j][0, 0]) + < 0.02 * max_c) and + (torch.abs(anchor_t2[0] + - exp_sw_list[j][0, 0]) + < 0.01 * max_c)) + if crit1 and crit2 and crit3 and crit4: + t_sw_adjust[j] += t2 - t1 + if penalize_gap: + dist_gap = torch.sum(((anchor_t1[1:] - anchor_t2[1:]) / + scale_factor_ten[1:])**2) + + idx_to_adjust = torch.tensor(t_pred >= t2, + device=self.device) + + idx_to_adjust = torch.reshape(idx_to_adjust, + (1, idx_to_adjust.size()[0])) + + true_tensor = torch.tensor([True], device=self.device) + false_tensor = torch.tensor([False], device=self.device) + + t_sw_array_ = torch.cat((t_sw_array, h_tens), dim=0) + state_to_adjust = torch.where(t_sw_array_ > t2, + true_tensor, false_tensor) + + dist_gap_add[idx_to_adjust] += dist_gap + + if state_to_adjust[0].item(): + dists0 += dist_gap_add + if state_to_adjust[1].item(): + dists1 += dist_gap_add + if state_to_adjust[2].item(): + dists2 += dist_gap_add + if state_to_adjust[3].item(): + dists3 += dist_gap_add + + dist_gap_add[idx_to_adjust] -= dist_gap + + dists = torch.cat((dists0, dists1, dists2, dists3), 0) + + state_pred = torch.argmin(dists, dim=0) + + if all_cells: + t_pred = ts[torch.arange(n, device=self.device), state_pred] + + min_dist = torch.min(dists, dim=0).values + + if all_cells: + exp_ss_mat = compute_ss_exp(alpha_c, alpha, beta, gamma, + model=model) + if rna_only: + exp_ss_mat[:, 0] = 1 + dists_ss = pairwise_distance_square(exp_mat, exp_ss_mat * + rescale_factor / scale_factor) + + reach_ss = np.full((n, 4), False) + for i in range(n): + for j in range(4): + if min_dist[i] > dists_ss[i, j]: + reach_ss[i, j] = True + late_phase = np.full(n, 
-1) + for i in range(3): + late_phase[torch.abs(t_pred - t_sw_array[i]) < 0.1] = i + + return min_dist, t_pred, state_pred.cpu().detach().numpy(), \ + reach_ss, late_phase, max_u, max_s, anchor_t1_list, \ + anchor_t2_list + + else: + return min_dist, state_pred.cpu().detach().numpy(), max_u, max_s, \ + t_sw_adjust.cpu().detach().numpy() + + # the torch tensor version of the mse function + def mse_ten(self, x, fit_outlier=False, + penalize_gap=True): + + t3 = x[:3] + r4 = x[3:7] + scale_cc = x[7] + rescale_c = x[8] + rescale_u = x[9] + + if not self.known_pars: + if self.fit_decoupling: + t3 = torch.clip(t3, 0.1, None) + else: + t3[2] = 30 / self.n_anchors + t3[:2] = torch.clip(t3[:2], 0.1, None) + r4 = torch.clip(r4, 0.001, 1000) + rescale_c = torch.clip(rescale_c, 0.75, 1.5) + rescale_u = torch.clip(rescale_u, 0.2, 3) + + t_sw_array = torch.cumsum(t3, dim=0) + + if self.rna_only: + t_sw_array[2] = 20 + + # conditions for minimum switch time and rate params + penalty = 0 + if any(t3 < 0.2) or any(r4 < 0.005): + penalty = (torch.sum(0.2 - t3[t3 < 0.2]) if self.fit_decoupling + else torch.sum(0.2 - t3[:2][t3[:2] < 0.2])) + penalty += torch.sum(0.005 - r4[r4 < 0.005]) * 1e2 + + # condition for all params + if any(x > 500): + penalty = torch.sum(x[x > 500] - 500) * 1e-2 + + c_array = self.c_all if fit_outlier else self.c + u_array = self.u_all if fit_outlier else self.u + s_array = self.s_all if fit_outlier else self.s + + if self.batch_size is not None and self.batch_size < len(c_array): + + subset_choice = np.random.choice(len(c_array), self.batch_size, + replace=False) + + c_array = c_array[subset_choice] + u_array = u_array[subset_choice] + s_array = s_array[subset_choice] + + if fit_outlier: + conn_for_calc = self.conn[subset_choice] + if not fit_outlier: + conn_for_calc = self.conn_sub[subset_choice] + + conn_for_calc = ((conn_for_calc.T)[subset_choice]).T + + else: + + if fit_outlier: + conn_for_calc = self.conn + if not fit_outlier: + conn_for_calc = self.conn_sub + + scale_factor_func = np.array(self.scale_factor, dtype=self.u.dtype) + + # distances and time assignments + res = self.calculate_dist_and_time_ten(c_array, + u_array, + s_array, + t_sw_array, + r4[0], + r4[1], + r4[2], + r4[3], + rescale_c, + rescale_u, + scale_cc=scale_cc, + scale_factor=scale_factor_func, + model=self.model, + direction=self.direction, + conn=conn_for_calc, + k=self.k_dist, + t=self.n_anchors, + rna_only=self.rna_only, + penalize_gap=penalize_gap, + all_cells=fit_outlier) + + if fit_outlier: + min_dist, t_pred, state_pred, reach_ss, late_phase, max_u, max_s, \ + self.anchor_t1_list, self.anchor_t2_list = res + else: + min_dist, state_pred, max_u, max_s, t_sw_adjust = res + + loss = torch.mean(min_dist) + + # avoid exceeding maximum expressions + reg = torch.max(torch.tensor([0, max_s - torch.tensor(self.max_s)], + requires_grad=True, + dtype=self.torch_type))\ + + torch.max(torch.tensor([0, max_u - torch.tensor(self.max_u)], + requires_grad=True, + dtype=self.torch_type)) + + loss += reg + + loss += 1e-1 * penalty + + self.cur_loss = loss.item() + self.cur_state_pred = state_pred + + if fit_outlier: + return loss, t_pred + else: + self.cur_t_sw_adjust = t_sw_adjust + + return loss + + def mse(self, x, fit_outlier=False, penalize_gap=True): + x = np.array(x) + + t3, r4, scale_cc, rescale_c, rescale_u = self._variables(x) + + t_sw_array = np.array([t3[0], t3[0]+t3[1], t3[0]+t3[1]+t3[2]]) + if self.rna_only: + t_sw_array[2] = 20 + + # conditions for minimum switch time and rate params + penalty = 0 + if 
any(t3 < 0.2) or any(r4 < 0.005): + penalty = (np.sum(0.2 - t3[t3 < 0.2]) if self.fit_decoupling + else np.sum(0.2 - t3[:2][t3[:2] < 0.2])) + penalty += np.sum(0.005 - r4[r4 < 0.005]) * 1e2 + + # condition for all params + if any(x > 500): + penalty = np.sum(x[x > 500] - 500) * 1e-2 + + c_array = self.c_all if fit_outlier else self.c + u_array = self.u_all if fit_outlier else self.u + s_array = self.s_all if fit_outlier else self.s + + if self.neural_net: + + res = calculate_dist_and_time_nn(c_array, + u_array, + s_array, + self.max_u_all if fit_outlier + else self.max_u, + self.max_s_all if fit_outlier + else self.max_s, + t_sw_array, + r4[0], + r4[1], + r4[2], + r4[3], + rescale_c, + rescale_u, + self.ode_model_0, + self.ode_model_1, + self.ode_model_2_m1, + self.ode_model_2_m2, + self.device, + scale_cc=scale_cc, + scale_factor=self.scale_factor, + model=self.model, + direction=self.direction, + conn=self.conn if fit_outlier + else self.conn_sub, + k=self.k_dist, + t=self.n_anchors, + rna_only=self.rna_only, + penalize_gap=penalize_gap, + all_cells=fit_outlier) + + if fit_outlier: + min_dist, t_pred, state_pred, max_u, max_s, nn_penalty = res + else: + min_dist, state_pred, max_u, max_s, nn_penalty = res + + penalty += nn_penalty + + t_sw_adjust = [0, 0, 0] + + else: + + # distances and time assignments + res = calculate_dist_and_time(c_array, + u_array, + s_array, + t_sw_array, + r4[0], + r4[1], + r4[2], + r4[3], + rescale_c, + rescale_u, + scale_cc=scale_cc, + scale_factor=self.scale_factor, + model=self.model, + direction=self.direction, + conn=self.conn if fit_outlier + else self.conn_sub, + k=self.k_dist, + t=self.n_anchors, + rna_only=self.rna_only, + penalize_gap=penalize_gap, + all_cells=fit_outlier) + + if fit_outlier: + min_dist, t_pred, state_pred, reach_ss, late_phase, max_u, \ + max_s, self.anchor_t1_list, self.anchor_t2_list = res + else: + min_dist, state_pred, max_u, max_s, t_sw_adjust = res + + loss = np.mean(min_dist) + + # avoid exceeding maximum expressions + reg = np.max([0, max_s - self.max_s]) + np.max([0, max_u - self.max_u]) + loss += reg + + loss += 1e-1 * penalty + self.cur_loss = loss + self.cur_state_pred = state_pred + + if fit_outlier: + return loss, t_pred + else: + self.cur_t_sw_adjust = t_sw_adjust + + return loss + + def update(self, x, perform_update=False, initialize=False, + fit_outlier=False, adjust_time=True, penalize_gap=True, + plot=True): + t3, r4, scale_cc, rescale_c, rescale_u = self._variables(x) + t_sw_array = np.array([t3[0], t3[0]+t3[1], t3[0]+t3[1]+t3[2]]) + + # read results + if initialize: + new_loss = self.mse(x, penalize_gap=penalize_gap) + elif fit_outlier: + new_loss, t_pred = self.mse(x, fit_outlier=True, + penalize_gap=penalize_gap) + else: + new_loss = self.cur_loss + t_sw_adjust = self.cur_t_sw_adjust + state_pred = self.cur_state_pred + + if new_loss < self.loss[-1] or perform_update: + perform_update = True + + self.loss.append(new_loss) + self.alpha_c, self.alpha, self.beta, self.gamma = r4 + self.rates = r4 + self.scale_cc = scale_cc + self.rescale_c = rescale_c + self.rescale_u = rescale_u + + # adjust overcrowded anchors + if not fit_outlier and adjust_time: + t_sw_array -= np.cumsum(t_sw_adjust) + if self.rna_only: + t_sw_array[2] = 20 + + self.t_sw_1, self.t_sw_2, self.t_sw_3 = t_sw_array + self.t_sw_array = t_sw_array + self.params = np.array([self.t_sw_1, + self.t_sw_2-self.t_sw_1, + self.t_sw_3-self.t_sw_2, + self.alpha_c, + self.alpha, + self.beta, + self.gamma, + self.scale_cc, + self.rescale_c, + self.rescale_u]) 
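Each Nelder-Mead stage relies on the same contract seen here: mse() caches cur_loss and cur_state_pred as a side effect, and update() (passed as the scipy callback) only commits a proposed point when its loss beats self.loss[-1]. A toy sketch of that accept-if-better callback pattern (the objective and names are illustrative):

    import numpy as np
    from scipy.optimize import minimize

    best = {'loss': np.inf, 'x': None}

    def mse(x):  # toy objective standing in for the real mse
        return float((x[0] - 1.0) ** 2 + (x[1] + 2.0) ** 2)

    def keep_best(xk):  # Nelder-Mead invokes the callback with each iterate
        loss = mse(xk)
        if loss < best['loss']:
            best['loss'], best['x'] = loss, xk.copy()

    minimize(mse, x0=[0.0, 0.0], method='Nelder-Mead', tol=1e-2,
             callback=keep_best, options={'maxiter': 20})
    print(best)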
+ if not initialize: + self.state = state_pred + if fit_outlier: + self.t = t_pred + + main_info(f'params updated as: {self.t_sw_array} {self.rates} ' + f'{self.scale_cc} {self.rescale_c} {self.rescale_u}', + indent_level=2) + + # interactive plot + if self.plot and plot: + tau_list = anchor_points(self.t_sw_array, 20, self.n_anchors) + switch = np.sum(self.t_sw_array < 20) + typed_tau_list = List() + [typed_tau_list.append(x) for x in tau_list] + self.alpha_c, self.alpha, self.beta, self.gamma, \ + self.c0, self.u0, self.s0 = \ + check_params(self.alpha_c, self.alpha, self.beta, + self.gamma, c0=self.c0, u0=self.u0, + s0=self.s0) + exp_list, exp_sw_list = generate_exp(typed_tau_list, + self.t_sw_array[:switch], + self.alpha_c, + self.alpha, + self.beta, + self.gamma, + scale_cc=self.scale_cc, + model=self.model, + rna_only=self.rna_only) + rescale_factor = np.array([self.rescale_c, + self.rescale_u, + 1.0]) + exp_list = [x*rescale_factor for x in exp_list] + exp_sw_list = [x*rescale_factor for x in exp_sw_list] + c = np.ravel(np.concatenate([exp_list[x][:, 0] for x in + range(switch+1)])) + u = np.ravel(np.concatenate([exp_list[x][:, 1] for x in + range(switch+1)])) + s = np.ravel(np.concatenate([exp_list[x][:, 2] for x in + range(switch+1)])) + c_ = self.c_all if fit_outlier else self.c + u_ = self.u_all if fit_outlier else self.u + s_ = self.s_all if fit_outlier else self.s + self.ax.clear() + plt.pause(0.1) + if self.rna_only: + self.ax.scatter(s, u, s=self.point_size*1.5, c='black', + alpha=0.6, zorder=2) + if switch >= 1: + c_sw1, u_sw1, s_sw1 = exp_sw_list[0][0] + self.ax.plot([s_sw1], [u_sw1], "om", + markersize=self.point_size, zorder=5) + if switch >= 2: + c_sw2, u_sw2, s_sw2 = exp_sw_list[1][0] + self.ax.plot([s_sw2], [u_sw2], "Xm", + markersize=self.point_size, zorder=5) + if switch == 3: + c_sw3, u_sw3, s_sw3 = exp_sw_list[2][0] + self.ax.plot([s_sw3], [u_sw3], "Dm", + markersize=self.point_size, zorder=5) + if np.max(self.t) == 20: + self.ax.plot([s[-1]], [u[-1]], "*m", + markersize=self.point_size, zorder=5) + for i in range(4): + if any(self.state == i): + self.ax.scatter(s_[(self.state == i)], + u_[(self.state == i)], + s=self.point_size, c=self.color[i]) + self.ax.set_xlabel('s') + self.ax.set_ylabel('u') + + else: + self.ax.scatter(s, u, c, s=self.point_size*1.5, + c='black', alpha=0.6, zorder=2) + if switch >= 1: + c_sw1, u_sw1, s_sw1 = exp_sw_list[0][0] + self.ax.plot([s_sw1], [u_sw1], [c_sw1], "om", + markersize=self.point_size, zorder=5) + if switch >= 2: + c_sw2, u_sw2, s_sw2 = exp_sw_list[1][0] + self.ax.plot([s_sw2], [u_sw2], [c_sw2], "Xm", + markersize=self.point_size, zorder=5) + if switch == 3: + c_sw3, u_sw3, s_sw3 = exp_sw_list[2][0] + self.ax.plot([s_sw3], [u_sw3], [c_sw3], "Dm", + markersize=self.point_size, zorder=5) + if np.max(self.t) == 20: + self.ax.plot([s[-1]], [u[-1]], [c[-1]], "*m", + markersize=self.point_size, zorder=5) + for i in range(4): + if any(self.state == i): + self.ax.scatter(s_[(self.state == i)], + u_[(self.state == i)], + c_[(self.state == i)], + s=self.point_size, c=self.color[i]) + self.ax.set_xlabel('s') + self.ax.set_ylabel('u') + self.ax.set_zlabel('c') + self.fig.canvas.draw() + plt.pause(0.1) + return perform_update + + def save_dyn_plot(self, c, u, s, c_sw, u_sw, s_sw, tau_list, + show_all=False): + if not os.path.exists(self.plot_path): + os.makedirs(self.plot_path) + main_info(f'{self.plot_path} directory created.', indent_level=2) + + switch = np.sum(self.t_sw_array < 20) + scale_back = np.array([self.scale_c, self.scale_u, 
self.scale_s]) + shift_back = np.array([self.offset_c, self.offset_u, self.offset_s]) + if switch >= 1: + c_sw1, u_sw1, s_sw1 = c_sw[0], u_sw[0], s_sw[0] + if switch >= 2: + c_sw2, u_sw2, s_sw2 = c_sw[1], u_sw[1], s_sw[1] + if switch == 3: + c_sw3, u_sw3, s_sw3 = c_sw[2], u_sw[2], s_sw[2] + + if not show_all: + n_anchors = len(u) + t_lower = np.min(self.t) + t_upper = np.max(self.t) + t_ = np.concatenate((tau_list[0], tau_list[1] + self.t_sw_array[0], + tau_list[2] + self.t_sw_array[1], + tau_list[3] + self.t_sw_array[2])) + c_pre = c[t_[:n_anchors] <= t_lower] + u_pre = u[t_[:n_anchors] <= t_lower] + s_pre = s[t_[:n_anchors] <= t_lower] + c = c[(t_lower < t_[:n_anchors]) & (t_[:n_anchors] < t_upper)] + u = u[(t_lower < t_[:n_anchors]) & (t_[:n_anchors] < t_upper)] + s = s[(t_lower < t_[:n_anchors]) & (t_[:n_anchors] < t_upper)] + + c_all = self.c_all * self.scale_c + self.offset_c + u_all = self.u_all * self.scale_u + self.offset_u + s_all = self.s_all * self.scale_s + self.offset_s + + fig = plt.figure(figsize=self.fig_size) + fig.patch.set_facecolor('white') + ax = fig.add_subplot(111, facecolor='white') + if not show_all and len(u_pre) > 0: + ax.scatter(s_pre, u_pre, s=self.point_size/2, c='black', + alpha=0.4, zorder=2) + ax.scatter(s, u, s=self.point_size*1.5, c='black', alpha=0.6, zorder=2) + for i in range(4): + if any(self.state == i): + ax.scatter(s_all[(self.state == i) & (self.non_outlier)], + u_all[(self.state == i) & (self.non_outlier)], + s=self.point_size, c=self.color[i]) + ax.scatter(s_all[~self.non_outlier], u_all[~self.non_outlier], + s=self.point_size/2, c='grey') + if show_all or t_lower <= self.t_sw_array[0]: + ax.plot([s_sw1], [u_sw1], "om", markersize=self.point_size, + zorder=5) + if switch >= 2 and (show_all or (t_lower <= self.t_sw_array[1] and + t_upper >= self.t_sw_array[1])): + ax.plot([s_sw2], [u_sw2], "Xm", markersize=self.point_size, + zorder=5) + if switch >= 3 and (show_all or (t_lower <= self.t_sw_array[2] and + t_upper >= self.t_sw_array[2])): + ax.plot([s_sw3], [u_sw3], "Dm", markersize=self.point_size, + zorder=5) + if np.max(self.t) == 20: + ax.plot([s[-1]], [u[-1]], "*m", markersize=self.point_size, + zorder=5) + if (self.anchor_t1_list is not None and len(self.anchor_t1_list) > 0 + and show_all): + for i in range(len(self.anchor_t1_list)): + exp_t1 = self.anchor_t1_list[i] * scale_back + shift_back + exp_t2 = self.anchor_t2_list[i] * scale_back + shift_back + ax.plot([exp_t1[2]], [exp_t1[1]], "|y", + markersize=self.point_size*1.5) + ax.plot([exp_t2[2]], [exp_t2[1]], "|c", + markersize=self.point_size*1.5) + ax.plot(s_all, + self.steady_state_func(self.s_all) * self.scale_u + + self.offset_u, c='grey', ls=':', lw=self.point_size/4, + alpha=0.7) + ax.set_xlabel('s') + ax.set_ylabel('u') + ax.set_title(f'{self.gene}-{self.model}') + plt.tight_layout() + fig.savefig(f'{self.plot_path}/{self.gene}-{self.model}-us.png', + dpi=fig.dpi, facecolor=fig.get_facecolor(), + transparent=False, edgecolor='none') + plt.close(fig) + plt.pause(0.2) + + if self.extra_color is not None: + fig = plt.figure(figsize=self.fig_size) + fig.patch.set_facecolor('white') + ax = fig.add_subplot(111, facecolor='white') + if not show_all and len(u_pre) > 0: + ax.scatter(s_pre, u_pre, s=self.point_size/2, c='black', + alpha=0.4, zorder=2) + ax.scatter(s, u, s=self.point_size*1.5, c='black', alpha=0.6, + zorder=2) + ax.scatter(s_all, u_all, s=self.point_size, c=self.extra_color) + if show_all or t_lower <= self.t_sw_array[0]: + ax.plot([s_sw1], [u_sw1], "om", 
markersize=self.point_size, + zorder=5) + if switch >= 2 and (show_all or (t_lower <= self.t_sw_array[1] and + t_upper >= self.t_sw_array[1])): + ax.plot([s_sw2], [u_sw2], "Xm", markersize=self.point_size, + zorder=5) + if switch >= 3 and (show_all or (t_lower <= self.t_sw_array[2] and + t_upper >= self.t_sw_array[2])): + ax.plot([s_sw3], [u_sw3], "Dm", markersize=self.point_size, + zorder=5) + if np.max(self.t) == 20: + ax.plot([s[-1]], [u[-1]], "*m", markersize=self.point_size, + zorder=5) + if (self.anchor_t1_list is not None and + len(self.anchor_t1_list) > 0 and show_all): + for i in range(len(self.anchor_t1_list)): + exp_t1 = self.anchor_t1_list[i] * scale_back + shift_back + exp_t2 = self.anchor_t2_list[i] * scale_back + shift_back + ax.plot([exp_t1[2]], [exp_t1[1]], "|y", + markersize=self.point_size*1.5) + ax.plot([exp_t2[2]], [exp_t2[1]], "|c", + markersize=self.point_size*1.5) + ax.plot(s_all, self.steady_state_func(self.s_all) * self.scale_u + + self.offset_u, c='grey', ls=':', lw=self.point_size/4, + alpha=0.7) + ax.set_xlabel('s') + ax.set_ylabel('u') + ax.set_title(f'{self.gene}-{self.model}') + plt.tight_layout() + fig.savefig(f'{self.plot_path}/{self.gene}-{self.model}-' + 'us_colorby_extra.png', dpi=fig.dpi, + facecolor=fig.get_facecolor(), transparent=False, + edgecolor='none') + plt.close(fig) + plt.pause(0.2) + + if not self.rna_only: + fig = plt.figure(figsize=self.fig_size) + fig.patch.set_facecolor('white') + ax = fig.add_subplot(111, facecolor='white') + if not show_all and len(u_pre) > 0: + ax.scatter(u_pre, c_pre, s=self.point_size/2, c='black', + alpha=0.4, zorder=2) + ax.scatter(u, c, s=self.point_size*1.5, c='black', alpha=0.6, + zorder=2) + ax.scatter(u_all, c_all, s=self.point_size, c=self.extra_color) + if show_all or t_lower <= self.t_sw_array[0]: + ax.plot([u_sw1], [c_sw1], "om", markersize=self.point_size, + zorder=5) + if switch >= 2 and (show_all or (t_lower <= self.t_sw_array[1] + and t_upper >= + self.t_sw_array[1])): + ax.plot([u_sw2], [c_sw2], "Xm", markersize=self.point_size, + zorder=5) + if switch >= 3 and (show_all or (t_lower <= self.t_sw_array[2] + and t_upper >= + self.t_sw_array[2])): + ax.plot([u_sw3], [c_sw3], "Dm", markersize=self.point_size, + zorder=5) + if np.max(self.t) == 20: + ax.plot([u[-1]], [c[-1]], "*m", markersize=self.point_size, + zorder=5) + ax.set_xlabel('u') + ax.set_ylabel('c') + ax.set_title(f'{self.gene}-{self.model}') + plt.tight_layout() + fig.savefig(f'{self.plot_path}/{self.gene}-{self.model}-' + 'cu_colorby_extra.png', dpi=fig.dpi, + facecolor=fig.get_facecolor(), transparent=False, + edgecolor='none') + plt.close(fig) + plt.pause(0.2) + + if not self.rna_only: + fig = plt.figure(figsize=self.fig_size) + fig.patch.set_facecolor('white') + ax = fig.add_subplot(111, projection='3d', facecolor='white') + if not show_all and len(u_pre) > 0: + ax.scatter(s_pre, u_pre, c_pre, s=self.point_size/2, c='black', + alpha=0.4, zorder=2) + ax.scatter(s, u, c, s=self.point_size*1.5, c='black', alpha=0.6, + zorder=2) + for i in range(4): + if any(self.state == i): + ax.scatter(s_all[(self.state == i) & (self.non_outlier)], + u_all[(self.state == i) & (self.non_outlier)], + c_all[(self.state == i) & (self.non_outlier)], + s=self.point_size, c=self.color[i]) + ax.scatter(s_all[~self.non_outlier], u_all[~self.non_outlier], + c_all[~self.non_outlier], s=self.point_size/2, c='grey') + if show_all or t_lower <= self.t_sw_array[0]: + ax.plot([s_sw1], [u_sw1], [c_sw1], "om", + markersize=self.point_size, zorder=5) + if switch >= 2 and 
(show_all or (t_lower <= self.t_sw_array[1] and + t_upper >= self.t_sw_array[1])): + ax.plot([s_sw2], [u_sw2], [c_sw2], "Xm", + markersize=self.point_size, zorder=5) + if switch >= 3 and (show_all or (t_lower <= self.t_sw_array[2] and + t_upper >= self.t_sw_array[2])): + ax.plot([s_sw3], [u_sw3], [c_sw3], "Dm", + markersize=self.point_size, zorder=5) + if np.max(self.t) == 20: + ax.plot([s[-1]], [u[-1]], [c[-1]], "*m", + markersize=self.point_size, zorder=5) + ax.set_xlabel('s') + ax.set_ylabel('u') + ax.set_zlabel('c') + ax.set_title(f'{self.gene}-{self.model}') + plt.tight_layout() + fig.savefig(f'{self.plot_path}/{self.gene}-{self.model}-cus.png', + dpi=fig.dpi, facecolor=fig.get_facecolor(), + transparent=False, edgecolor='none') + plt.close(fig) + plt.pause(0.2) + + fig = plt.figure(figsize=self.fig_size) + fig.patch.set_facecolor('white') + ax = fig.add_subplot(111, facecolor='white') + if not show_all and len(u_pre) > 0: + ax.scatter(s_pre, u_pre, s=self.point_size/2, c='black', + alpha=0.4, zorder=2) + ax.scatter(s, u, s=self.point_size*1.5, c='black', alpha=0.6, + zorder=2) + ax.scatter(s_all, u_all, s=self.point_size, c=np.log1p(self.c_all), + cmap='coolwarm') + if show_all or t_lower <= self.t_sw_array[0]: + ax.plot([s_sw1], [u_sw1], "om", markersize=self.point_size, + zorder=5) + if switch >= 2 and (show_all or (t_lower <= self.t_sw_array[1] and + t_upper >= self.t_sw_array[1])): + ax.plot([s_sw2], [u_sw2], "Xm", markersize=self.point_size, + zorder=5) + if switch >= 3 and (show_all or (t_lower <= self.t_sw_array[2] and + t_upper >= self.t_sw_array[2])): + ax.plot([s_sw3], [u_sw3], "Dm", markersize=self.point_size, + zorder=5) + if np.max(self.t) == 20: + ax.plot([s[-1]], [u[-1]], "*m", markersize=self.point_size, + zorder=5) + ax.plot(s_all, self.steady_state_func(self.s_all) * self.scale_u + + self.offset_u, c='grey', ls=':', lw=self.point_size/4, + alpha=0.7) + ax.set_xlabel('s') + ax.set_ylabel('u') + ax.set_title(f'{self.gene}-{self.model}') + plt.tight_layout() + fig.savefig(f'{self.plot_path}/{self.gene}-{self.model}-' + 'us_colorby_c.png', dpi=fig.dpi, + facecolor=fig.get_facecolor(), transparent=False, + edgecolor='none') + plt.close(fig) + plt.pause(0.2) + + fig = plt.figure(figsize=self.fig_size) + fig.patch.set_facecolor('white') + ax = fig.add_subplot(111, facecolor='white') + if not show_all and len(u_pre) > 0: + ax.scatter(u_pre, c_pre, s=self.point_size/2, c='black', + alpha=0.4, zorder=2) + ax.scatter(u, c, s=self.point_size*1.5, c='black', alpha=0.6, + zorder=2) + for i in range(4): + if any(self.state == i): + ax.scatter(u_all[(self.state == i) & (self.non_outlier)], + c_all[(self.state == i) & (self.non_outlier)], + s=self.point_size, c=self.color[i]) + ax.scatter(u_all[~self.non_outlier], c_all[~self.non_outlier], + s=self.point_size/2, c='grey') + if show_all or t_lower <= self.t_sw_array[0]: + ax.plot([u_sw1], [c_sw1], "om", markersize=self.point_size, + zorder=5) + if switch >= 2 and (show_all or (t_lower <= self.t_sw_array[1] and + t_upper >= self.t_sw_array[1])): + ax.plot([u_sw2], [c_sw2], "Xm", markersize=self.point_size, + zorder=5) + if switch >= 3 and (show_all or (t_lower <= self.t_sw_array[2] and + t_upper >= self.t_sw_array[2])): + ax.plot([u_sw3], [c_sw3], "Dm", markersize=self.point_size, + zorder=5) + if np.max(self.t) == 20: + ax.plot([u[-1]], [c[-1]], "*m", markersize=self.point_size, + zorder=5) + ax.set_xlabel('u') + ax.set_ylabel('c') + ax.set_title(f'{self.gene}-{self.model}') + plt.tight_layout() + 
fig.savefig(f'{self.plot_path}/{self.gene}-{self.model}-cu.png', + dpi=fig.dpi, facecolor=fig.get_facecolor(), + transparent=False, edgecolor='none') + plt.close(fig) + plt.pause(0.2) + + def get_loss(self): + return self.loss + + def get_model(self): + return self.model + + def get_params(self): + return self.t_sw_array, self.rates, self.scale_cc, self.rescale_c, \ + self.rescale_u, self.realign_ratio + + def is_partial(self): + return self.partial + + def get_direction(self): + return self.direction + + def realign_time_and_velocity(self, c, u, s, anchor_time): + # realign time to range (0,20) + self.anchor_min_idx = np.sum(anchor_time < (np.min(self.t)-1e-5)) + self.anchor_max_idx = np.sum(anchor_time < (np.max(self.t)-1e-5)) + self.c0 = c[self.anchor_min_idx] + self.u0 = u[self.anchor_min_idx] + self.s0 = s[self.anchor_min_idx] + self.realign_ratio = 20 / (np.max(self.t) - np.min(self.t)) + main_info(f'fitted params:\nswitch time array = {self.t_sw_array},\n' + f'rates = {self.rates},\ncc scale = {self.scale_cc},\n' + f'c rescale factor = {self.rescale_c},\n' + f'u rescale factor = {self.rescale_u}', + indent_level=1) + main_info(f'aligning to range (0,20) by {self.realign_ratio}..', + indent_level=1) + self.rates /= self.realign_ratio + self.alpha_c, self.alpha, self.beta, self.gamma = self.rates + self.params[3:7] = self.rates + self.t_sw_array = ((self.t_sw_array - np.min(self.t)) + * self.realign_ratio) + self.t_sw_1, self.t_sw_2, self.t_sw_3 = self.t_sw_array + self.params[:3] = np.array([self.t_sw_1, self.t_sw_2 - self.t_sw_1, + self.t_sw_3 - self.t_sw_2]) + self.t -= np.min(self.t) + self.t = self.t * 20 / np.max(self.t) + self.velocity /= self.realign_ratio + self.velocity[:, 0] = np.clip(self.velocity[:, 0], -self.c_all + * self.scale_c, None) + self.velocity[:, 1] = np.clip(self.velocity[:, 1], -self.u_all + * self.scale_u, None) + self.velocity[:, 2] = np.clip(self.velocity[:, 2], -self.s_all + * self.scale_s, None) + self.anchor_velo /= self.realign_ratio + self.anchor_velo[:, 0] = np.clip(self.anchor_velo[:, 0], + -np.max(self.c_all * self.scale_c), + None) + self.anchor_velo[:, 1] = np.clip(self.anchor_velo[:, 1], + -np.max(self.u_all * self.scale_u), + None) + self.anchor_velo[:, 2] = np.clip(self.anchor_velo[:, 2], + -np.max(self.s_all * self.scale_s), + None) + + def get_initial_exp(self): + return np.array([self.c0, self.u0, self.s0]) + + def get_time_assignment(self): + if self.low_quality: + return np.zeros(len(self.u_all)) + return self.t + + def get_state_assignment(self): + if self.low_quality: + return np.zeros(len(self.u_all)) + return self.state + + def get_velocity(self): + if self.low_quality: + return np.zeros((len(self.u_all), 3)) + return self.velocity + + def get_likelihood(self): + return self.likelihood, self.l_c, self.ssd_c, self.var_c + + def get_anchors(self): + if self.low_quality: + return (np.zeros((1, 3)), np.zeros((1, 3)), np.zeros((1, 3)), + 0, 0, 0, 0) + return self.anchor_exp, self.anchor_exp_sw, self.anchor_velo, \ + self.anchor_min_idx, self.anchor_max_idx, \ + self.anchor_velo_min_idx, self.anchor_velo_max_idx + + +def regress_func(c, u, s, m, mi, im, dev, nn, ad, lr, b1, b2, bs, gpdist, + embed, conn, pl, sp, pdir, fa, gene, pa, di, ro, fit, fd, + extra, ru, alpha, beta, gamma, t_, verbosity, log_folder, + log_filename): + + settings.VERBOSITY = verbosity + settings.LOG_FOLDER = log_folder + settings.LOG_FILENAME = log_filename + settings.GENE = gene + + if m is not None: + 
main_info('#########################################################' + '######################################', indent_level=1) + main_info(f'testing model {m}', indent_level=1) + + c_90 = np.percentile(c, 90) + u_90 = np.percentile(u, 90) + s_90 = np.percentile(s, 90) + low_quality = (u_90 == 0 or s_90 == 0) if ro else (c_90 == 0 or u_90 == 0 + or s_90 == 0) + if low_quality: + main_info(f'low quality gene {gene}, skipping', indent_level=1) + return (np.inf, np.nan, '', (np.zeros(3), np.zeros(4), 0, 0, 0, 0), + np.zeros(3), np.zeros(len(u)), np.zeros(len(u)), + np.zeros((len(u), 3)), (-1.0, 0, 0, 0), + (np.zeros((1, 3)), np.zeros((1, 3)), np.zeros((1, 3)), 0, 0, + 0, 0)) + + if gpdist is not None: + subset_cells = s > 0.1 * np.percentile(s, 99) + subset_cells = np.where(subset_cells)[0] + if len(subset_cells) > 3000: + rng = np.random.default_rng(2021) + subset_cells = rng.choice(subset_cells, 3000, replace=False) + local_pdist = gpdist[np.ix_(subset_cells, subset_cells)] + dists = (np.ravel(local_pdist[np.triu_indices_from(local_pdist, k=1)]) + .reshape(-1, 1)) + local_std = np.std(dists) + else: + local_std = None + + cdc = ChromatinDynamical(c, + u, + s, + model=m, + max_iter=mi, + init_mode=im, + device=dev, + neural_net=nn, + adam=ad, + adam_lr=lr, + adam_beta1=b1, + adam_beta2=b2, + batch_size=bs, + local_std=local_std, + embed_coord=embed, + connectivities=conn, + plot=pl, + save_plot=sp, + plot_dir=pdir, + fit_args=fa, + gene=gene, + partial=pa, + direction=di, + rna_only=ro, + fit_decoupling=fd, + extra_color=extra, + rescale_u=ru, + alpha=alpha, + beta=beta, + gamma=gamma, + t_=t_) + if fit: + loss = cdc.fit() + if loss[-1] == np.inf: + main_info(f'low quality gene {gene}, skipping..', indent_level=1) + loss = cdc.get_loss() + model = cdc.get_model() + direction = cdc.get_direction() + parameters = cdc.get_params() + initial_exp = cdc.get_initial_exp() + velocity = cdc.get_velocity() + likelihood = cdc.get_likelihood() + time = cdc.get_time_assignment() + state = cdc.get_state_assignment() + anchors = cdc.get_anchors() + return loss[-1], model, direction, parameters, initial_exp, time, state, \ + velocity, likelihood, anchors + + +def multimodel_helper(c, u, s, + model_to_run, + max_iter, + init_mode, + device, + neural_net, + adam, + adam_lr, + adam_beta1, + adam_beta2, + batch_size, + global_pdist, + embed_coord, + conn, + plot, + save_plot, + plot_dir, + fit_args, + gene, + partial, + direction, + rna_only, + fit, + fit_decoupling, + extra_color, + rescale_u, + alpha, + beta, + gamma, + t_, + verbosity, log_folder, log_filename): + + loss, param_cand, initial_cand, time_cand = [], [], [], [] + state_cand, velo_cand, likelihood_cand, anch_cand = [], [], [], [] + + for model in model_to_run: + (loss_m, _, direction_, parameters, initial_exp, + time, state, velocity, likelihood, anchors) = \ + regress_func(c, u, s, model, max_iter, init_mode, device, neural_net, + adam, adam_lr, adam_beta1, adam_beta2, batch_size, + global_pdist, embed_coord, conn, plot, save_plot, + plot_dir, fit_args, gene, partial, direction, rna_only, + fit, fit_decoupling, extra_color, rescale_u, alpha, beta, + gamma, t_) + loss.append(loss_m) + param_cand.append(parameters) + initial_cand.append(initial_exp) + time_cand.append(time) + state_cand.append(state) + velo_cand.append(velocity) + likelihood_cand.append(likelihood) + anch_cand.append(anchors) + + best_model = np.argmin(loss) + model = np.nan if rna_only else model_to_run[best_model] + parameters = param_cand[best_model] + initial_exp = 
initial_cand[best_model]
+    time = time_cand[best_model]
+    state = state_cand[best_model]
+    velocity = velo_cand[best_model]
+    likelihood = likelihood_cand[best_model]
+    anchors = anch_cand[best_model]
+    return loss, model, direction_, parameters, initial_exp, time, state, \
+        velocity, likelihood, anchors
+
+
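+# Example (editorial sketch, hypothetical arguments): `multimodel_helper`
+# fits every candidate model in `model_to_run` and keeps the fit with the
+# smallest loss, e.g.
+#
+#     losses, model, direction, params, init_exp, t, state, velo, \
+#         likelihood, anchors = multimodel_helper(c, u, s, [1, 2], ...)
+#
+# where `losses` holds one entry per candidate and `model` is the argmin.
+
+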
+def recover_dynamics_chrom(adata_rna,
+                           adata_atac=None,
+                           gene_list=None,
+                           max_iter=5,
+                           init_mode='invert',
+                           device="cpu",
+                           neural_net=False,
+                           adam=False,
+                           adam_lr=None,
+                           adam_beta1=None,
+                           adam_beta2=None,
+                           batch_size=None,
+                           model_to_run=None,
+                           plot=False,
+                           parallel=True,
+                           n_jobs=None,
+                           save_plot=False,
+                           plot_dir=None,
+                           rna_only=False,
+                           fit=True,
+                           fit_decoupling=True,
+                           extra_color_key=None,
+                           embedding='X_umap',
+                           n_anchors=500,
+                           k_dist=1,
+                           thresh_multiplier=1.0,
+                           weight_c=0.6,
+                           outlier=99.8,
+                           n_pcs=30,
+                           n_neighbors=30,
+                           fig_size=(8, 6),
+                           point_size=7,
+                           partial=None,
+                           direction=None,
+                           rescale_u=None,
+                           alpha=None,
+                           beta=None,
+                           gamma=None,
+                           t_sw=None
+                           ):
+
+    """Multi-omic dynamics recovery.
+
+    This function optimizes the joint chromatin and RNA model parameters of
+    the ODE solutions.
+
+    Parameters
+    ----------
+    adata_rna: :class:`~anndata.AnnData`
+        RNA anndata object. Required fields: `Mu`, `Ms`, and
+        `connectivities`.
+    adata_atac: :class:`~anndata.AnnData` (default: `None`)
+        ATAC anndata object. Required fields: `Mc`.
+    gene_list: `str`, list of `str` (default: highly variable genes)
+        Genes to use for model fitting.
+    max_iter: `int` (default: `5`)
+        Iterations to run for parameter optimization.
+    init_mode: `str` (default: `'invert'`)
+        Initialization method for switch times.
+        `'invert'`: initial RNA switch time will be computed with the scVelo
+        time inversion method.
+        `'grid'`: grid search the best set of switch times.
+        `'simple'`: simply initialize switch times to be 5, 10, and 15.
+    device: `str` (default: `'cpu'`)
+        The CUDA device that pytorch tensor calculations will be run on.
+        Only to be used with Adam or neural net mode.
+    neural_net: `bool` (default: `False`)
+        Whether to run time predictions with a neural network. Shortens
+        runtime at the expense of accuracy. If `False`, uses the usual
+        method of assigning each data point to an anchor time point, as
+        outlined in the MultiVelo paper.
+    adam: `bool` (default: `False`)
+        Whether MSE minimization is handled by the Adam algorithm. With the
+        default of `False`, Nelder-Mead is used instead.
+    adam_lr: `float` (default: `None`)
+        The learning rate to use for the Adam algorithm. If `adam` is
+        `False`, this value is ignored.
+    adam_beta1: `float` (default: `None`)
+        The beta1 parameter for the Adam algorithm. If `adam` is `False`,
+        this value is ignored.
+    adam_beta2: `float` (default: `None`)
+        The beta2 parameter for the Adam algorithm. If `adam` is `False`,
+        this value is ignored.
+    batch_size: `int` (default: `None`)
+        Speeds up performance using minibatch training. Specifies the number
+        of cells to use per run of MSE when running the Adam algorithm.
+        Ignored if `adam` is `False`.
+    model_to_run: `int` or list of `int` (default: `None`)
+        User specified models for each gene. Possible values are 1 and 2.
+        If `None`, the model for each gene will be inferred based on
+        expression patterns. If more than one value is given, the best model
+        will be decided based on loss of fit.
+    plot: `bool` or `None` (default: `False`)
+        Whether to interactively plot the 3D gene portraits. Ignored if
+        parallel is True.
+    parallel: `bool` (default: `True`)
+        Whether to fit genes in a parallel fashion (recommended).
+    n_jobs: `int` (default: available threads)
+        Number of parallel jobs.
+    save_plot: `bool` (default: `False`)
+        Whether to save the fitted gene portrait figures as files. This will
+        take some disk space.
+    plot_dir: `str` (default: `plots` for multiome and `rna_plots` for
+        RNA-only)
+        Directory to save the plots.
+    rna_only: `bool` (default: `False`)
+        Whether to only use RNA for fitting (RNA velocity).
+    fit: `bool` (default: `True`)
+        Whether to fit the models. If `False`, only pre-determination and
+        initialization will be run.
+    fit_decoupling: `bool` (default: `True`)
+        Whether to fit the decoupling phase (Model 1 vs Model 2
+        distinction).
+    n_anchors: `int` (default: 500)
+        Number of anchor time-points to generate as a representation of the
+        trajectory.
+    k_dist: `int` (default: 1)
+        Number of anchors to use to determine a cell's gene time. If more
+        than 1, times will be averaged.
+    thresh_multiplier: `float` (default: 1.0)
+        Multiplier for the heuristic threshold of partial versus complete
+        trajectory pre-determination.
+    weight_c: `float` (default: 0.6)
+        Weighting of scaled chromatin distances when performing 3D residual
+        calculation.
+    outlier: `float` (default: 99.8)
+        The percentile to mark as outlier that will be excluded when fitting
+        the model.
+    n_pcs: `int` (default: 30)
+        Number of principal components to compute distance smoothing
+        neighbors. This can be different from the one used for expression
+        smoothing.
+    n_neighbors: `int` (default: 30)
+        Number of nearest neighbors for distance smoothing. This can be
+        different from the one used for expression smoothing.
+    fig_size: `tuple` (default: (8, 6))
+        Size of each figure when saved.
+    point_size: `float` (default: 7)
+        Marker point size for plotting.
+    extra_color_key: `str` (default: `None`)
+        Extra color key used for plotting. Common choices are `leiden`,
+        `celltype`, etc. The colors for each category must be present in one
+        of the anndatas, and can be pre-computed with the
+        `scanpy.pl.scatter` function.
+    embedding: `str` (default: `X_umap`)
+        2D coordinates of the low-dimensional embedding of cells.
+    partial: `bool` or list of `bool` (default: `None`)
+        User specified trajectory completeness for each gene.
+    direction: `str` or list of `str` (default: `None`)
+        User specified trajectory directionality for each gene.
+    rescale_u: `float` or list of `float` (default: `None`)
+        Known scaling factors for unspliced. Can be computed from scVelo
+        `fit_scaling` values as `rescale_u = fit_scaling / std(u) * std(s)`.
+    alpha: `float` or list of `float` (default: `None`)
+        Known transcription rates. Can be computed from scVelo `fit_alpha`
+        values as `alpha = fit_alpha * fit_alignment_scaling`.
+    beta: `float` or list of `float` (default: `None`)
+        Known splicing rates. Can be computed from scVelo `fit_beta` values
+        as `beta = fit_beta * fit_alignment_scaling`.
+    gamma: `float` or list of `float` (default: `None`)
+        Known degradation rates. Can be computed from scVelo `fit_gamma`
+        values as `gamma = fit_gamma * fit_alignment_scaling`.
+    t_sw: `float` or list of `float` (default: `None`)
+        Known RNA switch time. Can be computed from scVelo `fit_t_` values
+        as `t_sw = fit_t_ / fit_alignment_scaling`.
+
+    Returns
+    -------
+    fit_alpha_c, fit_alpha, fit_beta, fit_gamma: `.var`
+        inferred chromatin opening, transcription, splicing, and degradation
+        (nuclear export) rates
+    fit_t_sw1, fit_t_sw2, fit_t_sw3: `.var`
+        inferred switching time points
+    fit_rescale_c, fit_rescale_u: `.var`
+        inferred scaling factors for chromatin and unspliced counts
+    fit_scale_cc: `.var`
+        inferred scaling value for chromatin closing rate compared to
+        opening rate
+    fit_alignment_scaling: `.var`
+        ratio used to realign observed time range to 0-20
+    fit_c0, fit_u0, fit_s0: `.var`
+        initial expression values at earliest observed time
+    fit_model: `.var`
+        inferred gene model
+    fit_direction: `.var`
+        inferred gene direction
+    fit_loss: `.var`
+        loss of model fit
+    fit_likelihood: `.var`
+        likelihood of model fit
+    fit_likelihood_c: `.var`
+        likelihood of chromatin fit
+    fit_anchor_c, fit_anchor_u, fit_anchor_s: `.varm`
+        anchor expressions
+    fit_anchor_c_sw, fit_anchor_u_sw, fit_anchor_s_sw: `.varm`
+        switch time-point expressions
+    fit_anchor_c_velo, fit_anchor_u_velo, fit_anchor_s_velo: `.varm`
+        velocities of anchors
+    fit_anchor_min_idx: `.var`
+        first anchor mapped to observations
+    fit_anchor_max_idx: `.var`
+        last anchor mapped to observations
+    fit_anchor_velo_min_idx: `.var`
+        first velocity anchor mapped to observations
+    fit_anchor_velo_max_idx: `.var`
+        last velocity anchor mapped to observations
+    fit_t: `.layers`
+        inferred gene time
+    fit_state: `.layers`
+        inferred state assignments
+    velo_s, velo_u, velo_chrom: `.layers`
+        velocities in spliced, unspliced, and chromatin space
+    velo_s_genes, velo_u_genes, velo_chrom_genes: `.var`
+        velocity genes
+    velo_s_params, velo_u_params, velo_chrom_params: `.uns`
+        fitting arguments used
+    ATAC: `.layers`
+        KNN smoothed chromatin accessibilities copied from adata_atac
+
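+    Examples
+    --------
+    A minimal sketch with hypothetical inputs; `adata_rna` is expected to
+    carry `Mu`/`Ms` layers and neighbor connectivities, and `adata_atac` an
+    `Mc` layer, as described above:
+
+    >>> adata_result = recover_dynamics_chrom(adata_rna,
+    ...                                       adata_atac,
+    ...                                       gene_list=None,
+    ...                                       parallel=True)
+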
+    """
+
+    fit_args = {}
+    fit_args['max_iter'] = max_iter
+    fit_args['init_mode'] = init_mode
+    fit_args['fit_decoupling'] = fit_decoupling
+    n_anchors = np.clip(int(n_anchors), 201, 2000)
+    fit_args['t'] = n_anchors
+    fit_args['k'] = k_dist
+    fit_args['thresh_multiplier'] = thresh_multiplier
+    fit_args['weight_c'] = weight_c
+    fit_args['outlier'] = outlier
+    fit_args['n_pcs'] = n_pcs
+    fit_args['n_neighbors'] = n_neighbors
+    fit_args['fig_size'] = list(fig_size)
+    fit_args['point_size'] = point_size
+
+    if adam and neural_net:
+        raise Exception("Adam and neural net mode cannot be run concurrently."
+                        " Please choose one of the two to run.")
+
+    if not adam and not neural_net and not device == "cpu":
+        raise Exception("MultiVelo only uses non-CPU devices for Adam or"
+                        " neural net mode. Please use one of those or"
+                        " set the device to \"cpu\".")
+
+    if adam and not device[0:5] == "cuda:":
+        raise Exception("Adam mode is only possible on a cuda "
+                        "device. 
Please try again.") + if not adam and batch_size is not None: + raise Exception("Batch training is for ADAM only, please set " + "batch_size to None") + + if adam: + from cuml.neighbors import NearestNeighbors + + all_genes = adata_rna.var_names + if adata_atac is None: + import anndata as ad + rna_only = True + adata_atac = ad.AnnData(X=np.ones(adata_rna.shape), obs=adata_rna.obs, + var=adata_rna.var) + adata_atac.layers['Mc'] = np.ones(adata_rna.shape) + if adata_rna.shape != adata_atac.shape: + raise ValueError('Shape of RNA and ATAC adata objects do not match: ' + f'{adata_rna.shape} {adata_atac.shape}') + if not np.all(adata_rna.obs_names == adata_atac.obs_names): + raise ValueError('obs_names of RNA and ATAC adata objects do not ' + 'match, please check if they are consistent') + if not np.all(all_genes == adata_atac.var_names): + raise ValueError('var_names of RNA and ATAC adata objects do not ' + 'match, please check if they are consistent') + if 'connectivities' not in adata_rna.obsp.keys(): + raise ValueError('Missing connectivities entry in RNA adata object') + if extra_color_key is None: + extra_color = None + elif (isinstance(extra_color_key, str) and extra_color_key in adata_rna.obs + and adata_rna.obs[extra_color_key].dtype.name == 'category'): + ngroups = len(adata_rna.obs[extra_color_key].cat.categories) + extra_color = adata_rna.obs[extra_color_key].cat.rename_categories( + adata_rna.uns[extra_color_key+'_colors'][:ngroups]).to_numpy() + elif (isinstance(extra_color_key, str) and extra_color_key in + adata_atac.obs and + adata_rna.obs[extra_color_key].dtype.name == 'category'): + ngroups = len(adata_atac.obs[extra_color_key].cat.categories) + extra_color = adata_atac.obs[extra_color_key].cat.rename_categories( + adata_atac.uns[extra_color_key+'_colors'][:ngroups]).to_numpy() + else: + raise ValueError('Currently, extra_color_key must be a single string ' + 'of categories and available in adata obs, and its ' + 'colors can be found in adata uns') + if ('connectivities' not in adata_rna.obsp.keys() or + (adata_rna.obsp['connectivities'] > 0).sum(1).min() + > (n_neighbors-1)): + from scanpy import Neighbors + neighbors = Neighbors(adata_rna) + neighbors.compute_neighbors(n_neighbors=n_neighbors, knn=True, + n_pcs=n_pcs) + rna_conn = neighbors.connectivities + else: + rna_conn = adata_rna.obsp['connectivities'].copy() + rna_conn.setdiag(1) + rna_conn = rna_conn.multiply(1.0 / rna_conn.sum(1)).tocsr() + if not rna_only: + if 'connectivities' not in adata_atac.obsp.keys(): + main_info('Missing connectivities in ATAC adata object, using ' + 'RNA connectivities instead', indent_level=1) + atac_conn = rna_conn + else: + atac_conn = adata_atac.obsp['connectivities'].copy() + atac_conn.setdiag(1) + atac_conn = atac_conn.multiply(1.0 / atac_conn.sum(1)).tocsr() + if gene_list is None: + if 'highly_variable' in adata_rna.var: + gene_list = adata_rna.var_names[adata_rna.var['highly_variable']]\ + .values + else: + gene_list = adata_rna.var_names.values[ + (~np.isnan(np.asarray(adata_rna.layers['Mu'].sum(0)) + .reshape(-1) + if sparse.issparse(adata_rna.layers['Mu']) + else np.sum(adata_rna.layers['Mu'], axis=0))) + & (~np.isnan(np.asarray(adata_rna.layers['Ms'].sum(0)) + .reshape(-1) + if sparse.issparse(adata_rna.layers['Ms']) + else np.sum(adata_rna.layers['Ms'], axis=0))) + & (~np.isnan(np.asarray(adata_atac.layers['Mc'].sum(0)) + .reshape(-1) + if sparse.issparse(adata_atac.layers['Mc']) + else np.sum(adata_atac.layers['Mc'], axis=0)))] + elif isinstance(gene_list, (list, 
np.ndarray, pd.Index, pd.Series)): + gene_list = np.array([x for x in gene_list if x in all_genes]) + elif isinstance(gene_list, str): + gene_list = np.array([gene_list]) if gene_list in all_genes else [] + else: + raise ValueError('Invalid gene list, must be one of (str, np.ndarray,' + 'pd.Index, pd.Series)') + gn = len(gene_list) + if gn == 0: + raise ValueError('None of the genes specified are in the adata object') + main_info(f'{gn} genes will be fitted', indent_level=1) + + models = np.zeros(gn) + t_sws = np.zeros((gn, 3)) + rates = np.zeros((gn, 4)) + scale_ccs = np.zeros(gn) + rescale_cs = np.zeros(gn) + rescale_us = np.zeros(gn) + realign_ratios = np.zeros(gn) + initial_exps = np.zeros((gn, 3)) + times = np.zeros((adata_rna.n_obs, gn)) + states = np.zeros((adata_rna.n_obs, gn)) + if not rna_only: + velo_c = np.zeros((adata_rna.n_obs, gn)) + velo_u = np.zeros((adata_rna.n_obs, gn)) + velo_s = np.zeros((adata_rna.n_obs, gn)) + likelihoods = np.zeros(gn) + l_cs = np.zeros(gn) + ssd_cs = np.zeros(gn) + var_cs = np.zeros(gn) + directions = [] + anchor_c = np.zeros((n_anchors, gn)) + anchor_u = np.zeros((n_anchors, gn)) + anchor_s = np.zeros((n_anchors, gn)) + anchor_c_sw = np.zeros((3, gn)) + anchor_u_sw = np.zeros((3, gn)) + anchor_s_sw = np.zeros((3, gn)) + anchor_vc = np.zeros((n_anchors, gn)) + anchor_vu = np.zeros((n_anchors, gn)) + anchor_vs = np.zeros((n_anchors, gn)) + anchor_min_idx = np.zeros(gn) + anchor_max_idx = np.zeros(gn) + anchor_velo_min_idx = np.zeros(gn) + anchor_velo_max_idx = np.zeros(gn) + + if rna_only: + model_to_run = [2] + main_info('Skipping model checking for RNA-only, running model 2', + indent_level=1) + + m_per_g = False + if model_to_run is not None: + if isinstance(model_to_run, (list, np.ndarray, pd.Index, pd.Series)): + model_to_run = [int(x) for x in model_to_run] + if np.any(~np.isin(model_to_run, [0, 1, 2])): + raise ValueError('Invalid model number (must be values in' + ' [0,1,2])') + if len(model_to_run) == gn: + losses = np.zeros((gn, 1)) + m_per_g = True + func_to_call = regress_func + else: + losses = np.zeros((gn, len(model_to_run))) + func_to_call = multimodel_helper + elif isinstance(model_to_run, (int, float)): + model_to_run = int(model_to_run) + if not np.isin(model_to_run, [0, 1, 2]): + raise ValueError('Invalid model number (must be values in ' + '[0,1,2])') + model_to_run = [model_to_run] + losses = np.zeros((gn, 1)) + func_to_call = multimodel_helper + else: + raise ValueError('Invalid model number (must be values in ' + '[0,1,2])') + else: + losses = np.zeros((gn, 1)) + func_to_call = regress_func + + p_per_g = False + if partial is not None: + if isinstance(partial, (list, np.ndarray, pd.Index, pd.Series)): + if np.any(~np.isin(partial, [True, False])): + raise ValueError('Invalid partial argument (must be values in' + ' [True,False])') + if len(partial) == gn: + p_per_g = True + else: + raise ValueError('Incorrect partial argument length') + elif isinstance(partial, bool): + if not np.isin(partial, [True, False]): + raise ValueError('Invalid partial argument (must be values in' + ' [True,False])') + else: + raise ValueError('Invalid partial argument (must be values in' + ' [True,False])') + + d_per_g = False + if direction is not None: + if isinstance(direction, (list, np.ndarray, pd.Index, pd.Series)): + if np.any(~np.isin(direction, ['on', 'off', 'complete'])): + raise ValueError('Invalid direction argument (must be values' + ' in ["on","off","complete"])') + if len(direction) == gn: + d_per_g = True + else: + raise 
ValueError('Incorrect direction argument length') + elif isinstance(direction, str): + if not np.isin(direction, ['on', 'off', 'complete']): + raise ValueError('Invalid direction argument (must be values' + ' in ["on","off","complete"])') + else: + raise ValueError('Invalid direction argument (must be values in' + ' ["on","off","complete"])') + + known_pars = [rescale_u, alpha, beta, gamma, t_sw] + for x in known_pars: + if x is not None: + if isinstance(x, (list, np.ndarray)): + if np.sum(np.isnan(x)) + np.sum(np.isinf(x)) > 0: + raise ValueError('Known parameters cannot contain NaN or' + ' Inf') + elif isinstance(x, (int, float)): + if x == np.nan or x == np.inf: + raise ValueError('Known parameters cannot contain NaN or' + ' Inf') + else: + raise ValueError('Invalid known parameters type') + + if ((embedding not in adata_rna.obsm) and + (embedding not in adata_atac.obsm)): + raise ValueError(f'{embedding} is not found in obsm') + embed_coord = adata_rna.obsm[embedding] if embedding in adata_rna.obsm \ + else adata_atac.obsm[embedding] + global_pdist = pairwise_distances(embed_coord) + + u_mat = adata_rna[:, gene_list].layers['Mu'].A \ + if sparse.issparse(adata_rna.layers['Mu']) \ + else adata_rna[:, gene_list].layers['Mu'] + s_mat = adata_rna[:, gene_list].layers['Ms'].A \ + if sparse.issparse(adata_rna.layers['Ms']) \ + else adata_rna[:, gene_list].layers['Ms'] + c_mat = adata_atac[:, gene_list].layers['Mc'].A \ + if sparse.issparse(adata_atac.layers['Mc']) \ + else adata_atac[:, gene_list].layers['Mc'] + + ru = rescale_u if rescale_u is not None else None + + if parallel: + if (n_jobs is None or not isinstance(n_jobs, int) or n_jobs < 0 or + n_jobs > os.cpu_count()): + n_jobs = os.cpu_count() + if n_jobs > gn: + n_jobs = gn + batches = -(-gn // n_jobs) + if n_jobs > 1: + main_info(f'running {n_jobs} jobs in parallel', indent_level=1) + else: + n_jobs = 1 + batches = gn + if n_jobs == 1: + parallel = False + + pbar = tqdm(total=gn) + for group in range(batches): + gene_indices = range(group * n_jobs, np.min([gn, (group+1) * n_jobs])) + if parallel: + from joblib import Parallel, delayed + verb = 51 if settings.VERBOSITY >= 2 else 0 + plot = False + + # clear the settings file if it exists + open("settings.txt", "w").close() + + # write our current settings to the file + with open("settings.txt", "a") as sfile: + sfile.write(str(settings.VERBOSITY) + "\n") + sfile.write(str(settings.CWD) + "\n") + sfile.write(str(settings.LOG_FOLDER) + "\n") + sfile.write(str(settings.LOG_FILENAME) + "\n") + + res = Parallel(n_jobs=n_jobs, backend='loky', verbose=verb)( + delayed(func_to_call)( + c_mat[:, i], + u_mat[:, i], + s_mat[:, i], + model_to_run[i] if m_per_g else model_to_run, + max_iter, + init_mode, + device, + neural_net, + adam, + adam_lr, + adam_beta1, + adam_beta2, + batch_size, + global_pdist, + embed_coord, + rna_conn, + plot, + save_plot, + plot_dir, + fit_args, + gene_list[i], + partial[i] if p_per_g else partial, + direction[i] if d_per_g else direction, + rna_only, + fit, + fit_decoupling, + extra_color, + ru[i] if isinstance(ru, (list, np.ndarray)) else ru, + alpha[i] if isinstance(alpha, (list, np.ndarray)) + else alpha, + beta[i] if isinstance(beta, (list, np.ndarray)) + else beta, + gamma[i] if isinstance(gamma, (list, np.ndarray)) + else gamma, + t_sw[i] if isinstance(t_sw, (list, np.ndarray)) else t_sw, + settings.VERBOSITY, + settings.LOG_FOLDER, + settings.LOG_FILENAME) + for i in gene_indices) + + for i, r in zip(gene_indices, res): + (loss, model, direct_out, parameters, 
initial_exp, + time, state, velocity, likelihood, anchors) = r + switch, rate, scale_cc, rescale_c, rescale_u, realign_ratio = \ + parameters + likelihood, l_c, ssd_c, var_c = likelihood + losses[i, :] = loss + models[i] = model + directions.append(direct_out) + t_sws[i, :] = switch + rates[i, :] = rate + scale_ccs[i] = scale_cc + rescale_cs[i] = rescale_c + rescale_us[i] = rescale_u + realign_ratios[i] = realign_ratio + likelihoods[i] = likelihood + l_cs[i] = l_c + ssd_cs[i] = ssd_c + var_cs[i] = var_c + if fit: + initial_exps[i, :] = initial_exp + times[:, i] = time + states[:, i] = state + n_anchors_ = anchors[0].shape[0] + n_switch = anchors[1].shape[0] + if not rna_only: + velo_c[:, i] = smooth_scale(atac_conn, velocity[:, 0]) + anchor_c[:n_anchors_, i] = anchors[0][:, 0] + anchor_c_sw[:n_switch, i] = anchors[1][:, 0] + anchor_vc[:n_anchors_, i] = anchors[2][:, 0] + velo_u[:, i] = smooth_scale(rna_conn, velocity[:, 1]) + velo_s[:, i] = smooth_scale(rna_conn, velocity[:, 2]) + anchor_u[:n_anchors_, i] = anchors[0][:, 1] + anchor_s[:n_anchors_, i] = anchors[0][:, 2] + anchor_u_sw[:n_switch, i] = anchors[1][:, 1] + anchor_s_sw[:n_switch, i] = anchors[1][:, 2] + anchor_vu[:n_anchors_, i] = anchors[2][:, 1] + anchor_vs[:n_anchors_, i] = anchors[2][:, 2] + anchor_min_idx[i] = anchors[3] + anchor_max_idx[i] = anchors[4] + anchor_velo_min_idx[i] = anchors[5] + anchor_velo_max_idx[i] = anchors[6] + else: + i = group + gene = gene_list[i] + main_info(f'@@@@@fitting {gene}', indent_level=1) + (loss, model, direct_out, + parameters, initial_exp, + time, state, velocity, + likelihood, anchors) = \ + func_to_call(c_mat[:, i], u_mat[:, i], s_mat[:, i], + model_to_run[i] if m_per_g else model_to_run, + max_iter, init_mode, + device, + neural_net, + adam, + adam_lr, + adam_beta1, + adam_beta2, + batch_size, + global_pdist, embed_coord, + rna_conn, plot, save_plot, plot_dir, + fit_args, gene, + partial[i] if p_per_g else partial, + direction[i] if d_per_g else direction, + rna_only, fit, fit_decoupling, extra_color, + ru[i] if isinstance(ru, (list, np.ndarray)) + else ru, + alpha[i] if isinstance(alpha, (list, np.ndarray)) + else alpha, + beta[i] if isinstance(beta, (list, np.ndarray)) + else beta, + gamma[i] if isinstance(gamma, (list, np.ndarray)) + else gamma, + t_sw[i] if isinstance(t_sw, (list, np.ndarray)) + else t_sw, + settings.VERBOSITY, + settings.LOG_FOLDER, + settings.LOG_FILENAME) + switch, rate, scale_cc, rescale_c, rescale_u, realign_ratio = \ + parameters + likelihood, l_c, ssd_c, var_c = likelihood + losses[i, :] = loss + models[i] = model + directions.append(direct_out) + t_sws[i, :] = switch + rates[i, :] = rate + scale_ccs[i] = scale_cc + rescale_cs[i] = rescale_c + rescale_us[i] = rescale_u + realign_ratios[i] = realign_ratio + likelihoods[i] = likelihood + l_cs[i] = l_c + ssd_cs[i] = ssd_c + var_cs[i] = var_c + if fit: + initial_exps[i, :] = initial_exp + times[:, i] = time + states[:, i] = state + n_anchors_ = anchors[0].shape[0] + n_switch = anchors[1].shape[0] + if not rna_only: + velo_c[:, i] = smooth_scale(atac_conn, velocity[:, 0]) + anchor_c[:n_anchors_, i] = anchors[0][:, 0] + anchor_c_sw[:n_switch, i] = anchors[1][:, 0] + anchor_vc[:n_anchors_, i] = anchors[2][:, 0] + velo_u[:, i] = smooth_scale(rna_conn, velocity[:, 1]) + velo_s[:, i] = smooth_scale(rna_conn, velocity[:, 2]) + anchor_u[:n_anchors_, i] = anchors[0][:, 1] + anchor_s[:n_anchors_, i] = anchors[0][:, 2] + anchor_u_sw[:n_switch, i] = anchors[1][:, 1] + anchor_s_sw[:n_switch, i] = anchors[1][:, 2] + 
anchor_vu[:n_anchors_, i] = anchors[2][:, 1] + anchor_vs[:n_anchors_, i] = anchors[2][:, 2] + anchor_min_idx[i] = anchors[3] + anchor_max_idx[i] = anchors[4] + anchor_velo_min_idx[i] = anchors[5] + anchor_velo_max_idx[i] = anchors[6] + pbar.update(len(gene_indices)) + pbar.close() + directions = np.array(directions) + + filt = np.sum(losses != np.inf, 1) >= 1 + if np.sum(filt) == 0: + raise ValueError('None of the genes were fitted due to low quality,' + ' not returning') + adata_copy = adata_rna[:, gene_list[filt]].copy() + adata_copy.layers['ATAC'] = c_mat[:, filt] + adata_copy.var['fit_alpha_c'] = rates[filt, 0] + adata_copy.var['fit_alpha'] = rates[filt, 1] + adata_copy.var['fit_beta'] = rates[filt, 2] + adata_copy.var['fit_gamma'] = rates[filt, 3] + adata_copy.var['fit_t_sw1'] = t_sws[filt, 0] + adata_copy.var['fit_t_sw2'] = t_sws[filt, 1] + adata_copy.var['fit_t_sw3'] = t_sws[filt, 2] + adata_copy.var['fit_scale_cc'] = scale_ccs[filt] + adata_copy.var['fit_rescale_c'] = rescale_cs[filt] + adata_copy.var['fit_rescale_u'] = rescale_us[filt] + adata_copy.var['fit_alignment_scaling'] = realign_ratios[filt] + adata_copy.var['fit_model'] = models[filt] + adata_copy.var['fit_direction'] = directions[filt] + if model_to_run is not None and not m_per_g and not rna_only: + for i, m in enumerate(model_to_run): + adata_copy.var[f'fit_loss_M{m}'] = losses[filt, i] + else: + adata_copy.var['fit_loss'] = losses[filt, 0] + adata_copy.var['fit_likelihood'] = likelihoods[filt] + adata_copy.var['fit_likelihood_c'] = l_cs[filt] + adata_copy.var['fit_ssd_c'] = ssd_cs[filt] + adata_copy.var['fit_var_c'] = var_cs[filt] + if fit: + adata_copy.layers['fit_t'] = times[:, filt] + adata_copy.layers['fit_state'] = states[:, filt] + adata_copy.layers['velo_s'] = velo_s[:, filt] + adata_copy.layers['velo_u'] = velo_u[:, filt] + if not rna_only: + adata_copy.layers['velo_chrom'] = velo_c[:, filt] + adata_copy.var['fit_c0'] = initial_exps[filt, 0] + adata_copy.var['fit_u0'] = initial_exps[filt, 1] + adata_copy.var['fit_s0'] = initial_exps[filt, 2] + adata_copy.var['fit_anchor_min_idx'] = anchor_min_idx[filt] + adata_copy.var['fit_anchor_max_idx'] = anchor_max_idx[filt] + adata_copy.var['fit_anchor_velo_min_idx'] = anchor_velo_min_idx[filt] + adata_copy.var['fit_anchor_velo_max_idx'] = anchor_velo_max_idx[filt] + adata_copy.varm['fit_anchor_c'] = np.transpose(anchor_c[:, filt]) + adata_copy.varm['fit_anchor_u'] = np.transpose(anchor_u[:, filt]) + adata_copy.varm['fit_anchor_s'] = np.transpose(anchor_s[:, filt]) + adata_copy.varm['fit_anchor_c_sw'] = np.transpose(anchor_c_sw[:, filt]) + adata_copy.varm['fit_anchor_u_sw'] = np.transpose(anchor_u_sw[:, filt]) + adata_copy.varm['fit_anchor_s_sw'] = np.transpose(anchor_s_sw[:, filt]) + adata_copy.varm['fit_anchor_c_velo'] = np.transpose(anchor_vc[:, filt]) + adata_copy.varm['fit_anchor_u_velo'] = np.transpose(anchor_vu[:, filt]) + adata_copy.varm['fit_anchor_s_velo'] = np.transpose(anchor_vs[:, filt]) + v_genes = adata_copy.var['fit_likelihood'] >= 0.05 + adata_copy.var['velo_s_genes'] = adata_copy.var['velo_u_genes'] = \ + adata_copy.var['velo_chrom_genes'] = v_genes + adata_copy.uns['velo_s_params'] = adata_copy.uns['velo_u_params'] = \ + adata_copy.uns['velo_chrom_params'] = {'mode': 'dynamical'} + adata_copy.uns['velo_s_params'].update(fit_args) + adata_copy.uns['velo_u_params'].update(fit_args) + adata_copy.uns['velo_chrom_params'].update(fit_args) + adata_copy.obsp['_RNA_conn'] = rna_conn + if not rna_only: + adata_copy.obsp['_ATAC_conn'] = atac_conn + return 
adata_copy + + +def smooth_scale(conn, vector): + max_to = np.max(vector) + min_to = np.min(vector) + v = conn.dot(vector.T).T + max_from = np.max(v) + min_from = np.min(v) + res = ((v - min_from) * (max_to - min_to) / (max_from - min_from)) + min_to + return res + + +def top_n_sparse(conn, n): + conn_ll = conn.tolil() + for i in range(conn_ll.shape[0]): + row_data = np.array(conn_ll.data[i]) + row_idx = np.array(conn_ll.rows[i]) + new_idx = row_data.argsort()[-n:] + top_val = row_data[new_idx] + top_idx = row_idx[new_idx] + conn_ll.data[i] = top_val.tolist() + conn_ll.rows[i] = top_idx.tolist() + conn = conn_ll.tocsr() + idx1 = conn > 0 + idx2 = conn > 0.25 + idx3 = conn > 0.5 + conn[idx1] = 0.25 + conn[idx2] = 0.5 + conn[idx3] = 1 + conn.eliminate_zeros() + return conn + + +def set_velocity_genes(adata, + likelihood_lower=0.05, + rescale_u_upper=None, + rescale_u_lower=None, + rescale_c_upper=None, + rescale_c_lower=None, + primed_upper=None, + primed_lower=None, + decoupled_upper=None, + decoupled_lower=None, + alpha_c_upper=None, + alpha_c_lower=None, + alpha_upper=None, + alpha_lower=None, + beta_upper=None, + beta_lower=None, + gamma_upper=None, + gamma_lower=None, + scale_cc_upper=None, + scale_cc_lower=None + ): + """Reset velocity genes. + + This function resets velocity genes based on criteria of variables. + + Parameters + ---------- + adata: :class:`~anndata.AnnData` + Anndata result from dynamics recovery. + likelihood_lower: `float` (default: 0.05) + Minimum ikelihood. + rescale_u_upper: `float` (default: `None`) + Maximum rescale_u. + rescale_u_lower: `float` (default: `None`) + Minimum rescale_u. + rescale_c_upper: `float` (default: `None`) + Maximum rescale_c. + rescale_c_lower: `float` (default: `None`) + Minimum rescale_c. + primed_upper: `float` (default: `None`) + Maximum primed interval. + primed_lower: `float` (default: `None`) + Minimum primed interval. + decoupled_upper: `float` (default: `None`) + Maximum decoupled interval. + decoupled_lower: `float` (default: `None`) + Minimum decoupled interval. + alpha_c_upper: `float` (default: `None`) + Maximum alpha_c. + alpha_c_lower: `float` (default: `None`) + Minimum alpha_c. + alpha_upper: `float` (default: `None`) + Maximum alpha. + alpha_lower: `float` (default: `None`) + Minimum alpha. + beta_upper: `float` (default: `None`) + Maximum beta. + beta_lower: `float` (default: `None`) + Minimum beta. + gamma_upper: `float` (default: `None`) + Maximum gamma. + gamma_lower: `float` (default: `None`) + Minimum gamma. + scale_cc_upper: `float` (default: `None`) + Maximum scale_cc. + scale_cc_lower: `float` (default: `None`) + Minimum scale_cc. + + Returns + ------- + velo_s_genes, velo_u_genes, velo_chrom_genes: `.var` + new velocity genes for each modalities. 
+ """ + + v_genes = (adata.var['fit_likelihood'] >= likelihood_lower) + if rescale_u_upper is not None: + v_genes &= adata.var['fit_rescale_u'] <= rescale_u_upper + if rescale_u_lower is not None: + v_genes &= adata.var['fit_rescale_u'] >= rescale_u_lower + if rescale_c_upper is not None: + v_genes &= adata.var['fit_rescale_c'] <= rescale_c_upper + if rescale_c_lower is not None: + v_genes &= adata.var['fit_rescale_c'] >= rescale_c_lower + t_sw1 = adata.var['fit_t_sw1'] + 20 / adata.uns['velo_s_params']['t'] * \ + adata.var['fit_anchor_min_idx'] * adata.var['fit_alignment_scaling'] + if primed_upper is not None: + v_genes &= t_sw1 <= primed_upper + if primed_lower is not None: + v_genes &= t_sw1 >= primed_lower + t_sw2 = np.clip(adata.var['fit_t_sw2'], None, 20) + t_sw3 = np.clip(adata.var['fit_t_sw3'], None, 20) + t_interval3 = t_sw3 - t_sw2 + if decoupled_upper is not None: + v_genes &= t_interval3 <= decoupled_upper + if decoupled_lower is not None: + v_genes &= t_interval3 >= decoupled_lower + if alpha_c_upper is not None: + v_genes &= adata.var['fit_alpha_c'] <= alpha_c_upper + if alpha_c_lower is not None: + v_genes &= adata.var['fit_alpha_c'] >= alpha_c_lower + if alpha_upper is not None: + v_genes &= adata.var['fit_alpha'] <= alpha_upper + if alpha_lower is not None: + v_genes &= adata.var['fit_alpha'] >= alpha_lower + if beta_upper is not None: + v_genes &= adata.var['fit_beta'] <= beta_upper + if beta_lower is not None: + v_genes &= adata.var['fit_beta'] >= beta_lower + if gamma_upper is not None: + v_genes &= adata.var['fit_gamma'] <= gamma_upper + if gamma_lower is not None: + v_genes &= adata.var['fit_gamma'] >= gamma_lower + if scale_cc_upper is not None: + v_genes &= adata.var['fit_scale_cc'] <= scale_cc_upper + if scale_cc_lower is not None: + v_genes &= adata.var['fit_scale_cc'] >= scale_cc_lower + main_info(f'{np.sum(v_genes)} velocity genes were selected', indent_level=1) + adata.var['velo_s_genes'] = adata.var['velo_u_genes'] = \ + adata.var['velo_chrom_genes'] = v_genes + + +def velocity_graph(adata, vkey='velo_s', xkey='Ms', **kwargs): + """Computes velocity graph. + + This function normalizes the velocity matrix and computes velocity graph + with `scvelo.tl.velocity_graph`. + + Parameters + ---------- + adata: :class:`~anndata.AnnData` + Anndata result from dynamics recovery. + vkey: `str` (default: `velo_s`) + Default to use spliced velocities. + xkey: `str` (default: `Ms`) + Default to use smoothed spliced counts. + Additional parameters passed to `scvelo.tl.velocity_graph`. + + Returns + ------- + Normalized velocity matrix and associated velocity genes and params. + Outputs of `scvelo.tl.velocity_graph`. + """ + if vkey not in adata.layers.keys(): + raise ValueError('Velocity matrix is not found. Please run multivelo' + '.recover_dynamics_chrom function first.') + if vkey+'_norm' not in adata.layers.keys(): + adata.layers[vkey+'_norm'] = adata.layers[vkey] / np.sum( + np.abs(adata.layers[vkey]), 0) + adata.layers[vkey+'_norm'] /= np.mean(adata.layers[vkey+'_norm']) + adata.uns[vkey+'_norm_params'] = adata.uns[vkey+'_params'] + if vkey+'_norm_genes' not in adata.var.columns: + adata.var[vkey+'_norm_genes'] = adata.var[vkey+'_genes'] + scv.tl.velocity_graph(adata, vkey=vkey+'_norm', xkey=xkey, **kwargs) + + +def velocity_embedding_stream(adata, vkey='velo_s', show=True, **kwargs): + """Plots velocity stream. + + This function plots velocity streamplot with + `scvelo.pl.velocity_embedding_stream`. 
+
+    Parameters
+    ----------
+    adata: :class:`~anndata.AnnData`
+        Anndata result from dynamics recovery.
+    vkey: `str` (default: `velo_s`)
+        Default to use spliced velocities. The normalized matrix will be
+        used.
+    show: `bool` (default: `True`)
+        Whether to show the plot.
+    Additional parameters are passed to `scvelo.pl.velocity_embedding_stream`
+    (and to `velocity_graph` if the graph has not been computed).
+
+    Returns
+    -------
+    If `show==False`, a matplotlib axis object.
+    """
+    if vkey not in adata.layers:
+        raise ValueError('Velocity matrix is not found. Please run multivelo.'
+                         'recover_dynamics_chrom function first.')
+    if vkey+'_norm' not in adata.layers.keys():
+        adata.layers[vkey+'_norm'] = adata.layers[vkey] / np.sum(
+            np.abs(adata.layers[vkey]), 0)
+        adata.uns[vkey+'_norm_params'] = adata.uns[vkey+'_params']
+    if vkey+'_norm_genes' not in adata.var.columns:
+        adata.var[vkey+'_norm_genes'] = adata.var[vkey+'_genes']
+    if vkey+'_norm_graph' not in adata.uns.keys():
+        velocity_graph(adata, vkey=vkey, **kwargs)
+    out = scv.pl.velocity_embedding_stream(adata, vkey=vkey+'_norm', show=show,
+                                           **kwargs)
+    if not show:
+        return out
+
+
+def latent_time(adata, vkey='velo_s', **kwargs):
+    """Computes latent time.
+
+    This function computes latent time with `scvelo.tl.latent_time`.
+
+    Parameters
+    ----------
+    adata: :class:`~anndata.AnnData`
+        Anndata result from dynamics recovery.
+    vkey: `str` (default: `velo_s`)
+        Default to use spliced velocities. The normalized matrix will be
+        used.
+    Additional parameters are passed to `scvelo.tl.latent_time` (and to
+    `velocity_graph` if the graph has not been computed).
+
+    Returns
+    -------
+    Outputs of `scvelo.tl.latent_time`.
+    """
+    if vkey not in adata.layers.keys() or 'fit_t' not in adata.layers.keys():
+        raise ValueError('Velocity or time matrix is not found. Please run '
+                         'multivelo.recover_dynamics_chrom function first.')
+    if vkey+'_norm' not in adata.layers.keys():
+        raise ValueError('Normalized velocity matrix is not found. Please '
+                         'run multivelo.velocity_graph function first.')
+    if vkey+'_norm_graph' not in adata.uns.keys():
+        velocity_graph(adata, vkey=vkey, **kwargs)
+    scv.tl.latent_time(adata, vkey=vkey+'_norm', **kwargs)
+
+
+def LRT_decoupling(adata_rna, adata_atac, **kwargs):
+    """Computes likelihood ratio test for decoupling state.
+
+    This function tests whether keeping the decoupling state improves the
+    fit likelihood.
+
+    Parameters
+    ----------
+    adata_rna: :class:`~anndata.AnnData`
+        RNA anndata object.
+    adata_atac: :class:`~anndata.AnnData`
+        ATAC anndata object.
+    Additional parameters are passed to `recover_dynamics_chrom`.
+
+    Returns
+    -------
+    adata_result_w_decoupled: :class:`~anndata.AnnData`
+        fit result with decoupling state
+    adata_result_wo_decoupled: :class:`~anndata.AnnData`
+        fit result without decoupling state
+    res: `pandas.DataFrame`
+        LRT statistics
+    """
+    from scipy.stats.distributions import chi2
+    main_info('fitting models with decoupling intervals', v=0)
+    adata_result_w_decoupled = recover_dynamics_chrom(adata_rna, adata_atac,
+                                                      fit_decoupling=True,
+                                                      **kwargs)
+    main_info('fitting models without decoupling intervals', v=0)
+    adata_result_wo_decoupled = recover_dynamics_chrom(adata_rna, adata_atac,
+                                                       fit_decoupling=False,
+                                                       **kwargs)
+    main_info('testing likelihood ratio', v=0)
+    shared_genes = pd.Index(np.intersect1d(adata_result_w_decoupled.var_names,
+                                           adata_result_wo_decoupled.var_names)
+                            )
+    l_c_w_decoupled = adata_result_w_decoupled[:, shared_genes].\
+        var['fit_likelihood_c'].values
+    l_c_wo_decoupled = adata_result_wo_decoupled[:, shared_genes].\
+        var['fit_likelihood_c'].values
+    n_obs = adata_rna.n_obs
+    LRT_c = -2 * n_obs * (np.log(l_c_wo_decoupled) - np.log(l_c_w_decoupled))
+    p_c = chi2.sf(LRT_c, 1)
+    l_w_decoupled = adata_result_w_decoupled[:, shared_genes].\
+        var['fit_likelihood'].values
+    l_wo_decoupled = adata_result_wo_decoupled[:, shared_genes].\
+        var['fit_likelihood'].values
+    LRT = -2 * n_obs * (np.log(l_wo_decoupled) - np.log(l_w_decoupled))
+    p = chi2.sf(LRT, 1)
+    res = pd.DataFrame({'likelihood_c_w_decoupled': l_c_w_decoupled,
+                        'likelihood_c_wo_decoupled': l_c_wo_decoupled,
+                        'LRT_c': LRT_c,
+                        'pval_c': p_c,
+                        'likelihood_w_decoupled': l_w_decoupled,
+                        'likelihood_wo_decoupled': l_wo_decoupled,
+                        'LRT': LRT,
+                        'pval': p,
+                        }, index=shared_genes)
+    return adata_result_w_decoupled, adata_result_wo_decoupled, res
+
+
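+# Example (editorial sketch, hypothetical variable names): rank genes by the
+# evidence that keeping a decoupling interval improves the chromatin fit.
+#
+#     w_dec, wo_dec, lrt_res = LRT_decoupling(adata_rna, adata_atac)
+#     decoupled_genes = lrt_res.index[lrt_res['pval_c'] < 0.05]
+
+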
+
+    Returns
+    -------
+    adata_result_w_decoupled: :class:`~anndata.AnnData`
+        Fit result with the decoupling state.
+    adata_result_wo_decoupled: :class:`~anndata.AnnData`
+        Fit result without the decoupling state.
+    res: `pandas.DataFrame`
+        LRT statistics.
+    """
+    from scipy.stats.distributions import chi2
+    main_info('fitting models with decoupling intervals', v=0)
+    adata_result_w_decoupled = recover_dynamics_chrom(adata_rna, adata_atac,
+                                                      fit_decoupling=True,
+                                                      **kwargs)
+    main_info('fitting models without decoupling intervals', v=0)
+    adata_result_wo_decoupled = recover_dynamics_chrom(adata_rna, adata_atac,
+                                                       fit_decoupling=False,
+                                                       **kwargs)
+    main_info('testing likelihood ratio', v=0)
+    shared_genes = pd.Index(np.intersect1d(adata_result_w_decoupled.var_names,
+                                           adata_result_wo_decoupled.var_names)
+                            )
+    l_c_w_decoupled = adata_result_w_decoupled[:, shared_genes].\
+        var['fit_likelihood_c'].values
+    l_c_wo_decoupled = adata_result_wo_decoupled[:, shared_genes].\
+        var['fit_likelihood_c'].values
+    n_obs = adata_rna.n_obs
+    LRT_c = -2 * n_obs * (np.log(l_c_wo_decoupled) - np.log(l_c_w_decoupled))
+    p_c = chi2.sf(LRT_c, 1)
+    l_w_decoupled = adata_result_w_decoupled[:, shared_genes].\
+        var['fit_likelihood'].values
+    l_wo_decoupled = adata_result_wo_decoupled[:, shared_genes].\
+        var['fit_likelihood'].values
+    LRT = -2 * n_obs * (np.log(l_wo_decoupled) - np.log(l_w_decoupled))
+    p = chi2.sf(LRT, 1)
+    res = pd.DataFrame({'likelihood_c_w_decoupled': l_c_w_decoupled,
+                        'likelihood_c_wo_decoupled': l_c_wo_decoupled,
+                        'LRT_c': LRT_c,
+                        'pval_c': p_c,
+                        'likelihood_w_decoupled': l_w_decoupled,
+                        'likelihood_wo_decoupled': l_wo_decoupled,
+                        'LRT': LRT,
+                        'pval': p,
+                        }, index=shared_genes)
+    return adata_result_w_decoupled, adata_result_wo_decoupled, res
+
+
+def transition_matrix_s(s_mat, velo_s, knn):
+    # Cosine similarity between each cell's spliced velocity and its
+    # displacement to every cell in a two-step KNN neighborhood; the positive
+    # and negative parts are returned as separate sparse matrices.
+    knn = knn.astype(int)
+    tm_val, tm_col, tm_row = [], [], []
+    for i in range(knn.shape[0]):
+        two_step_knn = knn[i, :]
+        for j in knn[i, :]:
+            two_step_knn = np.append(two_step_knn, knn[j, :])
+        two_step_knn = np.unique(two_step_knn)
+        for j in two_step_knn:
+            s = s_mat[i, :]
+            sn = s_mat[j, :]
+            ds = s - sn
+            dx = np.ravel(ds.A)
+            velo = velo_s[i, :]
+            cos_sim = np.dot(dx, velo)/(norm(dx)*norm(velo))
+            tm_val.append(cos_sim)
+            tm_col.append(j)
+            tm_row.append(i)
+    tm = coo_matrix((tm_val, (tm_row, tm_col)), shape=(s_mat.shape[0],
+                                                       s_mat.shape[0])).tocsr()
+    tm.setdiag(0)
+    tm_neg = tm.copy()
+    tm.data = np.clip(tm.data, 0, 1)
+    tm_neg.data = np.clip(tm_neg.data, -1, 0)
+    tm.eliminate_zeros()
+    tm_neg.eliminate_zeros()
+    return tm, tm_neg
+
+
+def transition_matrix_chrom(c_mat, u_mat, s_mat, velo_c, velo_u, velo_s, knn):
+    # Same construction as transition_matrix_s, but the displacement and
+    # velocity vectors concatenate standardized chromatin, unspliced, and
+    # spliced components.
+    knn = knn.astype(int)
+    tm_val, tm_col, tm_row = [], [], []
+    for i in range(knn.shape[0]):
+        two_step_knn = knn[i, :]
+        for j in knn[i, :]:
+            two_step_knn = np.append(two_step_knn, knn[j, :])
+        two_step_knn = np.unique(two_step_knn)
+        for j in two_step_knn:
+            u = u_mat[i, :].A
+            s = s_mat[i, :].A
+            c = c_mat[i, :].A
+            un = u_mat[j, :]
+            sn = s_mat[j, :]
+            cn = c_mat[j, :]
+            dc = (c - cn) / np.std(c)
+            du = (u - un) / np.std(u)
+            ds = (s - sn) / np.std(s)
+            dx = np.ravel(np.hstack((dc.A, du.A, ds.A)))
+            velo = np.hstack((velo_c[i, :], velo_u[i, :], velo_s[i, :]))
+            cos_sim = np.dot(dx, velo)/(norm(dx)*norm(velo))
+            tm_val.append(cos_sim)
+            tm_col.append(j)
+            tm_row.append(i)
+    tm = coo_matrix((tm_val, (tm_row, tm_col)), shape=(c_mat.shape[0],
+                                                       c_mat.shape[0])).tocsr()
+    tm.setdiag(0)
+    tm_neg = tm.copy()
+    tm.data = np.clip(tm.data, 0, 1)
+    tm_neg.data = np.clip(tm_neg.data, -1, 0)
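+    # Drop the explicit zeros left by clipping, so `tm` keeps only the
+    # positive cosines and `tm_neg` only the negative ones.
+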
tm.eliminate_zeros() + tm_neg.eliminate_zeros() + return tm, tm_neg + + +def likelihood_plot(adata, + genes=None, + figsize=(14, 10), + bins=50, + pointsize=4 + ): + """Likelihood plots. + + This function plots likelihood and variable distributions. + + Parameters + ---------- + adata: :class:`~anndata.AnnData` + Anndata result from dynamics recovery. + genes: `str`, list of `str` (default: `None`) + If `None`, will use all fitted genes. + figsize: `tuple` (default: (14,10)) + Figure size. + bins: `int` (default: 50) + Number of bins for histograms. + pointsize: `float` (default: 4) + Point size for scatter plots. + """ + if genes is None: + var = adata.var + else: + genes = np.array(genes) + var = adata[:, genes].var + likelihood = var[['fit_likelihood']].values + rescale_u = var[['fit_rescale_u']].values + rescale_c = var[['fit_rescale_c']].values + t_interval1 = var['fit_t_sw1'] + 20 / adata.uns['velo_s_params']['t'] \ + * var['fit_anchor_min_idx'] * var['fit_alignment_scaling'] + t_sw2 = np.clip(var['fit_t_sw2'], None, 20) + t_sw3 = np.clip(var['fit_t_sw3'], None, 20) + t_interval3 = t_sw3 - t_sw2 + log_s = np.log1p(np.sum(adata.layers['Ms'], axis=0)) + alpha_c = var[['fit_alpha_c']].values + alpha = var[['fit_alpha']].values + beta = var[['fit_beta']].values + gamma = var[['fit_gamma']].values + scale_cc = var[['fit_scale_cc']].values + + fig, axes = plt.subplots(4, 5, figsize=figsize) + axes[0, 0].hist(likelihood, bins=bins) + axes[0, 0].set_title('likelihood') + axes[0, 1].hist(rescale_u, bins=bins) + axes[0, 1].set_title('rescale u') + axes[0, 2].hist(rescale_c, bins=bins) + axes[0, 2].set_title('rescale c') + axes[0, 3].hist(t_interval1.values, bins=bins) + axes[0, 3].set_title('primed interval') + axes[0, 4].hist(t_interval3, bins=bins) + axes[0, 4].set_title('decoupled interval') + + axes[1, 0].scatter(log_s, likelihood, s=pointsize) + axes[1, 0].set_xlabel('log spliced') + axes[1, 0].set_ylabel('likelihood') + axes[1, 1].scatter(rescale_u, likelihood, s=pointsize) + axes[1, 1].set_xlabel('rescale u') + axes[1, 2].scatter(rescale_c, likelihood, s=pointsize) + axes[1, 2].set_xlabel('rescale c') + axes[1, 3].scatter(t_interval1.values, likelihood, s=pointsize) + axes[1, 3].set_xlabel('primed interval') + axes[1, 4].scatter(t_interval3, likelihood, s=pointsize) + axes[1, 4].set_xlabel('decoupled interval') + + axes[2, 0].hist(alpha_c, bins=bins) + axes[2, 0].set_title('alpha c') + axes[2, 1].hist(alpha, bins=bins) + axes[2, 1].set_title('alpha') + axes[2, 2].hist(beta, bins=bins) + axes[2, 2].set_title('beta') + axes[2, 3].hist(gamma, bins=bins) + axes[2, 3].set_title('gamma') + axes[2, 4].hist(scale_cc, bins=bins) + axes[2, 4].set_title('scale cc') + + axes[3, 0].scatter(alpha_c, likelihood, s=pointsize) + axes[3, 0].set_xlabel('alpha c') + axes[3, 0].set_ylabel('likelihood') + axes[3, 1].scatter(alpha, likelihood, s=pointsize) + axes[3, 1].set_xlabel('alpha') + axes[3, 2].scatter(beta, likelihood, s=pointsize) + axes[3, 2].set_xlabel('beta') + axes[3, 3].scatter(gamma, likelihood, s=pointsize) + axes[3, 3].set_xlabel('gamma') + axes[3, 4].scatter(scale_cc, likelihood, s=pointsize) + axes[3, 4].set_xlabel('scale cc') + fig.tight_layout() + + +def pie_summary(adata, genes=None): + """Summary of directions and models. + + This function plots a pie chart for (pre-determined or specified) + directions and models. + `induction`: induction-only genes. + `repression`: repression-only genes. + `Model 1`: model 1 complete genes. + `Model 2`: model 2 complete genes. 
+ + Parameters + ---------- + adata: :class:`~anndata.AnnData` + Anndata result from dynamics recovery. + genes: `str`, list of `str` (default: `None`) + If `None`, will use all fitted genes. + """ + if genes is None: + genes = adata.var_names + fit_model = adata[:, (adata.var['fit_direction'] == 'complete') & + np.isin(adata.var_names, genes)].var['fit_model'].values + fit_direction = adata[:, genes].var['fit_direction'].values + data = [np.sum(fit_direction == 'on'), np.sum(fit_direction == 'off'), + np.sum(fit_model == 1), np.sum(fit_model == 2)] + index = ['induction', 'repression', 'Model 1', 'Model 2'] + index = [x for i, x in enumerate(index) if data[i] > 0] + data = [x for x in data if x > 0] + df = pd.DataFrame({'data': data}, index=index) + df.plot.pie(y='data', autopct='%1.1f%%', legend=False, startangle=30, + ylabel='') + circle = plt.Circle((0, 0), 0.8, fc='white') + fig = plt.gcf() + fig.gca().add_artist(circle) + + +def switch_time_summary(adata, genes=None): + """Summary of switch times. + + This function plots a box plot for observed switch times. + `primed`: primed intervals. + `coupled-on`: coupled induction intervals. + `decoupled`: decoupled intervals. + `coupled-off`: coupled repression intervals. + + Parameters + ---------- + adata: :class:`~anndata.AnnData` + Anndata result from dynamics recovery. + genes: `str`, list of `str` (default: `None`) + If `None`, will use velocity genes. + """ + t_sw = adata[:, adata.var['velo_s_genes'] + if genes is None + else genes] \ + .var[['fit_t_sw1', 'fit_t_sw2', 'fit_t_sw3']].copy() + t_sw = t_sw.mask(t_sw > 20, 20) + t_sw = t_sw.mask(t_sw < 0) + t_sw['interval 1'] = t_sw['fit_t_sw1'] + t_sw['t_sw2 - t_sw1'] = t_sw['fit_t_sw2'] - t_sw['fit_t_sw1'] + t_sw['t_sw3 - t_sw2'] = t_sw['fit_t_sw3'] - t_sw['fit_t_sw2'] + t_sw['20 - t_sw3'] = 20 - t_sw['fit_t_sw3'] + t_sw = t_sw.mask(t_sw <= 0) + t_sw = t_sw.mask(t_sw > 20) + t_sw.columns = pd.Index(['time 1', 'time 2', 'time 3', 'primed', + 'coupled-on', 'decoupled', 'coupled-off']) + t_sw = t_sw[['primed', 'coupled-on', 'decoupled', 'coupled-off']] + t_sw = t_sw / 20 + fig, ax = plt.subplots(figsize=(4, 5)) + ax = sns.boxplot(data=t_sw, width=0.5, palette='Set2', ax=ax) + ax.set_yticks(np.linspace(0, 1, 5)) + ax.set_title('Switch Intervals') + + +def dynamic_plot(adata, + genes, + by='expression', + color_by='state', + gene_time=True, + axis_on=True, + frame_on=True, + show_anchors=True, + show_switches=True, + downsample=1, + full_range=False, + figsize=None, + pointsize=2, + linewidth=1.5, + cmap='coolwarm' + ): + """Gene dynamics plot. + + This function plots accessibility, expression, or velocity by time. + + Parameters + ---------- + adata: :class:`~anndata.AnnData` + Anndata result from dynamics recovery. + genes: `str`, list of `str` + List of genes to plot. + by: `str` (default: `expression`) + Plot accessibilities and expressions if `expression`. Plot velocities + if `velocity`. + color_by: `str` (default: `state`) + Color by the four potential states if `state`. Other common values are + leiden, louvain, celltype, etc. + If not `state`, the color field must be present in `.uns`, which can + be pre-computed with `scanpy.pl.scatter`. + For `state`, red, orange, green, and blue represent state 1, 2, 3, and + 4, respectively. + gene_time: `bool` (default: `True`) + Whether to use individual gene fitted time, or shared global latent + time. + Mean values of 20 equal sized windows will be connected and shown if + `gene_time==False`. 
+ axis_on: `bool` (default: `True`) + Whether to show axis labels. + frame_on: `bool` (default: `True`) + Whether to show plot frames. + show_anchors: `bool` (default: `True`) + Whether to display anchors. + show_switches: `bool` (default: `True`) + Whether to show switch times. The switch times are indicated by + vertical dotted line. + downsample: `int` (default: 1) + How much to downsample the cells. The remaining number will be + `1/downsample` of original. + full_range: `bool` (default: `False`) + Whether to show the full time range of velocities before smoothing or + subset to only smoothed range. + figsize: `tuple` (default: `None`) + Total figure size. + pointsize: `float` (default: 2) + Point size for scatter plots. + linewidth: `float` (default: 1.5) + Line width for anchor line or mean line. + cmap: `str` (default: `coolwarm`) + Color map for continuous color key. + """ + from pandas.api.types import is_numeric_dtype, is_categorical_dtype + if by not in ['expression', 'velocity']: + raise ValueError('"by" must be either "expression" or "velocity".') + if by == 'velocity': + show_switches = False + if color_by == 'state': + types = [0, 1, 2, 3] + colors = ['tab:red', 'tab:orange', 'tab:green', 'tab:blue'] + elif color_by in adata.obs and is_numeric_dtype(adata.obs[color_by]): + types = None + colors = adata.obs[color_by].values + elif color_by in adata.obs and is_categorical_dtype(adata.obs[color_by]) \ + and color_by+'_colors' in adata.uns.keys(): + types = adata.obs[color_by].cat.categories + colors = adata.uns[f'{color_by}_colors'] + else: + raise ValueError('Currently, color key must be a single string of ' + 'either numerical or categorical available in adata ' + 'obs, and the colors of categories can be found in ' + 'adata uns.') + + downsample = np.clip(int(downsample), 1, 10) + genes = np.array(genes) + missing_genes = genes[~np.isin(genes, adata.var_names)] + if len(missing_genes) > 0: + main_info(f'{missing_genes} not found', v=0) + genes = genes[np.isin(genes, adata.var_names)] + gn = len(genes) + if gn == 0: + return + if not gene_time: + show_anchors = False + latent_time = np.array(adata.obs['latent_time']) + time_window = latent_time // 0.05 + time_window = time_window.astype(int) + time_window[time_window == 20] = 19 + if 'velo_s_params' in adata.uns.keys() and 'outlier' \ + in adata.uns['velo_s_params']: + outlier = adata.uns['velo_s_params']['outlier'] + else: + outlier = 99 + + fig, axs = plt.subplots(gn, 3, squeeze=False, figsize=(10, 2.3*gn) + if figsize is None else figsize) + fig.patch.set_facecolor('white') + for row, gene in enumerate(genes): + u = adata[:, gene].layers['Mu' if by == 'expression' else 'velo_u'] + s = adata[:, gene].layers['Ms' if by == 'expression' else 'velo_s'] + c = adata[:, gene].layers['ATAC' if by == 'expression' + else 'velo_chrom'] + c = c.A if sparse.issparse(c) else c + u = u.A if sparse.issparse(u) else u + s = s.A if sparse.issparse(s) else s + c, u, s = np.ravel(c), np.ravel(u), np.ravel(s) + non_outlier = c <= np.percentile(c, outlier) + non_outlier &= u <= np.percentile(u, outlier) + non_outlier &= s <= np.percentile(s, outlier) + c, u, s = c[non_outlier], u[non_outlier], s[non_outlier] + time = np.array(adata[:, gene].layers['fit_t'] if gene_time + else latent_time) + if by == 'velocity': + time = np.reshape(time, (-1, 1)) + time = np.ravel(adata.obsp['_RNA_conn'].dot(time)) + time = time[non_outlier] + if types is not None: + for i in range(len(types)): + if color_by == 'state': + filt = adata[non_outlier, 
gene].layers['fit_state'] \ + == types[i] + else: + filt = adata[non_outlier, :].obs[color_by] == types[i] + filt = np.ravel(filt) + if np.sum(filt) > 0: + axs[row, 0].scatter(time[filt][::downsample], + c[filt][::downsample], s=pointsize, + c=colors[i], alpha=0.6) + axs[row, 1].scatter(time[filt][::downsample], + u[filt][::downsample], + s=pointsize, c=colors[i], alpha=0.6) + axs[row, 2].scatter(time[filt][::downsample], + s[filt][::downsample], s=pointsize, + c=colors[i], alpha=0.6) + else: + axs[row, 0].scatter(time[::downsample], c[::downsample], + s=pointsize, + c=colors[non_outlier][::downsample], + alpha=0.6, cmap=cmap) + axs[row, 1].scatter(time[::downsample], u[::downsample], + s=pointsize, + c=colors[non_outlier][::downsample], + alpha=0.6, cmap=cmap) + axs[row, 2].scatter(time[::downsample], s[::downsample], + s=pointsize, + c=colors[non_outlier][::downsample], + alpha=0.6, cmap=cmap) + + if not gene_time: + window_count = np.zeros(20) + window_mean_c = np.zeros(20) + window_mean_u = np.zeros(20) + window_mean_s = np.zeros(20) + for i in np.unique(time_window[non_outlier]): + idx = time_window[non_outlier] == i + window_count[i] = np.sum(idx) + window_mean_c[i] = np.mean(c[idx]) + window_mean_u[i] = np.mean(u[idx]) + window_mean_s[i] = np.mean(s[idx]) + window_idx = np.where(window_count > 20)[0] + axs[row, 0].plot(window_idx*0.05+0.025, window_mean_c[window_idx], + linewidth=linewidth, color='black', alpha=0.5) + axs[row, 1].plot(window_idx*0.05+0.025, window_mean_u[window_idx], + linewidth=linewidth, color='black', alpha=0.5) + axs[row, 2].plot(window_idx*0.05+0.025, window_mean_s[window_idx], + linewidth=linewidth, color='black', alpha=0.5) + + if show_anchors: + n_anchors = adata.uns['velo_s_params']['t'] + t_sw_array = np.array([adata[:, gene].var['fit_t_sw1'], + adata[:, gene].var['fit_t_sw2'], + adata[:, gene].var['fit_t_sw3']]) + t_sw_array = t_sw_array[t_sw_array < 20] + min_idx = int(adata[:, gene].var['fit_anchor_min_idx']) + max_idx = int(adata[:, gene].var['fit_anchor_max_idx']) + old_t = np.linspace(0, 20, n_anchors)[min_idx:max_idx+1] + new_t = old_t - np.min(old_t) + new_t = new_t * 20 / np.max(new_t) + if by == 'velocity' and not full_range: + anchor_interval = 20 / (max_idx + 1 - min_idx) + min_idx = int(adata[:, gene].var['fit_anchor_velo_min_idx']) + max_idx = int(adata[:, gene].var['fit_anchor_velo_max_idx']) + start = 0 + (min_idx - + adata[:, gene].var['fit_anchor_min_idx']) \ + * anchor_interval + end = 20 + (max_idx - + adata[:, gene].var['fit_anchor_max_idx']) \ + * anchor_interval + new_t = np.linspace(start, end, max_idx + 1 - min_idx) + ax = axs[row, 0] + a_c = adata[:, gene].varm['fit_anchor_c' if by == 'expression' + else 'fit_anchor_c_velo']\ + .ravel()[min_idx:max_idx+1] + if show_switches: + for t_sw in t_sw_array: + if t_sw > 0: + ax.vlines(t_sw, np.min(c), np.max(c), colors='black', + linestyles='dashed', alpha=0.5) + ax.plot(new_t[0:new_t.shape[0]], a_c, linewidth=linewidth, + color='black', alpha=0.5) + ax = axs[row, 1] + a_u = adata[:, gene].varm['fit_anchor_u' if by == 'expression' + else 'fit_anchor_u_velo']\ + .ravel()[min_idx:max_idx+1] + if show_switches: + for t_sw in t_sw_array: + if t_sw > 0: + ax.vlines(t_sw, np.min(u), np.max(u), colors='black', + linestyles='dashed', alpha=0.5) + ax.plot(new_t[0:new_t.shape[0]], a_u, linewidth=linewidth, + color='black', alpha=0.5) + ax = axs[row, 2] + a_s = adata[:, gene].varm['fit_anchor_s' if by == 'expression' + else 'fit_anchor_s_velo']\ + .ravel()[min_idx:max_idx+1] + if show_switches: + 
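                # Mark each in-range switch time as a dashed vertical line on
+                # the spliced panel, mirroring the chromatin and unspliced
+                # panels above.
+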
                for t_sw in t_sw_array:
+                    if t_sw > 0:
+                        ax.vlines(t_sw, np.min(s), np.max(s), colors='black',
+                                  linestyles='dashed', alpha=0.5)
+            ax.plot(new_t[0:new_t.shape[0]], a_s, linewidth=linewidth,
+                    color='black', alpha=0.5)
+
+        axs[row, 0].set_title(f'{gene} ATAC' if by == 'expression'
+                              else f'{gene} chromatin velocity')
+        axs[row, 0].set_xlabel('t' if by == 'expression' else '~t')
+        axs[row, 0].set_ylabel('c' if by == 'expression' else 'dc/dt')
+        axs[row, 1].set_title(f'{gene} unspliced' + ('' if by == 'expression'
+                                                     else ' velocity'))
+        axs[row, 1].set_xlabel('t' if by == 'expression' else '~t')
+        axs[row, 1].set_ylabel('u' if by == 'expression' else 'du/dt')
+        axs[row, 2].set_title(f'{gene} spliced' + ('' if by == 'expression'
+                                                   else ' velocity'))
+        axs[row, 2].set_xlabel('t' if by == 'expression' else '~t')
+        axs[row, 2].set_ylabel('s' if by == 'expression' else 'ds/dt')
+
+        for j in range(3):
+            ax = axs[row, j]
+            if not axis_on:
+                ax.xaxis.set_ticks_position('none')
+                ax.yaxis.set_ticks_position('none')
+                ax.get_xaxis().set_visible(False)
+                ax.get_yaxis().set_visible(False)
+            if not frame_on:
+                ax.xaxis.set_ticks_position('none')
+                ax.yaxis.set_ticks_position('none')
+                ax.set_frame_on(False)
+    fig.tight_layout()
+
+
+def scatter_plot(adata,
+                 genes,
+                 by='us',
+                 color_by='state',
+                 n_cols=5,
+                 axis_on=True,
+                 frame_on=True,
+                 show_anchors=True,
+                 show_switches=True,
+                 show_all_anchors=False,
+                 title_more_info=False,
+                 velocity_arrows=False,
+                 downsample=1,
+                 figsize=None,
+                 pointsize=2,
+                 markersize=5,
+                 linewidth=2,
+                 cmap='coolwarm',
+                 view_3d_elev=None,
+                 view_3d_azim=None,
+                 full_name=False
+                 ):
+    """Gene scatter plot.
+
+    This function plots phase portraits on the specified plane.
+
+    Parameters
+    ----------
+    adata: :class:`~anndata.AnnData`
+        Anndata result from dynamics recovery.
+    genes: `str`, list of `str`
+        List of genes to plot.
+    by: `str` (default: `us`)
+        Plot unspliced-spliced plane if `us`. Plot chromatin-unspliced plane
+        if `cu`. Plot 3D phase portraits if `cus`.
+    color_by: `str` (default: `state`)
+        Color by the four potential states if `state`. Other common values are
+        leiden, louvain, celltype, etc.
+        If not `state`, the color field must be present in `.uns`, which can
+        be pre-computed with `scanpy.pl.scatter`.
+        For `state`, red, orange, green, and blue represent state 1, 2, 3, and
+        4, respectively.
+        When `by=='us'`, `color_by` can also be `c`, which displays the log
+        accessibility on U-S phase portraits.
+    n_cols: `int` (default: 5)
+        Number of columns to plot on each row.
+    axis_on: `bool` (default: `True`)
+        Whether to show axis labels.
+    frame_on: `bool` (default: `True`)
+        Whether to show plot frames.
+    show_anchors: `bool` (default: `True`)
+        Whether to display anchors.
+    show_switches: `bool` (default: `True`)
+        Whether to show switch times. The three switch times and the end of
+        the trajectory are indicated by a circle, cross, diamond, and star,
+        respectively.
+    show_all_anchors: `bool` (default: `False`)
+        Whether to display the full range of (predicted) anchors even for
+        repression-only genes.
+    title_more_info: `bool` (default: `False`)
+        Whether to display model, direction, and likelihood information for
+        the gene in the title.
+    velocity_arrows: `bool` (default: `False`)
+        Whether to show velocity arrows of cells on the phase portraits.
+    downsample: `int` (default: 1)
+        How much to downsample the cells. The remaining number will be
+        `1/downsample` of the original.
+    figsize: `tuple` (default: `None`)
+        Total figure size.
+ pointsize: `float` (default: 2) + Point size for scatter plots. + markersize: `float` (default: 5) + Point size for switch time points. + linewidth: `float` (default: 2) + Line width for connected anchors. + cmap: `str` (default: `coolwarm`) + Color map for log accessibilities or other continuous color keys when + plotting on U-S plane. + view_3d_elev: `float` (default: `None`) + Matplotlib 3D plot `elev` argument. `elev=90` is the same as U-S plane, + and `elev=0` is the same as C-U plane. + view_3d_azim: `float` (default: `None`) + Matplotlib 3D plot `azim` argument. `azim=270` is the same as U-S + plane, and `azim=0` is the same as C-U plane. + full_name: `bool` (default: `False`) + Show full names for chromatin, unspliced, and spliced rather than + using abbreviated terms c, u, and s. + """ + from pandas.api.types import is_numeric_dtype, is_categorical_dtype + if by not in ['us', 'cu', 'cus']: + raise ValueError("'by' argument must be one of ['us', 'cu', 'cus']") + if color_by == 'state': + types = [0, 1, 2, 3] + colors = ['tab:red', 'tab:orange', 'tab:green', 'tab:blue'] + elif by == 'us' and color_by == 'c': + types = None + elif color_by in adata.obs and is_numeric_dtype(adata.obs[color_by]): + types = None + colors = adata.obs[color_by].values + elif color_by in adata.obs and is_categorical_dtype(adata.obs[color_by]) \ + and color_by+'_colors' in adata.uns.keys(): + types = adata.obs[color_by].cat.categories + colors = adata.uns[f'{color_by}_colors'] + else: + raise ValueError('Currently, color key must be a single string of ' + 'either numerical or categorical available in adata' + ' obs, and the colors of categories can be found in' + ' adata uns.') + + if 'velo_s_params' not in adata.uns.keys() \ + or 'fit_anchor_s' not in adata.varm.keys(): + show_anchors = False + if color_by == 'state' and 'fit_state' not in adata.layers.keys(): + raise ValueError('fit_state is not found. 
Please run ' + 'recover_dynamics_chrom function first or provide a ' + 'valid color key.') + + downsample = np.clip(int(downsample), 1, 10) + genes = np.array(genes) + missing_genes = genes[~np.isin(genes, adata.var_names)] + if len(missing_genes) > 0: + main_info(f'{missing_genes} not found', v=0) + genes = genes[np.isin(genes, adata.var_names)] + gn = len(genes) + if gn == 0: + return + if gn < n_cols: + n_cols = gn + if by == 'cus': + fig, axs = plt.subplots(-(-gn // n_cols), n_cols, squeeze=False, + figsize=(3.2*n_cols, 2.7*(-(-gn // n_cols))) + if figsize is None else figsize, + subplot_kw={'projection': '3d'}) + else: + fig, axs = plt.subplots(-(-gn // n_cols), n_cols, squeeze=False, + figsize=(2.7*n_cols, 2.4*(-(-gn // n_cols))) + if figsize is None else figsize) + fig.patch.set_facecolor('white') + count = 0 + for gene in genes: + u = adata[:, gene].layers['Mu'].copy() if 'Mu' in adata.layers \ + else adata[:, gene].layers['unspliced'].copy() + s = adata[:, gene].layers['Ms'].copy() if 'Ms' in adata.layers \ + else adata[:, gene].layers['spliced'].copy() + u = u.A if sparse.issparse(u) else u + s = s.A if sparse.issparse(s) else s + u, s = np.ravel(u), np.ravel(s) + if 'ATAC' not in adata.layers.keys() and \ + 'Mc' not in adata.layers.keys(): + show_anchors = False + elif 'ATAC' in adata.layers.keys(): + c = adata[:, gene].layers['ATAC'].copy() + c = c.A if sparse.issparse(c) else c + c = np.ravel(c) + elif 'Mc' in adata.layers.keys(): + c = adata[:, gene].layers['Mc'].copy() + c = c.A if sparse.issparse(c) else c + c = np.ravel(c) + + if velocity_arrows: + if 'velo_u' in adata.layers.keys(): + vu = adata[:, gene].layers['velo_u'].copy() + elif 'velocity_u' in adata.layers.keys(): + vu = adata[:, gene].layers['velocity_u'].copy() + else: + vu = np.zeros(adata.n_obs) + max_u = np.max([np.max(u), 1e-6]) + u /= max_u + vu = np.ravel(vu) + vu /= np.max([np.max(np.abs(vu)), 1e-6]) + if 'velo_s' in adata.layers.keys(): + vs = adata[:, gene].layers['velo_s'].copy() + elif 'velocity' in adata.layers.keys(): + vs = adata[:, gene].layers['velocity'].copy() + max_s = np.max([np.max(s), 1e-6]) + s /= max_s + vs = np.ravel(vs) + vs /= np.max([np.max(np.abs(vs)), 1e-6]) + if 'velo_chrom' in adata.layers.keys(): + vc = adata[:, gene].layers['velo_chrom'].copy() + max_c = np.max([np.max(c), 1e-6]) + c /= max_c + vc = np.ravel(vc) + vc /= np.max([np.max(np.abs(vc)), 1e-6]) + + row = count // n_cols + col = count % n_cols + ax = axs[row, col] + if types is not None: + for i in range(len(types)): + if color_by == 'state': + filt = adata[:, gene].layers['fit_state'] == types[i] + else: + filt = adata.obs[color_by] == types[i] + filt = np.ravel(filt) + if by == 'us': + if velocity_arrows: + ax.quiver(s[filt][::downsample], u[filt][::downsample], + vs[filt][::downsample], + vu[filt][::downsample], color=colors[i], + alpha=0.5, scale_units='xy', scale=10, + width=0.005, headwidth=4, headaxislength=5.5) + else: + ax.scatter(s[filt][::downsample], + u[filt][::downsample], s=pointsize, + c=colors[i], alpha=0.7) + elif by == 'cu': + if velocity_arrows: + ax.quiver(u[filt][::downsample], + c[filt][::downsample], + vu[filt][::downsample], + vc[filt][::downsample], color=colors[i], + alpha=0.5, scale_units='xy', scale=10, + width=0.005, headwidth=4, headaxislength=5.5) + else: + ax.scatter(u[filt][::downsample], + c[filt][::downsample], s=pointsize, + c=colors[i], alpha=0.7) + else: + if velocity_arrows: + ax.quiver(s[filt][::downsample], + u[filt][::downsample], c[filt][::downsample], + vs[filt][::downsample], 
+ vu[filt][::downsample], + vc[filt][::downsample], + color=colors[i], alpha=0.4, length=0.1, + arrow_length_ratio=0.5, normalize=True) + else: + ax.scatter(s[filt][::downsample], + u[filt][::downsample], + c[filt][::downsample], s=pointsize, + c=colors[i], alpha=0.7) + elif color_by == 'c': + if 'velo_s_params' in adata.uns.keys() and \ + 'outlier' in adata.uns['velo_s_params']: + outlier = adata.uns['velo_s_params']['outlier'] + else: + outlier = 99.8 + non_zero = (u > 0) & (s > 0) & (c > 0) + non_outlier = u < np.percentile(u, outlier) + non_outlier &= s < np.percentile(s, outlier) + non_outlier &= c < np.percentile(c, outlier) + c -= np.min(c) + c /= np.max(c) + if velocity_arrows: + ax.quiver(s[non_zero & non_outlier][::downsample], + u[non_zero & non_outlier][::downsample], + vs[non_zero & non_outlier][::downsample], + vu[non_zero & non_outlier][::downsample], + np.log1p(c[non_zero & non_outlier][::downsample]), + alpha=0.5, + scale_units='xy', scale=10, width=0.005, + headwidth=4, headaxislength=5.5, cmap=cmap) + else: + ax.scatter(s[non_zero & non_outlier][::downsample], + u[non_zero & non_outlier][::downsample], + s=pointsize, + c=np.log1p(c[non_zero & non_outlier][::downsample]), + alpha=0.8, cmap=cmap) + else: + if by == 'us': + if velocity_arrows: + ax.quiver(s[::downsample], u[::downsample], + vs[::downsample], vu[::downsample], + colors[::downsample], alpha=0.5, + scale_units='xy', scale=10, width=0.005, + headwidth=4, headaxislength=5.5, cmap=cmap) + else: + ax.scatter(s[::downsample], u[::downsample], s=pointsize, + c=colors[::downsample], alpha=0.7, cmap=cmap) + elif by == 'cu': + if velocity_arrows: + ax.quiver(u[::downsample], c[::downsample], + vu[::downsample], vc[::downsample], + colors[::downsample], alpha=0.5, + scale_units='xy', scale=10, width=0.005, + headwidth=4, headaxislength=5.5, cmap=cmap) + else: + ax.scatter(u[::downsample], c[::downsample], s=pointsize, + c=colors[::downsample], alpha=0.7, cmap=cmap) + else: + if velocity_arrows: + ax.quiver(s[::downsample], u[::downsample], + c[::downsample], vs[::downsample], + vu[::downsample], vc[::downsample], + colors[::downsample], alpha=0.4, length=0.1, + arrow_length_ratio=0.5, normalize=True, + cmap=cmap) + else: + ax.scatter(s[::downsample], u[::downsample], + c[::downsample], s=pointsize, + c=colors[::downsample], alpha=0.7, cmap=cmap) + + if show_anchors: + min_idx = int(adata[:, gene].var['fit_anchor_min_idx']) + max_idx = int(adata[:, gene].var['fit_anchor_max_idx']) + a_c = adata[:, gene].varm['fit_anchor_c']\ + .ravel()[min_idx:max_idx+1].copy() + a_u = adata[:, gene].varm['fit_anchor_u']\ + .ravel()[min_idx:max_idx+1].copy() + a_s = adata[:, gene].varm['fit_anchor_s']\ + .ravel()[min_idx:max_idx+1].copy() + if velocity_arrows: + a_c /= max_c + a_u /= max_u + a_s /= max_s + if by == 'us': + ax.plot(a_s, a_u, linewidth=linewidth, color='black', + alpha=0.7, zorder=1000) + elif by == 'cu': + ax.plot(a_u, a_c, linewidth=linewidth, color='black', + alpha=0.7, zorder=1000) + else: + ax.plot(a_s, a_u, a_c, linewidth=linewidth, color='black', + alpha=0.7, zorder=1000) + if show_all_anchors: + a_c_pre = adata[:, gene].varm['fit_anchor_c']\ + .ravel()[:min_idx].copy() + a_u_pre = adata[:, gene].varm['fit_anchor_u']\ + .ravel()[:min_idx].copy() + a_s_pre = adata[:, gene].varm['fit_anchor_s']\ + .ravel()[:min_idx].copy() + if velocity_arrows: + a_c_pre /= max_c + a_u_pre /= max_u + a_s_pre /= max_s + if len(a_c_pre) > 0: + if by == 'us': + ax.plot(a_s_pre, a_u_pre, linewidth=linewidth/1.3, + color='black', 
alpha=0.6, zorder=1000) + elif by == 'cu': + ax.plot(a_u_pre, a_c_pre, linewidth=linewidth/1.3, + color='black', alpha=0.6, zorder=1000) + else: + ax.plot(a_s_pre, a_u_pre, a_c_pre, + linewidth=linewidth/1.3, color='black', + alpha=0.6, zorder=1000) + if show_switches: + t_sw_array = np.array([adata[:, gene].var['fit_t_sw1'] + .values[0], + adata[:, gene].var['fit_t_sw2'] + .values[0], + adata[:, gene].var['fit_t_sw3'] + .values[0]]) + in_range = (t_sw_array > 0) & (t_sw_array < 20) + a_c_sw = adata[:, gene].varm['fit_anchor_c_sw'].ravel().copy() + a_u_sw = adata[:, gene].varm['fit_anchor_u_sw'].ravel().copy() + a_s_sw = adata[:, gene].varm['fit_anchor_s_sw'].ravel().copy() + if velocity_arrows: + a_c_sw /= max_c + a_u_sw /= max_u + a_s_sw /= max_s + if in_range[0]: + c_sw1, u_sw1, s_sw1 = a_c_sw[0], a_u_sw[0], a_s_sw[0] + if by == 'us': + ax.plot([s_sw1], [u_sw1], "om", markersize=markersize, + zorder=2000) + elif by == 'cu': + ax.plot([u_sw1], [c_sw1], "om", markersize=markersize, + zorder=2000) + else: + ax.plot([s_sw1], [u_sw1], [c_sw1], "om", + markersize=markersize, zorder=2000) + if in_range[1]: + c_sw2, u_sw2, s_sw2 = a_c_sw[1], a_u_sw[1], a_s_sw[1] + if by == 'us': + ax.plot([s_sw2], [u_sw2], "Xm", markersize=markersize, + zorder=2000) + elif by == 'cu': + ax.plot([u_sw2], [c_sw2], "Xm", markersize=markersize, + zorder=2000) + else: + ax.plot([s_sw2], [u_sw2], [c_sw2], "Xm", + markersize=markersize, zorder=2000) + if in_range[2]: + c_sw3, u_sw3, s_sw3 = a_c_sw[2], a_u_sw[2], a_s_sw[2] + if by == 'us': + ax.plot([s_sw3], [u_sw3], "Dm", markersize=markersize, + zorder=2000) + elif by == 'cu': + ax.plot([u_sw3], [c_sw3], "Dm", markersize=markersize, + zorder=2000) + else: + ax.plot([s_sw3], [u_sw3], [c_sw3], "Dm", + markersize=markersize, zorder=2000) + if max_idx > adata.uns['velo_s_params']['t'] - 4: + if by == 'us': + ax.plot([a_s[-1]], [a_u[-1]], "*m", + markersize=markersize, zorder=2000) + elif by == 'cu': + ax.plot([a_u[-1]], [a_c[-1]], "*m", + markersize=markersize, zorder=2000) + else: + ax.plot([a_s[-1]], [a_u[-1]], [a_c[-1]], "*m", + markersize=markersize, zorder=2000) + + if by == 'cus' and \ + (view_3d_elev is not None or view_3d_azim is not None): + # US: elev=90, azim=270. CU: elev=0, azim=0. 
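+            # For example, a hypothetical call such as
+            #   scatter_plot(adata, ['Gata1'], by='cus', view_3d_elev=30, view_3d_azim=120)
+            # tilts the 3D portrait between the two 2D projections noted above.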
+ ax.view_init(elev=view_3d_elev, azim=view_3d_azim) + title = gene + if title_more_info: + if 'fit_model' in adata.var: + title += f" M{int(adata[:,gene].var['fit_model'].values[0])}" + if 'fit_direction' in adata.var: + title += f" {adata[:,gene].var['fit_direction'].values[0]}" + if 'fit_likelihood' in adata.var \ + and not np.all(adata.var['fit_likelihood'].values == -1): + title += " " + f"{adata[:,gene].var['fit_likelihood'].values[0]:.3g}" + ax.set_title(f'{title}', fontsize=11) + if by == 'us': + ax.set_xlabel('spliced' if full_name else 's') + ax.set_ylabel('unspliced' if full_name else 'u') + elif by == 'cu': + ax.set_xlabel('unspliced' if full_name else 'u') + ax.set_ylabel('chromatin' if full_name else 'c') + elif by == 'cus': + ax.set_xlabel('spliced' if full_name else 's') + ax.set_ylabel('unspliced' if full_name else 'u') + ax.set_zlabel('chromatin' if full_name else 'c') + if by in ['us', 'cu']: + if not axis_on: + ax.xaxis.set_ticks_position('none') + ax.yaxis.set_ticks_position('none') + ax.get_xaxis().set_visible(False) + ax.get_yaxis().set_visible(False) + if not frame_on: + ax.xaxis.set_ticks_position('none') + ax.yaxis.set_ticks_position('none') + ax.set_frame_on(False) + elif by == 'cus': + if not axis_on: + ax.set_xlabel('') + ax.set_ylabel('') + ax.set_zlabel('') + ax.xaxis.set_ticklabels([]) + ax.yaxis.set_ticklabels([]) + ax.zaxis.set_ticklabels([]) + if not frame_on: + ax.xaxis._axinfo['grid']['color'] = (1, 1, 1, 0) + ax.yaxis._axinfo['grid']['color'] = (1, 1, 1, 0) + ax.zaxis._axinfo['grid']['color'] = (1, 1, 1, 0) + ax.xaxis._axinfo['tick']['inward_factor'] = 0 + ax.xaxis._axinfo['tick']['outward_factor'] = 0 + ax.yaxis._axinfo['tick']['inward_factor'] = 0 + ax.yaxis._axinfo['tick']['outward_factor'] = 0 + ax.zaxis._axinfo['tick']['inward_factor'] = 0 + ax.zaxis._axinfo['tick']['outward_factor'] = 0 + count += 1 + for i in range(col+1, n_cols): + fig.delaxes(axs[row, i]) + fig.tight_layout() \ No newline at end of file diff --git a/dynamo/multivelo/globals.py b/dynamo/multivelo/globals.py new file mode 100644 index 000000000..cb18de654 --- /dev/null +++ b/dynamo/multivelo/globals.py @@ -0,0 +1,58 @@ +import os +import platform + +# Determine platform on which analysis is running +running_on = platform.system() + +# Set up locale configuration here +REPO_PATH, ROOT_PATH = None, None # To make the lint checker happy ... +if running_on == 'Darwin': + # ... OSX system + # ... ... root path + ROOT_PATH = '/Users/cordessf/OneDrive' # <============= CHANGE THIS !!! + + # ... ... repo path + REPO_PATH = os.path.join(ROOT_PATH, 'ACI', 'Repositories') +elif running_on == 'Linux': + # ... Linux system + # ... ... root path + ROOT_PATH = '/data/LIRGE' # <============= CHANGE THIS !!! + + # ... ... repo path + REPO_PATH = os.path.join(ROOT_PATH, 'Repositories') + +# ... Path to base directory (where code and results are kept) +BASE_PATH = os.path.join(REPO_PATH, 'MultiDynamo') + +# ... Path to cache intermediate results +CACHE_PATH = os.path.join(ROOT_PATH, 'cache') +if not os.path.exists(CACHE_PATH): + os.makedirs(CACHE_PATH) + +# ... Path to data +DATA_PATH = os.path.join(ROOT_PATH, 'external_data', 'multiome') +if not os.path.exists(DATA_PATH): + os.makedirs(DATA_PATH) + +# ... Path to reference data +REFERENCE_DATA_PATH = os.path.join(ROOT_PATH, 'reference_data') +if not os.path.exists(REFERENCE_DATA_PATH): + os.makedirs(REFERENCE_DATA_PATH) + +# Structure the data as it would come out of a cellranger run +# ... 
cellranger outs directory +OUTS_PATH = os.path.join(DATA_PATH, 'outs') +if not os.path.exists(OUTS_PATH): + os.makedirs(OUTS_PATH) + +# Path to ATAC-seq data +ATAC_PATH = os.path.join(ROOT_PATH, 'external_data', '10k_human_PBMC_ATAC') + +# Path to genome annotation +GTF_PATH = os.path.join(REFERENCE_DATA_PATH, 'annotation', 'Homo_sapiens.GRCh38.112.gtf.gz') + +# Path to multiomic data +MULTIOME_PATH = DATA_PATH + +# Path to RNA-seq data +RNA_PATH = os.path.join(ROOT_PATH, 'external_data', '10k_human_PBMC_RNA') diff --git a/dynamo/multivelo/neural_nets/dir0.pt b/dynamo/multivelo/neural_nets/dir0.pt new file mode 100644 index 000000000..92fb3454e Binary files /dev/null and b/dynamo/multivelo/neural_nets/dir0.pt differ diff --git a/dynamo/multivelo/neural_nets/dir1.pt b/dynamo/multivelo/neural_nets/dir1.pt new file mode 100644 index 000000000..d67c03f0d Binary files /dev/null and b/dynamo/multivelo/neural_nets/dir1.pt differ diff --git a/dynamo/multivelo/neural_nets/dir2_m1.pt b/dynamo/multivelo/neural_nets/dir2_m1.pt new file mode 100644 index 000000000..7dc053f80 Binary files /dev/null and b/dynamo/multivelo/neural_nets/dir2_m1.pt differ diff --git a/dynamo/multivelo/neural_nets/dir2_m2.pt b/dynamo/multivelo/neural_nets/dir2_m2.pt new file mode 100644 index 000000000..e4b2a10da Binary files /dev/null and b/dynamo/multivelo/neural_nets/dir2_m2.pt differ diff --git a/dynamo/multivelo/old_MultiVelocity.py b/dynamo/multivelo/old_MultiVelocity.py new file mode 100644 index 000000000..44dc13565 --- /dev/null +++ b/dynamo/multivelo/old_MultiVelocity.py @@ -0,0 +1,1401 @@ +from anndata import AnnData +import matplotlib.pyplot as plt +from multiprocessing import Pool +from mudata import MuData +import numpy as np +import os +from os import PathLike +import pandas as pd +import scanpy as sc +from scipy.sparse import coo_matrix, csr_matrix, hstack, issparse +from scipy.sparse.linalg import svds + +from typing import ( + Dict, + List, + Literal, + Optional, + Tuple, + Union +) + +import warnings + +# Import from dynamo +from ..dynamo_logger import ( + LoggerManager, + main_exception, + main_info, +) + +# Imports from MultiDynamo +from .ChromatinVelocity import ChromatinVelocity +from .MultiConfiguration import MDKM +from .pyWNN import pyWNN + + +# Static function +# direction_cosine +def direction_cosine(args): + i, j, expression_mtx, velocity_mtx = args + + if i == j: + return i, j, -1 + + delta_ij = None + if isinstance(expression_mtx, csr_matrix): + delta_ij = (expression_mtx.getrow(j) - expression_mtx.getrow(i)).toarray().flatten() + elif isinstance(expression_mtx, np.ndarray): + delta_ij = (expression_mtx[j, :] - expression_mtx[i, :]).flatten() + else: + main_exception(f'Expression matrix is instance of class {type(expression_mtx)}') + + vel_i = velocity_mtx.getrow(i).toarray().flatten() + + dot_product = np.dot(delta_ij, vel_i) # vel_i.dot(delta_ij) + magnitude_vel_i = np.linalg.norm(vel_i) + magnitude_delta_ij = np.linalg.norm(delta_ij) + + if magnitude_vel_i != 0 and magnitude_delta_ij != 0: + cosine_similarity = dot_product / (magnitude_vel_i * magnitude_delta_ij) + else: + # One of velocity or delta_ij is zero, so can't compute a cosine, we'll just set to + # lowest possible value (-1) + cosine_similarity = -1 + + return i, j, cosine_similarity + + +# get_connectivities - patterned after function in scVelo +def get_connectivities(adata: AnnData, + mode: str = 'connectivities', + n_neighbors: int = None, + recurse_neighbors: bool = False + ) -> Union[csr_matrix, None]: + if 'neighbors' in 
adata.uns.keys(): + C = get_neighbors(adata=adata, mode=mode) + if n_neighbors is not None and n_neighbors < get_n_neighbors(adata=adata): + if mode == 'connectivities': + C = select_connectivities(C, n_neighbors) + else: + C = select_distances(C, n_neighbors) + connectivities = C > 0 + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + connectivities.setdiag(1) + if recurse_neighbors: + connectivities += connectivities.dot(connectivities * 0.5) + connectivities.data = np.clip(connectivities.data, 0, 1) + connectivities = connectivities.multiply(1.0 / connectivities.sum(1)) + return connectivities.tocsr().astype(np.float32) + else: + return None + + +# get_n_neighbors - lifted from scVelo +def get_n_neighbors(adata: AnnData) -> int: + return adata.uns.get('neighbors', {}).get('params', {}).get('n_neighbors', 0) + + +def get_neighbors(adata: AnnData, + mode: str = 'distances'): + if hasattr(adata, 'obsp') and mode in adata.obsp: + return adata.obsp[mode] + elif 'neighbors' in adata.uns.keys() and mode in adata.uns['neighbors']: + return adata.uns['neighbors'][mode] + else: + main_exception(f'The selected mode {mode} is not valid.') + + +def lifted_chromatin_velocity(arg): + i, j, chromatin_state, cosines, expression_mtx, rna_velocity = arg + + if i == j: + main_exception('A cell should never be its own integral neighbor.') + + # Compute change in chromatin state + delta_c_ij = None + if isinstance(chromatin_state, csr_matrix): + delta_c_ij = (chromatin_state.getrow(j) - chromatin_state.getrow(i)).toarray().flatten() + elif isinstance(chromatin_state, np.ndarray): + delta_c_ij = (chromatin_state[j, :] - chromatin_state[i, :]).flatten() + else: + main_exception(f'Chromatin state matrix is instance of class {type(chromatin_state)}') + + # Retrieve cosine + cosine = cosines[i, j] + + # Compute change in RNA expression + delta_s_ij = None + if isinstance(expression_mtx, csr_matrix): + delta_s_ij = (expression_mtx.getrow(j) - expression_mtx.getrow(i)).toarray().flatten() + elif isinstance(expression_mtx, np.ndarray): + delta_s_ij = (expression_mtx[j, :] - expression_mtx[i, :]).flatten() + else: + main_exception(f'RNA expression matrix is instance of class {type(expression_mtx)}') + + # Compute norms + norm_delta_s_ij = np.linalg.norm(delta_s_ij) + norm_rna_velocity = np.linalg.norm(rna_velocity.toarray()) + + if norm_delta_s_ij != 0: + chromatin_velocity = (norm_rna_velocity * cosine / norm_delta_s_ij) * delta_c_ij + else: + chromatin_velocity = np.zeros(chromatin_state.shape[1]) + + return i, chromatin_velocity + + +def regression(c, + u, + s, + ss, + us, + uu, + fit_args, + mode, + gene): + c_90 = np.percentile(c, 90) + u_90 = np.percentile(u, 90) + s_90 = np.percentile(s, 90) + + low_quality = (c_90 == 0 or s_90 == 0 or u_90 == 0) + + if low_quality: + # main_info(f'Skipping low quality gene {gene}.') + return np.zeros(len(u)), np.zeros(len(u)), 0, 0, np.inf + + cvc = ChromatinVelocity(c, + u, + s, + ss, + us, + uu, + fit_args, + gene=gene) + + if cvc.low_quality: + return np.zeros(len(u)), np.zeros(len(u)), 0, 0, np.inf + + if mode == 'deterministic': + cvc.compute_deterministic() + elif mode == 'stochastic': + cvc.compute_stochastic() + velocity = cvc.get_velocity(mode=mode) + gamma = cvc.get_gamma(mode=mode) + r2 = cvc.get_r2(mode=mode) + loss = cvc.get_loss(mode=mode) + variance_velocity = (None if mode == 'deterministic' + else cvc.get_variance_velocity()) + return velocity, variance_velocity, gamma, r2, loss + + +def select_connectivities(connectivities, + 
n_neighbors=None): + C = connectivities.copy() + n_counts = (C > 0).sum(1).A1 if issparse(C) else (C > 0).sum(1) + n_neighbors = ( + n_counts.min() if n_neighbors is None else min(n_counts.min(), + n_neighbors) + ) + rows = np.where(n_counts > n_neighbors)[0] + cumsum_neighs = np.insert(n_counts.cumsum(), 0, 0) + dat = C.data + + for row in rows: + n0, n1 = cumsum_neighs[row], cumsum_neighs[row + 1] + rm_idx = n0 + dat[n0:n1].argsort()[::-1][n_neighbors:] + dat[rm_idx] = 0 + + C.eliminate_zeros() + return C + + +def select_distances(dist, + n_neighbors: int = None): + D = dist.copy() + n_counts = (D > 0).sum(1).A1 if issparse(D) else (D > 0).sum(1) + n_neighbors = ( + n_counts.min() if n_neighbors is None else min(n_counts.min(), n_neighbors) + ) + rows = np.where(n_counts > n_neighbors)[0] + cumsum_neighs = np.insert(n_counts.cumsum(), 0, 0) + dat = D.data + + for row in rows: + n0, n1 = cumsum_neighs[row], cumsum_neighs[row + 1] + rm_idx = n0 + dat[n0:n1].argsort()[n_neighbors:] + dat[rm_idx] = 0 + + D.eliminate_zeros() + return D + + +# smooth_scale - lifted from MultiVelo +def smooth_scale(conn, + vector): + max_to = np.max(vector) + min_to = np.min(vector) + v = conn.dot(vector.T).T + max_from = np.max(v) + min_from = np.min(v) + res = ((v - min_from) * (max_to - min_to) / (max_from - min_from)) + min_to + return res + + +# top_n_sparse - lifted from MultiVelo +def top_n_sparse(conn, n): + conn_ll = conn.tolil() + for i in range(conn_ll.shape[0]): + row_data = np.array(conn_ll.data[i]) + row_idx = np.array(conn_ll.rows[i]) + new_idx = row_data.argsort()[-n:] + top_val = row_data[new_idx] + top_idx = row_idx[new_idx] + conn_ll.data[i] = top_val.tolist() + conn_ll.rows[i] = top_idx.tolist() + conn = conn_ll.tocsr() + idx1 = conn > 0 + idx2 = conn > 0.25 + idx3 = conn > 0.5 + conn[idx1] = 0.25 + conn[idx2] = 0.5 + conn[idx3] = 1 + conn.eliminate_zeros() + return conn + + +class MultiVelocity: + def __init__(self, + mdata: MuData, + cosine_similarities: csr_matrix = None, + cre_dict: Dict = None, + include_gene_body: bool = False, + integral_neighbors: Dict = None, + linkage_fn: str = 'feature_linkage.bedpe', # in 'outs/analysis/feature_linkage' directory + linkage_method: Literal['cellranger', 'cicero', 'scenic+'] = 'cellranger', + max_peak_dist: int = 10000, + min_corr: float = 0.5, + neighbor_method: Literal['multivi', 'wnn'] = 'multivi', + nn_dist: csr_matrix = None, + nn_idx: csr_matrix = None, + peak_annot_fn: str = 'peak_annotation.tsv', # in 'outs' directory + promoter_dict: Dict = None + ): + # Initialize instance variables + self.mdata = mdata.copy() if mdata is not None else None + + self._cre_dict = cre_dict.copy() if cre_dict is not None else None + + self.cosine_similarities = cosine_similarities.copy() if cosine_similarities is not None else None + + self.include_gene_body = include_gene_body + + self.integral_neighbors = integral_neighbors.copy() if integral_neighbors is not None else None + + self.linkage_fn = linkage_fn + + self.linkage_method = linkage_method + + self.max_peak_dist = max_peak_dist + + self.min_corr = min_corr + + self.neighbor_method = neighbor_method + + self.nn_dist = nn_dist.copy() if nn_dist is not None else None + + self.nn_idx = nn_idx.copy() if nn_idx is not None else None + + self.peak_annot_fn = peak_annot_fn + + self._promoter_dict = promoter_dict.copy() if promoter_dict is not None else None + + def atac_elements(self): + return self.mdata['atac'].var_names.tolist() + + def compute_linkages(self) -> None: + if self.linkage_method == 
'cellranger': + self.compute_linkages_via_cellranger() + elif self.linkage_method == 'cicero': + self.compute_linkages_via_cicero() + elif self.linkage_method == 'scenic+': + self.compute_linkages_via_scenicplus() + else: + main_exception(f'Unrecognized method to compute linkages ({self.linkage_method}) requested.') + + def compute_linkages_via_cellranger(self) -> None: + # This reads the cellranger-arc 'feature_linkage.bedpe' and 'peak_annotation.tsv' files + # to extract dictionaries attributing cis-regulatory elements with specific genes + main_info('Computing linkages via cellranger ...') + linkage_logger = LoggerManager.gen_logger('compute_linkages_via_cellranger') + linkage_logger.log_time() + + # Confirm that this is matched ATAC- and RNA-seq data + if not self.mdata.mod['atac'].uns[MDKM.MATCHED_ATAC_RNA_DATA_KEY]: + main_exception('Cannot use cellranger to compute CRE linkages for UNMATCHED data') + + outs_data_path = os.path.join(self.mdata.mod['atac'].uns['base_data_path'], 'outs') + # Confirm that the base path to the 'outs' directory exists + if not os.path.exists(outs_data_path): + main_exception(f'The path to the 10X outs directory ({outs_data_path}) does not exist.') + + # Read annotations + peak_annot_path = os.path.join(outs_data_path, self.peak_annot_fn) + if not os.path.exists(peak_annot_path): + main_exception(f'The path to the peak annotation file ({peak_annot_path}) does not exist.') + + corr_dict, distal_dict, gene_body_dict, promoter_dict = {}, {}, {}, {} + with open(peak_annot_path) as f: + # Scan the header to determine version of CellRanger used in making the peak annotation file + header = next(f) + fields = header.split('\t') + + # Peak annotation should contain 4 columns for version 1.X of CellRanger and 6 columns for + # version 2.X + if len(fields) not in [4, 6]: + main_exception('Peak annotation file should contain 4 columns (CellRanger ARC 1.0.0) ' + + 'or 6 columns (CellRanger ARC 2.0.0)') + else: + offset = 0 if len(fields) == 4 else 2 + + for line in f: + fields = line.rstrip().split('\t') + + peak = f'{fields[0]}:{fields[1]}-{fields[2]}' if offset else \ + f"{fields[0].split('_')[0]}:{fields[0].split('_')[1]}-{fields[0].split('_')[2]}" + + if fields[1 + offset] == '': + continue + + genes, dists, types = \ + fields[1 + offset].split(';'), fields[2 + offset].split(';'), fields[3 + offset].split(';') + + for gene, dist, annot in zip(genes, dists, types): + if annot == 'promoter': + promoter_dict.setdefault(gene, []).append(peak) + elif annot == 'distal': + if dist == '0': + gene_body_dict.setdefault(gene, []).append(peak) + else: + distal_dict.setdefault(gene, []).append(peak) + + # Read linkages + linkage_path = os.path.join(outs_data_path, 'analysis', 'feature_linkage', self.linkage_fn) + if not os.path.exists(linkage_path): + main_exception(f'The path to the linkage file ({linkage_path}) does not exist.') + with open(linkage_path) as f: + for line in f: + fields = line.rstrip().split('\t') + + # Form proper peak coordinates + peak_1, peak_2 = f'{fields[0]}:{fields[1]}-{fields[2]}', f'{fields[3]}:{fields[4]}-{fields[5]}' + + # Split the gene pairs + genes_annots_1, genes_annots_2 = \ + fields[6].split('><')[0][1:].split(';'), fields[6].split('><')[1][:-1].split(';') + + # Extract correlation + correlation = float(fields[7]) + + # Extract distance between peaks + dist = float(fields[11]) + + if fields[12] == 'peak-peak': + for gene_annot_1 in genes_annots_1: + gene_1, annot_1 = gene_annot_1.split('_') + for gene_annot_2 in genes_annots_2: + gene_2, 
annot_2 = gene_annot_2.split('_') + + if (((annot_1 == 'promoter') != (annot_2 == 'promoter')) and + ((gene_1 == gene_2) or (dist < self.max_peak_dist))): + gene = gene_1 if annot_1 == 'promoter' else gene_2 + + if (peak_2 not in corr_dict.get(gene, []) and annot_1 == 'promoter' and + (gene_2 not in gene_body_dict or peak_2 not in gene_body_dict.get(gene_2, []))): + corr_dict.setdefault(gene, [[], []])[0].append(peak_2) + corr_dict[gene][1].append(correlation) + + if (peak_1 not in corr_dict.get(gene, []) and annot_2 == 'promoter' and + (gene_1 not in gene_body_dict or peak_1 not in gene_body_dict.get(gene_1, []))): + corr_dict.setdefault(gene, [[], []])[0].append(peak_1) + corr_dict[gene][1].append(correlation) + + elif fields[12] == 'peak-gene': + gene_2 = genes_annots_2[0] + for gene_annot_1 in genes_annots_1: + gene_1, annot_1 = gene_annot_1.split('_') + + if (gene_1 == gene_2) or (dist < self.max_peak_dist): + gene = gene_1 + + if (peak_1 not in corr_dict.get(gene, []) and annot_1 != 'promoter' and + (gene_1 not in gene_body_dict or peak_1 not in gene_body_dict.get(gene_1, []))): + corr_dict.setdefault(gene, [[], []])[0].append(peak_1) + corr_dict[gene][1].append(correlation) + + elif fields[12] == 'gene-peak': + gene_1 = genes_annots_1[0] + for gene_annot_2 in genes_annots_2: + gene_2, annot_2 = gene_annot_2.split('_') + + if (gene_1 == gene_2) or (dist < self.max_peak_dist): + gene = gene_1 + + if (peak_2 not in corr_dict.get(gene, []) and annot_2 != 'promoter' and + (gene_2 not in gene_body_dict or peak_2 not in gene_body_dict.get(gene_2, []))): + corr_dict.setdefault(gene, [[], []])[0].append(peak_2) + corr_dict[gene][1].append(correlation) + + cre_dict = {} + gene_dict = promoter_dict + promoter_genes = list(promoter_dict.keys()) + + for gene in promoter_genes: + if self.include_gene_body: # add gene-body peaks + if gene in gene_body_dict: + for peak in gene_body_dict[gene]: + if peak not in gene_dict[gene]: + gene_dict[gene].append(peak) + cre_dict[gene] = [] + if gene in corr_dict: # add enhancer peaks + for j, peak in enumerate(corr_dict[gene][0]): + corr = corr_dict[gene][1][j] + if corr > self.min_corr: + if peak not in gene_dict[gene]: + gene_dict[gene].append(peak) + cre_dict[gene].append(peak) + + # Update the enhancer and promoter dictionaries + self._update_cre_and_promoter_dicts(cre_dict=cre_dict, + promoter_dict=promoter_dict) + + linkage_logger.finish_progress(progress_name='compute_linkages_via_cellranger') + + def compute_linkages_via_cicero(self) -> None: + # TODO: Use cicero to filter significant linkages + pass + + def compute_linkages_via_scenicplus(self) -> None: + # TODO: Use scenicplus to filter significant linkages + pass + + def compute_neighbors(self, + atac_lsi_key: str = MDKM.ATAC_OBSM_LSI_KEY, + lr: float = 0.0001, + max_epochs: int = 10, # 10 for debug mode 500 for release, + mv_algorithm: bool = True, + n_comps_atac: int = 20, + n_comps_rna: int = 20, + n_neighbors: int = 20, + pc_key: str = MDKM.ATAC_OBSM_PC_KEY, + random_state: int = 42, + rna_pca_key: str = MDKM.RNA_OBSM_PC_KEY, + scale_factor: float = 1e4, + use_highly_variable: bool = False + ) -> None: + if self.neighbor_method == 'multivi': + self.compute_neighbors_via_multivi( + lr=lr, + max_epochs=max_epochs) + elif self.neighbor_method == 'wnn': + self.weighted_nearest_neighbors( + atac_lsi_key=atac_lsi_key, + n_components_atac=n_comps_atac, + n_components_rna=n_comps_rna, + nn=n_neighbors, + random_state=random_state, + rna_pca_key=rna_pca_key, + use_highly_variable=use_highly_variable) + 
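        # Any other neighbor_method value is rejected below.
+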
else: + main_exception(f'Unrecognized method to compute neighbors ({self.neighbor_method}) requested.') + + def compute_neighbors_via_multivi( + self, + lr: float = 0.0001, + max_epochs: int = 500, + n_comps: int = 20, + n_neighbors: int = 20, + ) -> None: + import scvi + main_info('Computing nearest neighbors in latent representation generated by MULTIVI ...', indent_level=1) + nn_logger = LoggerManager.gen_logger('compute_nn_via_mvi') + nn_logger.log_time() + + # Extract the ATAC-seq and RNA-seq portions + atac_adata, rna_adata = self.mdata.mod['atac'], self.mdata.mod['rna'] + n_peaks, n_genes = atac_adata.n_vars, rna_adata.n_vars + + # Ensure that the ATAC- and RNA-seq portions have same number of cells + assert (atac_adata.n_obs == rna_adata.n_obs) + + # Restructure the data into MULTIVI format - we do not perform TF-IDF transformation + # ... X - counts or normalized counts??? + tmp_adata_X = hstack([rna_adata.layers[MDKM.RNA_COUNTS_LAYER], atac_adata.layers[MDKM.ATAC_COUNTS_LAYER]]) + + # ... obs + tmp_adata_obs = rna_adata.obs.copy() + + # ... var + tmp_adata_var = pd.concat([rna_adata.var.copy(), atac_adata.var.copy()], join='inner', axis=0) + + tmp_adata = AnnData(X=tmp_adata_X.copy(), obs=tmp_adata_obs, var=tmp_adata_var) + tmp_adata.layers['counts'] = tmp_adata.X.copy() + + # Get the number of cells + num_cells = tmp_adata.n_obs + + # Generate a random permutation of cell indices + cell_indices = np.random.permutation(num_cells) + + # Determine the split point + split_point = num_cells // 2 + + # Split indices into two groups + cell_indices_1 = cell_indices[:split_point] + cell_indices_2 = cell_indices[split_point:] + + # Subset the AnnData object into two disjoint AnnData objects + tmp_adata_1 = tmp_adata[cell_indices_1].copy() + tmp_adata_1.obs['modality'] = 'first_set' + tmp_adata_2 = tmp_adata[cell_indices_2].copy() + tmp_adata_2.obs['modality'] = 'second_set' + + tmp_adata = scvi.data.organize_multiome_anndatas(tmp_adata_1, tmp_adata_2) + + # Run MULTIVI + # ... setup AnnData object for scvi-tools + main_info('Setting up combined data for MULTIVI', indent_level=2) + scvi.model.MULTIVI.setup_anndata(tmp_adata, batch_key='modality') + + # ... instantiate the SCVI model + main_info('Instantiating MULTIVI model', indent_level=2) + multivi_model = scvi.model.MULTIVI(adata=tmp_adata, n_genes=n_genes, n_regions=n_peaks, n_latent=n_comps) + multivi_model.view_anndata_setup() + + # ... 
train the model + main_info('Training MULTIVI model', indent_level=2) + multivi_model.train(max_epochs=max_epochs, lr=lr) + + # Extract latent representation + main_info('extracting latent representation for ATAC-seq', indent_level=3) + atac_adata.obsm['X_mvi_latent'] = multivi_model.get_latent_representation().copy() + rna_adata.obsm['X_mvi_latent'] = multivi_model.get_latent_representation().copy() + + # Compute nearest neighbors + main_info('Computing nearest neighbors in MVI latent representation', indent_level=2) + sc.pp.neighbors(rna_adata, n_neighbors=n_neighbors, n_pcs=n_comps, use_rep='X_mvi_latent') + + # Redundantly copy over to atac-seq modality + atac_adata.obsp['distances'] = rna_adata.obsp['distances'].copy() + atac_adata.obsp['connectivities'] = rna_adata.obsp['connectivities'].copy() + atac_adata.uns['neighbors'] = rna_adata.uns['neighbors'].copy() + + # Extract the matrix storing the distances between each cell and its neighbors + cx = coo_matrix(rna_adata.obsp['distances'].copy()) + + # the number of cells + cells = rna_adata.obsp['distances'].shape[0] + + # define the shape of our final results + # and make the arrays that will hold the results + new_shape = (cells, n_neighbors) + nn_dist = np.zeros(shape=new_shape) + nn_idx = np.zeros(shape=new_shape) + + # new_col defines what column we store data in our result arrays + new_col = 0 + + # loop through the distance matrices + for i, j, v in zip(cx.row, cx.col, cx.data): + # store the distances between neighbor cells + nn_dist[i][new_col % n_neighbors] = v + + # for each cell's row, store the row numbers of its neighbor cells + # (1-indexing instead of 0- is a holdover from R multimodalneighbors()) + nn_idx[i][new_col % n_neighbors] = int(j) + 1 + + new_col += 1 + + # Add index and distance to the MultiomeVelocity object + self.nn_idx = nn_idx + self.nn_dist = nn_dist + + # Copy the subset AnnData scRNA-seq and scATAC-seq objects back into the MultiomeVelocity object + self.mdata.mod['atac'] = atac_adata.copy() + self.mdata.mod['rna'] = rna_adata.copy() + + nn_logger.finish_progress(progress_name='compute_nn_via_mvi') + + def compute_second_moments( + self, + adjusted: bool = False + ) -> Tuple[csr_matrix, csr_matrix, csr_matrix]: + # Extract transcriptome + rna_adata = self.mdata.mod['rna'] + + # Obtain connectivities matrix + connectivities = get_connectivities(rna_adata) + + s, u = (csr_matrix(rna_adata.layers[MDKM.RNA_SPLICED_LAYER]), + csr_matrix(rna_adata.layers[MDKM.RNA_UNSPLICED_LAYER])) + if s.shape[0] == 1: + s, u = s.T, u.T + Mss = csr_matrix.dot(connectivities, s.multiply(s)).astype(np.float32).A + Mus = csr_matrix.dot(connectivities, s.multiply(u)).astype(np.float32).A + Muu = csr_matrix.dot(connectivities, u.multiply(u)).astype(np.float32).A + if adjusted: + Mss = 2 * Mss - rna_adata.layers[MDKM.RNA_FIRST_MOMENT_SPLICED_LAYER].reshape(Mss.shape) + Mus = 2 * Mus - rna_adata.layers[MDKM.RNA_FIRST_MOMENT_UNSPLICED_LAYER].reshape(Mus.shape) + Muu = 2 * Muu - rna_adata.layers[MDKM.RNA_FIRST_MOMENT_UNSPLICED_LAYER].reshape(Muu.shape) + return Mss, Mus, Muu + + def compute_velocities(self, + linkage_method: Optional[Literal['cellranger', 'cicero', 'scenic+']] = 'cellranger', + mode: Literal['deterministic', 'stochastic'] = 'deterministic', + neighbor_method: Literal['multivi', 'wnn'] = 'wnn', + num_processes: int = 6) -> None: + if linkage_method is not None: + self.linkage_method = linkage_method + + if neighbor_method is not None: + self.neighbor_method = neighbor_method + + if (self.linkage_method is None) or 
(self.neighbor_method is None):
+            main_exception('linkage_method and neighbor_method must be specified.')
+
+        # Compute linkages
+        self.compute_linkages()
+
+        # Compute neighbors
+        self.compute_neighbors()
+
+        # Compute smoothed accessibility
+        self.knn_smoothed_chrom()
+
+        # Compute transcriptomic velocity
+        self.transcriptomic_velocity(mode=mode, num_processes=num_processes)
+
+        # Compute lift of transcriptomic velocity
+        self.lift_transcriptomic_velocity(num_processes=num_processes)
+
+    def find_cell_along_integral_curve(self,
+                                       num_processes: int = 6,
+                                       plot_dir_cosines: bool = False):
+        # Extract the ATAC- and RNA-seq portions
+        atac_adata, rna_adata = self.mdata.mod['atac'], self.mdata.mod['rna']
+
+        expression_mtx = rna_adata.layers[MDKM.RNA_FIRST_MOMENT_SPLICED_LAYER]
+        velocity_mtx = rna_adata.layers[MDKM.RNA_SPLICED_VELOCITY_LAYER]
+
+        # Extract connectivities
+        connectivities = get_connectivities(rna_adata)
+
+        # Get non-zero indices from connectivities
+        nonzero_idx = connectivities.nonzero()
+
+        # Prepare argument list for parallel processing
+        args_list = [(i, j, expression_mtx, velocity_mtx)
+                     for i, j in zip(nonzero_idx[0], nonzero_idx[1])]
+
+        # Use multiprocessing to compute the results
+        with Pool(processes=num_processes) as pool:
+            results = pool.map(direction_cosine, args_list)
+
+        # Convert results to sparse matrix
+        data = [cosines for _, _, cosines in results]
+        i_indices = [i_idx for i_idx, _, _ in results]
+        j_indices = [j_idx for _, j_idx, _ in results]
+        direction_cosines = csr_matrix((data, (i_indices, j_indices)), shape=connectivities.shape)
+
+        # Find nearest neighbor along integral curve
+        integral_neighbors = direction_cosines.argmax(axis=1).A.flatten()
+
+        if plot_dir_cosines:
+            # Summarize statistics about the best direction cosines
+            max_dir_cosines = direction_cosines.max(axis=1).A.flatten()
+            plt.hist(max_dir_cosines, bins=25)
+            plt.title('Frequencies of direction cosines')
+            plt.xlabel('Direction Cosines')
+            plt.ylabel('Frequency')
+            plt.show()
+
+        # Save the results in this class
+        # TODO: Consider whether to add to AnnData objects
+        self.cosine_similarities = direction_cosines
+        self.integral_neighbors = {int(idx): int(integral_neighbor)
+                                   for idx, integral_neighbor in enumerate(integral_neighbors)}
+
+    @classmethod
+    def from_mdata(cls,
+                   mdata: MuData):
+        # Deep copy MuData object for export
+        atac_adata, rna_adata = mdata.mod['atac'].copy(), mdata.mod['rna'].copy()
+
+        # ... from atac
+        # ... bit of kludge: dictionaries appear to require type casting after deserialization
+        deser_cre_dict = atac_adata.uns['cre_dict'].copy()
+        cre_dict = {}
+        for gene, cre_list in deser_cre_dict.items():
+            cre_dict[str(gene)] = [str(cre) for cre in cre_list]
+        # ... bit of kludge: dictionaries appear to require type casting after deserialization
+        deser_promoter_dict = atac_adata.uns['promoter_dict']
+        promoter_dict = {}
+        for gene, promoter_list in deser_promoter_dict.items():
+            promoter_dict[str(gene)] = [str(promoter) for promoter in promoter_list]
+
+        multi_dynamo_kwargs = atac_adata.uns['multi_dynamo_kwargs']
+        include_gene_body = multi_dynamo_kwargs.get('include_gene_body', False)
+        linkage_fn = multi_dynamo_kwargs.get('linkage_fn', 'feature_linkage.bedpe')
+        linkage_method = multi_dynamo_kwargs.get('linkage_method', 'cellranger')
+        max_peak_dist = multi_dynamo_kwargs.get('max_peak_dist', 10000)
+        min_corr = multi_dynamo_kwargs.get('min_corr', 0.5)
+        peak_annot_fn = multi_dynamo_kwargs.get('peak_annot_fn', 'peak_annotation.tsv')
+
+        # ...
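+        # A minimal loading sketch (hypothetical file name; expects an .h5mu
+        # produced by MultiVelocity.write / to_mdata below):
+        #     import mudata as md
+        #     mv = MultiVelocity.from_mdata(md.read_h5mu('multiome_velocity.h5mu'))
+
+        # ...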
from rna + nn_dist = rna_adata.obsm['multi_dynamo_nn_dist'] + nn_idx = rna_adata.obsm['multi_dynamo_nn_idx'] + + cosine_similarities = rna_adata.obsp['cosine_similarities'] + # ... bit of kludge: dictionaries appear to require type casting after deserialization + integral_neighbors = {int(k): int(v) for k,v in rna_adata.uns['integral_neighbors'].items()} + + multi_dynamo_kwargs = rna_adata.uns['multi_dynamo_kwargs'] + neighbor_method = multi_dynamo_kwargs.get('neighbor_method', 'multivi') + + multi_velocity = cls(mdata=mdata, + cre_dict=cre_dict, + cosine_similarities=cosine_similarities, + include_gene_body=include_gene_body, + integral_neighbors=integral_neighbors, + linkage_fn=linkage_fn, + linkage_method=linkage_method, + max_peak_dist=max_peak_dist, + min_corr=min_corr, + nn_dist=nn_dist, + nn_idx=nn_idx, + neighbor_method=neighbor_method, + peak_annot_fn=peak_annot_fn, + promoter_dict=promoter_dict) + + return multi_velocity + + def get_cre_dict(self): + return self._cre_dict + + def get_mdata(self): + return self.mdata + + def get_nn_dist(self): + return self.nn_dist + + def get_nn_idx(self): + return self.nn_idx + + def get_promoter_dict(self): + return self._promoter_dict + + # knn_smoothed_chrom - method adapted from MultiVelo + def knn_smoothed_chrom(self, + nn: int = 20 + ) -> None: + # Consistency checks + nn_idx = None + if self.nn_idx is None: + main_exception('Missing KNN index matrix. Try calling compute_neighbors first.') + else: + nn_idx = self.nn_idx + + nn_dist = None + if self.nn_dist is None: + main_exception('Missing KNN distance matrix. Try calling compute_neighbors first.') + else: + nn_dist = self.nn_dist + + atac_adata, rna_adata = self.mdata.mod['atac'], self.mdata.mod['rna'] + n_cells = atac_adata.n_obs + + if (nn_idx.shape[0] != n_cells) or (nn_dist.shape[0] != n_cells): + main_exception('Number of rows of KNN indices does not equal to number of cells.') + + X = coo_matrix(([], ([], [])), shape=(n_cells, 1)) + from umap.umap_ import fuzzy_simplicial_set + conn, sigma, rho, dists = fuzzy_simplicial_set(X=X, + n_neighbors=nn, + random_state=None, + metric=None, + knn_indices=nn_idx-1, + knn_dists=nn_dist, + return_dists=True) + + conn = conn.tocsr().copy() + n_counts = (conn > 0).sum(1).A1 + if nn is not None and nn < n_counts.min(): + conn = top_n_sparse(conn, nn) + conn.setdiag(1) + conn_norm = conn.multiply(1.0 / conn.sum(1)).tocsr() + + # Compute first moment of chromatin accessibility + atac_adata.layers[MDKM.RNA_FIRST_MOMENT_CHROM_LAYER] = \ + csr_matrix.dot(conn_norm, atac_adata.layers['counts']).copy() + + # Overwrite ATAC- and RNA-seq connectivities + atac_adata.obsp['connectivities'] = conn.copy() + rna_adata.obsp['connectivities'] = conn.copy() + + self.mdata.mod['atac'] = atac_adata.copy() + self.mdata.mod['rna'] = rna_adata.copy() + + def lift_transcriptomic_velocity(self, + num_processes: int = 6): + # Compute integral neighbors + main_info('Starting computation of integral neighbors ...') + self.find_cell_along_integral_curve(num_processes=num_processes) + + # Extract the ATAC- and RNA-seq data + atac_adata, rna_adata = self.mdata.mod['atac'], self.mdata.mod['rna'] + + # Retrieve specified layer for chromatin state + chromatin_state = atac_adata.layers[MDKM.ATAC_TFIDF_LAYER] + + cosine_similarities = None + if self.cosine_similarities is None: + main_exception('Please compute integral neighbors before calling lift_transcriptomic_velocity.') + else: + cosine_similarities = self.cosine_similarities + + # Retrieve specified layer for expression 
matrix + expression_mtx = rna_adata.layers[MDKM.RNA_FIRST_MOMENT_SPLICED_LAYER] + + integral_neighbors = None + if self.integral_neighbors is None: + main_exception('Please compute integral neighbors before calling lift_transcriptomic_velocity.') + else: + integral_neighbors = self.integral_neighbors + + # Retrieve specified layer for the velocity matrix + velocity_mtx = rna_adata.layers[MDKM.RNA_SPLICED_VELOCITY_LAYER] + + # Prepare argument list for parallel processing + args_list = [(i, j, chromatin_state, cosine_similarities, expression_mtx, velocity_mtx[i, :]) + for i, j in integral_neighbors.items()] + + # Use multiprocessing to compute the results + with Pool(processes=num_processes) as pool: + results = pool.map(lifted_chromatin_velocity, args_list) + + # Convert results to sparse matrix + chromatin_velocity_mtx = np.zeros(chromatin_state.shape) + for i, chromatin_velocity in results: + chromatin_velocity_mtx[i, :] = chromatin_velocity + + atac_adata.layers[MDKM.ATAC_CHROMATIN_VELOCITY_LAYER] = chromatin_velocity_mtx + + # Copy the scATAC-seq AnnData object into the MultiomeVelocity object + self.mdata.mod['atac'] = atac_adata.copy() + + def _restrict_dicts_to_gene_list(self, + gene_list: List[str], + cre_dict: Dict[str, List[str]] = None, + promoter_dict: Dict[str, List[str]] = None + ) -> Tuple[List[str], List[str], Dict[str, List[str]], Dict[str, List[str]]]: + # Elements present in scATAC-seq data + present_elements = self.atac_elements() + + if len(gene_list) == 0: + main_exception('Require non-trivial gene_list for _restrict_to_gene_list.') + + if len(cre_dict) == 0 or len(promoter_dict) == 0: + main_exception('Require non-trivial enhancer and promoter dicts for _restrict_to_gene_list.') + + # Elements associated to genes in gene_list and present in scATAC-seq data + shared_elements = [] + + # Dictionary from gene to element list for all genes present in gene_list and with + # corresponding elements in enhancer dicts + shared_cre_dict = {} + for gene, element_list in cre_dict.items(): + if gene in gene_list: + shared_elements_for_gene =\ + [element for element in element_list if element in present_elements] + shared_elements_for_gene = list(set(shared_elements_for_gene)) + + shared_elements += shared_elements_for_gene + shared_cre_dict[gene] = shared_elements_for_gene + + # Add all promoters for genes in gene_list + shared_promoter_dict = {} + for gene, element_list in promoter_dict.items(): + if gene in gene_list: + shared_elements_for_gene = \ + [element for element in element_list if element in present_elements] + shared_elements_for_gene = list(set(shared_elements_for_gene)) # Bit pedantic ... 
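+                    # (per-gene dedup; the pooled `shared_elements` list is
+                    # deduplicated once more after both loops)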
+ + shared_elements += shared_elements_for_gene + shared_promoter_dict[gene] = shared_elements_for_gene + + # Make elements into unique list + shared_elements = list(set(shared_elements)) + + # Determine which genes actually have elements present in the scATAC-seq data + all_dict_genes = list(set(list(shared_cre_dict.keys()) + list(shared_promoter_dict.keys()))) + shared_genes = [] + for gene in all_dict_genes: + enhancers_for_gene = len(shared_cre_dict.get(gene, [])) > 0 + + promoters_for_gene = len(shared_promoter_dict.get(gene, [])) > 0 + + if enhancers_for_gene or promoters_for_gene: + shared_genes.append(gene) + + # Clean up trivial entries in dicts + if not enhancers_for_gene and gene in shared_cre_dict: + del shared_cre_dict[gene] + + if not promoters_for_gene and gene in shared_promoter_dict: + del shared_promoter_dict[gene] + + shared_genes = list(set(shared_genes)) + + return shared_elements, shared_genes, shared_cre_dict, shared_promoter_dict + + def restrict_to_gene_list(self, + gene_list: List[str] = None, + subset: bool = False) -> Tuple[List[str], List[str]]: + # Extract genes from scRNA-seq data + rna_genes = self.rna_genes() + + if gene_list is None: + # If no gene_list offered, then use the genes found in scRNA-seq dataset + gene_list = rna_genes + else: + # Otherwise ensure gene is contained within the shared list + if not set(gene_list).issubset(set(rna_genes)): + main_exception('gene_list is not a subset of genes found in scRNA-seq dataset.') + + shared_elements, shared_genes, shared_enhancer_dict, shared_promoter_dict = \ + self._restrict_dicts_to_gene_list(gene_list=gene_list, + cre_dict=self._cre_dict, + promoter_dict=self._promoter_dict) + + if subset: + # Subset the scATAC-seq data to shared elements + self.mdata.mod['atac'] = self.mdata.mod['atac'][:, shared_elements].copy() + + # Subset the scRNA_seq data to shared genes + self.mdata.mod['rna'] = self.mdata.mod['rna'][:, shared_genes].copy() + + return shared_elements, shared_genes + + def rna_genes(self): + return self.mdata.mod['rna'].var_names.tolist() + + def to_mdata(self) -> MuData: + # Deep copy MuData object for export + atac_adata, rna_adata = self.mdata.mod['atac'].copy(), self.mdata.mod['rna'].copy() + + # ... embellish atac + atac_adata.uns['cre_dict'] = self._cre_dict.copy() + atac_adata.uns['promoter_dict'] = self._promoter_dict.copy() + atac_adata.uns['multi_dynamo_kwargs'] = {'include_gene_body': self.include_gene_body, + 'linkage_fn': self.linkage_fn, + 'linkage_method': self.linkage_method, + 'max_peak_dist': self.max_peak_dist, + 'min_corr': self.min_corr, + 'peak_annot_fn': self.peak_annot_fn} + + # ... embellish rna + rna_adata.obsm['multi_dynamo_nn_dist'] = self.nn_dist.copy() + rna_adata.obsm['multi_dynamo_nn_idx'] = self.nn_idx.copy() + + rna_adata.obsp['cosine_similarities'] = self.cosine_similarities.copy() + rna_adata.uns['integral_neighbors'] = {str(k): str(v) for k,v in self.integral_neighbors.items()}.copy() + rna_adata.uns['multi_dynamo_kwargs'] = {'neighbor_method': self.neighbor_method} + + return MuData({'atac': atac_adata, 'rna': rna_adata}) + + # transcriptomic_velocity: this could really be any of the many methods that already exist, including those in + # dynamo and we plan to add this capability later. 
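+    # A sketch of the deterministic fit (the actual estimator is the
+    # `regression` helper invoked below): under the steady-state assumption
+    # recorded in .uns['dynamics'] ('asspt_mRNA': 'ss'), gamma is fit per gene
+    # from the smoothed moments and the spliced velocity is the residual
+    #     velo_s = Mu - gamma * Ms,
+    # which is then smoothed over the cell-cell graph via smooth_scale.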
+ def transcriptomic_velocity(self, + adjusted: bool = False, + min_r2: float = 1e-2, + mode: Literal['deterministic', 'stochastic'] = 'deterministic', + n_neighbors: int = 20, + n_pcs: int = 20, + num_processes: int = 6, + outlier: float = 99.8): + # Extract transcriptome and chromatin accessibility + atac_adata, rna_adata = self.mdata.mod['atac'], self.mdata.mod['rna'] + + # Assemble dictionary of arguments for fits + fit_args = {'min_r2': min_r2, + 'mode': mode, + 'n_pcs': n_pcs, + 'n_neighbors': n_neighbors, + 'outlier': outlier} + + # Obtain connectivities from the scRNA-seq object + rna_conn = rna_adata.obsp['connectivities'] + + # Compute moments for transcriptome data + main_info('computing moments for transcriptomic data ...') + rna_adata.layers[MDKM.RNA_FIRST_MOMENT_SPLICED_LAYER] = ( + csr_matrix.dot(rna_conn, csr_matrix(rna_adata.layers[MDKM.RNA_SPLICED_LAYER])) + .astype(np.float32) + .toarray() + ) + rna_adata.layers[MDKM.RNA_FIRST_MOMENT_UNSPLICED_LAYER] = ( + csr_matrix.dot(rna_conn, csr_matrix(rna_adata.layers[MDKM.RNA_UNSPLICED_LAYER])) + .astype(np.float32) + .toarray() + ) + + # Initialize select second moments for the transcriptomic data + Mss, Mus, Muu = None, None, None + if mode == 'stochastic': + main_info('computing second moments', indent_level=2) + Mss, Mus, Muu = self.compute_second_moments(adjusted=adjusted) + + rna_adata.layers[MDKM.RNA_SECOND_MOMENT_SS_LAYER] = Mss.copy() + rna_adata.layers[MDKM.RNA_SECOND_MOMENT_US_LAYER] = Mus.copy() + rna_adata.layers[MDKM.RNA_SECOND_MOMENT_UU_LAYER] = Muu.copy() + + if 'highly_variable' in rna_adata.var: + main_info('using highly variable genes', indent_level=2) + rna_gene_list = rna_adata.var_names[rna_adata.var['highly_variable']].values + else: + rna_gene_list = rna_adata.var_names.values[ + (~np.isnan(np.asarray(rna_adata.layers[MDKM.RNA_FIRST_MOMENT_UNSPLICED_LAYER].sum(0)) + .reshape(-1) + if issparse(rna_adata.layers[MDKM.RNA_FIRST_MOMENT_UNSPLICED_LAYER]) + else np.sum(rna_adata.layers[MDKM.RNA_FIRST_MOMENT_UNSPLICED_LAYER], axis=0))) + & (~np.isnan(np.asarray(rna_adata.layers[MDKM.RNA_FIRST_MOMENT_SPLICED_LAYER].sum(0)) + .reshape(-1) + if issparse(rna_adata.layers[MDKM.RNA_FIRST_MOMENT_SPLICED_LAYER]) + else np.sum(rna_adata.layers[MDKM.RNA_FIRST_MOMENT_SPLICED_LAYER], axis=0)))] + + # Restrict to genes with corresponding peaks in scATAC-seq data + shared_elements, shared_genes = self.restrict_to_gene_list(gene_list=rna_gene_list, + subset=True) + + n_fitted_genes = len(shared_genes) + if n_fitted_genes: + main_info(f'{n_fitted_genes} genes will be fitted') + else: + main_exception('None of the genes specified are in the adata object') + + velo_s = np.zeros((rna_adata.n_obs, n_fitted_genes)) + variance_velo_s = np.zeros((rna_adata.n_obs, n_fitted_genes)) + gammas = np.zeros(n_fitted_genes) + r2s = np.zeros(n_fitted_genes) + losses = np.zeros(n_fitted_genes) + + u_mat = (rna_adata[:, shared_genes].layers[MDKM.RNA_FIRST_MOMENT_UNSPLICED_LAYER].A + if issparse(rna_adata.layers[MDKM.RNA_FIRST_MOMENT_UNSPLICED_LAYER]) + else rna_adata[:, shared_genes].layers[MDKM.RNA_FIRST_MOMENT_UNSPLICED_LAYER]) + s_mat = (rna_adata[:, shared_genes].layers[MDKM.RNA_FIRST_MOMENT_SPLICED_LAYER].A + if issparse(rna_adata.layers[MDKM.RNA_FIRST_MOMENT_SPLICED_LAYER]) + else rna_adata[:, shared_genes].layers[MDKM.RNA_FIRST_MOMENT_SPLICED_LAYER]) + + M_c = csr_matrix(atac_adata[:, shared_elements].layers[MDKM.RNA_FIRST_MOMENT_CHROM_LAYER]) \ + if issparse(atac_adata.layers[MDKM.RNA_FIRST_MOMENT_CHROM_LAYER]) else \ + atac_adata[:, 
shared_elements].layers[MDKM.RNA_FIRST_MOMENT_CHROM_LAYER] + c_mat = M_c.toarray() if issparse(M_c) else M_c + + # Create dictionary from gene to index + gene_to_idx_dict = {gene: idx for idx, gene in enumerate(shared_genes)} + + # Create dictionary from peak to index + peak_to_idx_dict = {element: idx for idx, element in enumerate(shared_elements)} + + # Create unified gene to list of elements dict + tmp_elements_for_gene_dict = {} + for gene, element_list in self._cre_dict.items(): + tmp_elements_for_gene_dict[gene] = tmp_elements_for_gene_dict.setdefault(gene, []) + element_list + + for gene, element_list in self._promoter_dict.items(): + tmp_elements_for_gene_dict[gene] = tmp_elements_for_gene_dict.setdefault(gene, []) + element_list + + elements_for_gene_dict = {} + for gene, element_list in tmp_elements_for_gene_dict.items(): + elements_for_gene_dict[gene] = list(set(element_list)) + + # Create dictionary from gene indices to list of peaks by indices + gene_idx_to_peak_idx = {gene_to_idx_dict[gene]: [peak_to_idx_dict[peak] for peak in peak_list] + for gene, peak_list in elements_for_gene_dict.items()} + + # Define batch arguments + batches_of_arguments = [] + for i in range(n_fitted_genes): + gene = shared_genes[i] + peak_idx = gene_idx_to_peak_idx[i] + + batches_of_arguments.append( + (c_mat[:, peak_idx], + u_mat[:, i], + s_mat[:, i], + None if mode == 'deterministic' else Mss[:, i], + None if mode == 'deterministic' else Mus[:, i], + None if mode == 'deterministic' else Muu[:, i], + fit_args, + mode, + gene)) + + # Carry out fits in parallel + with Pool(processes=num_processes) as pool: + results = pool.starmap(regression, batches_of_arguments) + + # Reformat the results + for idx, (velocity, velocity_variance, gamma, r2, loss) in enumerate(results): + gammas[idx] = gamma + r2s[idx] = r2 + losses[idx] = loss + velo_s[:, idx] = smooth_scale(rna_conn, velocity) + + if mode == 'stochastic': + variance_velo_s[:, idx] = smooth_scale(rna_conn, + velocity_variance) + + # Determine which fits failed + kept_genes = [gene for gene, loss in zip(shared_genes, losses) if loss != np.inf] + if len(kept_genes) == 0: + main_exception('None of the genes were fit due to low quality.') + + # Subset the transcriptome to the genes for which the fits were successful + rna_copy = rna_adata[:, kept_genes].copy() + + # Add the fit results + keep = [loss != np.inf for loss in losses] + + # ... layers + rna_copy.layers[MDKM.RNA_SPLICED_VELOCITY_LAYER] = csr_matrix(velo_s[:, keep]) + if mode == 'stochastic': + rna_copy.layers['variance_velo_s'] = csr_matrix(variance_velo_s[:, keep]) + + # ... .obsp + rna_copy.obsp['_RNA_conn'] = rna_conn + + # ... .uns + # ... ... 
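+        # NOTE (intent): the .uns entries assembled below imitate the metadata
+        # that dynamo's own preprocessing and dynamics routines would have
+        # written, so downstream dynamo tooling accepts this AnnData as if it
+        # had been fit natively.
+        # ... ...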
augment the dynamical and normalization information + dyn_and_norm_info = rna_copy.uns['pp'].copy() + dyn_and_norm_info['experiment_total_layers'] = None + dyn_and_norm_info['layers_norm_method'] = None + dyn_and_norm_info['tkey'] = None + rna_copy.uns['pp'] = dyn_and_norm_info.copy() + + dynamics = {'filter_gene_mode': 'final', + 't': None, + 'group': None, + 'X_data': None, + 'X_fit_data': None, + 'asspt_mRNA': 'ss', + 'experiment_type': dyn_and_norm_info.get('experiment_type', 'conventional'), + 'normalized': True, + 'model': mode, + 'est_method': 'gmm', # Consider altering + 'has_splicing': dyn_and_norm_info.get('has_splicing', True), + 'has_labeling': dyn_and_norm_info.get('has_labeling', False), + 'splicing_labeling': dyn_and_norm_info.get('splicing_labeling', False), + 'has_protein': dyn_and_norm_info.get('has_protein', False), + 'use_smoothed': True, + 'NTR_vel': False, + 'log_unnormalized': True, + # Ensure X is indeed log normalized (compute exp1m, sum and check rowsums) + 'fraction_for_deg': False} + rna_copy.uns['dynamics'] = dynamics.copy() + + rna_copy.uns['velo_s_params'] = {'mode': mode, + 'fit_offset': False, + 'perc': outlier} + rna_copy.uns['velo_s_params'].update(fit_args) + + # ... ... These are the column names for the array in .varm['vel_params'] + rna_copy.uns['vel_params_names'] = ['beta', 'gamma', 'half_life', 'alpha_b', 'alpha_r2', 'gamma_b', + 'gamma_r2', 'gamma_logLL', 'delta_b', 'delta_r2', 'bs', 'bf', + 'uu0', 'ul0', 'su0', 'sl0', 'U0', 'S0', 'total0'] + + # ... .var + rna_copy.var['fit_gamma'] = gammas[keep] + rna_copy.var['fit_loss'] = losses[keep] + rna_copy.var['fit_r2'] = r2s[keep] + + # Introduce var['use_for_dynamics'] for dynamo + v_gene_ind = rna_copy.var['fit_r2'] >= min_r2 + rna_copy.var['use_for_dynamics'] = v_gene_ind + rna_copy.var['velo_s_genes'] = v_gene_ind + + # ... .varm + vel_params_array = np.full((rna_copy.shape[1], len(rna_copy.uns['vel_params_names'])), np.nan) + + # ... ... ... transfer 'gamma' + gamma_index = np.where(np.array(rna_copy.uns['vel_params_names']) == 'gamma')[0][0] + vel_params_array[:, gamma_index] = rna_copy.var['fit_gamma'] + + # ... ... ... 
transfer 'gamma_r2' + gamma_r2_index = np.where(np.array(rna_copy.uns['vel_params_names']) == 'gamma_r2')[0][0] + vel_params_array[:, gamma_r2_index] = rna_copy.var['fit_r2'] + + rna_copy.varm['vel_params'] = vel_params_array + + # Copy the subset AnnData scRNA-seq and scATAC-seq objects back into the MultiomeVelocity object + self.mdata.mod['rna'] = rna_copy.copy() + + # Filter the scATAC-seq peaks to retain only those corresponding to fit genes + shared_elements, shared_genes = self.restrict_to_gene_list(gene_list=kept_genes, + subset=True) + + # Confer same status to element corresponding to genes declared as 'use_for_dynamics' + v_genes = [gene for gene, v_ind in zip(shared_genes, v_gene_ind) if v_ind] + # v_elements, v_genes = self.restrict_to_gene_list(gene_list=v_genes, subset=False) + # v_element_ind = [element in v_elements for element in shared_elements] + # TODO: Need to special case when no genes rise to significance + v_element_ind = [True for _ in range(atac_adata.n_vars)] + + # Introduce var['use_for_dynamics'] for dynamo + # TODO: This does NOT appear to work properly yet - so left permissive + atac_adata.var['use_for_dynamics'] = v_element_ind + + self.mdata.mod['atac'] = atac_adata.copy() + + def _update_cre_and_promoter_dicts(self, + cre_dict: Dict[str, List[str]] = None, + promoter_dict: Dict[str, List[str]] = None): + if cre_dict is not None or promoter_dict is not None: + # Should only have exogenous enhancer and promoter dicts if none are present in object + if self._cre_dict is not None or self._promoter_dict is not None: + main_exception('Should only specify exogenous CRE and promoter dicts if none are present in object.') + else: + # Extract the dictionaries + cre_dict = self._cre_dict + promoter_dict = self._promoter_dict + + # Extract the RNA genes + rna_genes = self.rna_genes() + + # ... 
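+        # Restricting to a gene list keeps the two modalities in lockstep:
+        # genes without any CRE/promoter among the ATAC peaks are dropped from
+        # the RNA data, and unattributed peaks are dropped from the ATAC data.
+        # ...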
determine which genes are actually present in the scATAC-seq data and for these
+        # which elements are present
+        shared_elements, shared_genes, shared_cre_dict, shared_promoter_dict = \
+            self._restrict_dicts_to_gene_list(gene_list=rna_genes,
+                                              cre_dict=cre_dict,
+                                              promoter_dict=promoter_dict)
+
+        if len(shared_genes) == 0:
+            main_exception('scATAC-seq data and scRNA-seq data do NOT share any genes.')
+
+        # Subset the scATAC-seq data to shared elements
+        self.mdata.mod['atac'] = self.mdata.mod['atac'][:, shared_elements].copy()
+
+        # Subset the scRNA-seq data to shared genes
+        self.mdata.mod['rna'] = self.mdata.mod['rna'][:, shared_genes].copy()
+
+        # Initialize the original enhancer and promoter dicts
+        self._cre_dict = shared_cre_dict
+        self._promoter_dict = shared_promoter_dict
+
+    def weighted_nearest_neighbors(
+            self,
+            atac_lsi_key: str = MDKM.ATAC_OBSM_LSI_KEY,
+            n_components_atac: int = 20,
+            n_components_rna: int = 20,
+            nn: int = 20,
+            random_state: int = 42,
+            rna_pca_key: str = MDKM.RNA_OBSM_PC_KEY,
+            use_highly_variable: bool = False):
+        main_info('Starting computation of weighted nearest neighbors ...', indent_level=1)
+        nn_logger = LoggerManager.gen_logger('weighted_nearest_neighbors')
+        nn_logger.log_time()
+
+        # Restrict to shared genes and their elements - as tied together by the attribution of CRE to genes
+        shared_elements, shared_genes = self.restrict_to_gene_list(subset=True)
+
+        # Extract scATAC-seq and scRNA-seq data
+        atac_adata = self.mdata.mod['atac'][:, shared_elements].copy()
+        rna_adata = self.mdata.mod['rna'][:, shared_genes].copy()
+
+        if rna_pca_key not in rna_adata.obsm:
+            # TODO: Consider normalizing counts here, if needed
+
+            # Carry out PCA on scRNA-seq data
+            main_info('computing PCA on normalized and scaled scRNA-seq data', indent_level=2)
+            sc.tl.pca(rna_adata,
+                      n_comps=n_components_rna,
+                      random_state=random_state,
+                      use_highly_variable=use_highly_variable)
+
+        if atac_lsi_key not in atac_adata.obsm:
+            # Carry out singular value decomposition on the scATAC-seq data
+            main_info('computing latent semantic indexing of scATAC-seq data ...')
+            lsi = svds(atac_adata.X, k=n_components_atac)
+
+            # get the lsi result
+            atac_adata.obsm[atac_lsi_key] = lsi[0]
+
+        # Cross copy the LSI decomposition
+        rna_adata.obsm[atac_lsi_key] = atac_adata.obsm[atac_lsi_key]
+
+        # Use Dylan Kotliar's python implementation of weighted nearest neighbors (pyWNN)
+        # TODO: As an alternative to PCA, could use the latent space from a variational autoencoder.
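+        # pyWNN (lightly adapted from Dylan Kotliar's implementation; see
+        # dynamo/multivelo/pyWNN.py in this diff) consumes the two normalized,
+        # truncated representations and writes a Seurat-style WNN graph into
+        # .obsp['WNN'] / .obsp['WNN_distance'], renamed to the canonical
+        # scanpy keys further below.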
+ WNNobj = pyWNN(rna_adata, + reps=[rna_pca_key, atac_lsi_key], + npcs=[n_components_rna, n_components_atac], + n_neighbors=nn, + seed=42) + + adata_seurat = WNNobj.compute_wnn(rna_adata) + + # extract the matrix storing the distances between each cell and its neighbors + cx = coo_matrix(adata_seurat.obsp["WNN_distance"]) + + # the number of cells + cells = adata_seurat.obsp['WNN_distance'].shape[0] + + # define the shape of our final results + # and make the arrays that will hold the results + new_shape = (cells, nn) + nn_dist = np.zeros(shape=new_shape) + nn_idx = np.zeros(shape=new_shape) + + # new_col defines what column we store data in + # our result arrays + new_col = 0 + + # loop through the distance matrices + for i, j, v in zip(cx.row, cx.col, cx.data): + + # store the distances between neighbor cells + nn_dist[i][new_col % nn] = v + + # for each cell's row, store the row numbers of its neighbor cells + # (1-indexing instead of 0- is a holdover from R multimodalneighbors()) + nn_idx[i][new_col % nn] = int(j) + 1 + + new_col += 1 + + # Add index and distance to the MultiomeVelocity object + self.nn_idx = nn_idx + self.nn_dist = nn_dist + + # Revert to canonical naming of connectivities and distances + # ... .uns['neighbors'] + atac_adata.uns['neighbors'] = adata_seurat.uns['WNN'].copy() + rna_adata.uns['neighbors'] = adata_seurat.uns['WNN'].copy() + del adata_seurat.uns['WNN'] + + # ... .obsp['connectivities'] + atac_adata.obsp['connectivities'] = adata_seurat.obsp['WNN'].copy() + rna_adata.obsp['connectivities'] = adata_seurat.obsp['WNN'].copy() + del adata_seurat.obsp['WNN'] + + # ... .obsp['distances'] + atac_adata.obsp['distances'] = adata_seurat.obsp['WNN_distance'].copy() + rna_adata.obsp['distances'] = adata_seurat.obsp['WNN_distance'].copy() + del adata_seurat.obsp['WNN_distance'] + + # Copy the subset AnnData scRNA-seq and scATAC-seq objects back into the MultiomeVelocity object + self.mdata.mod['atac'] = atac_adata.copy() + self.mdata.mod['rna'] = rna_adata.copy() + + def write(self, + filename: Union[PathLike, str]) -> None: + export_mdata = self.to_mdata() + export_mdata.write_h5mu(filename) diff --git a/dynamo/multivelo/old_MultiomicVectorField.py b/dynamo/multivelo/old_MultiomicVectorField.py new file mode 100644 index 000000000..8e21443b7 --- /dev/null +++ b/dynamo/multivelo/old_MultiomicVectorField.py @@ -0,0 +1,445 @@ +import anndata as ad +from anndata import AnnData +import matplotlib.pyplot as plt +from mudata import MuData +import numpy as np +import pandas as pd +from scipy.sparse import csr_matrix +from typing import ( + Dict, + List, + Literal, + Optional, + Tuple, + Union, +) + +# Imports from MultiDynamo +from .MultiConfiguration import MDKM +from .old_MultiVelocity import MultiVelocity + +from ..pl import cell_wise_vectors, streamline_plot, topography +from ..pd import fate, perturbation +from ..mv import animate_fates +from ..pp import pca +from ..tl import reduceDimension, cell_velocities +from ..vf import VectorField + + +# Helper functions +def compute_animations(adata, + cell_type_key: str, + cores: int = 6, + delta_epsilon: float = 0.25, + epsilon: float = 1.0, + max_tries: int = 10, + n_cells: int = 100, + n_earliest: int = 30, + prefix: str = None, + skip_cell_types: List = [] + ) -> None: + # Extract cell metadata + cell_metadata = adata.obs.copy() + + # Add UMAP + cell_metadata['umap_1'] = adata.obsm['X_umap'][:, 0] + cell_metadata['umap_2'] = adata.obsm['X_umap'][:, 1] + + # Group by cell_type_key and find the rows with the maximal 
'umap_1'
+    grouped = cell_metadata.groupby(cell_type_key)
+
+    # Find the mean locations of cell types
+    top_indices_1, top_indices_2 = {}, {}
+    for cell_type, celltype_data in grouped:
+        subset_df = celltype_data.nsmallest(n_cells, 'umap_1')
+        top_indices_1[cell_type] = subset_df['umap_1'].mean()
+        subset_df = celltype_data.nlargest(n_cells, 'umap_2')
+        top_indices_2[cell_type] = subset_df['umap_2'].mean()
+
+    cell_types = cell_metadata[cell_type_key].cat.categories.tolist()
+    progenitor_list = []
+
+    for cell_type in cell_types:
+        if (skip_cell_types is not None) and (cell_type in skip_cell_types):
+            continue
+
+        print(f'Computing animation for cell type {cell_type}')
+
+        # Find the progenitors
+        n_tries, progenitors = 1, []
+        while len(progenitors) < n_cells and n_tries < max_tries + 1:
+            progenitors = adata.obs_names[adata.obs[cell_type_key].isin([cell_type]) &
+                                          (abs(cell_metadata['umap_1'] - top_indices_1[cell_type]) < (
+                                                  epsilon + n_tries * delta_epsilon)) &
+                                          (abs(cell_metadata['umap_2'] - top_indices_2[cell_type]) < (
+                                                  epsilon + n_tries * delta_epsilon))]
+            n_tries += 1
+
+        if len(progenitors) >= n_earliest:
+            # Progenitors for all subset simulation
+            print(f'Adding {n_earliest} cells of type {cell_type}.')
+            progenitor_list.extend(progenitors[0:min(len(progenitors), n_earliest)])
+
+        # Progenitors for this animation
+        # progenitors = progenitors[0:min(len(progenitors), n_cells)]
+
+        # Determine their fate
+        # dyn.pd.fate(adata, basis='umap_perturbation', init_cells=progenitors, interpolation_num=100,
+        #             direction='forward', inverse_transform=False, average=False, cores=6)
+
+        # Compute the animation
+        # animation_fn = cell_type + '_perturbed_fate_ani.mp4'
+        # animation_fn = animation_fn.replace('/', '-')
+        # dyn.mv.animate_fates(adata, basis='umap_perturbation', color='celltype', n_steps=100,
+        #                      interval=100, save_show_or_return='save',
+        #                      save_kwargs={'filename': animation_fn,
+        #                                   'writer': 'ffmpeg'})
+
+    # Determine fate of progenitor_list
+    fate(adata, basis='umap_perturbation', init_cells=progenitor_list, interpolation_num=100,
+         direction='forward', inverse_transform=False, average=False, cores=cores)
+
+    # Compute the animation
+    file_name = prefix + '_perturbation.mpeg'
+    file_name = file_name.replace(':', '-')
+    file_name = file_name.replace('/', '-')
+    animate_fates(adata, basis='umap_perturbation', color=cell_type_key, n_steps=100,
+                  interval=100, save_show_or_return='save',
+                  save_kwargs={'filename': file_name,
+                               'writer': 'ffmpeg'})
+
+
+def genes_and_elements_for_dynamics(atac_adata: AnnData,
+                                    rna_adata: AnnData,
+                                    cre_dict: Dict[str, List[str]],
+                                    promoter_dict: Dict[str, List[str]],
+                                    min_r2: float = 0.01) -> List[bool]:
+    # Get fit parameters
+    vel_params_array = rna_adata.varm['vel_params']
+
+    # Extract 'gamma_r2'
+    gamma_r2_index = np.where(np.array(rna_adata.uns['vel_params_names']) == 'gamma_r2')[0][0]
+    r2 = vel_params_array[:, gamma_r2_index]
+
+    # Set genes for dynamics
+    genes_for_dynamics = rna_adata.var_names[r2 > min_r2].to_list()
+    use_for_dynamics = [gene in genes_for_dynamics for gene in rna_adata.var_names.to_list()]
+
+    # Compute elements for dynamics
+    cre_for_dynamics = []
+    for gene, cre_list in cre_dict.items():
+        if gene in genes_for_dynamics:
+            cre_for_dynamics += cre_list
+
+    for gene, promoter_list in promoter_dict.items():
+        if gene in genes_for_dynamics:
+            cre_for_dynamics += promoter_list
+
+    use_for_dynamics += [element in cre_for_dynamics for element in atac_adata.var_names]
+
+    return use_for_dynamics
+
+
+class MultiomicVectorField:
+    def __init__(self,
+                 multi_velocity: Union[MultiVelocity, MuData],
+                 min_gamma: float = None,
+                 min_r2: float = 0.01,
+                 rescale_velo_c: float = 1.0):
+        # This is basically an adapter from multiomic data to a format where we can borrow tools
+        # previously developed in dynamo.
+        if isinstance(multi_velocity, MuData):
+            multi_velocity = MultiVelocity.from_mdata(multi_velocity)
+
+        # ... mdata
+        mdata = multi_velocity.get_mdata()
+        atac_adata, rna_adata = mdata.mod['atac'], mdata.mod['rna']
+
+        # ... CRE dictionary
+        cre_dict = multi_velocity.get_cre_dict()
+
+        # ... promoter dictionary
+        promoter_dict = multi_velocity.get_promoter_dict()
+
+        # To estimate the multi-omic velocity field, we assemble a single AnnData object from the
+        # following components
+        # NOTE: In our descriptions below *+* signifies the direct sum of two vector spaces
+        # ... .layers
+        # ... ... counts: counts => rna counts *+* atac counts
+        rna_counts = rna_adata.layers[MDKM.RNA_COUNTS_LAYER].toarray().copy()
+        atac_counts = atac_adata.layers[MDKM.ATAC_COUNTS_LAYER].toarray().copy()
+        counts = np.concatenate((rna_counts, atac_counts), axis=1)
+
+        # ... ... raw: spliced, unspliced ==> spliced *+* chromatin, unspliced *+* 0
+        chromatin_state = atac_adata.layers[MDKM.ATAC_COUNTS_LAYER].toarray().copy()
+        spliced = rna_adata.layers[MDKM.RNA_SPLICED_LAYER].toarray().copy()
+        unspliced = rna_adata.layers[MDKM.RNA_UNSPLICED_LAYER].toarray().copy()
+
+        spliced = np.concatenate((spliced, chromatin_state), axis=1)
+        unspliced = np.concatenate((unspliced, np.zeros(chromatin_state.shape)), axis=1)
+        del chromatin_state
+
+        # ... ... first moments: M_s, M_u => M_s *+* Mc, M_u *+* 0
+        Mc = atac_adata.layers[MDKM.RNA_FIRST_MOMENT_CHROM_LAYER].toarray().copy()
+        Ms = rna_adata.layers[MDKM.RNA_FIRST_MOMENT_SPLICED_LAYER].copy()
+        Mu = rna_adata.layers[MDKM.RNA_FIRST_MOMENT_UNSPLICED_LAYER].copy()
+
+        Ms = np.concatenate((Ms, Mc), axis=1)
+        Mu = np.concatenate((Mu, np.zeros(Mc.shape)), axis=1)
+        del Mc
+
+        # ... ... velocity_S ==> velocity_S *+* lifted_velo_c
+        velocity_C = atac_adata.layers[MDKM.ATAC_CHROMATIN_VELOCITY_LAYER].copy()
+        velocity_S = rna_adata.layers[MDKM.RNA_SPLICED_VELOCITY_LAYER].toarray().copy()
+
+        velocity_S = np.concatenate((velocity_S, rescale_velo_c * velocity_C), axis=1)
+        del velocity_C
+
+        # ... .obs
+        # ... ... carry over entire obs for now
+        obs_df = rna_adata.obs.copy()
+
+        # ... .obsp
+        # ... ... connectivities ==> connectivities
+        connectivities = rna_adata.obsp['connectivities'].copy()
+
+        # ... ... distances ==> distances
+        distances = rna_adata.obsp['distances'].copy()
+
+        # ... .uns
+        # ... ... dynamics ==> dynamics
+        dynamics = rna_adata.uns['dynamics'].copy()
+
+        # ... ... neighbors ==> neighbors
+        neighbors = rna_adata.uns['neighbors'].copy()
+
+        # ... ... pp ==> pp
+        pp = rna_adata.uns['pp'].copy()
+
+        # ... ... vel_params_names ==> vel_params_names
+        vel_params_names = rna_adata.uns['vel_params_names'].copy()
+
+        # ... .var
+        # ... ... var_names ==> (rna) var_names + (atac) var_names
+        var_names = rna_adata.var_names.tolist() + atac_adata.var_names.tolist()
+
+        # ... ... feature_type ==> n_genes * 'gene', n_elements * 'CRE'
+        feature_type = rna_adata.n_vars * ['gene'] + atac_adata.n_vars * ['CRE']
+
+        # ... ... use_for_dynamics
+        use_for_dynamics = genes_and_elements_for_dynamics(atac_adata=atac_adata,
+                                                           rna_adata=rna_adata,
+                                                           cre_dict=cre_dict,
+                                                           promoter_dict=promoter_dict,
+                                                           min_r2=min_r2)
+
+        # ... ...
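+        # NOTE: use_for_pca is computed with the identical call below, so the
+        # PCA mask coincides with the dynamics mask by construction.
+        # ... ...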
use_for_pca + use_for_pca = genes_and_elements_for_dynamics(atac_adata=atac_adata, + rna_adata=rna_adata, + cre_dict=cre_dict, + promoter_dict=promoter_dict, + min_r2=min_r2) + + var_df = pd.DataFrame(data={'feature_type': feature_type, + 'use_for_dynamics': use_for_dynamics, + 'use_for_pca': use_for_pca}, + index=var_names) + + # ... .varm + # ... ... vel_params => vel_params + (1,1) + vel_params_array = rna_adata.varm['vel_params'] + + chrom_vel_params_array = np.full((atac_adata.n_vars, len(vel_params_names)), np.nan) + + # ... ... create vacuous 'gamma' for chromatin data + gamma_index = np.where(np.array(vel_params_names) == 'gamma')[0][0] + chrom_vel_params_array[:, gamma_index] = np.ones(atac_adata.n_vars) + + # ... ... create vacuous 'gamma_r2' for chromatin data + gamma_r2_index = np.where(np.array(vel_params_names) == 'gamma_r2')[0][0] + chrom_vel_params_array[:, gamma_r2_index] = np.ones(atac_adata.n_vars) + + # ... ... concatenate the arrays + vel_params_array = np.concatenate((vel_params_array, chrom_vel_params_array), axis=0) + + # X ==> X + X + X = np.concatenate((rna_adata.X.toarray().copy(), atac_adata.X.toarray().copy()), axis=1) + + # Instantiate the multiomic AnnData object + adata_multi = AnnData(obs=obs_df, + var=var_df, + X=X) + # ... add .layers + # ... ... counts + adata_multi.layers[MDKM.RNA_COUNTS_LAYER] = counts + + # ... ... raw + adata_multi.layers[MDKM.RNA_SPLICED_LAYER] = spliced + adata_multi.layers[MDKM.RNA_UNSPLICED_LAYER] = unspliced + + # ... ... first moments + adata_multi.layers[MDKM.RNA_FIRST_MOMENT_SPLICED_LAYER] = Ms + adata_multi.layers[MDKM.RNA_FIRST_MOMENT_UNSPLICED_LAYER] = Mu + + # ... ... rna velocity + adata_multi.layers[MDKM.RNA_SPLICED_VELOCITY_LAYER] = velocity_S + + # ... add .obsp + adata_multi.obsp['connectivities'] = connectivities + adata_multi.obsp['distances'] = distances + + # ... add .uns + adata_multi.uns['dynamics'] = dynamics + adata_multi.uns['neighbors'] = neighbors + adata_multi.uns['pp'] = pp + adata_multi.uns['vel_params_names'] = vel_params_names + + # ... add varm + adata_multi.varm['vel_params'] = vel_params_array + + # Set instance variables + + self.multi_adata = adata_multi.copy() + + def cell_velocities(self, + cores: int = 6, + min_r2: float = 0.5, + n_neighbors: int = 30, + n_pcs: int = 30, + random_seed: int = 42, + trans_matrix_method: Literal["kmc", "fp", "cosine", "pearson", "transform"] = "pearson", + ) -> AnnData: + # We'll save ourselves some grief and just compute both the PCA and UMAP representations + # of the vector field up front + # ... extract the multiomic AnnData object + adata_multi = self.multi_adata.copy() + + # ... compute PCA + adata_multi = pca(adata=adata_multi, + n_pca_components=n_pcs, + random_state=random_seed) + + # ... compute the appropriate dimensional reduction + reduceDimension(adata_multi, + basis='pca', + cores=cores, + n_pca_components=n_pcs, + n_components=2, + n_neighbors=n_neighbors, + reduction_method='umap') + + # ... project high dimensional velocities onto PCA embeddings and compute cell transitions + cell_velocities(adata_multi, + basis='pca', + method=trans_matrix_method, + min_r2=min_r2, + other_kernels_dict={'transform': 'sqrt'}) + + # ... 
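+        # Computing transitions in both bases up front means compute_vector_field
+        # (which learns in PCA space) and the UMAP-based plotting helpers below
+        # can run without recomputing projections.
+
+        # ...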
project high dimensional velocities onto UMAP embeddings and compute cell transitions
+        cell_velocities(adata_multi,
+                        basis='umap',
+                        method=trans_matrix_method,
+                        min_r2=min_r2,
+                        other_kernels_dict={'transform': 'sqrt'})
+
+        self.multi_adata = adata_multi.copy()
+
+        return self.multi_adata
+
+    def compute_vector_field(self,
+                             cores: int = 6,
+                             restart_num: int = 5
+                             ):
+        VectorField(self.multi_adata,
+                    basis='pca',
+                    cores=cores,
+                    grid_num=100,
+                    M=1000,
+                    pot_curl_div=True,
+                    restart_num=restart_num,
+                    restart_seed=[i * 888888888 for i in range(1, restart_num + 1)])
+        '''
+        dyn.vf.VectorField(self.multi_adata,
+                           basis='umap',
+                           cores=cores,
+                           grid_num=100,
+                           M=1000,
+                           pot_curl_div=True,
+                           restart_num=restart_num,
+                           restart_seed=[i * 888888888 for i in range(1, restart_num + 1)])
+        '''
+
+    def plot_cell_wise_vectors(self,
+                               color: str = 'cell_type',
+                               figsize: Tuple[float, float] = (9, 6),
+                               **save_kwargs
+                               ) -> None:
+        fig, ax = plt.subplots(figsize=figsize)
+        cell_wise_vectors(self.multi_adata,
+                          basis='umap',
+                          color=[color],
+                          pointsize=0.1,
+                          quiver_length=6,
+                          quiver_size=6,
+                          save_kwargs=save_kwargs,
+                          save_show_or_return='show',
+                          show_arrowed_spines=False,
+                          show_legend='on_data',
+                          ax=ax)
+        plt.show()
+
+    def plot_streamline_plot(self,
+                             color: str = 'cell_type',
+                             figsize: Tuple[float, float] = (9, 6),
+                             **save_kwargs
+                             ) -> None:
+        fig, ax = plt.subplots(figsize=figsize)
+        streamline_plot(self.multi_adata,
+                        basis='umap',
+                        color=[color],
+                        show_arrowed_spines=True,
+                        show_legend='on_data',
+                        ax=ax)
+        plt.show()
+
+    def plot_topography(self,
+                        color: str = 'cell_type',
+                        figsize: Tuple[float, float] = (9, 6),
+                        **save_kwargs
+                        ) -> None:
+        fig, ax = plt.subplots(figsize=figsize)
+        topography(self.multi_adata,
+                   basis='pca',
+                   background='white',
+                   color=color,
+                   frontier=True,
+                   n=200,
+                   show_legend='on data',
+                   streamline_color='black',
+                   ax=ax)
+
+    def predict_perturbation(self,
+                             gene: str,
+                             expression: float,
+                             cell_type_key: str = 'cell_type',
+                             compute_animation: bool = False,
+                             emb_basis: str = 'umap',
+                             skip_cell_types: List = None
+                             ) -> AnnData:
+        perturbed_multi_adata = perturbation(self.multi_adata,
+                                             genes=gene,
+                                             expression=expression,
+                                             emb_basis='umap')
+        streamline_plot(self.multi_adata, color=["cell_type", gene],
+                        basis="umap_perturbation")
+
+        if compute_animation:
+            # Fit analytic vector field
+            VectorField(self.multi_adata,
+                        basis='umap_perturbation')
+
+            compute_animations(adata=self.multi_adata,
+                               cell_type_key=cell_type_key,
+                               prefix=gene,
+                               skip_cell_types=skip_cell_types)
+
+        return perturbed_multi_adata
diff --git a/dynamo/multivelo/pyWNN.py b/dynamo/multivelo/pyWNN.py
new file mode 100644
index 000000000..763fef4fb
--- /dev/null
+++ b/dynamo/multivelo/pyWNN.py
@@ -0,0 +1,270 @@
+# This has been taken and lightly modified from Dylan Kotliar's GitHub repository
+from anndata import AnnData
+import numpy as np
+import scanpy as sc
+from sklearn import preprocessing
+from scipy.sparse import csr_matrix, lil_matrix, diags
+import sys
+import time
+from typing import List
+
+# Import from dynamo
+from ..dynamo_logger import (
+    LoggerManager,
+    main_debug,
+    main_exception,
+    main_finish_progress,
+    main_info,
+    main_info_insert_adata,
+    main_warning,
+)
+
+
+def compute_bw(knn_adj, embedding, n_neighbors=20):
+    intersect = knn_adj.dot(knn_adj.T)
+    indices = intersect.indices
+    indptr = intersect.indptr
+    data = intersect.data
+    data = data / ((n_neighbors * 2) - data)
+    bandwidth = []
+    num = 0
+    for i in range(intersect.shape[0]):
+        cols = indices[indptr[i]:indptr[i + 1]]
+        rowvals = data[indptr[i]:indptr[i + 1]]
+        idx = np.argsort(rowvals)
+        valssort = rowvals[idx]
+        numinset = len(cols)
+        if numinset < n_neighbors:
+            sys.exit(f'Fewer than {n_neighbors} cells with Jaccard sim > 0')
+        else:
+            curval = valssort[n_neighbors]
+            for num in range(n_neighbors, numinset):
+                if valssort[num] != curval:
+                    break
+                else:
+                    num += 1
+            minjacinset = cols[idx][:num]
+            if num < n_neighbors:
+                main_exception('compute_bw method failed.')
+                sys.exit(-1)
+            else:
+                euc_dist = ((embedding[minjacinset, :] - embedding[i, :]) ** 2).sum(axis=1) ** .5
+                euc_dist_sorted = np.sort(euc_dist)[::-1]
+                bandwidth.append(np.mean(euc_dist_sorted[:n_neighbors]))
+    return np.array(bandwidth)
+
+
+def compute_affinity(dist_to_predict, dist_to_nn, bw):
+    affinity = dist_to_predict - dist_to_nn
+    affinity[affinity < 0] = 0
+    affinity = affinity * -1
+    affinity = np.exp(affinity / (bw - dist_to_nn))
+    return affinity
+
+
+def dist_from_adj(adjacency, embed1, embed2, nndist1, nndist2):
+    dist1 = lil_matrix(adjacency.shape)
+    dist2 = lil_matrix(adjacency.shape)
+
+    indices = adjacency.indices
+    indptr = adjacency.indptr
+    ncells = adjacency.shape[0]
+
+    tic = time.perf_counter()
+    for i in range(ncells):
+        for j in range(indptr[i], indptr[i + 1]):
+            col = indices[j]
+            a = (((embed1[i, :] - embed1[col, :]) ** 2).sum() ** .5) - nndist1[i]
+            if a == 0:
+                dist1[i, col] = np.nan
+            else:
+                dist1[i, col] = a
+            b = (((embed2[i, :] - embed2[col, :]) ** 2).sum() ** .5) - nndist2[i]
+            if b == 0:
+                dist2[i, col] = np.nan
+            else:
+                dist2[i, col] = b
+
+        if (i % 2000) == 0:
+            toc = time.perf_counter()
+            main_info('%d out of %d %.2f seconds elapsed' % (i, ncells, toc - tic), indent_level=3)
+
+    return csr_matrix(dist1), csr_matrix(dist2)
+
+
+def get_nearestneighbor(knn, neighbor=1):
+    # For each row of knn, returns the column with the lowest value, i.e. the nearest neighbor
+    indices = knn.indices
+    indptr = knn.indptr
+    data = knn.data
+    nn_idx = []
+    for i in range(knn.shape[0]):
+        cols = indices[indptr[i]:indptr[i + 1]]
+        rowvals = data[indptr[i]:indptr[i + 1]]
+        idx = np.argsort(rowvals)
+        nn_idx.append(cols[idx[neighbor - 1]])
+    return np.array(nn_idx)
+
+
+def select_topK(dist, n_neighbors=20):
+    indices = dist.indices
+    indptr = dist.indptr
+    data = dist.data
+    nrows = dist.shape[0]
+
+    final_data = []
+    final_col_ind = []
+
+    for i in range(nrows):
+        cols = indices[indptr[i]:indptr[i + 1]]
+        rowvals = data[indptr[i]:indptr[i + 1]]
+        idx = np.argsort(rowvals)
+        final_data.append(rowvals[idx[(-1 * n_neighbors):]])
+        final_col_ind.append(cols[idx[(-1 * n_neighbors):]])
+
+    final_data = np.concatenate(final_data)
+    final_col_ind = np.concatenate(final_col_ind)
+    final_row_ind = np.tile(np.arange(nrows), (n_neighbors, 1)).reshape(-1, order='F')
+
+    result = csr_matrix((final_data, (final_row_ind, final_col_ind)), shape=(nrows, dist.shape[1]))
+
+    return result
+
+
+class pyWNN():
+
+    def __init__(self,
+                 adata: AnnData,
+                 reps: List[str] = None,
+                 n_neighbors: int = 20,
+                 npcs: List[int] = None,
+                 seed: int = 14,
+                 distances: csr_matrix = None
+                 ) -> None:
+        """\
+        Class for running weighted nearest neighbors analysis as described in Hao
+        et al 2021.
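+
+        A sketch of the expected inputs (defaults are set just below): `reps`
+        holds the two .obsm keys for the per-modality embeddings, `npcs` the
+        number of components used from each, and `distances` optionally
+        supplies precomputed KNN distance matrices.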
+ """ + # Set default arguments + if npcs is None: + npcs = [20, 20] + + if reps is None: + reps = ['X_pca', 'X_apca'] + + self.seed = seed + np.random.seed(seed) + + if len(reps) > 2: + sys.exit('WNN currently only implemented for 2 modalities') + + self.adata = adata.copy() + self.reps = [r + '_norm' for r in reps] + self.npcs = npcs + for (i, r) in enumerate(reps): + self.adata.obsm[self.reps[i]] = preprocessing.normalize(adata.obsm[r][:, 0:npcs[i]]) + + self.n_neighbors = n_neighbors + if distances is None: + main_info('Computing KNN distance matrices using default Scanpy implementation') + # ... n_neighbors in each modality + sc.pp.neighbors(self.adata, n_neighbors=n_neighbors, n_pcs=npcs[0], use_rep=self.reps[0], + metric='euclidean', key_added='1') + sc.pp.neighbors(self.adata, n_neighbors=n_neighbors, n_pcs=npcs[1], use_rep=self.reps[1], + metric='euclidean', key_added='2') + + # ... top 200 nearest neighbors in each modality + sc.pp.neighbors(self.adata, n_neighbors=200, n_pcs=npcs[0], use_rep=self.reps[0], metric='euclidean', + key_added='1_200') + sc.pp.neighbors(self.adata, n_neighbors=200, n_pcs=npcs[1], use_rep=self.reps[1], metric='euclidean', + key_added='2_200') + self.distances = ['1_distances', '2_distances', '1_200_distances', '2_200_distances'] + else: + main_info('Using pre-computed KNN distance matrices') + self.distances = distances + + for d in self.distances: + # Convert to sparse CSR matrices as needed + if type(self.adata.obsp[d]) is not csr_matrix: + self.adata.obsp[d] = csr_matrix(self.adata.obsp[d]) + + self.NNdist = [] + self.NNidx = [] + self.NNadjacency = [] + self.BWs = [] + + for (i, r) in enumerate(self.reps): + nn = get_nearestneighbor(self.adata.obsp[self.distances[i]]) + dist_to_nn = ((self.adata.obsm[r] - self.adata.obsm[r][nn, :]) ** 2).sum(axis=1) ** .5 + nn_adj = (self.adata.obsp[self.distances[i]] > 0).astype(int) + nn_adj_wdiag = csr_matrix(nn_adj.copy()) + nn_adj_wdiag.setdiag(1) + bw = compute_bw(nn_adj_wdiag, self.adata.obsm[r], n_neighbors=self.n_neighbors) + self.NNidx.append(nn) + self.NNdist.append(dist_to_nn) + self.NNadjacency.append(nn_adj) + self.BWs.append(bw) + + self.cross = [] + self.weights = [] + self.within = [] + self.WNN = None + self.WNNdist = None + + def compute_weights(self) -> None: + cmap = {0: 1, 1: 0} + affinity_ratios = [] + self.within = [] + self.cross = [] + for (i, r) in enumerate(self.reps): + within_predict = self.NNadjacency[i].dot(self.adata.obsm[r]) / (self.n_neighbors - 1) + cross_predict = self.NNadjacency[cmap[i]].dot(self.adata.obsm[r]) / (self.n_neighbors - 1) + + within_predict_dist = ((self.adata.obsm[r] - within_predict) ** 2).sum(axis=1) ** .5 + cross_predict_dist = ((self.adata.obsm[r] - cross_predict) ** 2).sum(axis=1) ** .5 + within_affinity = compute_affinity(within_predict_dist, self.NNdist[i], self.BWs[i]) + cross_affinity = compute_affinity(cross_predict_dist, self.NNdist[i], self.BWs[i]) + affinity_ratios.append(within_affinity / (cross_affinity + 0.0001)) + self.within.append(within_predict_dist) + self.cross.append(cross_predict_dist) + + self.weights.append(1 / (1 + np.exp(affinity_ratios[1] - affinity_ratios[0]))) + self.weights.append(1 - self.weights[0]) + + def compute_wnn( + self, + adata: AnnData + ) -> AnnData: + main_info('Computing modality weights', indent_level=2) + self.compute_weights() + union_adj_mat = ((self.adata.obsp[self.distances[2]] + self.adata.obsp[self.distances[3]]) > 0).astype(int) + + main_info('Computing weighted distances for union of 200 nearest neighbors 
between modalities', indent_level=2) + full_dists = dist_from_adj(union_adj_mat, self.adata.obsm[self.reps[0]], self.adata.obsm[self.reps[1]], + self.NNdist[0], self.NNdist[1]) + weighted_dist = csr_matrix(union_adj_mat.shape) + for (i, dist) in enumerate(full_dists): + dist = diags(-1 / (self.BWs[i] - self.NNdist[i]), format='csr').dot(dist) + dist.data = np.exp(dist.data) + ind = np.isnan(dist.data) + dist.data[ind] = 1 + dist = diags(self.weights[i]).dot(dist) + weighted_dist += dist + + main_info('Selecting top K neighbors', indent_level=2) + self.WNN = select_topK(weighted_dist, n_neighbors=self.n_neighbors) + WNNdist = self.WNN.copy() + x = (1 - WNNdist.data) / 2 + x[x < 0] = 0 + x[x > 1] = 1 + WNNdist.data = np.sqrt(x) + self.WNNdist = WNNdist + + adata.obsp['WNN'] = self.WNN + adata.obsp['WNN_distance'] = self.WNNdist + adata.obsm[self.reps[0]] = self.adata.obsm[self.reps[0]] + adata.obsm[self.reps[1]] = self.adata.obsm[self.reps[1]] + adata.uns['WNN'] = {'connectivities_key': 'WNN', + 'distances_key': 'WNN_distance', + 'params': {'n_neighbors': self.n_neighbors, + 'method': 'WNN', + 'random_state': self.seed, + 'metric': 'euclidean', + 'use_rep': self.reps[0], + 'n_pcs': self.npcs[0]}} + return (adata) diff --git a/dynamo/multivelo/settings.py b/dynamo/multivelo/settings.py new file mode 100644 index 000000000..5d07355a9 --- /dev/null +++ b/dynamo/multivelo/settings.py @@ -0,0 +1,27 @@ +import os + +"""Settings +""" + +# the desired verbosity +global VERBOSITY + +# cwd: The current working directory +global CWD + +# the name of the file to which we're writing the log files +global LOG_FOLDER + +# the name of the file to which we're writing the logs +# (If left to the default value of None, we don't write to a file) +global LOG_FILENAME + +# the name of the gene the code is processing +global GENE + +VERBOSITY = 1 +CWD = os.path.abspath(os.getcwd()) +LOG_FOLDER = os.path.join(CWD, "../logs") +LOG_FILENAME = None +GENE = None + diff --git a/dynamo/multivelo/sparse_matrix_utils.py b/dynamo/multivelo/sparse_matrix_utils.py new file mode 100644 index 000000000..d64aa2765 --- /dev/null +++ b/dynamo/multivelo/sparse_matrix_utils.py @@ -0,0 +1,94 @@ +import os +import warnings + +if "NVCC" not in os.environ: + os.environ["NVCC"] = "/usr/local/cuda-11.5/bin/nvcc" + warnings.warn( + "NVCC Path not found, set to : /usr/local/cuda-11.5/bin/nvcc . 
\nPlease set NVCC as appropriate to your environment"
+    )
+
+import cupy as cp
+from numba import cuda
+import math
+
+## Cuda JIT
+# (reconstructed include: thrust headers are required for thrust::sort_by_key
+# and thrust::seq used in the kernel below)
+code = """
+#include <thrust/sort.h>
+#include <thrust/execution_policy.h>
+extern "C" __global__
+void sort_sparse_array(double *data, int *indices, int *indptr, int n_rows)
+{
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+    if (tid >= n_rows) return;
+    thrust::sort_by_key(thrust::seq, data + indptr[tid], data + indptr[tid + 1], indices + indptr[tid]);
+}
+"""
+
+kernel = cp.RawModule(code=code, backend="nvcc")
+sort_f = kernel.get_function("sort_sparse_array")
+
+
+## Numba function
+@cuda.jit
+def find_top_k_values(
+    data, indices, indptr, output_values_ar, output_idx_ar, k, n_rows
+):
+    gid = cuda.grid(1)
+
+    if gid >= n_rows:
+        return
+
+    row_st_ind = indptr[gid]
+    row_end_ind = indptr[gid + 1] - 1
+
+    k = min(k, 1 + row_end_ind - row_st_ind)
+    for i in range(0, k):
+        index = row_st_ind + i
+        if data[index] != 0:
+            output_values_ar[gid][i] = data[index]
+            output_idx_ar[gid][i] = indices[index]
+
+
+def find_top_k_values_sparse_matrix(X, k):
+
+    X = X.copy()
+
+    ### Output arrays to save the top k values
+    values_ar = cp.full(fill_value=0, shape=(X.shape[0], k), dtype=cp.float64)
+    idx_ar = cp.full(fill_value=-1, shape=(X.shape[0], k), dtype=cp.int32)
+
+    ### sort in decreasing order
+    X.data = X.data * -1
+    sort_f(
+        (math.ceil(X.shape[0] / 32),), (32,), (X.data, X.indices, X.indptr, X.shape[0])
+    )
+    X.data = X.data * -1
+
+    ## configure kernel based on number of tasks
+    find_top_k_values_k = find_top_k_values.forall(X.shape[0])
+
+    find_top_k_values_k(X.data, X.indices, X.indptr, values_ar, idx_ar, k, X.shape[0])
+
+    return idx_ar, values_ar
+
+
+def top_n_sparse(X, n):
+    """Return indices and values of the top n values in each row of a sparse matrix
+    Args:
+        X: The sparse matrix from which to get the
+           top n indices and values per row
+        n: The number of highest values to extract from each row
+    Returns:
+        indices: The top n indices per row
+        values: The top n values per row
+    """
+    value_ls, idx_ls = [], []
+    batch_size = 500
+    for s in range(0, X.shape[0], batch_size):
+        e = min(s + batch_size, X.shape[0])
+        idx_ar, value_ar = find_top_k_values_sparse_matrix(X[s:e], n)
+        value_ls.append(value_ar)
+        idx_ls.append(idx_ar)
+
+    indices = cp.concatenate(idx_ls)
+    values = cp.concatenate(value_ls)
+
+    return indices, values
\ No newline at end of file
diff --git a/dynamo/tools/utils.py b/dynamo/tools/utils.py
index f39ec0265..0dd9e3b00 100755
--- a/dynamo/tools/utils.py
+++ b/dynamo/tools/utils.py
@@ -2718,6 +2718,7 @@ def get_ekey_vkey_from_adata(adata: AnnData) -> Tuple[str, str, str]:
     mapper = get_mapper()
     layer = []
 
+
     if has_splicing:
         if has_labeling:
             if "X_new" not in adata.layers.keys():  # unlabel spliced: S
diff --git a/requirements.txt b/requirements.txt
index 4cf7fc3d4..b1e42bf0b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -25,3 +25,4 @@ openpyxl
 typing-extensions
 session-info>=1.0.0
 adjustText
+mudata
diff --git a/setup.cfg b/setup.cfg
index 37db253e2..807606544 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -6,3 +6,6 @@ tag = True
 [bumpversion:file:setup.py]
 
 [bumpversion:file:docs/source/conf.py]
+
+[options.package_data]
+* = multivelo/neural_nets/*