Cohort Tracker #658

Merged
merged 50 commits into from
Mar 13, 2024
Changes from 5 commits
Commits
decd34a
nb with stuff that at least does not just fail
eroell Feb 12, 2024
6407623
population logging and tracking first support
eroell Feb 13, 2024
cd7e113
refined and first-line debugged population tracking
eroell Feb 14, 2024
5ee6b5d
population to cohort
eroell Feb 14, 2024
2e28f2b
cohort logging with tests
eroell Feb 14, 2024
3673d28
toy notebook cleaned
eroell Feb 15, 2024
9687e4d
small comments included
eroell Feb 29, 2024
c6f5955
documentation working somewhat
eroell Feb 29, 2024
dcc3841
remove class in tests
eroell Mar 1, 2024
b844096
move read_csv to fixture
eroell Mar 1, 2024
75451bd
remove tracking dict, use tableones for tracking instead
eroell Mar 5, 2024
c42d847
remove DataFrame as accepted input
eroell Mar 5, 2024
232c2b1
legend label order matching bar order
eroell Mar 5, 2024
4090f6f
prepare type detection for alignment, added test
eroell Mar 5, 2024
4606155
add ax and remove unused args, return not solved yet
eroell Mar 6, 2024
2728ba5
tests for plots, move to pyplot for flowchart
eroell Mar 6, 2024
c8ff205
Update ehrapy/tools/cohort_tracking/_cohort_tracker.py
eroell Mar 9, 2024
3c7cdc9
Update ehrapy/tools/cohort_tracking/_cohort_tracker.py
eroell Mar 9, 2024
088983c
Update ehrapy/tools/cohort_tracking/_cohort_tracker.py
eroell Mar 9, 2024
e9dedbb
Update ehrapy/tools/cohort_tracking/_cohort_tracker.py
eroell Mar 9, 2024
711c710
Update ehrapy/tools/cohort_tracking/_cohort_tracker.py
eroell Mar 9, 2024
7f25d66
Update ehrapy/tools/cohort_tracking/_cohort_tracker.py
eroell Mar 9, 2024
331c095
Update ehrapy/tools/cohort_tracking/_cohort_tracker.py
eroell Mar 9, 2024
ee0b68f
Update ehrapy/tools/cohort_tracking/_cohort_tracker.py
eroell Mar 9, 2024
7472eaa
remove reset, add updated notebook for quick check
eroell Mar 9, 2024
89e5d5b
typehints and review comments
eroell Mar 9, 2024
ced853e
remove comment in test
eroell Mar 9, 2024
0d44608
tableone to requirements?
eroell Mar 9, 2024
482d0cb
allow typehint union
eroell Mar 12, 2024
da4f922
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2024
83c0cca
Fix scanpy pre-release compat
Zethson Mar 12, 2024
6327638
Remove anndata warning ignore
Zethson Mar 12, 2024
3e58315
future import fixed in test conf
eroell Mar 12, 2024
92f9112
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2024
057dd8f
track_t1 -> tracked_tables
eroell Mar 12, 2024
5589a22
updates with better names, label-dicts, better colors, more tests
eroell Mar 13, 2024
a7646a0
Merge branch 'pop-log' of github.com:eroell/ehrapy into pop-log
eroell Mar 13, 2024
e79c031
remove grid lines, add notebook for testimages generation
eroell Mar 13, 2024
d24f677
prettier docstring demo
eroell Mar 13, 2024
a4c021b
Update ehrapy/tools/cohort_tracking/_cohort_tracker.py
eroell Mar 13, 2024
b75cdbb
Update ehrapy/tools/cohort_tracking/_cohort_tracker.py
eroell Mar 13, 2024
d5cfa23
Update ehrapy/tools/cohort_tracking/_cohort_tracker.py
eroell Mar 13, 2024
5261de3
Update ehrapy/tools/cohort_tracking/_cohort_tracker.py
eroell Mar 13, 2024
61683f5
Update ehrapy/tools/cohort_tracking/_cohort_tracker.py
eroell Mar 13, 2024
9663cfb
remove old comments, better variable names, simplify adata check
eroell Mar 13, 2024
0bd391a
Merge branch 'pop-log' of github.com:eroell/ehrapy into pop-log
eroell Mar 13, 2024
d638a88
fix two doc typos
eroell Mar 13, 2024
50118d3
Merge branch 'main' of github.com:eroell/ehrapy
eroell Mar 13, 2024
2b78475
Merge branch 'main' into pop-log
eroell Mar 13, 2024
3c92f26
identical Returns field
eroell Mar 13, 2024
1 change: 1 addition & 0 deletions ehrapy/tools/__init__.py
@@ -1,6 +1,7 @@
from ehrapy.tools._sa import anova_glm, cox_ph, glm, kmf, ols, test_kmf_logrank, test_nested_f_statistic
from ehrapy.tools._scanpy_tl_api import * # noqa: F403
from ehrapy.tools.causal._dowhy import causal_inference
from ehrapy.tools.cohort_tracking._cohort_tracker import CohortTracker
from ehrapy.tools.feature_ranking._rank_features_groups import filter_rank_features_groups, rank_features_groups

try: # pragma: no cover
Empty file.
332 changes: 332 additions & 0 deletions ehrapy/tools/cohort_tracking/_cohort_tracker.py
@@ -0,0 +1,332 @@
from __future__ import annotations

import copy
from typing import Any, Union

import graphviz
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scanpy import AnnData
from tableone import TableOne


def _check_columns_exist(df, columns):
    missing_columns = [col for col in columns if col not in df.columns]
    if missing_columns:
        raise ValueError(f"Columns {missing_columns} not found in dataframe.")


# from tableone: https://github.com/tompollard/tableone/blob/bfd6fbaa4ed3e9f59e1a75191c6296a2a80ccc64/tableone/tableone.py#L555
def _detect_categorical_columns(data) -> list:
    # assume all columns that are neither numerical nor dates are categorical
    numeric_cols = set(data._get_numeric_data().columns.values)
    date_cols = set(data.select_dtypes(include=[np.datetime64]).columns)
    likely_cat = set(data.columns) - numeric_cols
    # mypy absolutely loses it if likely_cat is overwritten to be a list
    likely_cat_no_dates = list(likely_cat - date_cols)

    # numerical columns with a very low proportion of unique values are treated as categorical, too
    for var in data._get_numeric_data().columns:
        likely_flag = 1.0 * data[var].nunique() / data[var].count() < 0.005
        if likely_flag:
            likely_cat_no_dates.append(var)
    return likely_cat_no_dates


class CohortTracker:
    def __init__(self, adata: AnnData | pd.DataFrame, columns: list = None, categorical: list = None, *args: Any):
        """Track cohort changes over multiple filtering or processing steps.

        This class offers functionality to track and plot cohort changes over multiple filtering or processing steps,
        enabling the user to monitor the impact of each step on the cohort.

        Interacts closely with the `tableone` package [1].

        Args:
            adata: :class:`~anndata.AnnData` or :class:`~pandas.DataFrame` object to track.
            columns: List of columns to track. If `None`, all columns will be tracked.
            categorical: List of columns that contain categorical variables. If not given, they will be inferred from the data.

        References
        ----------
        [1] Tom Pollard, Alistair E.W. Johnson, Jesse D. Raffa, Roger G. Mark; tableone: An open source Python package for producing summary statistics for research papers, Journal of the American Medical Informatics Association, Volume 24, Issue 2, 1 March 2017, Pages 267–271, https://doi.org/10.1093/jamia/ocw117

        """
        if isinstance(adata, AnnData):
            df = adata.obs
        elif isinstance(adata, pd.DataFrame):
            df = adata
        else:
            raise ValueError("adata must be an AnnData or a DataFrame.")

        self.columns = columns if columns is not None else list(df.columns)

        if columns is not None:
            _check_columns_exist(df, columns)
        if categorical is not None:
            _check_columns_exist(df, categorical)
            if set(categorical).difference(set(self.columns)):
                raise ValueError("categorical columns must be in the (selected) columns.")

        self._tracked_steps: int = 0
        self._tracked_text: list = []
        self._tracked_operations: list = []

        # if categorical columns specified, use them
        # else, follow tableone's logic
        self.categorical = categorical if categorical is not None else _detect_categorical_columns(df[self.columns])
        self.track = self._get_column_structure(df)

        self._track_backup = copy.deepcopy(self.track)

    def __call__(
        self,
        adata: AnnData | pd.DataFrame,
        label: str = None,
        operations_done: str = None,
        *args: Any,
        **tableone_kwargs: Any,
    ) -> None:
        if isinstance(adata, AnnData):
            df = adata.obs
        elif isinstance(adata, pd.DataFrame):
            df = adata
        else:
            raise ValueError("adata must be an AnnData or a DataFrame.")

        _check_columns_exist(df, self.columns)

        # track a small text with each tracking step, for the flowchart
        # len(df) is used for the cohort size because a DataFrame has no `n_obs` attribute
        track_text = label if label is not None else f"Cohort {self.tracked_steps}"
        track_text += "\n (n=" + str(len(df)) + ")"
        self._tracked_text.append(track_text)

        # track a small text with the operations done
        self._tracked_operations.append(operations_done)

        self._tracked_steps += 1

        t1 = TableOne(df, columns=self.columns, categorical=self.categorical, **tableone_kwargs)
        # append this step's summary statistics to the tracked columns
        self._get_column_dicts(t1)

    def _get_column_structure(self, df):
        column_structure = {}
        for column in self.columns:
            if column in self.categorical:
                # if e.g. a column containing integers is deemed categorical, coerce it to categorical
                df[column] = df[column].astype("category")
                column_structure[column] = {category: [] for category in df[column].cat.categories}
            else:
                column_structure[column] = []

        return column_structure

    def _get_column_dicts(self, table_one):
        for col, value in self.track.items():
            if isinstance(value, dict):
                self._get_cat_dicts(table_one, col)
            else:
                self._get_num_dicts(table_one, col)

    def _get_cat_dicts(self, table_one, col):
        for cat in self.track[col].keys():
            # if tableone does not have the category of this column anymore, set the percentage to 0
            # for categorized columns (e.g. gender 1.0/0.0), str(cat) helps to avoid considering the category as a float
            if (col, str(cat)) in table_one.cat_table["Overall"].index:
                pct = float(table_one.cat_table["Overall"].loc[(col, str(cat))].split("(")[1].split(")")[0])
            else:
                pct = 0
            self.track[col][cat].append(pct)

    def _get_num_dicts(self, table_one, col):
        summary = table_one.cont_table["Overall"].loc[(col, "")]
        self.track[col].append(summary)

    def reset(self):
        # restore from a deep copy so that later tracking does not mutate the backup itself
        self.track = copy.deepcopy(self._track_backup)
        self._tracked_steps = 0
        self._tracked_text = []
        self._tracked_operations = []

    @property
    def tracked_steps(self):
        return self._tracked_steps

    def plot_cohort_change(
        self,
        set_axis_labels: bool = True,
        subfigure_title: bool = False,
        sns_color_palette: str = "husl",
        save: str = None,
        return_plot: bool = False,
        subplots_kwargs: dict = None,
        legend_kwargs: dict = None,
    ):
        """Plot the cohort change over the tracked steps.

        Create stacked bar plots to monitor cohort changes over the steps tracked with `CohortTracker`.

        Args:
            set_axis_labels: If `True`, the y-axis labels will be set to the column names.
            subfigure_title: If `True`, each subplot will have a title with the `label` provided during tracking.
            sns_color_palette: The color palette to use for the plot. Default is "husl".
            save: If a string is provided, the plot will be saved to the path specified.
            return_plot: If `True`, the plot will be returned as a tuple of (fig, ax).
            subplots_kwargs: Additional keyword arguments for the subplots.
            legend_kwargs: Additional keyword arguments for the legend.

        Returns:
            If `return_plot`, a :class:`~matplotlib.figure.Figure` and a :class:`~matplotlib.axes.Axes` or a list of them.

        Example:
            .. code-block:: python

                import ehrapy as ep

                adata = ep.dt.diabetes_130(columns_obs_only=["gender", "race", "weight", "age"])
                cohort_tracker = ep.tl.CohortTracker(adata)
                cohort_tracker(adata, label="original")
                adata = adata[:1000]
                cohort_tracker(adata, label="filtered cohort", operations_done="filtered to first 1000 entries")
                cohort_tracker.plot_cohort_change()
        Preview:
            .. image:: /_static/docstring_previews/flowchart.png
        """
        # Plotting
        subplots_kwargs = {} if subplots_kwargs is None else subplots_kwargs

        fig, axes = plt.subplots(self.tracked_steps, 1, **subplots_kwargs)

        legend_labels = []

        # if only one step is tracked, axes object is not iterable
        if self.tracked_steps == 1:
            axes = [axes]

        # each tracked step is a subplot
        for idx, ax in enumerate(axes):
            if subfigure_title:
                ax.set_title(self._tracked_text[idx])

            # iterate over the tracked columns in the dataframe
            for pos, (_cols, data) in enumerate(self.track.items()):
                data = pd.DataFrame(data).loc[idx]

                cumwidth = 0

                # Adjust the hue shift based on the category position such that the colors are more distinguishable
                hue_shift = (pos + 1) / len(data)
                colors = sns.color_palette(sns_color_palette, len(data))
                adjusted_colors = [((color[0] + hue_shift) % 1, color[1], color[2]) for color in colors]

                # for categoricals, plot multiple bars
                if _cols in self.categorical:
                    for i, value in enumerate(data):
                        ax.barh(pos, value, left=cumwidth, color=adjusted_colors[i], height=0.7)

                        if value > 5:
                            # Add proportion numbers to the bars
                            width = value
                            ax.text(
                                cumwidth + width / 2,
                                pos,
                                f"{value:.1f}",
                                ha="center",
                                va="center",
                                color="white",
                                fontweight="bold",
                            )

                        ax.set_yticks([])
                        ax.set_xticks([])
                        cumwidth += value
                        legend_labels.append(data.index[i])

                # for numericals, plot a single bar
                else:
                    ax.barh(pos, 100, left=cumwidth, color=adjusted_colors[0], height=0.8)
                    ax.text(
                        100 / 2,
                        pos,
                        data[0],
                        ha="center",
                        va="center",
                        color="white",
                        fontweight="bold",
                    )
                    legend_labels.append(_cols)

            # Set y-axis labels
            if set_axis_labels:
                ax.set_yticks(
                    range(len(self.track.keys()))
                )  # Set ticks at positions corresponding to the number of columns
                ax.set_yticklabels(self.track.keys())  # Set y-axis labels to the column names

        # makes the frames invisible
        # for ax in axes:
        #     ax.axis('off')

        # Add legend
        tot_legend_kwargs = {"loc": "best", "bbox_to_anchor": (1, 1)}
        if legend_kwargs is not None:
            tot_legend_kwargs.update(legend_kwargs)

        plt.legend(legend_labels, **tot_legend_kwargs)

        if save is not None:
            if not isinstance(save, str):
                raise ValueError("'save' must be a string.")
            plt.savefig(
                save,
            )

        if return_plot:
            return fig, axes

        else:
            plt.tight_layout()
            plt.show()

    def plot_flowchart(self, save: str = None, return_plot: bool = True):
Member

Is there any way to style the graphviz plot? It's pretty ugly^^

Collaborator Author

From my limited experience here, going to matplotlib in our case seems to make more sense and certainly saves some pain. Not amazingly pretty, still
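
For illustration, a minimal sketch of what a matplotlib-based flowchart could look like. This is not the implementation in this PR: `draw_flowchart`, its arguments, and all styling values are hypothetical. The two input lists are assumed to mirror `_tracked_text` and `_tracked_operations`, so `operations[i]` describes how step `i` was reached and `operations[0]` is unused.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def draw_flowchart(labels, operations, ax=None):
    """Draw one box per step, top to bottom, joined by labeled arrows (hypothetical helper)."""
    if ax is None:
        _, ax = plt.subplots(figsize=(4, 1.5 * len(labels)))
    ax.set_axis_off()

    # Evenly spaced y positions in axes coordinates, first step at the top.
    n = len(labels)
    ys = [1 - (i + 0.5) / n for i in range(n)]

    for i, (y, label) in enumerate(zip(ys, labels)):
        # One rounded box per tracked step.
        ax.text(
            0.5, y, label, ha="center", va="center", transform=ax.transAxes,
            bbox={"boxstyle": "round,pad=0.5", "facecolor": "lightgrey"},
        )
        if i == 0:
            continue
        # Arrow from the previous step to this one; shrink keeps it off the box centers.
        ax.annotate(
            "", xy=(0.5, y), xytext=(0.5, ys[i - 1]), xycoords="axes fraction",
            arrowprops={"arrowstyle": "->", "shrinkA": 20, "shrinkB": 20},
        )
        # Operation text next to the arrow, if one was recorded for this step.
        if operations[i] is not None:
            ax.text(0.55, (y + ys[i - 1]) / 2, operations[i], ha="left", va="center",
                    transform=ax.transAxes)
    return ax.figure, ax
```

Because the result is an ordinary matplotlib figure, box colors, fonts, and arrow styles can be adjusted with the usual matplotlib keyword arguments.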

        """Flowchart over the tracked steps.

        Create a simple flowchart of data preparation steps tracked with `CohortTracker`.

        Args:
            save: If a string is provided, the plot will be saved to the path specified.
Member

This save behavior is different from the one scanpy uses (which is bullshit). I'm not sure whether we should have this. Is this interoperable with matplotlib in any way and can just be saved this way?

Collaborator Author

Absolutely agree - using return_figure=True the user gets the figure object, and can do e.g. fig.savefig("<name>.png", <other args>) just as usual.

Asking for the matplotlib object should be the only way of saving the plot, I think now.

Collaborator Author

Now going with a show argument; this will likely be harmonized away in the future.
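
A short sketch of the saving pattern discussed in this thread: the caller receives the matplotlib figure and saves it with `savefig`, so no `save` argument is needed. The figure below is a placeholder built directly with `plt.subplots`; in practice it would be whatever figure object the plotting method returns. The in-memory buffer and the `dpi` and `bbox_inches` values are illustrative choices only.

```python
import io

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

# Placeholder figure standing in for the one returned by a plotting method.
fig, ax = plt.subplots()
ax.barh(0, 100)

# Any target accepted by Figure.savefig works: a path string or a file-like object.
buffer = io.BytesIO()
fig.savefig(buffer, format="png", dpi=150, bbox_inches="tight")
png_bytes = buffer.getvalue()
```

Passing a path such as `"cohort.png"` instead of the buffer writes the file to disk, and all other `savefig` options stay under the caller's control.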

            return_plot: If `True`, the plot will be returned as a :class:`~graphviz.Digraph`.

        Returns:
            If `return_plot`, a :class:`~graphviz.Digraph`.

        Example:
            .. code-block:: python

                import ehrapy as ep

                adata = ep.dt.diabetes_130(columns_obs_only=["gender", "race", "weight", "age"])
                cohort_tracker = ep.tl.CohortTracker(adata)
                cohort_tracker(adata, label="original")
                adata = adata[:1000]
                cohort_tracker(adata, label="filtered cohort", operations_done="filtered to first 1000 entries")
                cohort_tracker.plot_flowchart()
        Preview:
            .. image:: /_static/docstring_previews/flowchart.png

        """

        # Create Digraph object
        dot = graphviz.Digraph()

        # Define nodes (edgy nodes)
        for i, text in enumerate(self._tracked_text):
            dot.node(name=str(i), label=text, style="filled", shape="box")

        # The operation recorded at step i + 1 labels the edge from step i to step i + 1
        for i, op in enumerate(self._tracked_operations[1:]):
            dot.edge(str(i), str(i + 1), label=op, labeldistance="2.5")

        # Render the graph
        if save is not None:
            if not isinstance(save, str):
                raise ValueError("'save' must be a string.")
            dot.render(save, format="png", cleanup=True)

        # To be shown, the graph can either be rendered to a file (as above)
        # or be returned, in which case a notebook displays it as the cell output
        if return_plot:
            return dot
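
To make the node and edge pairing in `plot_flowchart` easier to follow, here is a dependency-free sketch that builds comparable DOT text by hand. `flowchart_dot` is a hypothetical helper, not part of this PR or of the `graphviz` package, and its escaping is deliberately minimal. It assumes the same convention as the tracker: the operation stored at index `i + 1` labels the edge from step `i` to step `i + 1`.

```python
def flowchart_dot(tracked_text, tracked_operations):
    """Return DOT source with one box per step and one edge per transition (hypothetical helper)."""

    def escape(text):
        # Minimal escaping for a double-quoted DOT label.
        return text.replace('"', '\\"').replace("\n", "\\n")

    lines = ["digraph {"]
    for i, text in enumerate(tracked_text):
        lines.append(f'    {i} [label="{escape(text)}" shape=box style=filled]')

    # The first recorded operation belongs to the initial cohort and has no incoming edge.
    for i, op in enumerate(tracked_operations[1:]):
        edge = f"    {i} -> {i + 1}"
        if op is not None:
            edge += f' [label="{escape(op)}"]'
        lines.append(edge)

    lines.append("}")
    return "\n".join(lines)


dot_source = flowchart_dot(
    ["original\n (n=3)", "filtered\n (n=2)"],
    [None, "drop one"],
)
```

Two tracked steps yield two nodes and a single edge, which matches the slicing with `[1:]` in the method above.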