Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rank features groups obs #622

Merged
merged 19 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions ehrapy/tools/feature_ranking/_rank_features_groups.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from collections.abc import Iterable
from typing import Literal, Optional, Union

Expand All @@ -6,6 +8,8 @@
import scanpy as sc
from anndata import AnnData

from ehrapy.anndata import move_to_x
from ehrapy.preprocessing import encode
from ehrapy.tools import _method_options


Expand Down Expand Up @@ -240,6 +244,51 @@ def _evaluate_categorical_features(
)


def _check_no_datetime_columns(df):
datetime_cols = [col for col in df.columns if df[col].dtype == "datetime64[ns]"]
eroell marked this conversation as resolved.
Show resolved Hide resolved
if datetime_cols:
raise ValueError(f"Columns with datetime format found: {datetime_cols}")


def _get_intersection(adata_uns, key, selection):
"""Get intersection of adata_uns[key] and selection"""
if key in adata_uns:
uns_enc_to_keep = list(set(adata_uns["encoded_non_numerical_columns"]) & set(selection))
else:
uns_enc_to_keep = []
return uns_enc_to_keep


def _check_columns_to_rank_dict(columns_to_rank):
if isinstance(columns_to_rank, str):
if columns_to_rank == "all":
_var_subset = _obs_subset = False
else:
raise ValueError("If columns_to_rank is a string, it must be 'all'.")

elif isinstance(columns_to_rank, dict):
allowed_keys = {"var_names", "obs_names"}
for key in columns_to_rank.keys():
if key not in allowed_keys:
raise ValueError(
f"columns_to_rank dictionary must have only keys 'var_names' and/or 'obs_names', not {key}."
)
if not isinstance(key, str):
raise ValueError(f"columns_to_rank dictionary keys must be strings, not {type(key)}.")

for key, value in columns_to_rank.items():
if not isinstance(value, Iterable) or any(not isinstance(item, str) for item in value):
raise ValueError(f"The value associated with key '{key}' must be an iterable of strings.")

_var_subset = "var_names" in columns_to_rank.keys()
_obs_subset = "obs_names" in columns_to_rank.keys()

else:
raise ValueError("columns_to_rank must be either 'all' or a dictionary.")

return _var_subset, _obs_subset


