Skip to content

Commit

Permalink
refactor: load entire data to GPU
Browse files (browse the repository at this point in the history)
  • Loading branch information
CaibinSh committed Jul 25, 2024
1 parent 7ada773 commit cab3872
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions scar/main/_scar.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -228,7 +228,7 @@ def __init__(
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
self.device = torch.device("mps")
self.logger.info("MPS is detected and will be used.")
self.logger.warning("PyTorch is slower on MPS than on the CPU; we recommend using the CPU by specifying device='cpu' on Mac.")
# self.logger.warning("PyTorch is slower on MPS than on the CPU; we recommend using the CPU by specifying device='cpu' on Mac.")
else:
self.device = torch.device("cpu")
self.logger.info("No GPU detected. Use CPU instead.")
Expand Down Expand Up @@ -441,11 +441,11 @@ def train(
train_ids, test_ids = train_test_split(list_ids, train_size=train_size)

# Generators
training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, list_ids=train_ids)
training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, list_ids=train_ids)
training_generator = torch.utils.data.DataLoader(
training_set, batch_size=batch_size, shuffle=shuffle
)
val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, list_ids=test_ids)
val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, list_ids=test_ids)
val_generator = torch.utils.data.DataLoader(
val_set, batch_size=batch_size, shuffle=shuffle
)
Expand Down Expand Up @@ -495,9 +495,9 @@ def train(
vae_nets.train()
for x_batch, ambient_freq, batch_id_onehot in training_generator:
# Move data to device
x_batch = x_batch.to(self.device)
ambient_freq = ambient_freq.to(self.device)
batch_id_onehot = batch_id_onehot.to(self.device)
# x_batch = x_batch.to(self.device)
# ambient_freq = ambient_freq.to(self.device)
# batch_id_onehot = batch_id_onehot.to(self.device)

optim.zero_grad()
dec_nr, dec_prob, means, var, dec_dp = vae_nets(x_batch, batch_id_onehot)
Expand Down Expand Up @@ -592,7 +592,7 @@ def inference(
native_frequencies, and noise_ratio. \
A feature_assignment will be added in 'sgRNA' or 'tag' or 'CMO' feature type.
"""
total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id)
total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device)
n_features = self.n_features
sample_size = self.raw_count.shape[0]
self.native_counts = np.empty([sample_size, n_features])
Expand All @@ -609,9 +609,9 @@ def inference(

for x_batch_tot, ambient_freq_tot, x_batch_id_onehot_tot in generator_full_data:
# Move data to device
x_batch_tot = x_batch_tot.to(self.device)
x_batch_id_onehot_tot = x_batch_id_onehot_tot.to(self.device)
ambient_freq_tot = ambient_freq_tot.to(self.device)
# x_batch_tot = x_batch_tot.to(self.device)
# x_batch_id_onehot_tot = x_batch_id_onehot_tot.to(self.device)
# ambient_freq_tot = ambient_freq_tot.to(self.device)

minibatch_size = x_batch_tot.shape[
0
Expand Down Expand Up @@ -716,12 +716,12 @@ def assignment(self, cutoff=3, moi=None):
class UMIDataset(torch.utils.data.Dataset):
"""Characterizes dataset for PyTorch"""

def __init__(self, raw_count, ambient_profile, batch_id, list_ids=None):
def __init__(self, raw_count, ambient_profile, batch_id, device, list_ids=None):
"""Initialization"""
self.raw_count = torch.from_numpy(raw_count).int()
self.ambient_profile = torch.from_numpy(ambient_profile).float()
self.batch_id = torch.from_numpy(batch_id).to(torch.int64)
self.batch_onehot = torch.from_numpy(np.eye(len(np.unique(batch_id)))).to(torch.int64)
self.raw_count = torch.from_numpy(raw_count).int().to(device)
self.ambient_profile = torch.from_numpy(ambient_profile).float().to(device)
self.batch_id = torch.from_numpy(batch_id).to(torch.int64).to(device)
self.batch_onehot = torch.from_numpy(np.eye(len(np.unique(batch_id)))).to(torch.int64).to(device)

if list_ids:
self.list_ids = list_ids
Expand Down

0 comments on commit cab3872

Please sign in to comment.