diff --git a/scar/main/_scar.py b/scar/main/_scar.py
index 7ff1ff7..a048fe0 100644
--- a/scar/main/_scar.py
+++ b/scar/main/_scar.py
@@ -228,7 +228,7 @@ def __init__(
         elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
             self.device = torch.device("mps")
             self.logger.info("MPS is detected and will be used.")
-            self.logger.warning("PyTorch is slower on MPS than on the CPU; we recommend using the CPU by specifying device='cpu' on Mac.")
+            # self.logger.warning("PyTorch is slower on MPS than on the CPU; we recommend using the CPU by specifying device='cpu' on Mac.")
         else:
             self.device = torch.device("cpu")
             self.logger.info("No GPU detected. Use CPU instead.")
@@ -441,11 +441,11 @@ def train(
         train_ids, test_ids = train_test_split(list_ids, train_size=train_size)
 
         # Generators
-        training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, list_ids=train_ids)
+        training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, list_ids=train_ids)
         training_generator = torch.utils.data.DataLoader(
             training_set, batch_size=batch_size, shuffle=shuffle
         )
-        val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, list_ids=test_ids)
+        val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, list_ids=test_ids)
         val_generator = torch.utils.data.DataLoader(
             val_set, batch_size=batch_size, shuffle=shuffle
         )
@@ -495,9 +495,9 @@ def train(
             vae_nets.train()
             for x_batch, ambient_freq, batch_id_onehot in training_generator:
                 # Move data to device
-                x_batch = x_batch.to(self.device)
-                ambient_freq = ambient_freq.to(self.device)
-                batch_id_onehot = batch_id_onehot.to(self.device)
+                # x_batch = x_batch.to(self.device)
+                # ambient_freq = ambient_freq.to(self.device)
+                # batch_id_onehot = batch_id_onehot.to(self.device)
 
                 optim.zero_grad()
                 dec_nr, dec_prob, means, var, dec_dp = vae_nets(x_batch, batch_id_onehot)
@@ -592,7 +592,7 @@ def inference(
             native_frequencies, and noise_ratio. \
                 A feature_assignment will be added in 'sgRNA' or 'tag' or 'CMO' feature type.
         """
-        total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id)
+        total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device)
         n_features = self.n_features
         sample_size = self.raw_count.shape[0]
         self.native_counts = np.empty([sample_size, n_features])
@@ -609,9 +609,9 @@ def inference(
 
         for x_batch_tot, ambient_freq_tot, x_batch_id_onehot_tot in generator_full_data:
             # Move data to device
-            x_batch_tot = x_batch_tot.to(self.device)
-            x_batch_id_onehot_tot = x_batch_id_onehot_tot.to(self.device)
-            ambient_freq_tot = ambient_freq_tot.to(self.device)
+            # x_batch_tot = x_batch_tot.to(self.device)
+            # x_batch_id_onehot_tot = x_batch_id_onehot_tot.to(self.device)
+            # ambient_freq_tot = ambient_freq_tot.to(self.device)
 
             minibatch_size = x_batch_tot.shape[
                 0
@@ -716,12 +716,12 @@ def assignment(self, cutoff=3, moi=None):
 
 class UMIDataset(torch.utils.data.Dataset):
     """Characterizes dataset for PyTorch"""
 
-    def __init__(self, raw_count, ambient_profile, batch_id, list_ids=None):
+    def __init__(self, raw_count, ambient_profile, batch_id, device, list_ids=None):
         """Initialization"""
-        self.raw_count = torch.from_numpy(raw_count).int()
-        self.ambient_profile = torch.from_numpy(ambient_profile).float()
-        self.batch_id = torch.from_numpy(batch_id).to(torch.int64)
-        self.batch_onehot = torch.from_numpy(np.eye(len(np.unique(batch_id)))).to(torch.int64)
+        self.raw_count = torch.from_numpy(raw_count).int().to(device)
+        self.ambient_profile = torch.from_numpy(ambient_profile).float().to(device)
+        self.batch_id = torch.from_numpy(batch_id).to(torch.int64).to(device)
+        self.batch_onehot = torch.from_numpy(np.eye(len(np.unique(batch_id)))).to(torch.int64).to(device)
         if list_ids:
             self.list_ids = list_ids