Skip to content

Commit

Permalink
refactor: refactor codes for better GPU usage
Browse files Browse the repository at this point in the history
  • Loading branch information
CaibinSh committed May 28, 2024
1 parent 7bb5d9a commit 535326f
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions scar/main/_scar.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,11 @@ def train(

vae_nets.train()
for x_batch, ambient_freq, batch_id_onehot in training_generator:
# Move data to device
x_batch = x_batch.to(self.device)
ambient_freq = ambient_freq.to(self.device)
batch_id_onehot = batch_id_onehot.to(self.device)

optim.zero_grad()
dec_nr, dec_prob, means, var, dec_dp = vae_nets(x_batch, batch_id_onehot)
recon_loss_minibatch, kld_loss_minibatch, loss_minibatch = loss_fn(
Expand Down Expand Up @@ -603,6 +608,11 @@ def inference(
)

for x_batch_tot, ambient_freq_tot, x_batch_id_onehot_tot in generator_full_data:
# Move data to device
x_batch_tot = x_batch_tot.to(self.device)
x_batch_id_onehot_tot = x_batch_id_onehot_tot.to(self.device)
ambient_freq_tot = ambient_freq_tot.to(self.device)

minibatch_size = x_batch_tot.shape[
0
] # if not the last batch, equals to batch size
Expand Down Expand Up @@ -706,13 +716,12 @@ def assignment(self, cutoff=3, moi=None):
class UMIDataset(torch.utils.data.Dataset):
"""Characterizes dataset for PyTorch"""

def __init__(self, raw_count, ambient_profile, batch_id, device, list_ids=None):
def __init__(self, raw_count, ambient_profile, batch_id, list_ids=None):
"""Initialization"""
self.device = device
self.raw_count = torch.from_numpy(raw_count).int().to(self.device)
self.ambient_profile = torch.from_numpy(ambient_profile).float().to(device)
self.raw_count = torch.from_numpy(raw_count).int()
self.ambient_profile = torch.from_numpy(ambient_profile).float()
self.batch_id = torch.from_numpy(batch_id).to(torch.int64)
self.batch_onehot = torch.from_numpy(np.eye(len(np.unique(batch_id)))).to(torch.int64).to(self.device)
self.batch_onehot = torch.from_numpy(np.eye(len(np.unique(batch_id)))).to(torch.int64)

if list_ids:
self.list_ids = list_ids
Expand Down

0 comments on commit 535326f

Please sign in to comment.