Commit

fix: optimize dataloading for ambient profile
CaibinSh committed May 26, 2024
1 parent 6ad76db commit 8b77310
Showing 1 changed file with 41 additions and 13 deletions.
54 changes: 41 additions & 13 deletions scar/main/_scar.py
@@ -91,6 +91,17 @@ class model:
Thank Will Macnair for the valuable feedback.
.. versionadded:: 0.4.0
+ batch_key : str, optional
+     batch key in AnnData.obs, by default None. \
+     If assigned, batch ambient removal will be performed and \
+     the ambient profile will be estimated for each batch.
+     .. versionadded:: 0.6.1
device : str, optional
    either "auto", "cpu" or "cuda", by default "auto"
verbose : bool, optional
    whether to print the details, by default True
Raises
------
@@ -200,6 +211,7 @@ def __init__(
feature_type: str = "mRNA",
count_model: str = "binomial",
sparsity: float = 0.9,
+ batch_key: str = None,
device: str = "auto",
verbose: bool = True,
):
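With this change, per-batch ambient removal becomes opt-in through the new `batch_key` argument; leaving it at `None` keeps the previous single-profile behaviour. A minimal usage sketch, assuming an AnnData object whose `.obs` has a batch column named "sample" (the file name and column name are illustrative, not part of this commit):

```python
import scanpy as sc
from scar import model  # defined in scar/main/_scar.py

adata = sc.read_h5ad("filtered_counts.h5ad")  # hypothetical input file

# One ambient profile is estimated per level of adata.obs["sample"].
scar_model = model(
    raw_count=adata,
    feature_type="mRNA",
    count_model="binomial",
    batch_key="sample",  # new in 0.6.1
    device="auto",
)
```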
@@ -262,7 +274,7 @@ def __init__(
"""float, the sparsity of expected native signals. (0, 1]. \
Forced to be one in the mode of "sgRNA(s)" and "tag(s)".
"""

if isinstance(raw_count, str):
raw_count = pd.read_pickle(raw_count)
elif isinstance(raw_count, np.ndarray):
@@ -274,8 +286,24 @@ def __init__(
elif isinstance(raw_count, pd.DataFrame):
pass
elif isinstance(raw_count, ad.AnnData):
+ if batch_key:
+     if batch_key not in raw_count.obs.columns:
+         raise ValueError(f"{batch_key} not found in AnnData.obs.")
+
+     self.logger.info(
+         f"Estimating ambient profile for each batch defined by {batch_key} in AnnData.obs..."
+     )
+     batch_id_per_cell = pd.Categorical(raw_count.obs[batch_key]).codes
+     ambient_profile = np.empty((len(np.unique(batch_id_per_cell)), raw_count.shape[1]))
+     for batch_id in np.unique(batch_id_per_cell):
+         subset = raw_count[batch_id_per_cell == batch_id]
+         ambient_profile[batch_id, :] = subset.X.sum(axis=0) / subset.X.sum()
+
+     # add a mapper to locate the batch id
+     self.batch_id = batch_id_per_cell
+
# get ambient profile from AnnData.uns
- if (ambient_profile is None) and ("ambient_profile_all" in raw_count.uns):
+ elif (ambient_profile is None) and ("ambient_profile_all" in raw_count.uns):
self.logger.info(
"Found ambient profile in AnnData.uns['ambient_profile_all']"
)
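The per-batch profile is just each batch's summed counts normalized to sum to one. A standalone sketch of the same computation on a toy dense matrix (all names and numbers here are illustrative):

```python
import numpy as np
import pandas as pd

# Toy data: 5 cells x 3 genes, two batches "a" and "b".
counts = np.array([[1, 2, 3],
                   [0, 1, 1],
                   [4, 0, 2],
                   [2, 2, 2],
                   [1, 0, 1]])
batch_codes = pd.Categorical(["a", "a", "b", "b", "b"]).codes  # -> [0, 0, 1, 1, 1]

profiles = np.empty((len(np.unique(batch_codes)), counts.shape[1]))
for b in np.unique(batch_codes):
    subset = counts[batch_codes == b]
    # Each row sums to 1: the relative abundance of each gene in that batch's soup.
    profiles[b, :] = subset.sum(axis=0) / subset.sum()

print(profiles)  # row 0 is batch "a", row 1 is batch "b"
```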
@@ -324,8 +352,10 @@ def __init__(
ambient_profile = (
ambient_profile.squeeze()
.reshape(1, -1)
- .repeat(raw_count.shape[0], axis=0)
)
+ # add a mapper to locate the artificial batch id
+ self.batch_id = np.zeros(raw_count.shape[0], dtype=int)

self.ambient_profile = torch.from_numpy(ambient_profile).float().to(self.device)
"""ambient_profile : np.ndarray, the probability of occurrence of each ambient transcript.
"""
@@ -410,21 +440,17 @@ def train(
train_ids, test_ids = train_test_split(list_ids, train_size=train_size)

# Generators
- training_set = UMIDataset(self.raw_count, self.ambient_profile, train_ids)
+ training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, train_ids)
training_generator = torch.utils.data.DataLoader(
training_set, batch_size=batch_size, shuffle=shuffle
)
- val_set = UMIDataset(self.raw_count, self.ambient_profile, test_ids)
+ val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, test_ids)
val_generator = torch.utils.data.DataLoader(
val_set, batch_size=batch_size, shuffle=shuffle
)

loss_values = []

- # self.n_batch_train = len(training_generator)
- # self.n_batch_val = len(val_generator)
- # self.batch_size = batch_size

# Define model
vae_nets = VAE(
n_features=self.n_features,
@@ -459,7 +485,7 @@
desc="Training",
)
progress_bar.clear()
- for epoch in range(epochs):
+ for _ in range(epochs):
train_tot_loss = 0
train_kld_loss = 0
train_recon_loss = 0
@@ -559,7 +585,7 @@ def inference(
native_frequencies, and noise_ratio. \
A feature_assignment will be added in 'sgRNA' or 'tag' or 'CMO' feature type.
"""
- total_set = UMIDataset(self.raw_count, self.ambient_profile)
+ total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id)
n_features = self.n_features
sample_size = self.raw_count.shape[0]
self.native_counts = np.empty([sample_size, n_features])
@@ -677,10 +703,12 @@ def assignment(self, cutoff=3, moi=None):
class UMIDataset(torch.utils.data.Dataset):
"""Characterizes dataset for PyTorch"""

- def __init__(self, raw_count, ambient_profile, list_ids=None):
+ def __init__(self, raw_count, ambient_profile, batch_id, list_ids=None):
"""Initialization"""
self.raw_count = raw_count
self.ambient_profile = ambient_profile
+ self.batch_id = batch_id

if list_ids:
self.list_ids = list_ids
else:
@@ -695,5 +723,5 @@ def __getitem__(self, index):
# Select sample
sc_id = self.list_ids[index]
sc_count = self.raw_count[sc_id, :]
- sc_ambient = self.ambient_profile[sc_id, :]
+ sc_ambient = self.ambient_profile[self.batch_id[sc_id], :]
return sc_count, sc_ambient
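With the batch index threaded through, the ambient row is resolved per cell at fetch time instead of being pre-expanded. A self-contained sketch of the same lookup pattern (the class is re-declared here only for illustration, with toy shapes):

```python
import numpy as np
import torch

class UMIDataset(torch.utils.data.Dataset):
    """Pairs each cell's counts with its batch's ambient profile row."""

    def __init__(self, raw_count, ambient_profile, batch_id, list_ids=None):
        self.raw_count = raw_count              # (n_cells, n_genes)
        self.ambient_profile = ambient_profile  # (n_batches, n_genes)
        self.batch_id = batch_id                # (n_cells,) integer batch codes
        self.list_ids = list_ids if list_ids else list(range(raw_count.shape[0]))

    def __len__(self):
        return len(self.list_ids)

    def __getitem__(self, index):
        sc_id = self.list_ids[index]
        # Indexing by batch id avoids materializing a per-cell copy of the profile.
        return self.raw_count[sc_id, :], self.ambient_profile[self.batch_id[sc_id], :]

# Toy usage: 4 cells in 2 batches, 3 genes.
counts = torch.arange(12, dtype=torch.float32).reshape(4, 3)
profiles = torch.tensor([[0.2, 0.3, 0.5], [0.1, 0.6, 0.3]])
batch_id = np.array([0, 0, 1, 1])
dataset = UMIDataset(counts, profiles, batch_id)
print(dataset[2])  # counts of cell 2 together with the batch-1 profile row
```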
