"
]
@@ -827,9 +839,9 @@
],
"metadata": {
"kernelspec": {
- "display_name": "conda_gpu",
+ "display_name": "conda_scar-0.6.0",
"language": "python",
- "name": "conda_gpu"
+ "name": "conda_scar-0.6.0"
},
"language_info": {
"codemirror_mode": {
@@ -841,7 +853,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.9"
+ "version": "3.12.3"
},
"vscode": {
"interpreter": {
diff --git a/docs/tutorials/scAR_tutorial_sgRNA_assignment.ipynb b/docs/tutorials/scAR_tutorial_sgRNA_assignment.ipynb
index 3ad4d94..06ad9c5 100644
--- a/docs/tutorials/scAR_tutorial_sgRNA_assignment.ipynb
+++ b/docs/tutorials/scAR_tutorial_sgRNA_assignment.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
- "id": "db46edf9",
+ "id": "b48c6ca1",
"metadata": {},
"source": [
"# sgRNA assignment\n",
@@ -27,7 +27,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "22b76e35",
+ "id": "21831340",
"metadata": {},
"outputs": [],
"source": [
@@ -42,7 +42,7 @@
{
"cell_type": "code",
"execution_count": 1,
- "id": "39e1033c",
+ "id": "a5425759",
"metadata": {},
"outputs": [],
"source": [
@@ -58,7 +58,7 @@
},
{
"cell_type": "markdown",
- "id": "86f1e620",
+ "id": "162dc896",
"metadata": {},
"source": [
"## Download data\n",
@@ -69,13 +69,13 @@
{
"cell_type": "code",
"execution_count": 2,
- "id": "e2d98e7d",
+ "id": "a2f319f4",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "387a73ce320c440f963734dbf756e882",
+ "model_id": "ae48a91f3a364204a6b55388c9c0582e",
"version_major": 2,
"version_minor": 0
},
@@ -97,7 +97,7 @@
},
{
"cell_type": "markdown",
- "id": "bde606d4",
+ "id": "11f41dbe",
"metadata": {},
"source": [
"sgRNA counts (unfiltered droplets)"
@@ -106,7 +106,7 @@
{
"cell_type": "code",
"execution_count": 3,
- "id": "0dfe91ac",
+ "id": "bb4028ce",
"metadata": {},
"outputs": [],
"source": [
@@ -115,7 +115,7 @@
},
{
"cell_type": "markdown",
- "id": "c2a8ea3a",
+ "id": "6ef0d42f",
"metadata": {
"tags": []
},
@@ -125,7 +125,7 @@
},
{
"cell_type": "markdown",
- "id": "ea33196f",
+ "id": "83b8c23f",
"metadata": {},
"source": [
"Identify cell-containing and cell-free droplets using kneeplot of mRNA counts."
@@ -133,7 +133,7 @@
},
{
"cell_type": "markdown",
- "id": "09882854",
+ "id": "b3005918",
"metadata": {},
"source": [
"\n",
@@ -148,7 +148,7 @@
{
"cell_type": "code",
"execution_count": 4,
- "id": "63ae0666",
+ "id": "53b0ac8c",
"metadata": {},
"outputs": [],
"source": [
@@ -164,7 +164,7 @@
},
{
"cell_type": "markdown",
- "id": "965e7f43",
+ "id": "2074f2c5",
"metadata": {},
"source": [
"The thresholds (200 and 500) are experiment-specific. We currently manually determine them by examing the following kneeplot. "
@@ -173,12 +173,12 @@
{
"cell_type": "code",
"execution_count": 5,
- "id": "f16e1041",
+ "id": "4c7172d8",
"metadata": {},
"outputs": [
{
"data": {
- "image/png": "",
+ "image/png": "",
"text/plain": [
""
]
@@ -213,7 +213,7 @@
},
{
"cell_type": "markdown",
- "id": "af962da9",
+ "id": "76384998",
"metadata": {
"tags": []
},
@@ -224,7 +224,7 @@
{
"cell_type": "code",
"execution_count": 6,
- "id": "6ea194fa",
+ "id": "da65b16c",
"metadata": {},
"outputs": [],
"source": [
@@ -234,7 +234,7 @@
{
"cell_type": "code",
"execution_count": 7,
- "id": "f653d1e1",
+ "id": "6f1f6559",
"metadata": {},
"outputs": [
{
@@ -452,7 +452,7 @@
},
{
"cell_type": "markdown",
- "id": "a87b59af",
+ "id": "1bafe491",
"metadata": {},
"source": [
"**Ambient profile of sgRNAs**"
@@ -461,7 +461,7 @@
{
"cell_type": "code",
"execution_count": 8,
- "id": "d452ce6c",
+ "id": "dd0661af",
"metadata": {},
"outputs": [
{
@@ -537,7 +537,7 @@
},
{
"cell_type": "markdown",
- "id": "dc36ac46",
+ "id": "eccdb2c9",
"metadata": {},
"source": [
"## Training"
@@ -546,33 +546,34 @@
{
"cell_type": "code",
"execution_count": 9,
- "id": "020c0b79",
+ "id": "33a6a2e9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "2023-05-01 16:45:39|INFO|VAE|Running VAE using the following param set:\n",
- "2023-05-01 16:45:39|INFO|VAE|...denoised count type: sgRNAs\n",
- "2023-05-01 16:45:39|INFO|VAE|...count model: binomial\n",
- "2023-05-01 16:45:39|INFO|VAE|...num_input_feature: 93\n",
- "2023-05-01 16:45:39|INFO|VAE|...NN_layer1: 150\n",
- "2023-05-01 16:45:39|INFO|VAE|...NN_layer2: 100\n",
- "2023-05-01 16:45:39|INFO|VAE|...latent_space: 15\n",
- "2023-05-01 16:45:39|INFO|VAE|...dropout_prob: 0.00\n",
- "2023-05-01 16:45:39|INFO|VAE|...expected data sparsity: 1.00\n",
- "2023-05-01 16:45:40|INFO|model|kld_weight: 1.00e-05\n",
- "2023-05-01 16:45:40|INFO|model|learning rate: 1.00e-03\n",
- "2023-05-01 16:45:40|INFO|model|lr_step_size: 5\n",
- "2023-05-01 16:45:40|INFO|model|lr_gamma: 0.97\n"
+ "2024-08-10 21:02:55|INFO|model|cuda is detected and will be used.\n",
+ "2024-08-10 21:02:55|INFO|VAE|Running VAE using the following param set:\n",
+ "2024-08-10 21:02:55|INFO|VAE|...denoised count type: sgRNAs\n",
+ "2024-08-10 21:02:55|INFO|VAE|...count model: binomial\n",
+ "2024-08-10 21:02:55|INFO|VAE|...num_input_feature: 93\n",
+ "2024-08-10 21:02:55|INFO|VAE|...NN_layer1: 150\n",
+ "2024-08-10 21:02:55|INFO|VAE|...NN_layer2: 100\n",
+ "2024-08-10 21:02:55|INFO|VAE|...latent_space: 15\n",
+ "2024-08-10 21:02:55|INFO|VAE|...dropout_prob: 0.00\n",
+ "2024-08-10 21:02:55|INFO|VAE|...expected data sparsity: 1.00\n",
+ "2024-08-10 21:02:55|INFO|model|kld_weight: 1.00e-05\n",
+ "2024-08-10 21:02:55|INFO|model|learning rate: 1.00e-03\n",
+ "2024-08-10 21:02:55|INFO|model|lr_step_size: 5\n",
+ "2024-08-10 21:02:55|INFO|model|lr_gamma: 0.97\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Training: 100%|██████████| 100/100 [05:17<00:00, 3.18s/it, Loss=2.8968e+02]\n"
+ "Training: 100%|██████████| 100/100 [02:55<00:00, 1.76s/it, Loss=4.2189e+02]\n"
]
}
],
@@ -592,7 +593,7 @@
},
{
"cell_type": "markdown",
- "id": "397a229f",
+ "id": "510e3173",
"metadata": {},
"source": [
"Resulting assignment is saved in `sgRNAs.feature_assignment`.\n",
@@ -603,7 +604,7 @@
{
"cell_type": "code",
"execution_count": 10,
- "id": "5197865c",
+ "id": "b0a4f116",
"metadata": {},
"outputs": [
{
@@ -634,8 +635,8 @@
" \n",
" \n",
" AAACCCAAGCTAAGTA-1 \n",
- " H2AFY-1 \n",
- " 1 \n",
+ " APH1A-1, H2AFY-1 \n",
+ " 2 \n",
" \n",
" \n",
" AAACCCAAGGAAGTGA-1 \n",
@@ -644,13 +645,13 @@
" \n",
" \n",
" AAACCCAAGGTTGGAC-1 \n",
- " ACE2-1 \n",
- " 1 \n",
+ " GSK3A-1, ACE2-1 \n",
+ " 2 \n",
" \n",
" \n",
" AAACCCAAGTGCGTCC-1 \n",
- " ACE2-2 \n",
- " 1 \n",
+ " H2AFY-2, ACE2-2 \n",
+ " 2 \n",
" \n",
" \n",
" AAACCCAAGTGCTCGC-1 \n",
@@ -693,18 +694,18 @@
""
],
"text/plain": [
- " sgRNAs n_sgRNAs\n",
- "AAACCCAAGCTAAGTA-1 H2AFY-1 1\n",
- "AAACCCAAGGAAGTGA-1 PPIB-2 1\n",
- "AAACCCAAGGTTGGAC-1 ACE2-1 1\n",
- "AAACCCAAGTGCGTCC-1 ACE2-2 1\n",
- "AAACCCAAGTGCTCGC-1 CTCF-2 1\n",
- "... ... ...\n",
- "TTTGTTGTCCCATTTA-1 CSNK2A1-1 1\n",
- "TTTGTTGTCGGAACTT-1 EIF4EBP1-2 1\n",
- "TTTGTTGTCGGCTGTG-1 SUZ12-1 1\n",
- "TTTGTTGTCTGGGCGT-1 GSK3A-2 1\n",
- "TTTGTTGTCTTCCAGC-1 PPIB-2, RBBP4-1 2\n",
+ " sgRNAs n_sgRNAs\n",
+ "AAACCCAAGCTAAGTA-1 APH1A-1, H2AFY-1 2\n",
+ "AAACCCAAGGAAGTGA-1 PPIB-2 1\n",
+ "AAACCCAAGGTTGGAC-1 GSK3A-1, ACE2-1 2\n",
+ "AAACCCAAGTGCGTCC-1 H2AFY-2, ACE2-2 2\n",
+ "AAACCCAAGTGCTCGC-1 CTCF-2 1\n",
+ "... ... ...\n",
+ "TTTGTTGTCCCATTTA-1 CSNK2A1-1 1\n",
+ "TTTGTTGTCGGAACTT-1 EIF4EBP1-2 1\n",
+ "TTTGTTGTCGGCTGTG-1 SUZ12-1 1\n",
+ "TTTGTTGTCTGGGCGT-1 GSK3A-2 1\n",
+ "TTTGTTGTCTTCCAGC-1 PPIB-2, RBBP4-1 2\n",
"\n",
"[21091 rows x 2 columns]"
]
@@ -720,7 +721,7 @@
},
{
"cell_type": "markdown",
- "id": "8131bc78",
+ "id": "cebd202d",
"metadata": {},
"source": [
"## Visulization"
@@ -728,7 +729,7 @@
},
{
"cell_type": "markdown",
- "id": "15b9190b",
+ "id": "46dc569a",
"metadata": {},
"source": [
"Plot setting"
@@ -737,7 +738,7 @@
{
"cell_type": "code",
"execution_count": 11,
- "id": "5234d078",
+ "id": "b304bffd",
"metadata": {},
"outputs": [
{
@@ -774,7 +775,7 @@
},
{
"cell_type": "markdown",
- "id": "cfc89b2f",
+ "id": "ee7c6080",
"metadata": {},
"source": [
"### Cell number of sgRNA assignments"
@@ -783,7 +784,7 @@
{
"cell_type": "code",
"execution_count": 12,
- "id": "7e892df4",
+ "id": "aa8de84e",
"metadata": {
"tags": [
"nbsphinx-gallery",
@@ -817,7 +818,7 @@
},
{
"cell_type": "markdown",
- "id": "4870a675",
+ "id": "13ce1cf5",
"metadata": {},
"source": [
"Most of cells are assigned with a single sgRNA."
@@ -826,9 +827,9 @@
],
"metadata": {
"kernelspec": {
- "display_name": "conda_gpu",
+ "display_name": "conda_scar-0.6.0",
"language": "python",
- "name": "conda_gpu"
+ "name": "conda_scar-0.6.0"
},
"language_info": {
"codemirror_mode": {
@@ -840,7 +841,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.9"
+ "version": "3.12.3"
},
"vscode": {
"interpreter": {
diff --git a/scar/main/__main__.py b/scar/main/__main__.py
index 0f22dcb..fb7cb07 100644
--- a/scar/main/__main__.py
+++ b/scar/main/__main__.py
@@ -4,14 +4,11 @@
import argparse
import os
-import pandas as pd
-import scanpy as sc
-from scipy.sparse import csr_matrix
+import pandas as pd, scanpy as sc
from ._scar import model
from ..__init__ import __version__
from ._utils import get_logger
-
def main():
"""main function for command line interface"""
args = Config()
@@ -28,6 +25,9 @@ def main():
epochs = args.epochs
device = args.device
sparsity = args.sparsity
+ batchkey = args.batchkey
+ cachecapacity = args.cachecapacity
+ gnf = bool(args.get_native_frequencies)
save_model = args.save_model
batch_size = args.batchsize
batch_size_infer = args.batchsize_infer
@@ -63,37 +63,37 @@ def main():
if feature_type.lower() == "all":
features = adata.var["feature_types"].unique()
- count_matrix = adata.to_df()
+ count_matrix = adata.copy()
# Denoising mRNAs
elif feature_type.lower() in ["mrna", "mrnas"]:
features = "Gene Expression"
adata_fb = adata[:, adata.var["feature_types"] == features]
- count_matrix = adata_fb.to_df()
+ count_matrix = adata_fb.copy()
# Denoising sgRNAs
elif feature_type.lower() in ["sgrna", "sgrnas"]:
features = "CRISPR Guide Capture"
adata_fb = adata[:, adata.var["feature_types"] == features]
- count_matrix = adata_fb.to_df()
+ count_matrix = adata_fb.copy()
# Denoising CMO tags
elif feature_type.lower() in ["tag", "tags"]:
features = "Multiplexing Capture"
adata_fb = adata[:, adata.var["feature_types"] == features]
- count_matrix = adata_fb.to_df()
+ count_matrix = adata_fb.copy()
# Denoising ADTs
elif feature_type.lower() in ["adt", "adts"]:
features = "Antibody Capture"
adata_fb = adata[:, adata.var["feature_types"] == features]
- count_matrix = adata_fb.to_df()
+ count_matrix = adata_fb.copy()
# Denoising ATAC peaks
elif feature_type.lower() in ["atac"]:
features = "Peaks"
adata_fb = adata[:, adata.var["feature_types"] == features]
- count_matrix = adata_fb.to_df()
+ count_matrix = adata_fb.copy()
main_logger.info(f"modalities to denoise: {features}")
@@ -135,6 +135,8 @@ def main():
latent_dim=latent_dim,
feature_type=feature_type,
count_model=count_model,
+ batch_key=batchkey,
+ cache_capacity=cachecapacity,
sparsity=sparsity,
device=device,
)
@@ -147,6 +149,7 @@ def main():
scar_model.inference(
adjust=adjust,
+ get_native_frequencies=gnf,
round_to_int=round_to_int,
batch_size=batch_size_infer,
clip_to_obs=clip_to_obs,
@@ -167,30 +170,35 @@ def main():
)
pd.DataFrame(
- scar_model.native_counts,
+ scar_model.native_counts.toarray(),
index=count_matrix.index,
columns=count_matrix.columns,
).to_pickle(output_path01)
+ main_logger.info(f"denoised counts saved in: {output_path01}")
+
pd.DataFrame(
- scar_model.bayesfactor,
- index=count_matrix.index,
- columns=count_matrix.columns,
- ).to_pickle(output_path02)
- pd.DataFrame(
- scar_model.native_frequencies,
- index=count_matrix.index,
- columns=count_matrix.columns,
- ).to_pickle(output_path03)
- pd.DataFrame(
- scar_model.noise_ratio, index=count_matrix.index, columns=["noise_ratio"]
+ scar_model.noise_ratio.toarray(),
+ index=count_matrix.index,
+ columns=["noise_ratio"]
).to_pickle(output_path04)
-
- main_logger.info(f"denoised counts saved in: {output_path01}")
- main_logger.info(f"BayesFactor matrix saved in: {output_path02}")
- main_logger.info(f"expected native frequencies saved in: {output_path03}")
main_logger.info(f"expected noise ratio saved in: {output_path04}")
+ if scar_model.native_frequencies is not None:
+ pd.DataFrame(
+ scar_model.native_frequencies.toarray(),
+ index=count_matrix.index,
+ columns=count_matrix.columns,
+ ).to_pickle(output_path03)
+ main_logger.info(f"expected native frequencies saved in: {output_path03}")
+
if feature_type.lower() in ["sgrna", "sgrnas", "tag", "tags", "cmo", "cmos"]:
+ pd.DataFrame(
+ scar_model.bayesfactor.toarray(),
+ index=count_matrix.index,
+ columns=count_matrix.columns,
+ ).to_pickle(output_path02)
+ main_logger.info(f"BayesFactor matrix saved in: {output_path02}")
+
output_path05 = os.path.join(output_dir, "assignment.pickle")
scar_model.feature_assignment.to_pickle(output_path05)
main_logger.info(f"assignment saved in: {output_path05}")
@@ -201,23 +209,21 @@ def main():
)
denoised_adata = adata.copy()
- denoised_adata.X = csr_matrix(scar_model.native_counts)
+ denoised_adata.X = scar_model.native_counts
denoised_adata.obs["noise_ratio"] = pd.DataFrame(
- scar_model.noise_ratio,
- index=count_matrix.index,
+ scar_model.noise_ratio.toarray(),
+ index=count_matrix.obs_names,
columns=["noise_ratio"],
)
-
- denoised_adata.layers["native_frequencies"] = csr_matrix(
- scar_model.native_frequencies
- )
- denoised_adata.layers["BayesFactor"] = csr_matrix(scar_model.bayesfactor)
+ if scar_model.native_frequencies is not None:
+ denoised_adata.layers["native_frequencies"] = scar_model.native_frequencies.toarray()
if feature_type.lower() in ["sgrna", "sgrnas", "tag", "tags", "cmo", "cmos"]:
denoised_adata.obs = denoised_adata.obs.join(scar_model.feature_assignment)
+ denoised_adata.layers["BayesFactor"] = scar_model.bayesfactor.toarray()
denoised_adata.write(output_path_h5ad)
- main_logger.info("the denoised h5ad file saved in: {output_path_h5ad}")
+ main_logger.info(f"the denoised h5ad file saved in: {output_path_h5ad}")
class Config:
@@ -238,8 +244,7 @@ def scar_parser():
"""Argument parser"""
parser = argparse.ArgumentParser(
- description="scAR (single cell Ambient Remover): \
- denoising drop-based single-cell omics data",
+ description="scAR (single-cell Ambient Remover) is a deep learning model for removal of the ambient signals in droplet-based single cell omics",
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
@@ -280,6 +285,27 @@ def scar_parser():
default=0.9,
help="The sparsity of expected native signals",
)
+ parser.add_argument(
+ "-bk",
+ "--batchkey",
+ type=str,
+ default=None,
+ help="The batch key for batch correction",
+ )
+ parser.add_argument(
+ "-cache",
+ "--cachecapacity",
+ type=int,
+ default=20000,
+ help="The capacity of cache for batch correction",
+ )
+ parser.add_argument(
+ "-gnf",
+ "--get_native_frequencies",
+ type=int,
+ default=0,
+ help="Whether to get native frequencies, 0 or 1, by default 0, not to get native frequencies",
+ )
parser.add_argument(
"-hl1",
"--hidden_layer1",
diff --git a/scar/main/_scar.py b/scar/main/_scar.py
index 5d8a94d..7fcaac4 100644
--- a/scar/main/_scar.py
+++ b/scar/main/_scar.py
@@ -2,17 +2,12 @@
"""The main module of scar"""
-import sys
-import time
-import warnings
+import sys, time, contextlib, torch
from typing import Optional, Union
-import contextlib
-import numpy as np
-import pandas as pd
-import anndata as ad
+from scipy import sparse
+import numpy as np, pandas as pd, anndata as ad
-import torch
-from sklearn.model_selection import train_test_split
+from torch.utils.data import Dataset, random_split, DataLoader
from tqdm import tqdm
from tqdm.contrib import DummyTqdmFile
@@ -91,15 +86,19 @@ class model:
Thank Will Macnair for the valuable feedback.
.. versionadded:: 0.4.0
+ cache_capacity : int, optional
+ the number of cells whose data are cached on the device (e.g. GPU). Reduce this value if the device runs out of memory. By default 20000 cells are cached.
+
+ .. versionadded:: 0.7.0
batch_key : str, optional
batch key in AnnData.obs, by default None. \
If assigned, batch ambient removel will be performed and \
the ambient profile will be estimated for each batch.
- .. versionadded:: 0.6.1
+ .. versionadded:: 0.7.0
device : str, optional
- either "auto, "cpu" or "cuda", by default "auto"
+ either "auto", "cpu", "cuda" or "mps", by default "auto"
verbose : bool, optional
whether to print the details, by default True
@@ -153,7 +152,7 @@ class model:
sorted_native_counts = citeseq.native_signals[citeseq.celltype.argsort()][
:, citeseq.ambient_profile.argsort()
] # native counts
- sorted_denoised_counts = citeseq_denoised.native_counts[citeseq.celltype.argsort()][
+ sorted_denoised_counts = citeseq_denoised.native_counts.toarray()[citeseq.celltype.argsort()][
:, citeseq.ambient_profile.argsort()
] # denoised counts
@@ -213,6 +212,7 @@ def __init__(
sparsity: float = 0.9,
batch_key: str = None,
device: str = "auto",
+ cache_capacity: int = 20000,
verbose: bool = True,
):
"""initialize object"""
@@ -224,14 +224,13 @@ def __init__(
if device == "auto":
if torch.cuda.is_available():
self.device = torch.device("cuda")
- self.logger.info("CPU is detected and will be used.")
+ self.logger.info(f"{self.device} is detected and will be used.")
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
self.device = torch.device("mps")
- self.logger.info("MPS is detected and will be used.")
- self.logger.warning("PyTorch is slower on MPS than on the CPU; we recommend using the CPU by specifying device='cpu' on Mac.")
+ self.logger.info(f"{self.device} is detected and will be used.")
else:
self.device = torch.device("cpu")
- self.logger.info("No GPU detected. Use CPU instead.")
+ self.logger.info(f"No GPU detected. {self.device} will be used.")
else:
self.device = device
self.logger.info(f"{device} will be used.")
@@ -274,24 +273,19 @@ def __init__(
"""float, the sparsity of expected native signals. (0, 1]. \
Forced to be one in the mode of "sgRNA(s)" and "tag(s)".
"""
+ self.cache_capacity = cache_capacity
+ """int, the capacity of caching data on GPU. Set a smaller value upon GPU memory issue. By default 20000 cells are cached on GPU/MPS.
+
+ .. versionadded:: 0.7.0
+ """
- if isinstance(raw_count, str):
- raw_count = pd.read_pickle(raw_count)
- elif isinstance(raw_count, np.ndarray):
- raw_count = pd.DataFrame(
- raw_count,
- index=range(raw_count.shape[0]),
- columns=range(raw_count.shape[1]),
- )
- elif isinstance(raw_count, pd.DataFrame):
- pass
- elif isinstance(raw_count, ad.AnnData):
- if batch_key:
+ if isinstance(raw_count, ad.AnnData):
+ if batch_key is not None:
if batch_key not in raw_count.obs.columns:
raise ValueError(f"{batch_key} not found in AnnData.obs.")
self.logger.info(
- f"Estimating ambient profile for each batch defined by {batch_key} in AnnData.obs..."
+ f"Found {raw_count.obs[batch_key].nunique()} batches defined by {batch_key} in AnnData.obs. Estimating ambient profile per batch..."
)
batch_id_per_cell = pd.Categorical(raw_count.obs[batch_key]).codes
ambient_profile = np.empty((len(np.unique(batch_id_per_cell)),raw_count.shape[1]))
@@ -300,38 +294,51 @@ def __init__(
ambient_profile[batch_id, :] = subset.X.sum(axis=0) / subset.X.sum()
# add a mapper to locate the batch id
- self.batch_id = torch.from_numpy(batch_id_per_cell).int().to(self.device)
- self.n_batch = np.unique(batch_id_per_cell).size
+ self.batch_id = batch_id_per_cell
+ self.n_batch = len(np.unique(batch_id_per_cell))
+ else:
+ # get ambient profile from AnnData.uns
+ if "ambient_profile_all" in raw_count.uns:
+ self.logger.info(
+ "Found ambient profile in AnnData.uns['ambient_profile_all']"
+ )
+ ambient_profile = raw_count.uns["ambient_profile_all"]
+ else:
+ self.logger.info(
+ "Ambient profile not found in AnnData.uns['ambient_profile'], estimating it by averaging pooled cells..."
+ )
- # get ambient profile from AnnData.uns
- elif (ambient_profile is None) and ("ambient_profile_all" in raw_count.uns):
- self.logger.info(
- "Found ambient profile in AnnData.uns['ambient_profile_all']"
- )
- ambient_profile = raw_count.uns["ambient_profile_all"]
- elif (ambient_profile is None) and (
- "ambient_profile_all" not in raw_count.uns
- ):
- self.logger.info(
- "Ambient profile not found in AnnData.uns['ambient_profile'], estimating it by averaging pooled cells..."
- )
- # convert AnnData to pd.DataFrame
- raw_count = raw_count.to_df()
+ elif isinstance(raw_count, str):
+ # read pickle file into dataframe
+ raw_count = pd.read_pickle(raw_count)
+
+ elif isinstance(raw_count, np.ndarray):
+ # convert np.array to pd.DataFrame
+ raw_count = pd.DataFrame(
+ raw_count,
+ index=range(raw_count.shape[0]),
+ columns=range(raw_count.shape[1]),
+ )
+
+ elif isinstance(raw_count, pd.DataFrame):
+ pass
else:
raise TypeError(
- f"Expecting str or np.array or pd.DataFrame object, but get a {type(raw_count)}"
+ f"Expecting str or np.array or pd.DataFrame or AnnData object, but get a {type(raw_count)}"
)
- raw_count = raw_count.fillna(0) # missing vals -> zeros
-
- # Loading numpy to tensor on GPU
- self.raw_count = raw_count.values
+ self.raw_count = raw_count
"""raw_count : np.ndarray, raw count matrix.
"""
self.n_features = raw_count.shape[1]
-
- self.cell_id = list(raw_count.index)
- self.feature_names = list(raw_count.columns)
+ """int, number of features.
+ """
+ self.cell_id = raw_count.index.to_list() if isinstance(raw_count, pd.DataFrame) else raw_count.obs_names.to_list()
+ """list, cell id.
+ """
+ self.feature_names = raw_count.columns.to_list() if isinstance(raw_count, pd.DataFrame) else raw_count.var_names.to_list()
+ """list, feature names.
+ """
if isinstance(ambient_profile, str):
ambient_profile = pd.read_pickle(ambient_profile)
@@ -341,9 +348,13 @@ def __init__(
elif isinstance(ambient_profile, np.ndarray):
ambient_profile = np.nan_to_num(ambient_profile) # missing vals -> zeros
elif not ambient_profile:
- self.logger.info(" Evaluate empty profile from cells")
- ambient_profile = raw_count.sum() / raw_count.sum().sum()
- ambient_profile = ambient_profile.fillna(0).values
+ self.logger.info(" Evaluate ambient profile from cells")
+ if isinstance(raw_count, pd.DataFrame):
+ ambient_profile = raw_count.sum() / raw_count.sum().sum()
+ ambient_profile = ambient_profile.fillna(0).values
+ elif isinstance(raw_count, ad.AnnData):
+ ambient_profile = np.array(raw_count.X.sum(axis=0)/raw_count.X.sum())
+ ambient_profile = np.nan_to_num(ambient_profile).flatten()
else:
raise TypeError(
f"Expecting str / np.array / None / pd.DataFrame, but get a {type(ambient_profile)}"
@@ -355,7 +366,7 @@ def __init__(
.reshape(1, -1)
)
# add a mapper to locate the artificial batch id
- self.batch_id = torch.zeros(raw_count.shape[0]).int().to(self.device)
+ self.batch_id = np.zeros(raw_count.shape[0], dtype=int)#.reshape(-1, 1)
self.n_batch = 1
self.ambient_profile = ambient_profile
@@ -437,19 +448,14 @@ def train(
After training, a trained_model attribute will be added.
"""
-
- list_ids = list(range(self.raw_count.shape[0]))
- train_ids, test_ids = train_test_split(list_ids, train_size=train_size)
-
# Generators
- training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, list_ids=train_ids)
- training_generator = torch.utils.data.DataLoader(
- training_set, batch_size=batch_size, shuffle=shuffle
- )
- val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, list_ids=test_ids)
- val_generator = torch.utils.data.DataLoader(
- val_set, batch_size=batch_size, shuffle=shuffle
+ total_dataset = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, cache_capacity=self.cache_capacity)
+ training_set, validation_set = random_split(total_dataset, [train_size, 1 - train_size])
+ training_generator = DataLoader(
+ training_set, batch_size=batch_size, shuffle=shuffle,
+ drop_last=True
)
+ self.dataset = total_dataset
loss_values = []
@@ -543,6 +549,7 @@ def inference(
cutoff=3,
round_to_int="stochastic_rounding",
clip_to_obs=False,
+ get_native_frequencies=False,
moi=None,
):
"""inference infering the expected native signals, noise ratios, Bayesfactors and expected native frequencies
@@ -576,6 +583,11 @@ def inference(
Use it with caution, as it may lead to over-estimation of overall noise.
.. versionadded:: 0.5.0
+
+ get_native_frequencies : bool, optional
+ whether to compute and store the expected native frequencies, by default False (native_frequencies is then set to None)
+
+ .. versionadded:: 0.7.0
moi : int, optional (under development)
multiplicity of infection. If assigned, it will allow optimized thresholding, \
@@ -588,19 +600,33 @@ def inference(
native_frequencies, and noise_ratio. \
A feature_assignment will be added in 'sgRNA' or 'tag' or 'CMO' feature type.
"""
- total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device)
n_features = self.n_features
sample_size = self.raw_count.shape[0]
- self.native_counts = np.empty([sample_size, n_features])
- self.bayesfactor = np.empty([sample_size, n_features])
- self.native_frequencies = np.empty([sample_size, n_features])
- self.noise_ratio = np.empty([sample_size, 1])
+ dt = np.int64 if round_to_int=="stochastic_rounding" else np.float32
+ native_counts = sparse.lil_matrix((sample_size, n_features), dtype=dt)
+ noise_ratio = sparse.lil_matrix((sample_size, 1), dtype=np.float32)
+
+ native_frequencies = sparse.lil_matrix((sample_size, n_features), dtype=np.float32) if get_native_frequencies else None
+
+ if self.feature_type.lower() in [
+ "sgrna",
+ "sgrnas",
+ "tag",
+ "tags",
+ "cmo",
+ "cmos",
+ "atac",
+ ]:
+ bayesfactor = sparse.lil_matrix((sample_size, n_features), dtype=np.float32)
+ else:
+ bayesfactor = None
+
if not batch_size:
batch_size = sample_size
i = 0
- generator_full_data = torch.utils.data.DataLoader(
- total_set, batch_size=batch_size, shuffle=False
+ generator_full_data = DataLoader(
+ self.dataset, batch_size=batch_size, shuffle=False
)
for x_batch_tot, ambient_freq_tot, x_batch_id_onehot_tot in generator_full_data:
@@ -622,20 +648,28 @@ def inference(
round_to_int=round_to_int,
clip_to_obs=clip_to_obs,
)
- self.native_counts[
+ native_counts[
i * batch_size : i * batch_size + minibatch_size, :
] = native_counts_batch
- self.bayesfactor[
- i * batch_size : i * batch_size + minibatch_size, :
- ] = bayesfactor_batch
- self.native_frequencies[
- i * batch_size : i * batch_size + minibatch_size, :
- ] = native_frequencies_batch
- self.noise_ratio[
+ noise_ratio[
i * batch_size : i * batch_size + minibatch_size, :
] = noise_ratio_batch
+ if native_frequencies is not None:
+ native_frequencies[
+ i * batch_size : i * batch_size + minibatch_size, :
+ ] = native_frequencies_batch
+ if bayesfactor is not None:
+ bayesfactor[
+ i * batch_size : i * batch_size + minibatch_size, :
+ ] = bayesfactor_batch
+
i += 1
+ self.native_counts = native_counts.tocsr()
+ self.noise_ratio = noise_ratio.tocsr()
+ self.bayesfactor = bayesfactor.tocsr() if bayesfactor is not None else None
+ self.native_frequencies = native_frequencies.tocsr() if native_frequencies is not None else None
+
if self.feature_type.lower() in [
"sgrna",
"sgrnas",
@@ -676,7 +710,7 @@ def assignment(self, cutoff=3, moi=None):
index=self.cell_id, columns=[self.feature_type, f"n_{self.feature_type}"]
)
bayesfactor_df = pd.DataFrame(
- self.bayesfactor, index=self.cell_id, columns=self.feature_names
+ self.bayesfactor.toarray(), index=self.cell_id, columns=self.feature_names
)
bayesfactor_df[bayesfactor_df < cutoff] = 0 # Apply the cutoff for Bayesfactors
@@ -703,39 +737,39 @@ def assignment(self, cutoff=3, moi=None):
if moi:
raise NotImplementedError
-
-class UMIDataset(torch.utils.data.Dataset):
+class UMIDataset(Dataset):
"""Characterizes dataset for PyTorch"""
- def __init__(self, raw_count, ambient_profile, batch_id, device, list_ids=None):
+ def __init__(self, raw_count, ambient_profile, batch_id, device, cache_capacity=20000):
"""Initialization"""
- self.device = device
- self.raw_count = torch.from_numpy(raw_count).int().to(device)
+
+ self.raw_count = torch.from_numpy(raw_count.fillna(0).values).int() if isinstance(raw_count, pd.DataFrame) else raw_count
self.ambient_profile = torch.from_numpy(ambient_profile).float().to(device)
- self.batch_id = batch_id.to(torch.int64).to(device)
- self.batch_onehot = self._onehot()
+ self.batch_id = torch.from_numpy(batch_id).to(torch.int64).to(device)
+ self.batch_onehot = torch.from_numpy(np.eye(len(np.unique(batch_id)))).to(torch.int64).to(device)
+ self.device = device
+ self.cache_capacity = cache_capacity
- if list_ids:
- self.list_ids = list_ids
- else:
- self.list_ids = list(range(raw_count.shape[0]))
+ # Cache data
+ self.cache = {}
def __len__(self):
"""Denotes the total number of samples"""
- return len(self.list_ids)
+ return self.raw_count.shape[0]
def __getitem__(self, index):
"""Generates one sample of data"""
- # Select sample
- sc_id = self.list_ids[index]
- sc_count = self.raw_count[sc_id, :]
- sc_ambient = self.ambient_profile[self.batch_id[sc_id], :]
- sc_batch_id_onehot = self.batch_onehot[self.batch_id[sc_id], :]
- return sc_count, sc_ambient, sc_batch_id_onehot
-
- def _onehot(self):
- """One-hot encoding"""
- n_batch = self.batch_id.unique().size()[0]
- x_onehot = torch.zeros(n_batch, n_batch).to(self.device)
- x_onehot.scatter_(1, self.batch_id.unique().unsqueeze(1), 1)
- return x_onehot
\ No newline at end of file
+
+ if index in self.cache:
+ return self.cache[index]
+ else:
+ # Select samples
+ sc_count = self.raw_count[index].to(self.device) if isinstance(self.raw_count, torch.Tensor) else torch.from_numpy(self.raw_count[index].X.toarray().flatten()).int().to(self.device)
+ sc_ambient = self.ambient_profile[self.batch_id[index], :]
+ sc_batch_id_onehot = self.batch_onehot[self.batch_id[index], :]
+
+ # Cache samples
+ if len(self.cache) <= self.cache_capacity:
+ self.cache[index] = (sc_count, sc_ambient, sc_batch_id_onehot)
+
+ return sc_count, sc_ambient, sc_batch_id_onehot
diff --git a/scar/main/_vae.py b/scar/main/_vae.py
index 00368e1..98dfd93 100644
--- a/scar/main/_vae.py
+++ b/scar/main/_vae.py
@@ -149,20 +149,18 @@ def inference(
elif round_to_int.lower() == "stochastic_rounding":
expected_native_counts = (
np.floor(expected_native_counts)
- + np.random.binomial(
- 1,
- expected_native_counts - np.floor(expected_native_counts),
- expected_native_counts.shape,
- )
+ + (
+ np.random.rand(*expected_native_counts.shape)
+ < expected_native_counts - np.floor(expected_native_counts)
+ ).astype(int)
).astype(int)
expected_amb_counts = (
np.floor(expected_amb_counts)
- + np.random.binomial(
- 1,
- expected_amb_counts - np.floor(expected_amb_counts),
- expected_amb_counts.shape,
- )
+ + (
+ np.random.rand(*expected_amb_counts.shape)
+ < expected_amb_counts - np.floor(expected_amb_counts)
+ ).astype(int)
).astype(int)
if clip_to_obs:
@@ -172,13 +170,6 @@ def inference(
a_max=input_matrix_np,
)
- if clip_to_obs:
- expected_native_counts = np.clip(
- expected_native_counts,
- a_min=np.zeros_like(input_matrix_np),
- a_max=input_matrix_np,
- )
-
if not adjust:
adjust = 0
elif adjust == "global":
diff --git a/scar/test/test_scar.py b/scar/test/test_scar.py
index 502d62d..3080d05 100755
--- a/scar/test/test_scar.py
+++ b/scar/test/test_scar.py
@@ -22,7 +22,7 @@ def test_scar(self):
feature_type="sgRNAs",
)
- scarObj.train(epochs=40, batch_size=64)
+ scarObj.train(epochs=40, batch_size=32)
scarObj.inference()
@@ -58,7 +58,7 @@ def test_scar_citeseq(self):
feature_type="ADTs",
)
- citeseq_scar.train(epochs=200, batch_size=64, verbose=False)
+ citeseq_scar.train(epochs=200, batch_size=32, verbose=False)
citeseq_scar.inference()
dist = euclidean_distances(