Skip to content

Commit

Permalink
refactor(scar): optimize dataloading
Browse files Browse the repository at this point in the history
  • Loading branch information
CaibinSh committed Aug 8, 2024
1 parent 51a85e7 commit 41dd2a2
Showing 1 changed file with 20 additions and 34 deletions.
54 changes: 20 additions & 34 deletions scar/main/_scar.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from scipy import sparse
import numpy as np, pandas as pd, anndata as ad

from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, random_split, DataLoader
from tqdm import tqdm
from tqdm.contrib import DummyTqdmFile

Expand Down Expand Up @@ -224,7 +224,6 @@ def __init__(
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
self.device = torch.device("mps")
self.logger.info("MPS is detected and will be used.")
# self.logger.warning("PyTorch is slower on MPS than on the CPU; we recommend using the CPU by specifying device='cpu' on Mac.")
else:
self.device = torch.device("cpu")
self.logger.info("No GPU detected. Use CPU instead.")
Expand Down Expand Up @@ -440,21 +439,14 @@ def train(
After training, a trained_model attribute will be added.
"""

list_ids = list(range(self.raw_count.shape[0]))
train_ids, test_ids = train_test_split(list_ids, train_size=train_size)

# Generators
training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, list_ids=train_ids, device=self.device, cache_capacity=self.cache_capacity)
training_generator = torch.utils.data.DataLoader(
total_dataset = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, cache_capacity=self.cache_capacity)
training_set, validation_set = random_split(total_dataset, [train_size, 1 - train_size])
training_generator = DataLoader(
training_set, batch_size=batch_size, shuffle=shuffle,
drop_last=True
)
self.dataset = training_set
# val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, list_ids=test_ids)
# val_generator = torch.utils.data.DataLoader(
# val_set, batch_size=batch_size, shuffle=shuffle
# )
self.dataset = total_dataset

loss_values = []

Expand Down Expand Up @@ -600,7 +592,6 @@ def inference(
native_frequencies, and noise_ratio. \
A feature_assignment will be added in 'sgRNA' or 'tag' or 'CMO' feature type.
"""
# total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, cache_capacity=self.cache_capacity)
n_features = self.n_features
sample_size = self.raw_count.shape[0]

Expand All @@ -626,7 +617,7 @@ def inference(
if not batch_size:
batch_size = sample_size
i = 0
generator_full_data = torch.utils.data.DataLoader(
generator_full_data = DataLoader(
self.dataset, batch_size=batch_size, shuffle=False
)

Expand Down Expand Up @@ -739,10 +730,10 @@ def assignment(self, cutoff=3, moi=None):
if moi:
raise NotImplementedError

class UMIDataset(torch.utils.data.Dataset):
class UMIDataset(Dataset):
"""Characterizes dataset for PyTorch"""

def __init__(self, raw_count, ambient_profile, batch_id, device, list_ids=None, cache_capacity=20000):
def __init__(self, raw_count, ambient_profile, batch_id, device, cache_capacity=20000):
"""Initialization"""

self.raw_count = torch.from_numpy(raw_count.fillna(0).values).int() if isinstance(raw_count, pd.DataFrame) else raw_count
Expand All @@ -752,31 +743,26 @@ def __init__(self, raw_count, ambient_profile, batch_id, device, list_ids=None,
self.device = device
self.cache_capacity = cache_capacity

if list_ids:
self.list_ids = list_ids
else:
self.list_ids = list(range(raw_count.shape[0]))

# Cache data
self.cache = {}

def __len__(self):
"""Denotes the total number of samples"""
return len(self.list_ids)
return self.raw_count.shape[0]

def __getitem__(self, index):
"""Generates one sample of data"""

if index in self.cache:
return self.cache[index]

# Select sample
sc_count = self.raw_count[index].to(self.device) if isinstance(self.raw_count, torch.Tensor) else torch.from_numpy(self.raw_count[index].X.toarray().flatten()).int().to(self.device)
sc_ambient = self.ambient_profile[self.batch_id[index], :]
sc_batch_id_onehot = self.batch_onehot[self.batch_id[index], :]

# Cache the sample
if len(self.cache) <= self.cache_capacity:
self.cache[index] = (sc_count, sc_ambient, sc_batch_id_onehot)

return sc_count, sc_ambient, sc_batch_id_onehot
else:
# Select samples
sc_count = self.raw_count[index].to(self.device) if isinstance(self.raw_count, torch.Tensor) else torch.from_numpy(self.raw_count[index].X.toarray().flatten()).int().to(self.device)
sc_ambient = self.ambient_profile[self.batch_id[index], :]
sc_batch_id_onehot = self.batch_onehot[self.batch_id[index], :]

# Cache samples
if len(self.cache) <= self.cache_capacity:
self.cache[index] = (sc_count, sc_ambient, sc_batch_id_onehot)
return sc_count, sc_ambient, sc_batch_id_onehot

0 comments on commit 41dd2a2

Please sign in to comment.