From 8b773108070f1e56e806ddb33207c2da9a6bd07c Mon Sep 17 00:00:00 2001 From: Caibin Sheng Date: Sun, 26 May 2024 10:55:50 +0200 Subject: [PATCH] fix: optimize dataloading for ambient profile --- scar/main/_scar.py | 54 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 41 insertions(+), 13 deletions(-) diff --git a/scar/main/_scar.py b/scar/main/_scar.py index c6cca95..4bf042c 100644 --- a/scar/main/_scar.py +++ b/scar/main/_scar.py @@ -91,6 +91,17 @@ class model: Thank Will Macnair for the valuable feedback. .. versionadded:: 0.4.0 + batch_key : str, optional + batch key in AnnData.obs, by default None. \ + If assigned, batch ambient removal will be performed and \ + the ambient profile will be estimated for each batch. + + .. versionadded:: 0.6.1 + + device : str, optional + either "auto", "cpu" or "cuda", by default "auto" + verbose : bool, optional + whether to print the details, by default True Raises ------ @@ -200,6 +211,7 @@ def __init__( feature_type: str = "mRNA", count_model: str = "binomial", sparsity: float = 0.9, + batch_key: str = None, device: str = "auto", verbose: bool = True, ): @@ -262,7 +274,7 @@ def __init__( """float, the sparsity of expected native signals. (0, 1]. \ Forced to be one in the mode of "sgRNA(s)" and "tag(s)". """ - + if isinstance(raw_count, str): raw_count = pd.read_pickle(raw_count) elif isinstance(raw_count, np.ndarray): @@ -274,8 +286,24 @@ def __init__( elif isinstance(raw_count, pd.DataFrame): pass elif isinstance(raw_count, ad.AnnData): + if batch_key: + if batch_key not in raw_count.obs.columns: + raise ValueError(f"{batch_key} not found in AnnData.obs.") + + self.logger.info( + f"Estimating ambient profile for each batch defined by {batch_key} in AnnData.obs..." 
+ ) + batch_id_per_cell = pd.Categorical(raw_count.obs[batch_key]).codes + ambient_profile = np.empty((len(np.unique(batch_id_per_cell)),raw_count.shape[1])) + for batch_id in np.unique(batch_id_per_cell): + subset = raw_count[batch_id_per_cell==batch_id] + ambient_profile[batch_id, :] = subset.X.sum(axis=0) / subset.X.sum() + + # add a mapper to locate the batch id + self.batch_id = batch_id_per_cell + # get ambient profile from AnnData.uns - if (ambient_profile is None) and ("ambient_profile_all" in raw_count.uns): + elif (ambient_profile is None) and ("ambient_profile_all" in raw_count.uns): self.logger.info( "Found ambient profile in AnnData.uns['ambient_profile_all']" ) @@ -324,8 +352,10 @@ def __init__( ambient_profile = ( ambient_profile.squeeze() .reshape(1, -1) - .repeat(raw_count.shape[0], axis=0) ) + # add a mapper to locate the artificial batch id + self.batch_id = np.zeros(raw_count.shape[0]) + self.ambient_profile = torch.from_numpy(ambient_profile).float().to(self.device) """ambient_profile : np.ndarray, the probability of occurrence of each ambient transcript. 
""" @@ -410,21 +440,17 @@ def train( train_ids, test_ids = train_test_split(list_ids, train_size=train_size) # Generators - training_set = UMIDataset(self.raw_count, self.ambient_profile, train_ids) + training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, train_ids) training_generator = torch.utils.data.DataLoader( training_set, batch_size=batch_size, shuffle=shuffle ) - val_set = UMIDataset(self.raw_count, self.ambient_profile, test_ids) + val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, test_ids) val_generator = torch.utils.data.DataLoader( val_set, batch_size=batch_size, shuffle=shuffle ) loss_values = [] - # self.n_batch_train = len(training_generator) - # self.n_batch_val = len(val_generator) - # self.batch_size = batch_size - # Define model vae_nets = VAE( n_features=self.n_features, @@ -459,7 +485,7 @@ def train( desc="Training", ) progress_bar.clear() - for epoch in range(epochs): + for _ in range(epochs): train_tot_loss = 0 train_kld_loss = 0 train_recon_loss = 0 @@ -559,7 +585,7 @@ def inference( native_frequencies, and noise_ratio. \ A feature_assignment will be added in 'sgRNA' or 'tag' or 'CMO' feature type. 
""" - total_set = UMIDataset(self.raw_count, self.ambient_profile) + total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id) n_features = self.n_features sample_size = self.raw_count.shape[0] self.native_counts = np.empty([sample_size, n_features]) @@ -677,10 +703,12 @@ def assignment(self, cutoff=3, moi=None): class UMIDataset(torch.utils.data.Dataset): """Characterizes dataset for PyTorch""" - def __init__(self, raw_count, ambient_profile, list_ids=None): + def __init__(self, raw_count, ambient_profile, batch_id, list_ids=None): """Initialization""" self.raw_count = raw_count self.ambient_profile = ambient_profile + self.batch_id = batch_id + if list_ids: self.list_ids = list_ids else: @@ -695,5 +723,5 @@ def __getitem__(self, index): # Select sample sc_id = self.list_ids[index] sc_count = self.raw_count[sc_id, :] - sc_ambient = self.ambient_profile[sc_id, :] + sc_ambient = self.ambient_profile[self.batch_id[sc_id], :] return sc_count, sc_ambient