Merge branch 'main' into feature/python_3_12
Zethson authored Nov 28, 2024
2 parents 288a4be + 1d7c5d7 commit 8ea1638
Showing 24 changed files with 509 additions and 408 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/run_notebooks.yml
@@ -13,7 +13,7 @@ jobs:
"docs/tutorials/notebooks/ehrapy_introduction.ipynb",
"docs/tutorials/notebooks/mimic_2_introduction.ipynb",
"docs/tutorials/notebooks/mimic_2_survival_analysis.ipynb",
"docs/tutorials/notebooks/mimic_2_fate.ipynb",
# "docs/tutorials/notebooks/mimic_2_fate.ipynb", # https://github.com/theislab/cellrank/issues/1235
"docs/tutorials/notebooks/mimic_2_causal_inference.ipynb",
# "docs/tutorials/notebooks/mimic_3_demo.ipynb",
# "docs/tutorials/notebooks/medcat.ipynb",
@@ -34,5 +34,8 @@ jobs:
- name: Install ehrapy and additional dependencies
run: uv pip install --system . cellrank nbconvert ipykernel graphviz

- name: Install scvelo from Github
run: uv pip install --system git+https://github.com/theislab/scvelo.git

- name: Run ${{ matrix.notebook }} Notebook
run: jupyter nbconvert --to notebook --execute ${{ matrix.notebook }}
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -11,7 +11,7 @@ repos:
hooks:
- id: prettier
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.3
rev: v0.8.0
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes]
17 changes: 7 additions & 10 deletions .readthedocs.yml
@@ -3,16 +3,13 @@ build:
os: ubuntu-22.04
tools:
python: "3.11"
jobs:
pre_build:
- python -c "import ehrapy"
- pip freeze
post_create_environment:
- pip install uv
post_install:
# VIRTUAL_ENV needs to be set manually for now.
# See https://github.com/readthedocs/readthedocs.org/pull/11152/
- VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH pip install .[docs]
commands:
- asdf plugin add uv
- asdf install uv latest
- asdf global uv latest
- uv venv
- uv pip install .[docs]
- .venv/bin/python -m sphinx -T -b html -d docs/_build/doctrees -D language=en docs $READTHEDOCS_OUTPUT/html
sphinx:
configuration: docs/conf.py
fail_on_warning: false
2 changes: 1 addition & 1 deletion docs/_ext/edit_on_github.py
@@ -20,7 +20,7 @@ def get_github_repo(app: Sphinx, path: str) -> str:


def _html_page_context(
app: Sphinx, _pagename: str, templatename: str, context: dict[str, Any], doctree: Optional[Any]
app: Sphinx, _pagename: str, templatename: str, context: dict[str, Any], doctree: Any | None
) -> None:
# doctree is None - otherwise viewcode fails
if templatename != "page.html" or doctree is None:
2 changes: 1 addition & 1 deletion ehrapy/_settings.py
@@ -53,7 +53,7 @@ def __init__(
figdir: str | Path = "./figures/",
cache_compression: str | None = "lzf",
max_memory=15,
n_jobs: int = 1,
n_jobs: int = -1,
logfile: str | Path | None = None,
categories_to_ignore: Iterable[str] = ("N/A", "dontknow", "no_gate", "?"),
_frameon: bool = True,
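The n_jobs default changes from 1 to -1, conventionally interpreted (as in joblib) as "use all available cores". A minimal sketch of pinning it back, assuming ehrapy exposes this settings object as ehrapy.settings in the same way scanpy exposes scanpy.settings:

import ehrapy as ep

# Hypothetical override: restore single-core behaviour if the new
# all-cores default (-1) is not wanted.
ep.settings.n_jobs = 1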
6 changes: 3 additions & 3 deletions ehrapy/core/_tool_available.py → ehrapy/_utils_available.py
@@ -4,7 +4,7 @@
from subprocess import PIPE, Popen


def _check_module_importable(package: str) -> bool: # pragma: no cover
def _check_module_importable(package: str) -> bool:
"""Checks whether a module is installed and can be loaded.
Args:
@@ -19,7 +19,7 @@ def _check_module_importable(package: str) -> bool: # pragma: no cover
return module_available


def _shell_command_accessible(command: list[str]) -> bool: # pragma: no cover
def _shell_command_accessible(command: list[str]) -> bool:
"""Checks whether the provided command is accessible in the current shell.
Args:
@@ -29,7 +29,7 @@ def _shell_command_accessible(command: list[str]) -> bool: # pragma: no cover
True if the command is accessible, False otherwise.
"""
command_accessible = Popen(command, stdout=PIPE, stderr=PIPE, universal_newlines=True, shell=True)
(commmand_stdout, command_stderr) = command_accessible.communicate()
command_accessible.communicate()
if command_accessible.returncode != 0:
return False

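The # pragma: no cover markers are dropped, presumably because these helpers are now covered by tests. A minimal sketch of the importability check (the package names are illustrative):

from ehrapy._utils_available import _check_module_importable

# True when the package can be imported in the current environment.
_check_module_importable("numpy")
# False for a package that is not installed.
_check_module_importable("surely_not_installed")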
4 changes: 2 additions & 2 deletions ehrapy/_doc_util.py → ehrapy/_utils_doc.py
@@ -1,9 +1,9 @@
import inspect
from collections.abc import Callable
from textwrap import dedent
from typing import Callable, Optional, Union


def getdoc(c_or_f: Union[Callable, type]) -> Optional[str]: # pragma: no cover
def getdoc(c_or_f: Callable | type) -> str | None: # pragma: no cover
if getattr(c_or_f, "__doc__", None) is None:
return None
doc = inspect.getdoc(c_or_f)
21 changes: 21 additions & 0 deletions ehrapy/_utils_rendering.py
@@ -0,0 +1,21 @@
import functools

from rich.progress import Progress, SpinnerColumn


def spinner(message: str = "Running task"):
def wrap(func):
@functools.wraps(func)
def wrapped_f(*args, **kwargs):
with Progress(
"[progress.description]{task.description}",
SpinnerColumn(),
refresh_per_second=1500,
) as progress:
progress.add_task(f"[blue]{message}", total=1)
result = func(*args, **kwargs)
return result

return wrapped_f

return wrap
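
A minimal usage sketch of the new spinner decorator (the decorated function and the sleep are illustrative):

import time

from ehrapy._utils_rendering import spinner


@spinner("Simulating a long-running task")
def slow_task() -> int:
    # The rich progress spinner renders while this body executes.
    time.sleep(2)
    return 42


result = slow_task()  # 42, returned once the task finishes and the spinner stops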
56 changes: 51 additions & 5 deletions ehrapy/anndata/anndata_ext.py
@@ -3,7 +3,7 @@
import random
from collections import OrderedDict
from string import ascii_letters
from typing import TYPE_CHECKING, NamedTuple
from typing import TYPE_CHECKING, Any, NamedTuple

import numpy as np
import pandas as pd
@@ -303,7 +303,7 @@ def move_to_x(adata: AnnData, to_x: list[str] | str) -> AnnData:
return new_adata


def _get_column_indices(adata: AnnData, col_names: str | Iterable[str]) -> list[int]:
def get_column_indices(adata: AnnData, col_names: str | Iterable[str]) -> list[int]:
"""Fetches the column indices in X for a given list of column names
Args:
@@ -383,7 +383,7 @@ def set_numeric_vars(
if copy:
adata = adata.copy()

vars_idx = _get_column_indices(adata, vars)
vars_idx = get_column_indices(adata, vars)

adata.X[:, vars_idx] = values

@@ -404,7 +404,7 @@ def _detect_binary_columns(df: pd.DataFrame, numerical_columns: list[str]) -> li
for column in numerical_columns:
# checking for float and int as well as NaNs (this is safe since checked columns are numericals only)
# only columns that contain at least one 0 and one 1 are counted as binary (or 0.0/1.0)
if df[column].isin([0.0, 1.0, np.NaN, 0, 1]).all() and df[column].nunique() == 2:
if df[column].isin([0.0, 1.0, np.nan, 0, 1]).all() and df[column].nunique() == 2:
binary_columns.append(column)

return binary_columns
@@ -423,7 +423,7 @@ def _cast_obs_columns(obs: pd.DataFrame) -> pd.DataFrame:
# type cast each non-numerical column to either bool (if possible) or category else
obs[object_columns] = obs[object_columns].apply(
lambda obs_name: obs_name.astype("category")
if not set(pd.unique(obs_name)).issubset({False, True, np.NaN})
if not set(pd.unique(obs_name)).issubset({False, True, np.nan})
else obs_name.astype("bool"),
axis=0,
)
@@ -663,3 +663,49 @@ def get_rank_features_df(

class NotEncodedError(AssertionError):
pass


def _are_ndarrays_equal(arr1: np.ndarray, arr2: np.ndarray) -> np.bool_:
"""Check if two arrays are equal member-wise.
Note: Two NaNs are considered equal.
Args:
arr1: First array to compare
arr2: Second array to compare
Returns:
True if the two arrays are equal member-wise
"""
return np.all(np.equal(arr1, arr2, dtype=object) | ((arr1 != arr1) & (arr2 != arr2)))


def _is_val_missing(data: np.ndarray) -> np.ndarray[Any, np.dtype[np.bool_]]:
"""Check if values in a AnnData matrix are missing.
Args:
data: The AnnData matrix to check
Returns:
An array of bool representing the missingness of the original data, with the same shape
"""
return np.isin(data, [None, ""]) | (data != data)


def _to_dense_matrix(adata: AnnData, layer: str | None = None) -> np.ndarray: # pragma: no cover
"""Extract a layer from an AnnData object and convert it to a dense matrix if required.
Args:
adata: The AnnData where to extract the layer from.
layer: Name of the layer to extract. If omitted, X is considered.
Returns:
The layer as a dense matrix. If a conversion was required, this function returns a copy of the original layer,
otherwise this function returns a reference.
"""
from scipy.sparse import issparse

if layer is None:
return adata.X.toarray() if issparse(adata.X) else adata.X
else:
return adata.layers[layer].toarray() if issparse(adata.layers[layer]) else adata.layers[layer]
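
A small sketch of the new internal helpers (the arrays are illustrative; behaviour follows the definitions above):

import numpy as np

from ehrapy.anndata.anndata_ext import _are_ndarrays_equal, _is_val_missing

a = np.array([1.0, np.nan, "x"], dtype=object)
b = np.array([1.0, np.nan, "x"], dtype=object)

_are_ndarrays_equal(a, b)  # truthy: NaNs in matching positions compare as equal

vals = np.array([1.0, None, "", np.nan], dtype=object)
_is_val_missing(vals)  # boolean mask flagging None, empty strings and NaN as missing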
2 changes: 1 addition & 1 deletion ehrapy/data/_datasets.py
@@ -743,7 +743,7 @@ def synthea_1k_sample(

df = anndata_to_df(adata)
df.drop(
columns=[col for col in df.columns if any(isinstance(x, (list, dict)) for x in df[col].dropna())], inplace=True
columns=[col for col in df.columns if any(isinstance(x, list | dict) for x in df[col].dropna())], inplace=True
)
df.drop(columns=df.columns[df.isna().all()], inplace=True)
adata = df_to_anndata(df, index_column="id")
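Both this change and the analogous one in ehrapy/preprocessing/_encoding.py below rely on isinstance() accepting PEP 604 union types (Python 3.10+); a quick illustration:

isinstance({"a": 1}, list | dict)    # True, equivalent to isinstance(x, (list, dict))
isinstance(None, str | type(None))   # True, the form used in encode()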
12 changes: 6 additions & 6 deletions ehrapy/plot/_scanpy_pl_api.py
@@ -1,15 +1,15 @@
from __future__ import annotations

from collections.abc import Collection, Iterable, Mapping, Sequence
from collections.abc import Callable, Collection, Iterable, Mapping, Sequence
from enum import Enum
from functools import partial
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Callable, Literal, Union
from typing import TYPE_CHECKING, Any, Literal

import scanpy as sc
from scanpy.plotting import DotPlot, MatrixPlot, StackedViolin

from ehrapy._doc_util import (
from ehrapy._utils_doc import (
_doc_params,
doc_adata_color_etc,
doc_common_groupby_plot_args,
@@ -36,12 +36,12 @@
from scanpy.plotting._utils import _AxesSubplot

_Basis = Literal["pca", "tsne", "umap", "diffmap", "draw_graph_fr"]
_VarNames = Union[str, Sequence[str]]
ColorLike = Union[str, tuple[float, ...]]
_VarNames = str | Sequence[str]
ColorLike = str | tuple[float, ...]
_IGraphLayout = Literal["fa", "fr", "rt", "rt_circular", "drl", "eq_tree", ...] # type: ignore
_FontWeight = Literal["light", "normal", "medium", "semibold", "bold", "heavy", "black"]
_FontSize = Literal["xx-small", "x-small", "small", "medium", "large", "x-large", "xx-large"]
VBound = Union[str, float, Callable[[Sequence[float]], float]]
VBound = str | float | Callable[[Sequence[float]], float]


@_doc_params(scatter_temp=doc_scatter_basic, show_save_ax=doc_show_save_ax)
4 changes: 2 additions & 2 deletions ehrapy/preprocessing/_encoding.py
@@ -73,7 +73,7 @@ def encode(

if isinstance(encodings, str) and not autodetect:
raise ValueError("Passing a string for parameter encodings is only possible when using autodetect=True!")
elif autodetect and not isinstance(encodings, (str, type(None))):
elif autodetect and not isinstance(encodings, str | type(None)):
raise ValueError(
f"Setting encode mode with autodetect=True only works by passing a string (encode mode name) or None not {type(encodings)}!"
)
@@ -630,7 +630,7 @@ def _update_obs(adata: AnnData, categorical_names: list[str]) -> pd.DataFrame:
updated_obs[var_name] = adata.X[::, idx : idx + 1].flatten()
# note: this will count binary columns (0 and 1 only) as well
# needed for writing to .h5ad files
if set(pd.unique(updated_obs[var_name])).issubset({False, True, np.NaN}):
if set(pd.unique(updated_obs[var_name])).issubset({False, True, np.nan}):
updated_obs[var_name] = updated_obs[var_name].astype("bool")
# get all non bool object columns and cast them to category dtype
object_columns = list(updated_obs.select_dtypes(include="object").columns)