diff --git a/scar/main/_scar.py b/scar/main/_scar.py index 9c2f902..d6af6e6 100644 --- a/scar/main/_scar.py +++ b/scar/main/_scar.py @@ -772,10 +772,9 @@ def __getitem__(self, index): return self.cache[index] # Select sample - sc_id = self.list_ids[index] - sc_count = self.raw_count[sc_id].to(self.device) if isinstance(self.raw_count, torch.Tensor) else torch.from_numpy(self.raw_count[sc_id].X.toarray().flatten()).int().to(self.device) - sc_ambient = self.ambient_profile[self.batch_id[sc_id], :].to(self.device) - sc_batch_id_onehot = self.batch_onehot[self.batch_id[sc_id], :].to(self.device) + sc_count = self.raw_count[index].to(self.device) if isinstance(self.raw_count, torch.Tensor) else torch.from_numpy(self.raw_count[index].X.toarray().flatten()).int().to(self.device) + sc_ambient = self.ambient_profile[self.batch_id[index], :].to(self.device) + sc_batch_id_onehot = self.batch_onehot[self.batch_id[index], :].to(self.device) # Cache the sample if len(self.cache) <= self.cache_capacity: