Skip to content

Commit

Permalink
refactor(scar): refactor dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
CaibinSh committed Aug 8, 2024
1 parent 4ba263a commit 4930798
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions scar/main/_scar.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,10 +772,9 @@ def __getitem__(self, index):
return self.cache[index]

# Select sample
sc_id = self.list_ids[index]
sc_count = self.raw_count[sc_id].to(self.device) if isinstance(self.raw_count, torch.Tensor) else torch.from_numpy(self.raw_count[sc_id].X.toarray().flatten()).int().to(self.device)
sc_ambient = self.ambient_profile[self.batch_id[sc_id], :].to(self.device)
sc_batch_id_onehot = self.batch_onehot[self.batch_id[sc_id], :].to(self.device)
sc_count = self.raw_count[index].to(self.device) if isinstance(self.raw_count, torch.Tensor) else torch.from_numpy(self.raw_count[index].X.toarray().flatten()).int().to(self.device)
sc_ambient = self.ambient_profile[self.batch_id[index], :].to(self.device)
sc_batch_id_onehot = self.batch_onehot[self.batch_id[index], :].to(self.device)

# Cache the sample
if len(self.cache) <= self.cache_capacity:
Expand Down

0 comments on commit 4930798

Please sign in to comment.