Skip to content

Commit

Permalink
refactor: refactor scar dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
CaibinSh committed May 28, 2024
1 parent db17c87 commit 7bb5d9a
Showing 1 changed file with 10 additions and 18 deletions.
28 changes: 10 additions & 18 deletions scar/main/_scar.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,8 @@ def __init__(
ambient_profile[batch_id, :] = subset.X.sum(axis=0) / subset.X.sum()

# add a mapper to locate the batch id
self.batch_id = torch.from_numpy(batch_id_per_cell).int().to(self.device)
self.n_batch = np.unique(batch_id_per_cell).size
self.batch_id = batch_id_per_cell
self.n_batch = len(np.unique(batch_id_per_cell))

# get ambient profile from AnnData.uns
elif (ambient_profile is None) and ("ambient_profile_all" in raw_count.uns):
Expand All @@ -322,16 +322,15 @@ def __init__(
f"Expecting str or np.array or pd.DataFrame object, but get a {type(raw_count)}"
)

raw_count = raw_count.fillna(0) # missing vals -> zeros

# Loading numpy to tensor on GPU
raw_count = raw_count.fillna(0) # missing vals -> zeros raw_count if isinstance(raw_count, ad.AnnData) else
self.raw_count = raw_count.values
"""raw_count : np.ndarray, raw count matrix.
"""
self.n_features = raw_count.shape[1]

self.cell_id = list(raw_count.index)
self.feature_names = list(raw_count.columns)
self.cell_id = raw_count.index.to_list()
self.feature_names = raw_count.columns.to_list()

if isinstance(ambient_profile, str):
ambient_profile = pd.read_pickle(ambient_profile)
Expand All @@ -355,7 +354,7 @@ def __init__(
.reshape(1, -1)
)
# add a mapper to locate the artificial batch id
self.batch_id = torch.zeros(raw_count.shape[0]).int().to(self.device)
self.batch_id = np.zeros(raw_count.shape[0], dtype=int)#.reshape(-1, 1)
self.n_batch = 1

self.ambient_profile = ambient_profile
Expand Down Expand Up @@ -710,10 +709,10 @@ class UMIDataset(torch.utils.data.Dataset):
def __init__(self, raw_count, ambient_profile, batch_id, device, list_ids=None):
"""Initialization"""
self.device = device
self.raw_count = torch.from_numpy(raw_count).int().to(device)
self.raw_count = torch.from_numpy(raw_count).int().to(self.device)
self.ambient_profile = torch.from_numpy(ambient_profile).float().to(device)
self.batch_id = batch_id.to(torch.int64).to(device)
self.batch_onehot = self._onehot()
self.batch_id = torch.from_numpy(batch_id).to(torch.int64)
self.batch_onehot = torch.from_numpy(np.eye(len(np.unique(batch_id)))).to(torch.int64).to(self.device)

if list_ids:
self.list_ids = list_ids
Expand All @@ -731,11 +730,4 @@ def __getitem__(self, index):
sc_count = self.raw_count[sc_id, :]
sc_ambient = self.ambient_profile[self.batch_id[sc_id], :]
sc_batch_id_onehot = self.batch_onehot[self.batch_id[sc_id], :]
return sc_count, sc_ambient, sc_batch_id_onehot

def _onehot(self):
"""One-hot encoding"""
n_batch = self.batch_id.unique().size()[0]
x_onehot = torch.zeros(n_batch, n_batch).to(self.device)
x_onehot.scatter_(1, self.batch_id.unique().unsqueeze(1), 1)
return x_onehot
return sc_count, sc_ambient, sc_batch_id_onehot

0 comments on commit 7bb5d9a

Please sign in to comment.