diff --git a/scarf/assay.py b/scarf/assay.py index 477d8c0..83cf143 100644 --- a/scarf/assay.py +++ b/scarf/assay.py @@ -9,17 +9,18 @@ method for feature selection. """ -from typing import Tuple, List, Generator, Optional, Union +from typing import Generator, List, Optional, Tuple, Union import numpy as np import pandas as pd +import zarr from dask.array.core import Array as daskArrayType from dask.array.core import from_zarr from scipy.sparse import csr_matrix, vstack from zarr import hierarchy as z_hierarchy from .metadata import MetaData -from .utils import show_dask_progress, controlled_compute, logger +from .utils import controlled_compute, logger, show_dask_progress zarrGroup = z_hierarchy.Group @@ -279,12 +280,18 @@ def _verify_keys(self, cell_key: str, feat_key: str) -> None: feat_key: Name of the key (column) from feature attribute table Returns: None + + Note on type checking /GA: + 1. ds.cells.get_dtype(cell_key) == bool returns True because dtype('bool') (from numpy) is conceptually equivalent to Python's bool. + 2. isinstance(ds.cells.get_dtype(cell_key), bool) returns False because dtype('bool') is a numpy.dtype object, not the native Python bool type. + 3. Reason: isinstance tests the class of the object itself (including its parent classes); dtype('bool') is an instance of numpy.dtype, not of bool.
+ """ - if cell_key not in self.cells.columns or self.cells.get_dtype(cell_key) != bool: + if cell_key not in self.cells.columns or self.cells.get_dtype(cell_key) != bool: # noqa: E721 raise ValueError( f"ERROR: Either {cell_key} does not exist or is not bool type" ) - if feat_key not in self.feats.columns or self.feats.get_dtype(feat_key) != bool: + if feat_key not in self.feats.columns or self.feats.get_dtype(feat_key) != bool: # noqa: E721 raise ValueError( f"ERROR: Either {feat_key} does not exist or is not bool type" ) @@ -526,9 +533,10 @@ def iter_normed_feature_wise( columns=feat_idx[chunk], ) else: - yield controlled_compute(data[:, chunk], self.nthreads).T, feat_idx[ - chunk - ] + yield ( + controlled_compute(data[:, chunk], self.nthreads).T, + feat_idx[chunk], + ) def save_normed_for_query( self, feat_key: Optional[str], batch_size: int, overwrite: bool = True @@ -549,6 +557,7 @@ def save_normed_for_query( None """ from joblib import Parallel, delayed + from .writers import create_zarr_obj_array def write_wrapper(idx: str, v: np.ndarray) -> None: @@ -563,7 +572,8 @@ def write_wrapper(idx: str, v: np.ndarray) -> None: None, feat_key, batch_size, "Saving features", False ): Parallel(n_jobs=self.nthreads)( - delayed(write_wrapper)(inds[i], mat[i]) for i in range(len(inds)) # type: ignore + delayed(write_wrapper)(inds[i], mat[i]) + for i in range(len(inds)) # type: ignore ) def save_aggregated_ordering( @@ -888,6 +898,51 @@ def set_feature_stats(self, cell_key: str) -> None: self.feats.unmount_location(identifier) return None + def set_summary_stats( + self, cell_key: str = None, n_bins: int = 200, lowess_frac: float = 0.1 + ) -> Tuple[str, str]: + """Calculates summary statistics for the features of the assay using only cells that are marked True by the 'cell_key' parameter. + + Args: + cell_key: Name of the key (column) from cell attribute table. + n_bins: Number of bins to divide the data into. + lowess_frac: Between 0 and 1. 
The fraction of the data used when estimating the fit between mean and + variance. This is the same as `frac` in statsmodels.nonparametric.smoothers_lowess.lowess + + Returns: + A tuple of two strings. + identifier: The text that will be prepended to column names when summary statistics are loaded onto the feature attributes table. + c_var_col: The name of the column in the feature attribute table that contains the corrected variance values. + """ + + def col_renamer(x): + return f"{identifier}_{x}" + + if cell_key is None: + cell_key = "I" + + # check lowess_frac is between 0 and 1 + if not 0 <= lowess_frac <= 1: + raise ValueError("lowess_frac must be between 0 and 1") + + self.set_feature_stats(cell_key) + identifier = self._load_stats_loc(cell_key) + c_var_col = f"c_var__{n_bins}__{lowess_frac}" + if col_renamer(c_var_col) in self.feats.columns: + logger.info("Using existing corrected dispersion values") + else: + slots = ["normed_tot", "avg", "nz_mean", "sigmas", "normed_n"] + for i in slots: + i = col_renamer(i) + if i not in self.feats.columns: + raise KeyError(f"ERROR: {i} not found in feature metadata") + c_var = self.feats.remove_trend( + col_renamer("avg"), col_renamer("sigmas"), n_bins, lowess_frac + ) + self.feats.insert(c_var_col, c_var, overwrite=True, location=identifier) + + return identifier, c_var_col + # maybe we should return plot here? If one wants to modify it.
/raz def mark_hvgs( self, @@ -950,21 +1005,9 @@ def mark_hvgs( def col_renamer(x): return f"{identifier}_{x}" - self.set_feature_stats(cell_key) - identifier = self._load_stats_loc(cell_key) - c_var_col = f"c_var__{n_bins}__{lowess_frac}" - if col_renamer(c_var_col) in self.feats.columns: - logger.info("Using existing corrected dispersion values") - else: - slots = ["normed_tot", "avg", "nz_mean", "sigmas", "normed_n"] - for i in slots: - i = col_renamer(i) - if i not in self.feats.columns: - raise KeyError(f"ERROR: {i} not found in feature metadata") - c_var = self.feats.remove_trend( - col_renamer("avg"), col_renamer("sigmas"), n_bins, lowess_frac - ) - self.feats.insert(c_var_col, c_var, overwrite=True, location=identifier) + logger.info("Calculating summary statistics") + identifier, c_var_col = self.set_summary_stats(cell_key, n_bins, lowess_frac) + logger.info("Calculating HVGs") if max_mean != np.inf: max_mean = 2**max_mean