From 7bb5d9a4dea71da4fb1bd7014b3dab37cce7f4e8 Mon Sep 17 00:00:00 2001 From: Caibin Sheng Date: Tue, 28 May 2024 10:37:10 +0200 Subject: [PATCH] refactor: refactor scar dataloader --- scar/main/_scar.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/scar/main/_scar.py b/scar/main/_scar.py index 5d8a94d..3e56a78 100644 --- a/scar/main/_scar.py +++ b/scar/main/_scar.py @@ -300,8 +300,8 @@ def __init__( ambient_profile[batch_id, :] = subset.X.sum(axis=0) / subset.X.sum() # add a mapper to locate the batch id - self.batch_id = torch.from_numpy(batch_id_per_cell).int().to(self.device) - self.n_batch = np.unique(batch_id_per_cell).size + self.batch_id = batch_id_per_cell + self.n_batch = len(np.unique(batch_id_per_cell)) # get ambient profile from AnnData.uns elif (ambient_profile is None) and ("ambient_profile_all" in raw_count.uns): @@ -322,16 +322,15 @@ def __init__( f"Expecting str or np.array or pd.DataFrame object, but get a {type(raw_count)}" ) - raw_count = raw_count.fillna(0) # missing vals -> zeros - # Loading numpy to tensor on GPU + raw_count = raw_count.fillna(0) # missing vals -> zeros raw_count if isinstance(raw_count, ad.AnnData) else self.raw_count = raw_count.values """raw_count : np.ndarray, raw count matrix. """ self.n_features = raw_count.shape[1] - self.cell_id = list(raw_count.index) - self.feature_names = list(raw_count.columns) + self.cell_id = raw_count.index.to_list() + self.feature_names = raw_count.columns.to_list() if isinstance(ambient_profile, str): ambient_profile = pd.read_pickle(ambient_profile) @@ -355,7 +354,7 @@ def __init__( .reshape(1, -1) ) # add a mapper to locate the artificial batch id - self.batch_id = torch.zeros(raw_count.shape[0]).int().to(self.device) + self.batch_id = np.zeros(raw_count.shape[0], dtype=int)#.reshape(-1, 1) self.n_batch = 1 self.ambient_profile = ambient_profile @@ -710,10 +709,10 @@ class UMIDataset(torch.utils.data.Dataset): def __init__(self, raw_count, ambient_profile, batch_id, device, list_ids=None): """Initialization""" self.device = device - self.raw_count = torch.from_numpy(raw_count).int().to(device) + self.raw_count = torch.from_numpy(raw_count).int().to(self.device) self.ambient_profile = torch.from_numpy(ambient_profile).float().to(device) - self.batch_id = batch_id.to(torch.int64).to(device) - self.batch_onehot = self._onehot() + self.batch_id = torch.from_numpy(batch_id).to(torch.int64) + self.batch_onehot = torch.from_numpy(np.eye(len(np.unique(batch_id)))).to(torch.int64).to(self.device) if list_ids: self.list_ids = list_ids @@ -731,11 +730,4 @@ def __getitem__(self, index): sc_count = self.raw_count[sc_id, :] sc_ambient = self.ambient_profile[self.batch_id[sc_id], :] sc_batch_id_onehot = self.batch_onehot[self.batch_id[sc_id], :] - return sc_count, sc_ambient, sc_batch_id_onehot - - def _onehot(self): - """One-hot encoding""" - n_batch = self.batch_id.unique().size()[0] - x_onehot = torch.zeros(n_batch, n_batch).to(self.device) - x_onehot.scatter_(1, self.batch_id.unique().unsqueeze(1), 1) - return x_onehot \ No newline at end of file + return sc_count, sc_ambient, sc_batch_id_onehot \ No newline at end of file