def rank_features_groups(
adata: AnnData,
groupby: str,
Expand All @@ -255,6 +304,8 @@ def rank_features_groups(
correction_method: _method_options._correction_method = "benjamini-hochberg",
tie_correct: bool = False,
layer: Optional[str] = None,
field_to_rank: Union[Literal["layer"], Literal["obs"], Literal["layer_and_obs"]] = "layer",
columns_to_rank: Union[dict[str, Iterable[str]], Literal["all"]] = "all",
**kwds,
) -> None: # pragma: no cover
"""Rank features for characterizing groups.
Expand Down Expand Up @@ -288,6 +339,8 @@ def rank_features_groups(
Used only for statistical tests (e.g. doesn't work for "logreg" `num_cols_method`)
tie_correct: Use tie correction for `'wilcoxon'` scores. Used only for `'wilcoxon'`.
layer: Key from `adata.layers` whose value will be used to perform tests on.
field_to_rank: Set to `layer` to rank variables in `adata.X` or `adata.layers[layer]` (default), `obs` to rank `adata.obs`, or `layer_and_obs` to rank both. Layer needs to be None if this is not 'layer'.
columns_to_rank: Subset of columns to rank. If 'all', all columns are used. If a dictionary, it must have keys 'var_names' and/or 'obs_names' and values must be iterables of strings. E.g. {'var_names': ['glucose'], 'obs_names': ['age', 'height']}.
**kwds: Are passed to test methods. Currently this affects only parameters that
are passed to :class:`sklearn.linear_model.LogisticRegression`.
For instance, you can pass `penalty='l1'` to try to come up with a
Expand Down Expand Up @@ -320,8 +373,88 @@ def rank_features_groups(
>>> ep.tl.rank_features_groups(adata, "service_unit")
>>> ep.pl.rank_features_groups(adata)
"""
if layer is not None and field_to_rank == "obs":
raise ValueError("If 'layer' is not None, 'field_to_rank' cannot be 'obs'.")

if field_to_rank not in ["layer", "obs", "layer_and_obs"]:
raise ValueError(f"layer must be one of 'layer', 'obs', 'layer_and_obs', not {field_to_rank}")

# to give better error messages, check if columns_to_rank have valid keys and values here
_var_subset, _obs_subset = _check_columns_to_rank_dict(columns_to_rank)

adata = adata.copy() if copy else adata

# to create a minimal adata object below, grab a reference to X/layer of the original adata,
# subsetted to the specified columns
if field_to_rank in ["layer", "layer_and_obs"]:
# for some reason ruff insists on this type check. columns_to_rank is always a dict with key "var_names" if _var_subset is True
if _var_subset and isinstance(columns_to_rank, dict):
X_to_keep = (
adata[:, columns_to_rank["var_names"]].X
if layer is None
else adata[:, columns_to_rank["var_names"]].layers[layer]
)
var_to_keep = adata[:, columns_to_rank["var_names"]].var
uns_num_to_keep = _get_intersection(
adata_uns=adata.uns, key="numerical_columns", selection=columns_to_rank["var_names"]
)
uns_non_num_to_keep = _get_intersection(
adata_uns=adata.uns, key="non_numerical_columns", selection=columns_to_rank["var_names"]
)
uns_enc_to_keep = _get_intersection(
adata_uns=adata.uns, key="encoded_non_numerical_columns", selection=columns_to_rank["var_names"]
)

else:
X_to_keep = adata.X if layer is None else adata.layers[layer]
var_to_keep = adata.var
uns_num_to_keep = adata.uns["numerical_columns"] if "numerical_columns" in adata.uns else []
uns_enc_to_keep = (
adata.uns["encoded_non_numerical_columns"] if "encoded_non_numerical_columns" in adata.uns else []
)
uns_non_num_to_keep = adata.uns["non_numerical_columns"] if "non_numerical_columns" in adata.uns else []

else:
X_to_keep = np.zeros((len(adata), 1))
Zethson marked this conversation as resolved.
Show resolved Hide resolved
var_to_keep = pd.DataFrame({"dummy": [0]})
uns_num_to_keep = []
uns_enc_to_keep = []
uns_non_num_to_keep = []

adata_minimal = sc.AnnData(
X=X_to_keep,
obs=adata.obs,
var=var_to_keep,
uns={
"numerical_columns": uns_num_to_keep,
"encoded_non_numerical_columns": uns_enc_to_keep,
"non_numerical_columns": uns_non_num_to_keep,
},
)

if field_to_rank in ["obs", "layer_and_obs"]:
# want columns of obs to become variables in X to be able to use rank_features_groups
# for some reason ruff insists on this type check. columns_to_rank is always a dict with key "obs_names" if _obs_subset is True
if _obs_subset and isinstance(columns_to_rank, dict):
obs_to_move = adata.obs[columns_to_rank["obs_names"]].keys()
else:
obs_to_move = adata.obs.keys()
_check_no_datetime_columns(adata.obs[obs_to_move])
adata_minimal = move_to_x(adata_minimal, list(obs_to_move))

if field_to_rank == "obs":
# the 0th column is a dummy of zeros and is meaningless in this case, and needs to be removed
adata_minimal = adata_minimal[:, 1:]

adata_minimal = encode(adata_minimal, autodetect=True, encodings="label")

if layer is not None:
adata_minimal.layers[layer] = adata_minimal.X

# save the reference to the original adata, because we will need to access it later
adata_orig = adata
adata = adata_minimal

if not adata.obs[groupby].dtype == "category":
adata.obs[groupby] = pd.Categorical(adata.obs[groupby])

Expand Down Expand Up @@ -403,12 +536,17 @@ def rank_features_groups(
groups_order=group_names,
)

# if field_to_rank was obs or layer_and_obs, the adata object we have been working with is adata_minimal
adata_orig.uns[key_added] = adata.uns[key_added]
adata = adata_orig

# Adjust p values
if "pvals" in adata.uns[key_added]:
adata.uns[key_added]["pvals_adj"] = _adjust_pvalues(
adata.uns[key_added]["pvals"], corr_method=correction_method
)

# For some reason, pts should be a DataFrame
if "pts" in adata.uns[key_added]:
adata.uns[key_added]["pts"] = pd.DataFrame(adata.uns[key_added]["pts"])

Expand Down
13 changes: 13 additions & 0 deletions tests/tools/test_data_features_ranking/dataset1.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
idx,sys_bp_entry,dia_bp_entry,glucose,weight,disease,station
1,138,78,80,77,A,ICU
2,139,79,90,76,A,ICU
3,140,80,120,60,A,MICU
4,141,81,130,90,A,MICU
5,148,77,80,110,B,ICU
6,149,78,135,78,B,ICU
7,150,79,125,56,B,MICU
8,151,80,95,76,B,MICU
9,158,55,70,67,C,ICU
10,159,56,85,82,C,ICU
11,160,57,125,59,C,MICU
12,161,58,125,81,C,MICU
114 changes: 114 additions & 0 deletions tests/tools/test_features_ranking.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from pathlib import Path

import numpy as np
import pandas as pd
import pytest

import ehrapy as ep
import ehrapy.tools.feature_ranking._rank_features_groups as _utils
from ehrapy.io._read import read_csv

CURRENT_DIR = Path(__file__).parent
_TEST_PATH = f"{CURRENT_DIR}/test_data_features_ranking"


class TestHelperFunctions:
Expand Down Expand Up @@ -270,3 +276,111 @@ def test_only_cat_features(self):
assert "scores" in adata.uns["rank_features_groups"]
assert "logfoldchanges" in adata.uns["rank_features_groups"]
assert "pvals_adj" in adata.uns["rank_features_groups"]

@pytest.mark.parametrize("field_to_rank", ["layer", "obs", "layer_and_obs"])
def test_rank_adata_immutability_property(self, field_to_rank):
    """Check that ranking leaves the passed AnnData untouched apart from `.uns`.

    rank_features_groups avoids full copies to save memory, so this guards
    against the caller's object being altered as a side effect.
    """
    csv_path = f"{_TEST_PATH}/dataset1.csv"
    adata = read_csv(dataset_path=csv_path, columns_x_only=["station", "sys_bp_entry", "dia_bp_entry"])
    adata = ep.pp.encode(adata, encodings={"label": ["station"]})
    # snapshot taken before ranking, used as the reference below
    snapshot = adata.copy()

    ep.tl.rank_features_groups(adata, groupby="disease", field_to_rank=field_to_rank)

    # dimensions of the object and of its components are unchanged
    assert snapshot.shape == adata.shape
    for component in ("X", "obs", "var"):
        assert getattr(snapshot, component).shape == getattr(adata, component).shape

    # contents of X and obs are unchanged
    assert np.allclose(snapshot.X, adata.X)
    assert np.array_equal(snapshot.obs, adata.obs)

    # the only intended change: the results entry in .uns
    assert "rank_features_groups" in adata.uns

@pytest.mark.parametrize("field_to_rank", ["layer", "obs", "layer_and_obs"])
def test_rank_features_groups_generates_outputs(self, field_to_rank):
    """Check that the expected result entries are written for every `field_to_rank`."""
    adata = read_csv(
        dataset_path=f"{_TEST_PATH}/dataset1.csv",
        columns_obs_only=["disease", "station", "sys_bp_entry", "dia_bp_entry"],
    )

    ep.tl.rank_features_groups(adata, groupby="disease", field_to_rank=field_to_rank)
    ranking = adata.uns["rank_features_groups"]

    # standard rank_features_groups entries that must be present ...
    for present_key in ("names", "pvals", "scores", "pvals_adj", "logfoldchanges"):
        assert present_key in ranking
    # ... and entries that must not be produced
    for absent_key in ("log2foldchanges", "pts"):
        assert absent_key not in ranking

    # number of ranked entries; len() only captures the length of each group
    expected_lengths = {"layer": 3, "obs": 3, "layer_and_obs": 6}
    if field_to_rank in expected_lengths:
        n_expected = expected_lengths[field_to_rank]
        for sized_key in ("names", "pvals", "scores"):
            assert len(ranking[sized_key]) == n_expected

def test_rank_features_groups_consistent_results(self):
    """Check that the ranking does not depend on where the features are stored.

    The same dataset is loaded three times, with the features placed in X,
    in obs, or split between X and obs. The scores, p-values and names computed
    from obs and from X plus obs must match those computed from X, group by group.
    """
    csv_path = f"{_TEST_PATH}/dataset1.csv"

    adata_features_in_x = read_csv(
        dataset_path=csv_path,
        columns_x_only=["station", "sys_bp_entry", "dia_bp_entry", "glucose"],
    )
    adata_features_in_x = ep.pp.encode(adata_features_in_x, encodings={"label": ["station"]})

    adata_features_in_obs = read_csv(
        dataset_path=csv_path,
        columns_obs_only=["disease", "station", "sys_bp_entry", "dia_bp_entry", "glucose"],
    )

    adata_features_in_x_and_obs = read_csv(
        dataset_path=csv_path,
        columns_obs_only=["disease", "station"],
    )
    # Keep the same variables as in the datasets above so that the results are comparable.
    # Subsetting yields a view; copy it so that writing to `.uns` below does not
    # implicitly modify a view.
    adata_features_in_x_and_obs = adata_features_in_x_and_obs[:, ["sys_bp_entry", "dia_bp_entry", "glucose"]].copy()
    adata_features_in_x_and_obs.uns["numerical_columns"] = ["sys_bp_entry", "dia_bp_entry", "glucose"]

    ep.tl.rank_features_groups(adata_features_in_x, groupby="disease")
    ep.tl.rank_features_groups(adata_features_in_obs, groupby="disease", field_to_rank="obs")
    ep.tl.rank_features_groups(adata_features_in_x_and_obs, groupby="disease", field_to_rank="layer_and_obs")

    # The features-in-X run is the reference the other two layouts are compared to.
    reference = adata_features_in_x.uns["rank_features_groups"]
    for other_adata in (adata_features_in_obs, adata_features_in_x_and_obs):
        other = other_adata.uns["rank_features_groups"]
        for record in reference["names"].dtype.names:
            assert np.allclose(reference["scores"][record], other["scores"][record])
            assert np.allclose(np.array(reference["pvals"][record]), np.array(other["pvals"][record]))
            assert np.array_equal(np.array(reference["names"][record]), np.array(other["names"][record]))
Loading