diff --git a/scar/main/_scar.py b/scar/main/_scar.py index 24e2773..675cbf7 100644 --- a/scar/main/_scar.py +++ b/scar/main/_scar.py @@ -6,6 +6,7 @@ from typing import Optional, Union import numpy as np, pandas as pd, anndata as ad +from collections import OrderedDict from sklearn.model_selection import train_test_split from tqdm import tqdm from tqdm.contrib import DummyTqdmFile @@ -207,6 +208,7 @@ def __init__( sparsity: float = 0.9, batch_key: str = None, device: str = "auto", + cache_capacity: int = 20000, verbose: bool = True, ): """initialize object""" @@ -268,6 +270,11 @@ def __init__( """float, the sparsity of expected native signals. (0, 1]. \ Forced to be one in the mode of "sgRNA(s)" and "tag(s)". """ + self.cache_capacity = cache_capacity + """int, the capacity of cache. + + .. versionadded:: 0.6.1 + """ if isinstance(raw_count, ad.AnnData): if batch_key: @@ -438,9 +445,10 @@ def train( train_ids, test_ids = train_test_split(list_ids, train_size=train_size) # Generators - training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, list_ids=train_ids) + training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, list_ids=train_ids, device=self.device, cache_capacity=self.cache_capacity) training_generator = torch.utils.data.DataLoader( - training_set, batch_size=batch_size, shuffle=shuffle + training_set, batch_size=batch_size, shuffle=shuffle, + drop_last=True ) # val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, list_ids=test_ids) # val_generator = torch.utils.data.DataLoader( @@ -492,9 +500,9 @@ def train( vae_nets.train() for x_batch, ambient_freq, batch_id_onehot in training_generator: # Move data to device - x_batch = x_batch.to(self.device) - ambient_freq = ambient_freq.to(self.device) - batch_id_onehot = batch_id_onehot.to(self.device) + # x_batch = x_batch.to(self.device) + # ambient_freq = ambient_freq.to(self.device) + # batch_id_onehot = batch_id_onehot.to(self.device) optim.zero_grad() dec_nr, dec_prob, means, var, dec_dp = vae_nets(x_batch, batch_id_onehot) @@ -589,7 +597,7 @@ def inference( native_frequencies, and noise_ratio. \ A feature_assignment will be added in 'sgRNA' or 'tag' or 'CMO' feature type. """ - total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id) + total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, cache_capacity=self.cache_capacity) n_features = self.n_features sample_size = self.raw_count.shape[0] self.native_counts = np.empty([sample_size, n_features]) @@ -606,9 +614,9 @@ def inference( for x_batch_tot, ambient_freq_tot, x_batch_id_onehot_tot in generator_full_data: # Move data to device - x_batch_tot = x_batch_tot.to(self.device) - x_batch_id_onehot_tot = x_batch_id_onehot_tot.to(self.device) - ambient_freq_tot = ambient_freq_tot.to(self.device) + # x_batch_tot = x_batch_tot.to(self.device) + # x_batch_id_onehot_tot = x_batch_id_onehot_tot.to(self.device) + # ambient_freq_tot = ambient_freq_tot.to(self.device) minibatch_size = x_batch_tot.shape[ 0 @@ -710,20 +718,44 @@ def assignment(self, cutoff=3, moi=None): raise NotImplementedError +class LRUCache: + def __init__(self, capacity: int): + self.cache = OrderedDict() + self.capacity = capacity + + def get(self, key): + if key not in self.cache: + return None + self.cache.move_to_end(key) + return self.cache[key] + + def put(self, key, value): + if key in self.cache: + self.cache.move_to_end(key) + self.cache[key] = value + if len(self.cache) > self.capacity: + self.cache.popitem(last=False) + class UMIDataset(torch.utils.data.Dataset): """Characterizes dataset for PyTorch""" - def __init__(self, raw_count, ambient_profile, batch_id, list_ids=None): + def __init__(self, raw_count, ambient_profile, batch_id, device, list_ids=None, cache_capacity=20000): """Initialization""" + self.raw_count = torch.from_numpy(raw_count.fillna(0).values).int() if isinstance(raw_count, pd.DataFrame) else raw_count - self.ambient_profile = torch.from_numpy(ambient_profile).float() - self.batch_id = torch.from_numpy(batch_id).to(torch.int64) - self.batch_onehot = torch.from_numpy(np.eye(len(np.unique(batch_id)))).to(torch.int64) + self.ambient_profile = torch.from_numpy(ambient_profile).float().to(device) + self.batch_id = torch.from_numpy(batch_id).to(torch.int64).to(device) + self.batch_onehot = torch.from_numpy(np.eye(len(np.unique(batch_id)))).to(torch.int64).to(device) + self.device = device + self.cache_capacity = cache_capacity if list_ids: self.list_ids = list_ids else: self.list_ids = list(range(raw_count.shape[0])) + + # Cache data + self.cache = {} def __len__(self): """Denotes the total number of samples""" @@ -731,9 +763,18 @@ def __len__(self): def __getitem__(self, index): """Generates one sample of data""" + + if index in self.cache: + return self.cache[index] + # Select sample sc_id = self.list_ids[index] - sc_count = self.raw_count[sc_id] if isinstance(self.raw_count, torch.Tensor) else torch.from_numpy(self.raw_count[sc_id].X.toarray().flatten()).int() - sc_ambient = self.ambient_profile[self.batch_id[sc_id], :] - sc_batch_id_onehot = self.batch_onehot[self.batch_id[sc_id], :] + sc_count = self.raw_count[sc_id].to(self.device) if isinstance(self.raw_count, torch.Tensor) else torch.from_numpy(self.raw_count[sc_id].X.toarray().flatten()).int().to(self.device) + sc_ambient = self.ambient_profile[self.batch_id[sc_id], :].to(self.device) + sc_batch_id_onehot = self.batch_onehot[self.batch_id[sc_id], :].to(self.device) + + # Cache the sample + if len(self.cache) <= self.cache_capacity: + self.cache[index] = (sc_count, sc_ambient, sc_batch_id_onehot) + return sc_count, sc_ambient, sc_batch_id_onehot \ No newline at end of file