diff --git a/sctk/__init__.py b/sctk/__init__.py index 007d681..e78344d 100644 --- a/sctk/__init__.py +++ b/sctk/__init__.py @@ -99,13 +99,13 @@ auto_zoom_in, calculate_qc, cellwise_qc, - cluster_qc_find_resolution, clusterwise_qc, crossmap, custom_pipeline, generate_qc_clusters, get_good_sized_batch, integrate, + multi_resolution_cluster_qc, recluster_subset, simple_default_pipeline, default_metric_params_df, diff --git a/sctk/_pipeline.py b/sctk/_pipeline.py index 2aef3fa..16e6e33 100644 --- a/sctk/_pipeline.py +++ b/sctk/_pipeline.py @@ -393,10 +393,10 @@ def cellwise_qc(adata, metrics=None, cell_qc_key="cell_passed_qc", uns_qc_key="s Args: adata: AnnData object to filter cells from. - metrics: Optional list/tuple of metric names or dictionary of metric - names and their corresponding parameters. If not provided, the function - uses a set of default metrics. For defaults and an explanation, please - refer to the QC workflow demo notebook. + metrics: Optional data frame specifying the parameters for each metric. + If not provided, the function uses a set of default metrics stored + in ``sctk.default_metric_params_df``. For defaults and an + explanation, please refer to the QC workflow demo notebook. cell_qc_key: Obs column in the object to store the per-cell QC calls in. uns_qc_key: Uns key to store the determined QC ranges used in filtering. **kwargs: Additional keyword arguments to pass to the @@ -767,24 +767,33 @@ def clusterwise_qc( ad.obs[key_added] = ad.obs[clus_key].isin(good_qc_clusters) -def cluster_qc_find_resolution(ad, metrics, - failed=True, - res=np.arange(0.1,1.1,0.1), - n_pcs=None, - n_neighbors=None, - threshold=0.5, - clus_key="qc_cluster", - umap_key="X_umap_qc", - cell_qc_key="cell_passed_qc", - key_added="cluster_passed_qc", - ) -> None: +def multi_resolution_cluster_qc(ad, metrics, + failed=True, + res=np.arange(0.1,1.1,0.1), + n_pcs=None, + n_neighbors=None, + threshold=0.5, + clus_key="qc_cluster", + umap_key="X_umap_qc", + cell_qc_key="cell_passed_qc", + key_added="cluster_passed_qc", + consensus_threshold=0.5, + consensus_frac_key="consensus_fraction", + consensus_call_key="consensus_passed_qc" + ) -> None: """ Run ``generate_qc_clusters()`` and ``clusterwise_qc()`` for a number of - potential resolutions. The optimal resolution is determined as the one that - grants the highest Jaccard index between cell-level and cluster-level QC - calls. + potential resolutions. Uses the multiple clusterings to identify a more + robust set of QC calls than that derived from a single resolution. Proposes + two sets of QC calls: - Stores optimal resolution clusters and QC calls in input object. + 1. A clustering resolution is proposed as the one that grants the highest + Jaccard index between cell-level and cluster-level QC calls. + + 2. For each cell, the fraction of tested resolutions where the cell passes + QC is computed. This fraction is then thresholded to get a QC call. + + Stores both sets of QC calls in input object. Args: ad: AnnData object to generate QC clusters for. @@ -809,6 +818,13 @@ def cluster_qc_find_resolution(ad, metrics, cell QC calls from obs in the AnnData. key_added: ``clusterwise_qc()`` argument. Key to use for storing the results in the AnnData obs object. + consensus_threshold: A cell has to pass QC in more than this fraction of + tested resolutions to be flagged as a good QC cell in the consensus + calls. + consensus_frac_key: Key to use for storing the consensus fraction in the + AnnData obs. + consensus_call_key: Key to use for storing the consensus calls + (thresholded consensus fraction) in the AnnData obs. Returns: None. @@ -823,11 +839,13 @@ def cluster_qc_find_resolution(ad, metrics, >>> sctk.calculate_qc(adata) >>> metrics_list = ["n_counts", "n_genes", "percent_mito", "percent_ribo", "percent_hb"] >>> sctk.cellwise_qc(adata) - >>> sctk.cluster_qc_find_resolution(adata, metrics=metrics_list) + >>> sctk.multi_resolution_cluster_qc(adata, metrics=metrics_list) """ # store the qc cluster object into a helper variable, along with the Jaccards we make aux_ad = None jaccards = [] + # initialise consensus count (to become fraction later) column + ad.obs[consensus_frac_key] = 0 for sres in res: # compute the clusters and then get cluster-level QC calls # calling with aux_ad=None in the first pass will work just fine @@ -853,6 +871,8 @@ def cluster_qc_find_resolution(ad, metrics, jaccards.append(np.sum(~ad.obs[cell_qc_key] & ~ad.obs[key_added])/np.sum(~ad.obs[cell_qc_key] | ~ad.obs[key_added])) else: jaccards.append(np.sum(ad.obs[cell_qc_key] & ad.obs[key_added])/np.sum(ad.obs[cell_qc_key] | ad.obs[key_added])) + # tick up the QC passed count for consensus computation + ad.obs[consensus_frac_key] += ad.obs[key_added].astype(int) # find our best Jaccard and set the output object to have the corresponding clustering best_res = res[np.argmax(jaccards)] print("Best overlap found for resolution "+str(best_res)) @@ -871,6 +891,9 @@ def cluster_qc_find_resolution(ad, metrics, cell_qc_key=cell_qc_key, key_added=key_added ) + # compute the consensus fraction and calls + ad.obs[consensus_frac_key] = ad.obs[consensus_frac_key]/len(res) + ad.obs[consensus_call_key] = (ad.obs[consensus_frac_key] > consensus_threshold) def get_good_sized_batch(batches, min_size=10) -> list: