From 861d76251830d4de71c301ee6b28d0917d0fdf81 Mon Sep 17 00:00:00 2001 From: Carl Buchholz <32228189+aGuyLearning@users.noreply.github.com> Date: Mon, 2 Dec 2024 09:05:42 +0100 Subject: [PATCH] Improve survival analysis interface (#825) * updated kmf to match method signature * updated notebook * updated ehrapy tutorial commit * updated docs for new method signature * added outputs to survival analysis * correctly passing on fitting options * pull request fixes. - removed kwargs - updated documentation * added legacy support * added kmf function legacy support in tests and added new kaplan_meier function in line with new signature * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * updated notebook * added stacklevel to deprecation warning * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * added deprecation warning in comment * Update ehrapy/plot/_survival_analysis.py * Update ehrapy/plot/_survival_analysis.py * Update ehrapy/plot/_survival_analysis.py * Update ehrapy/plot/_survival_analysis.py * Update tests/tools/test_sa.py * doc adjustments * change name of kmf plot to kaplan_meier, some adjustments * introduce keyword only for univariate sa * correct docstring * update submodule * add lifelines intersphinx mappings * Update ehrapy/tools/_sa.py * Update ehrapy/tools/_sa.py * Update ehrapy/tools/_sa.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Lukas Heumos Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com> Co-authored-by: eroell --- docs/conf.py | 1 + docs/contributing.md | 2 +- docs/tutorials/notebooks | 2 +- docs/usage/usage.md | 4 +- ehrapy/plot/__init__.py | 2 +- ehrapy/plot/_survival_analysis.py | 70 +++++++++-- ehrapy/tools/__init__.py | 2 + ehrapy/tools/_sa.py | 190 ++++++++++++++++++++++++++++-- tests/tools/test_sa.py | 11 +- 9 files changed, 256 
insertions(+), 28 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 761c2a22..dd46ddab 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -96,6 +96,7 @@ "flax": ("https://flax.readthedocs.io/en/latest/", None), "jax": ("https://jax.readthedocs.io/en/latest/", None), "lamin": ("https://lamin.ai/docs", None), + "lifelines": ("https://lifelines.readthedocs.io/en/latest/", None), } language = "en" diff --git a/docs/contributing.md b/docs/contributing.md index ab0d890b..ce5858eb 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -51,7 +51,7 @@ and [prettier][prettier-editors]. ## Writing tests ```{note} -Remember to first install the package with `pip install -e "[dev,test,docs]"` +Remember to first install the package with `pip install -e ".[dev,test,docs]"` ``` This package uses the [pytest][] for automated testing. Please [write tests][scanpy-test-docs] for every function added diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 99b17e70..ac088bca 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 99b17e7039699548a908433fa3ee6b5cbac5e29f +Subproject commit ac088bcabae5de8516ca9a5aa036b4e3cdf67df6 diff --git a/docs/usage/usage.md b/docs/usage/usage.md index c77593b0..6f3f2366 100644 --- a/docs/usage/usage.md +++ b/docs/usage/usage.md @@ -226,7 +226,7 @@ In contrast to a preprocessing function, a tool usually adds an easily interpret tools.ols tools.glm - tools.kmf + tools.kaplan_meier tools.test_kmf_logrank tools.test_nested_f_statistic tools.cox_ph @@ -368,7 +368,7 @@ Methods that extract and visualize tool-specific annotation in an AnnData object :nosignatures: plot.ols - plot.kmf + plot.kaplan_meier ``` ### Causal Inference diff --git a/ehrapy/plot/__init__.py b/ehrapy/plot/__init__.py index 5ae52ab1..0c740e95 100644 --- a/ehrapy/plot/__init__.py +++ b/ehrapy/plot/__init__.py @@ -2,6 +2,6 @@ from ehrapy.plot._colormaps import * # noqa: F403 from ehrapy.plot._missingno_pl_api 
import * # noqa: F403 from ehrapy.plot._scanpy_pl_api import * # noqa: F403 -from ehrapy.plot._survival_analysis import kmf, ols +from ehrapy.plot._survival_analysis import kaplan_meier, kmf, ols from ehrapy.plot.causal_inference._dowhy import causal_effect from ehrapy.plot.feature_ranking._feature_importances import rank_features_supervised diff --git a/ehrapy/plot/_survival_analysis.py b/ehrapy/plot/_survival_analysis.py index bf74df85..717f9202 100644 --- a/ehrapy/plot/_survival_analysis.py +++ b/ehrapy/plot/_survival_analysis.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING import matplotlib.pyplot as plt @@ -38,7 +39,7 @@ def ols( ax: Axes | None = None, title: str | None = None, **kwds, -): +) -> Axes | None: """Plots an Ordinary Least Squares (OLS) Model result, scatter plot, and line plot. Args: @@ -134,6 +135,8 @@ def ols( if not show: return ax + else: + return None def kmf( @@ -152,7 +155,48 @@ def kmf( figsize: tuple[float, float] | None = None, show: bool | None = None, title: str | None = None, -): +) -> Axes | None: + warnings.warn( + "This function is deprecated and will be removed in the next release. 
Use `ep.pl.kaplan_meier` instead.", + DeprecationWarning, + stacklevel=2, + ) + return kaplan_meier( + kmfs=kmfs, + ci_alpha=ci_alpha, + ci_force_lines=ci_force_lines, + ci_show=ci_show, + ci_legend=ci_legend, + at_risk_counts=at_risk_counts, + color=color, + grid=grid, + xlim=xlim, + ylim=ylim, + xlabel=xlabel, + ylabel=ylabel, + figsize=figsize, + show=show, + title=title, + ) + + +def kaplan_meier( + kmfs: Sequence[KaplanMeierFitter], + ci_alpha: list[float] | None = None, + ci_force_lines: list[Boolean] | None = None, + ci_show: list[Boolean] | None = None, + ci_legend: list[Boolean] | None = None, + at_risk_counts: list[Boolean] | None = None, + color: list[str] | None | None = None, + grid: Boolean | None = False, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + xlabel: str | None = None, + ylabel: str | None = None, + figsize: tuple[float, float] | None = None, + show: bool | None = None, + title: str | None = None, +) -> Axes | None: """Plots a pretty figure of the Fitted KaplanMeierFitter model See https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html @@ -186,23 +230,21 @@ def kmf( # So we need to flip `censor_fl` when pass `censor_fl` to KaplanMeierFitter >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0) - >>> kmf = ep.tl.kmf(adata[:, ["mort_day_censored"]].X, adata[:, ["censor_flg"]].X) - >>> ep.pl.kmf( + >>> kmf = ep.tl.kaplan_meier(adata, "mort_day_censored", "censor_flg") + >>> ep.pl.kaplan_meier( ... [kmf], color=["r"], xlim=[0, 700], ylim=[0, 1], xlabel="Days", ylabel="Proportion Survived", show=True ... ) .. 
image:: /_static/docstring_previews/kmf_plot_1.png - >>> T = adata[:, ["mort_day_censored"]].X - >>> E = adata[:, ["censor_flg"]].X >>> groups = adata[:, ["service_unit"]].X - >>> ix1 = groups == "FICU" - >>> ix2 = groups == "MICU" - >>> ix3 = groups == "SICU" - >>> kmf_1 = ep.tl.kmf(T[ix1], E[ix1], label="FICU") - >>> kmf_2 = ep.tl.kmf(T[ix2], E[ix2], label="MICU") - >>> kmf_3 = ep.tl.kmf(T[ix3], E[ix3], label="SICU") - >>> ep.pl.kmf([kmf_1, kmf_2, kmf_3], ci_show=[False,False,False], color=['k','r', 'g'], + >>> adata_ficu = adata[groups == "FICU"] + >>> adata_micu = adata[groups == "MICU"] + >>> adata_sicu = adata[groups == "SICU"] + >>> kmf_1 = ep.tl.kaplan_meier(adata_ficu, "mort_day_censored", "censor_flg", label="FICU") + >>> kmf_2 = ep.tl.kaplan_meier(adata_micu, "mort_day_censored", "censor_flg", label="MICU") + >>> kmf_3 = ep.tl.kaplan_meier(adata_sicu, "mort_day_censored", "censor_flg", label="SICU") + >>> ep.pl.kaplan_meier([kmf_1, kmf_2, kmf_3], ci_show=[False,False,False], color=['k','r', 'g'], >>> xlim=[0, 750], ylim=[0, 1], xlabel="Days", ylabel="Proportion Survived") .. 
image:: /_static/docstring_previews/kmf_plot_2.png @@ -251,3 +293,5 @@ def kmf( if not show: return ax + else: + return None diff --git a/ehrapy/tools/__init__.py b/ehrapy/tools/__init__.py index 5da8fa69..c034882f 100644 --- a/ehrapy/tools/__init__.py +++ b/ehrapy/tools/__init__.py @@ -2,6 +2,7 @@ anova_glm, cox_ph, glm, + kaplan_meier, kmf, log_logistic_aft, nelson_aalen, @@ -31,6 +32,7 @@ "cox_ph", "glm", "kmf", + "kaplan_meier", "log_logistic_aft", "nelson_aalen", "ols", diff --git a/ehrapy/tools/_sa.py b/ehrapy/tools/_sa.py index e23b6a43..fed63b9e 100644 --- a/ehrapy/tools/_sa.py +++ b/ehrapy/tools/_sa.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Literal import numpy as np # This package is implicitly used @@ -126,7 +127,9 @@ def kmf( weights: Iterable | None = None, censoring: Literal["right", "left"] = None, ) -> KaplanMeierFitter: - """Fit the Kaplan-Meier estimate for the survival function. + """DEPRECATION WARNING: This function is deprecated and will be removed in the next release. Use `kaplan_meier` instead. + + Fit the Kaplan-Meier estimate for the survival function. The Kaplan–Meier estimator, also known as the product limit estimator, is a non-parametric statistic used to estimate the survival function from lifetime data. In medical research, it is often used to measure the fraction of patients living for a certain amount of time after treatment. @@ -158,6 +161,12 @@ def kmf( >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0) >>> kmf = ep.tl.kmf(adata[:, ["mort_day_censored"]].X, adata[:, ["censor_flg"]].X) """ + + warnings.warn( + "This function is deprecated and will be removed in the next release. 
Use `ep.tl.kaplan_meier` instead.", + DeprecationWarning, + stacklevel=2, + ) kmf = KaplanMeierFitter() if censoring == "None" or "right": kmf.fit( @@ -185,6 +194,71 @@ def kmf( return kmf +def kaplan_meier( + adata: AnnData, + duration_col: str, + event_col: str | None = None, + *, + timeline: list[float] | None = None, + entry: str | None = None, + label: str | None = None, + alpha: float | None = None, + ci_labels: list[str] | None = None, + weights: list[float] | None = None, + fit_options: dict | None = None, + censoring: Literal["right", "left"] = "right", +) -> KaplanMeierFitter: + """Fit the Kaplan-Meier estimate for the survival function. + + The Kaplan–Meier estimator, also known as the product limit estimator, is a non-parametric statistic used to estimate the survival function from lifetime data. + In medical research, it is often used to measure the fraction of patients living for a certain amount of time after treatment. + + See https://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator + https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html#module-lifelines.fitters.kaplan_meier_fitter + + Args: + adata: AnnData object with necessary columns `duration_col` and `event_col`. + duration_col: The name of the column in the AnnData objects that contains the subjects’ lifetimes. + event_col: The name of the column in anndata that contains the subjects’ death observation. + timeline: Return the best estimate at the values in timelines (positively increasing) + entry: Relative time when a subject entered the study. This is useful for left-truncated (not left-censored) observations. + If None, all members of the population entered study when they were "born". + label: A string to name the column of the estimate. + alpha: The alpha value in the confidence intervals. Overrides the initializing alpha for this call to fit only. 
+ ci_labels: Add custom column names to the generated confidence intervals as a length-2 list: [, ] (default: