diff --git a/docs/conf.py b/docs/conf.py
index bf21e403..761c2a22 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -86,6 +86,7 @@
     "pandas": ("https://pandas.pydata.org/docs/", None),
     "python": ("https://docs.python.org/3", None),
     "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
+    "pynndescent": ("https://pynndescent.readthedocs.io/en/latest/", None),
     "sklearn": ("https://scikit-learn.org/stable/", None),
     "torch": ("https://pytorch.org/docs/master/", None),
     "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None),
diff --git a/ehrapy/preprocessing/_scanpy_pp_api.py b/ehrapy/preprocessing/_scanpy_pp_api.py
index a34b4aec..5317530e 100644
--- a/ehrapy/preprocessing/_scanpy_pp_api.py
+++ b/ehrapy/preprocessing/_scanpy_pp_api.py
@@ -1,27 +1,35 @@
-from collections.abc import Collection, Mapping, Sequence
+from __future__ import annotations
+
 from types import MappingProxyType
-from typing import Any, Callable, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Literal, Union
 
 import numpy as np
 import scanpy as sc
-from anndata import AnnData
-from scipy.sparse import spmatrix
+
+if TYPE_CHECKING:
+    from collections.abc import Collection, Mapping, Sequence
+
+    from anndata import AnnData
+    from scanpy.neighbors import KnnTransformerLike
+    from scipy.sparse import spmatrix
+
+    from ehrapy.preprocessing._types import KnownTransformer
 
 AnyRandom = Union[int, np.random.RandomState, None]
 
 
 def pca(
-    data: Union[AnnData, np.ndarray, spmatrix],
-    n_comps: Optional[int] = None,
-    zero_center: Optional[bool] = True,
+    data: AnnData | np.ndarray | spmatrix,
+    n_comps: int | None = None,
+    zero_center: bool | None = True,
     svd_solver: str = "arpack",
     random_state: AnyRandom = 0,
     return_info: bool = False,
     dtype: str = "float32",
     copy: bool = False,
     chunked: bool = False,
-    chunk_size: Optional[int] = None,
-) -> Union[AnnData, np.ndarray, spmatrix]:  # pragma: no cover
+    chunk_size: int | None = None,
+) -> AnnData | np.ndarray | spmatrix | None:  # pragma: no cover
     """Computes a principal component analysis.
 
     Computes PCA coordinates, loadings and variance decomposition. Uses the implementation of *scikit-learn*.
@@ -91,17 +99,17 @@ def pca(
 
 def regress_out(
     adata: AnnData,
-    keys: Union[str, Sequence[str]],
-    n_jobs: Optional[int] = None,
+    keys: str | Sequence[str],
+    n_jobs: int | None = None,
     copy: bool = False,
-) -> Optional[AnnData]:  # pragma: no cover
+) -> AnnData | None:  # pragma: no cover
     """Regress out (mostly) unwanted sources of variation.
 
     Uses simple linear regression. This is inspired by Seurat's `regressOut` function in R [Satija15].
     Note that this function tends to overcorrect in certain circumstances.
 
     Args:
-        adata: :class:`~anndata.AnnData` object object containing all observations.
+        adata: :class:`~anndata.AnnData` object containing all observations.
         keys: Keys for observation annotation on which to regress on.
         n_jobs: Number of jobs for parallel computation. `None` means using :attr:`scanpy._settings.ScanpyConfig.n_jobs`.
         copy: Determines whether a copy of `adata` is returned.
@@ -113,12 +121,12 @@ def regress_out(
 
 
 def subsample(
-    data: Union[AnnData, np.ndarray, spmatrix],
-    fraction: Optional[float] = None,
-    n_obs: Optional[int] = None,
+    data: AnnData | np.ndarray | spmatrix,
+    fraction: float | None = None,
+    n_obs: int | None = None,
     random_state: AnyRandom = 0,
     copy: bool = False,
-) -> Optional[AnnData]:  # pragma: no cover
+) -> AnnData | None:  # pragma: no cover
     """Subsample to a fraction of the number of observations.
 
     Args:
@@ -138,9 +146,9 @@ def subsample(
 def combat(
     adata: AnnData,
     key: str = "batch",
-    covariates: Optional[Collection[str]] = None,
+    covariates: Collection[str] | None = None,
     inplace: bool = True,
-) -> Union[AnnData, np.ndarray, None]:  # pragma: no cover
+) -> AnnData | np.ndarray | None:  # pragma: no cover
     """ComBat function for batch effect correction [Johnson07]_ [Leek12]_ [Pedersen12]_.
 
     Corrects for batch effects by fitting linear models, gains statistical power via an EB framework where information is borrowed across features.
@@ -149,7 +157,7 @@ def combat(
     .. _combat.py: https://github.com/brentp/combat.py
 
     Args:
-        adata: :class:`~anndata.AnnData` object object containing all observations.
+        adata: :class:`~anndata.AnnData` object containing all observations.
         key: Key to a categorical annotation from :attr:`~anndata.AnnData.obs` that will be used for batch effect removal.
         covariates: Additional covariates besides the batch variable such as adjustment variables or biological condition.
                     This parameter refers to the design matrix `X` in Equation 2.1 in [Johnson07]_ and to the `mod` argument in
@@ -163,7 +171,7 @@ def combat(
     return sc.pp.combat(adata=adata, key=key, covariates=covariates, inplace=inplace)
 
 
-_Method = Literal["umap", "gauss", "rapids"]
+_Method = Literal["umap", "gauss"]
 _MetricFn = Callable[[np.ndarray, np.ndarray], float]
 _MetricSparseCapable = Literal["cityblock", "cosine", "euclidean", "l1", "l2", "manhattan"]
 _MetricScipySpatial = Literal[
@@ -191,16 +199,17 @@ def combat(
 def neighbors(
     adata: AnnData,
     n_neighbors: int = 15,
-    n_pcs: Optional[int] = None,
-    use_rep: Optional[str] = None,
+    n_pcs: int | None = None,
+    use_rep: str | None = None,
     knn: bool = True,
     random_state: AnyRandom = 0,
-    method: Optional[_Method] = "umap",
-    metric: Union[_Metric, _MetricFn] = "euclidean",
+    method: _Method = "umap",
+    transformer: KnnTransformerLike | KnownTransformer | None = None,
+    metric: _Metric | _MetricFn = "euclidean",
     metric_kwds: Mapping[str, Any] = MappingProxyType({}),
-    key_added: Optional[str] = None,
+    key_added: str | None = None,
     copy: bool = False,
-) -> Optional[AnnData]:  # pragma: no cover
+) -> AnnData | None:  # pragma: no cover
     """Compute a neighborhood graph of observations [McInnes18]_.
 
     The neighbor search efficiency of this heavily relies on UMAP [McInnes18]_,
@@ -209,7 +218,7 @@ def neighbors(
     connectivities are computed according to [Coifman05]_, in the adaption of [Haghverdi16]_.
 
     Args:
-        adata: :class:`~anndata.AnnData` object object containing all observations.
+        adata: :class:`~anndata.AnnData` object containing all observations.
         n_neighbors: The size of local neighborhood (in terms of number of neighboring data points) used for manifold approximation.
                      Larger values result in more global views of the manifold, while smaller values result in more local data being preserved.
                      In general values should be in the range 2 to 100. If `knn` is `True`, number of nearest neighbors to be searched.
@@ -225,6 +234,19 @@ def neighbors(
         method: Use 'umap' [McInnes18]_ or 'gauss' (Gauss kernel following [Coifman05]_ with adaptive width [Haghverdi16]_) for computing connectivities.
                 Use 'rapids' for the RAPIDS implementation of UMAP (experimental, GPU only).
         metric: A known metric’s name or a callable that returns a distance.
+        transformer: Approximate kNN search implementation. Follows the API of
+            :class:`~sklearn.neighbors.KNeighborsTransformer`.
+            See scanpy's `knn-transformers tutorial `_ for more details. This tutorial is also valid for ehrapy's `neighbors` function.
+            Next to the advanced options from the knn-transformers tutorial, this argument accepts the following basic options:
+
+            `None` (the default)
+                Behavior depends on data size.
+                For small data, uses :class:`~sklearn.neighbors.KNeighborsTransformer` with algorithm="brute" for exact kNN, otherwise uses
+                :class:`~pynndescent.pynndescent_.PyNNDescentTransformer` for approximate kNN.
+            `'pynndescent'`
+                Uses :class:`~pynndescent.pynndescent_.PyNNDescentTransformer` for approximate kNN.
+            `'sklearn'`
+                Uses :class:`~sklearn.neighbors.KNeighborsTransformer` with algorithm="brute" for exact kNN.
         metric_kwds: Options for the metric.
         key_added: If not specified, the neighbors data is stored in .uns['neighbors'],
                    distances and connectivities are stored in .obsp['distances'] and .obsp['connectivities'] respectively.
@@ -250,6 +272,7 @@ def neighbors(
         knn=knn,
         random_state=random_state,
         method=method,
+        transformer=transformer,
         metric=metric,
         metric_kwds=metric_kwds,
         key_added=key_added,
diff --git a/ehrapy/preprocessing/_types.py b/ehrapy/preprocessing/_types.py
new file mode 100644
index 00000000..6810761a
--- /dev/null
+++ b/ehrapy/preprocessing/_types.py
@@ -0,0 +1,5 @@
+from __future__ import annotations
+
+from typing import Literal
+
+KnownTransformer = Literal["pynndescent", "sklearn"]
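
For context when reviewing, the sketch below exercises the new `transformer` argument through the public `ep.pp` namespace. The synthetic AnnData, the `key_added` names, and the Manhattan-metric transformer are illustrative assumptions rather than part of this change; the `"pynndescent"`/`"sklearn"` shortcuts and the option to pass a transformer instance come from the signature and docstring added above.

    # Usage sketch for the new `transformer` argument (illustrative, not part of the diff).
    # Assumption: a small synthetic AnnData stands in for an encoded EHR matrix.
    import numpy as np
    from anndata import AnnData
    from sklearn.neighbors import KNeighborsTransformer

    import ehrapy as ep

    rng = np.random.default_rng(0)
    adata = AnnData(X=rng.normal(size=(200, 20)).astype(np.float32))

    ep.pp.pca(adata)

    # Default: transformer=None picks exact brute-force kNN for small data and
    # PyNNDescent-based approximate kNN otherwise.
    ep.pp.neighbors(adata, n_neighbors=15)

    # Select one of the two built-in options by name.
    ep.pp.neighbors(adata, transformer="pynndescent", key_added="approx")
    ep.pp.neighbors(adata, transformer="sklearn", key_added="exact")

    # Or pass any KNeighborsTransformer-like estimator directly.
    ep.pp.neighbors(
        adata,
        transformer=KNeighborsTransformer(n_neighbors=15, metric="manhattan"),
        key_added="manhattan",
    )

Because the new parameter is forwarded unchanged to `sc.pp.neighbors`, anything accepted there (including the advanced options from the knn-transformers tutorial) should behave identically when called via ehrapy.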