From 4930798f7feb2cd0772e5d83919a12ce5cba6686 Mon Sep 17 00:00:00 2001 From: Caibin Sheng Date: Thu, 8 Aug 2024 10:14:26 +0200 Subject: [PATCH] refactor(scar): refactor dataloader --- scar/main/_scar.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/scar/main/_scar.py b/scar/main/_scar.py index 9c2f902..d6af6e6 100644 --- a/scar/main/_scar.py +++ b/scar/main/_scar.py @@ -772,10 +772,9 @@ def __getitem__(self, index): return self.cache[index] # Select sample - sc_id = self.list_ids[index] - sc_count = self.raw_count[sc_id].to(self.device) if isinstance(self.raw_count, torch.Tensor) else torch.from_numpy(self.raw_count[sc_id].X.toarray().flatten()).int().to(self.device) - sc_ambient = self.ambient_profile[self.batch_id[sc_id], :].to(self.device) - sc_batch_id_onehot = self.batch_onehot[self.batch_id[sc_id], :].to(self.device) + sc_count = self.raw_count[index].to(self.device) if isinstance(self.raw_count, torch.Tensor) else torch.from_numpy(self.raw_count[index].X.toarray().flatten()).int().to(self.device) + sc_ambient = self.ambient_profile[self.batch_id[index], :].to(self.device) + sc_batch_id_onehot = self.batch_onehot[self.batch_id[index], :].to(self.device) # Cache the sample if len(self.cache) <= self.cache_capacity: