Develop #91

Merged · 61 commits · Aug 14, 2024

Changes from 1 commit

Commits:
8b77310
fix: optomize dataloading for ambient profile
CaibinSh May 26, 2024
3896417
fix: optomize dataloading for ambient profile
CaibinSh May 26, 2024
3a0058d
docs: update installation
CaibinSh May 26, 2024
5bc0d84
feat: add functionality for batch ambient removal
CaibinSh May 26, 2024
353059f
feat: add functionality for batch ambient removal
CaibinSh May 26, 2024
5066899
fix: fix a bug and refactor dataloading
CaibinSh May 26, 2024
e5ebaf8
fix: fix a bug
CaibinSh May 26, 2024
d6c1663
fix: fix a bug
CaibinSh May 26, 2024
24e1f95
fix: fix a bug
CaibinSh May 26, 2024
db17c87
fix: fix a bug
CaibinSh May 26, 2024
7bb5d9a
refactor: refactor scar dataloader
CaibinSh May 28, 2024
535326f
refactor: refactor codes for better GPU usage
CaibinSh May 28, 2024
73f737c
refactor: refactor codes for better GPU usage
CaibinSh May 28, 2024
7ada773
refactor: Refactor functions for compatibility with scipy 1.14
CaibinSh Jul 24, 2024
cab3872
refactor: load entire data to GPU
CaibinSh Jul 25, 2024
c4901cd
refactor(scar): refactor dataloading to improve efficiency
CaibinSh Jul 27, 2024
cbf222b
refactor(scar): refactor dataloading to improve efficiency
CaibinSh Jul 27, 2024
99665da
refactor(scar): refactor dataloading to improve efficiency
CaibinSh Jul 27, 2024
707a0e0
refactor(scar): refactor dataloading to improve efficiency
CaibinSh Jul 27, 2024
3f07c35
feature(scar): introduce caching functionality
CaibinSh Jul 27, 2024
7d85da3
doc(docs): update python version for readthedocs
CaibinSh Jul 28, 2024
7d436e4
doc(docs): update sphinx-gallery version for readthedocs
CaibinSh Jul 28, 2024
6b101f1
refactor(scar): delete unneccessary codes
CaibinSh Jul 28, 2024
c846d9b
refactor(scar): refactor scar to allow efficient usage of GPU
CaibinSh Jul 29, 2024
4c808ab
refactor(scar): refactor scar to allow efficient usage of GPU
CaibinSh Jul 29, 2024
5eb2ebb
refactor(scar): refactor scar to allow efficient usage of GPU
CaibinSh Jul 29, 2024
6529b97
refactor(scar): refactor scar to allow efficient usage of GPU
CaibinSh Jul 29, 2024
4ba263a
refactor(scar): not output native frequencies by default
CaibinSh Aug 7, 2024
4930798
refactor(scar): refactor dataloader
CaibinSh Aug 8, 2024
51a85e7
refactor(scar): remove redundant codes
CaibinSh Aug 8, 2024
41dd2a2
refactor(scar): optimize dataloading
CaibinSh Aug 8, 2024
0af1c99
refactor(vae): refactor stochastic rounding for performance
CaibinSh Aug 8, 2024
08f9e22
chore: rewrite device log information
CaibinSh Aug 9, 2024
d44f3a3
docs(scar): correct docs for scar
CaibinSh Aug 10, 2024
06f3d7e
docs(tutorials): update tutorial notebooks
Aug 10, 2024
78f8f76
docs(scar): refactor docs
CaibinSh Aug 10, 2024
43038ca
Merge branch 'develop' of https://github.com/Novartis/scar into develop
CaibinSh Aug 10, 2024
2262e34
docs(scar): add docs for cache_capacity parameter
CaibinSh Aug 11, 2024
d3d4d1b
refactor(main): refactor command line tool
CaibinSh Aug 11, 2024
9ef0272
refator(scar): refactor scar to automate calculation of ambient profi…
CaibinSh Aug 11, 2024
bedfa6e
fix(main): fix a bug
CaibinSh Aug 11, 2024
1615684
fix(main): fix a bug
CaibinSh Aug 11, 2024
7363753
Merge branch 'main' into develop
CaibinSh Aug 12, 2024
4c86c1b
fix(test): lower down sample size for test
CaibinSh Aug 12, 2024
6b00c74
docs(scar): update version in doc
CaibinSh Aug 12, 2024
d581d7c
docs(tutorial): update notebooks
CaibinSh Aug 14, 2024
631e4bb
docs(tutorial): add a tutorial for batch denoising
CaibinSh Aug 14, 2024
53b839b
refactor(docs): reset semantic_release to main branch
CaibinSh Aug 14, 2024
8a7b448
docs(tutorial): update hyperlink
CaibinSh Aug 14, 2024
b0d3e2f
Develop (#82)
CaibinSh Aug 14, 2024
bd64320
fix(workflow): update semantic release
CaibinSh Aug 14, 2024
0dbc5d4
fix(workflow): update semantic release
CaibinSh Aug 14, 2024
0260953
fix(workflow): update semantic release
CaibinSh Aug 14, 2024
3d65da8
update semantic release (#84)
CaibinSh Aug 14, 2024
8adb533
fix(workflow): update semantic release
CaibinSh Aug 14, 2024
9c6a9e2
Merge branch 'main' into develop
CaibinSh Aug 14, 2024
5d882b7
chore: change semantic release branch
CaibinSh Aug 14, 2024
92e3a86
Merge branch 'develop' of https://github.com/Novartis/scar into develop
CaibinSh Aug 14, 2024
9d950b1
chore: change semantic release branch
CaibinSh Aug 14, 2024
0b28e11
chore: change semantic release branch
CaibinSh Aug 14, 2024
c89123a
chore: change semantic release branch
CaibinSh Aug 14, 2024
fix: optomize dataloading for ambient profile
CaibinSh committed May 26, 2024
commit 8b773108070f1e56e806ddb33207c2da9a6bd07c
54 changes: 41 additions & 13 deletions scar/main/_scar.py
@@ -91,6 +91,17 @@ class model:
Thank Will Macnair for the valuable feedback.

.. versionadded:: 0.4.0
batch_key : str, optional
batch key in AnnData.obs, by default None. \
If assigned, batch ambient removal will be performed and \
the ambient profile will be estimated for each batch.

.. versionadded:: 0.6.1

device : str, optional
either "auto", "cpu" or "cuda", by default "auto"
verbose : bool, optional
whether to print the details, by default True

Raises
------
@@ -200,6 +211,7 @@ def __init__(
feature_type: str = "mRNA",
count_model: str = "binomial",
sparsity: float = 0.9,
batch_key: str = None,
device: str = "auto",
verbose: bool = True,
):
@@ -262,7 +274,7 @@ def __init__(
"""float, the sparsity of expected native signals. (0, 1]. \
Forced to be one in the mode of "sgRNA(s)" and "tag(s)".
"""

if isinstance(raw_count, str):
raw_count = pd.read_pickle(raw_count)
elif isinstance(raw_count, np.ndarray):
@@ -274,8 +286,24 @@ def __init__(
elif isinstance(raw_count, pd.DataFrame):
pass
elif isinstance(raw_count, ad.AnnData):
if batch_key:
if batch_key not in raw_count.obs.columns:
raise ValueError(f"{batch_key} not found in AnnData.obs.")

self.logger.info(
f"Estimating ambient profile for each batch defined by {batch_key} in AnnData.obs..."
)
batch_id_per_cell = pd.Categorical(raw_count.obs[batch_key]).codes
ambient_profile = np.empty((len(np.unique(batch_id_per_cell)),raw_count.shape[1]))
for batch_id in np.unique(batch_id_per_cell):
subset = raw_count[batch_id_per_cell==batch_id]
ambient_profile[batch_id, :] = subset.X.sum(axis=0) / subset.X.sum()

# add a mapper to locate the batch id
self.batch_id = batch_id_per_cell

# get ambient profile from AnnData.uns
-            if (ambient_profile is None) and ("ambient_profile_all" in raw_count.uns):
+            elif (ambient_profile is None) and ("ambient_profile_all" in raw_count.uns):
self.logger.info(
"Found ambient profile in AnnData.uns['ambient_profile_all']"
)
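The per-batch estimation added above can be sketched as a standalone helper (function and variable names hypothetical; assumes a dense cells × genes count matrix):

```python
import numpy as np
import pandas as pd

def estimate_ambient_per_batch(counts, batches):
    """Estimate one ambient profile per batch.

    counts : (n_cells, n_genes) ndarray of raw UMI counts
    batches : length-n_cells array-like of batch labels
    Returns (profiles, batch_id_per_cell), where profiles[b] holds the
    normalized per-gene frequencies for batch b.
    """
    # integer batch codes, as in the diff's pd.Categorical(...).codes
    batch_id_per_cell = pd.Categorical(batches).codes
    n_batches = len(np.unique(batch_id_per_cell))
    profiles = np.empty((n_batches, counts.shape[1]))
    for b in np.unique(batch_id_per_cell):
        subset = counts[batch_id_per_cell == b]
        # gene-wise sums normalized by the batch's total counts,
        # so each row of profiles sums to 1
        profiles[b, :] = subset.sum(axis=0) / subset.sum()
    return profiles, batch_id_per_cell
```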
@@ -324,8 +352,10 @@ def __init__(
ambient_profile = (
ambient_profile.squeeze()
.reshape(1, -1)
.repeat(raw_count.shape[0], axis=0)
)
# add a mapper to locate the artificial batch id
self.batch_id = np.zeros(raw_count.shape[0])

self.ambient_profile = torch.from_numpy(ambient_profile).float().to(self.device)
"""ambient_profile : np.ndarray, the probability of occurrence of each ambient transcript.
"""
@@ -410,21 +440,17 @@ def train(
train_ids, test_ids = train_test_split(list_ids, train_size=train_size)

# Generators
-        training_set = UMIDataset(self.raw_count, self.ambient_profile, train_ids)
+        training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, train_ids)
training_generator = torch.utils.data.DataLoader(
training_set, batch_size=batch_size, shuffle=shuffle
)
-        val_set = UMIDataset(self.raw_count, self.ambient_profile, test_ids)
+        val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, test_ids)
val_generator = torch.utils.data.DataLoader(
val_set, batch_size=batch_size, shuffle=shuffle
)

loss_values = []

# self.n_batch_train = len(training_generator)
# self.n_batch_val = len(val_generator)
# self.batch_size = batch_size

# Define model
vae_nets = VAE(
n_features=self.n_features,
@@ -459,7 +485,7 @@ def train(
desc="Training",
)
progress_bar.clear()
-        for epoch in range(epochs):
+        for _ in range(epochs):
train_tot_loss = 0
train_kld_loss = 0
train_recon_loss = 0
@@ -559,7 +585,7 @@ def inference(
native_frequencies, and noise_ratio. \
A feature_assignment will be added in 'sgRNA' or 'tag' or 'CMO' feature type.
"""
-        total_set = UMIDataset(self.raw_count, self.ambient_profile)
+        total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id)
n_features = self.n_features
sample_size = self.raw_count.shape[0]
self.native_counts = np.empty([sample_size, n_features])
@@ -677,10 +703,12 @@ def assignment(self, cutoff=3, moi=None):
class UMIDataset(torch.utils.data.Dataset):
"""Characterizes dataset for PyTorch"""

-    def __init__(self, raw_count, ambient_profile, list_ids=None):
+    def __init__(self, raw_count, ambient_profile, batch_id, list_ids=None):
"""Initialization"""
self.raw_count = raw_count
self.ambient_profile = ambient_profile
self.batch_id = batch_id

if list_ids:
self.list_ids = list_ids
else:
@@ -695,5 +723,5 @@ def __getitem__(self, index):
# Select sample
sc_id = self.list_ids[index]
sc_count = self.raw_count[sc_id, :]
-        sc_ambient = self.ambient_profile[sc_id, :]
+        sc_ambient = self.ambient_profile[self.batch_id[sc_id], :]
return sc_count, sc_ambient
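Taken together, the refactor stores one ambient-profile row per batch and resolves each cell's row at fetch time via its batch id. A framework-free sketch of that lookup (class name hypothetical; the real `UMIDataset` subclasses `torch.utils.data.Dataset`):

```python
import numpy as np

class UMIDatasetSketch:
    """Minimal sketch of the refactored dataset lookup: the ambient
    profile row for a cell is selected through its batch id instead of
    being stored once per cell."""

    def __init__(self, raw_count, ambient_profile, batch_id, list_ids=None):
        self.raw_count = raw_count              # (n_cells, n_genes)
        self.ambient_profile = ambient_profile  # (n_batches, n_genes)
        self.batch_id = batch_id                # (n_cells,) integer codes
        # default to all cells when no subset of ids is given
        self.list_ids = list_ids if list_ids else list(range(raw_count.shape[0]))

    def __len__(self):
        return len(self.list_ids)

    def __getitem__(self, index):
        sc_id = self.list_ids[index]
        sc_count = self.raw_count[sc_id, :]
        # batch-aware lookup: one shared profile row per batch
        sc_ambient = self.ambient_profile[self.batch_id[sc_id], :]
        return sc_count, sc_ambient
```

Keeping `ambient_profile` at shape `(n_batches, n_genes)` rather than repeating it per cell is what lets the whole profile matrix stay small enough to live on the GPU, which matches the "efficient usage of GPU" commits in this PR.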