diff --git a/scar/main/_scar.py b/scar/main/_scar.py index 3e56a78..374d84f 100644 --- a/scar/main/_scar.py +++ b/scar/main/_scar.py @@ -494,6 +494,11 @@ def train( vae_nets.train() for x_batch, ambient_freq, batch_id_onehot in training_generator: + # Move data to device + x_batch = x_batch.to(self.device) + ambient_freq = ambient_freq.to(self.device) + batch_id_onehot = batch_id_onehot.to(self.device) + optim.zero_grad() dec_nr, dec_prob, means, var, dec_dp = vae_nets(x_batch, batch_id_onehot) recon_loss_minibatch, kld_loss_minibatch, loss_minibatch = loss_fn( @@ -603,6 +608,11 @@ def inference( ) for x_batch_tot, ambient_freq_tot, x_batch_id_onehot_tot in generator_full_data: + # Move data to device + x_batch_tot = x_batch_tot.to(self.device) + x_batch_id_onehot_tot = x_batch_id_onehot_tot.to(self.device) + ambient_freq_tot = ambient_freq_tot.to(self.device) + minibatch_size = x_batch_tot.shape[ 0 ] # if not the last batch, equals to batch size @@ -706,13 +716,12 @@ def assignment(self, cutoff=3, moi=None): class UMIDataset(torch.utils.data.Dataset): """Characterizes dataset for PyTorch""" - def __init__(self, raw_count, ambient_profile, batch_id, device, list_ids=None): + def __init__(self, raw_count, ambient_profile, batch_id, list_ids=None): """Initialization""" - self.device = device - self.raw_count = torch.from_numpy(raw_count).int().to(self.device) - self.ambient_profile = torch.from_numpy(ambient_profile).float().to(device) + self.raw_count = torch.from_numpy(raw_count).int() + self.ambient_profile = torch.from_numpy(ambient_profile).float() self.batch_id = torch.from_numpy(batch_id).to(torch.int64) - self.batch_onehot = torch.from_numpy(np.eye(len(np.unique(batch_id)))).to(torch.int64).to(self.device) + self.batch_onehot = torch.from_numpy(np.eye(len(np.unique(batch_id)))).to(torch.int64) if list_ids: self.list_ids = list_ids