Skip to content

Commit

Permalink
Add faiss backend for KNN imputation (#704)
Browse files Browse the repository at this point in the history
* [pre-commit.ci] pre-commit autoupdate (#702)

updates:
- [github.com/astral-sh/ruff-pre-commit: v0.3.7 → v0.4.1](astral-sh/ruff-pre-commit@v0.3.7...v0.4.1)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Unify feature type annotations (#697)

* Added infer and check feature types methods

* Added and tested decorator and adapted feature importances

* Added test cases and updated imputation

* Adapted encoding

* Feature specifications output

* Fix HVF test

* Added tree printing for inferred feature types

* Notebook fixes

* Fix feature importance test

* Beautify tree

* Base encoding on original feature types

* Added to usage

* Update logging message

* Improved method description

* Submodule update

Signed-off-by: zethson <[email protected]>

* PR Revisions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update submodule

* Extended method docs description

---------

Signed-off-by: zethson <[email protected]>
Co-authored-by: zethson <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Add faiss backend for KNN imputation

Signed-off-by: zethson <[email protected]>

* Fix MIMIC-II notebook

Signed-off-by: zethson <[email protected]>

* Fix MIMIC-II notebook

Signed-off-by: zethson <[email protected]>

* Refactoring

Signed-off-by: zethson <[email protected]>

* Refactoring

Signed-off-by: zethson <[email protected]>

* Submodule

Signed-off-by: zethson <[email protected]>

---------

Signed-off-by: zethson <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Lilly May <[email protected]>
  • Loading branch information
3 people authored Apr 24, 2024
1 parent d97f179 commit 63fca78
Show file tree
Hide file tree
Showing 16 changed files with 317 additions and 90 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repos:
hooks:
- id: prettier
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.7
rev: v0.4.1
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes]
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
1 change: 1 addition & 0 deletions docs/usage/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ Methods that extract and visualize tool-specific annotation in an AnnData object
:toctree: anndata
:nosignatures:
anndata.infer_feature_types
anndata.df_to_anndata
anndata.anndata_to_df
anndata.move_to_obs
Expand Down
1 change: 1 addition & 0 deletions ehrapy/anndata/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ehrapy.anndata._feature_specifications import check_feature_types, infer_feature_types
from ehrapy.anndata.anndata_ext import (
anndata_to_df,
delete_from_obs,
Expand Down
8 changes: 7 additions & 1 deletion ehrapy/anndata/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
# -----------------------
# The column name and used values in adata.var for column types.

EHRAPY_TYPE_KEY = "ehrapy_column_type"
EHRAPY_TYPE_KEY = "ehrapy_column_type" # TODO: Change to ENCODING_TYPE_KEY
NUMERIC_TAG = "numeric"
NON_NUMERIC_TAG = "non_numeric"
NON_NUMERIC_ENCODED_TAG = "non_numeric_encoded"


FEATURE_TYPE_KEY = "feature_type"
CONTINUOUS_TAG = "numeric" # TODO: Eventually rename to NUMERIC_TAG (as soon as the other NUMERIC_TAG is removed)
CATEGORICAL_TAG = "categorical"
DATE_TAG = "date"
95 changes: 95 additions & 0 deletions ehrapy/anndata/_feature_specifications.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from typing import Literal

import numpy as np
import pandas as pd
from anndata import AnnData
from rich import print
from rich.tree import Tree

from ehrapy import logging as logg
from ehrapy.anndata._constants import CATEGORICAL_TAG, CONTINUOUS_TAG, DATE_TAG, FEATURE_TYPE_KEY
from ehrapy.anndata.anndata_ext import anndata_to_df


def infer_feature_types(adata: AnnData, layer: str | None = None, output: Literal["tree", "dataframe"] | None = "tree"):
"""Infer feature types from AnnData object.
For each feature in adata.var_names, the method infers one of the following types: 'date', 'categorical', or 'numeric'.
The inferred types are stored in adata.var['feature_type']. Please check the inferred types and adjust if necessary using
adata.var['feature_type']['feature1']='corrected_type'.
Be aware that not all features stored numerically are of 'numeric' type, as categorical features might be stored in a numerically encoded format.
For example, a feature with values [0, 1, 2] might be a categorical feature with three categories. This is accounted for in the method, but it is
recommended to check the inferred types.
Args:
adata: :class:`~anndata.AnnData` object storing the EHR data.
layer: The layer to use from the AnnData object. If None, the X layer is used.
output: The output format. Choose between 'tree', 'dataframe', or None. If 'tree', the feature types will be printed to the console in a tree format.
If 'dataframe', a pandas DataFrame with the feature types will be returned. If None, nothing will be returned. Defaults to 'tree'.
"""
feature_types = {}

df = anndata_to_df(adata, layer=layer)
for feature in adata.var_names:
col = df[feature].dropna()
majority_type = col.apply(type).value_counts().idxmax()
if majority_type == pd.Timestamp:
feature_types[feature] = DATE_TAG
elif majority_type not in [int, float, complex]:
feature_types[feature] = CATEGORICAL_TAG
# Guess categorical if the feature is an integer and the values are 0/1 to n-1 with no gaps
elif np.all(i.is_integer() for i in col) and (
(col.min() == 0 and np.all(np.sort(col.unique()) == np.arange(col.nunique())))
or (col.min() == 1 and np.all(np.sort(col.unique()) == np.arange(1, col.nunique() + 1)))
):
feature_types[feature] = CATEGORICAL_TAG
else:
feature_types[feature] = CONTINUOUS_TAG

adata.var[FEATURE_TYPE_KEY] = pd.Series(feature_types)[adata.var_names]

logg.info(
f"Stored feature types in adata.var['{FEATURE_TYPE_KEY}']. Please verify and adjust if necessary using adata.var['{FEATURE_TYPE_KEY}']['feature1']='corrected_type'."
)

if output == "tree":
feature_type_overview(adata)
elif output == "dataframe":
return adata.var[FEATURE_TYPE_KEY]
elif output is not None:
raise ValueError(f"Output format {output} not recognized. Choose between 'tree', 'dataframe', or None.")


def check_feature_types(func):
def wrapper(adata, *args, **kwargs):
if FEATURE_TYPE_KEY not in adata.var.keys():
raise ValueError("Feature types are not specified in adata.var. Please run `infer_feature_types` first.")
np.all(adata.var[FEATURE_TYPE_KEY].isin([CATEGORICAL_TAG, CONTINUOUS_TAG, DATE_TAG]))
return func(adata, *args, **kwargs)

return wrapper


@check_feature_types
def feature_type_overview(adata: AnnData):
"""Print an overview of the feature types in the AnnData object."""
tree = Tree(
f"[b] Detected feature types for AnnData object with {len(adata.obs_names)} obs and {len(adata.var_names)} vars",
guide_style="underline2",
)

branch = tree.add("📅[b] Date features")
for date in sorted(adata.var_names[adata.var[FEATURE_TYPE_KEY] == DATE_TAG]):
branch.add(date)

branch = tree.add("📐[b] Numerical features")
for numeric in sorted(adata.var_names[adata.var[FEATURE_TYPE_KEY] == CONTINUOUS_TAG]):
branch.add(numeric)

branch = tree.add("🗂️[b] Categorical features")
cat_features = adata.var_names[adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG]
df = anndata_to_df(adata[:, cat_features])
for categorical in sorted(cat_features):
branch.add(f"{categorical} ({df.loc[:, categorical].nunique()} categories)")

print(tree)
3 changes: 1 addition & 2 deletions ehrapy/anndata/anndata_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ def _adata_type_overview(
f"[b green]Variable names for AnnData object with {len(adata.obs_names)} obs and {len(adata.var_names)} vars",
guide_style="underline2 bright_blue",
)

if "var_to_encoding" in adata.uns.keys():
original_values = adata.uns["original_values_categoricals"]
branch = tree.add("🔐 Encoded variables", style="b green")
Expand Down Expand Up @@ -545,8 +546,6 @@ def set_numeric_vars(
for i in range(n_values):
adata.X[:, vars_idx[i]] = values[:, i]

logg.info(f"Values in columns {vars} were replaced by {values}.")

return adata


Expand Down
17 changes: 16 additions & 1 deletion ehrapy/preprocessing/_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,16 @@
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

from ehrapy import logging as logg
from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_ENCODED_TAG, NON_NUMERIC_TAG, NUMERIC_TAG
from ehrapy.anndata._constants import (
CATEGORICAL_TAG,
CONTINUOUS_TAG,
DATE_TAG,
EHRAPY_TYPE_KEY,
FEATURE_TYPE_KEY,
NON_NUMERIC_ENCODED_TAG,
NON_NUMERIC_TAG,
NUMERIC_TAG,
)
from ehrapy.anndata.anndata_ext import _get_var_indices_for_type

multi_encoding_modes = {"hash"}
Expand Down Expand Up @@ -141,6 +150,9 @@ def encode(
new_var = pd.DataFrame(index=encoded_var_names)
new_var[EHRAPY_TYPE_KEY] = adata.var[EHRAPY_TYPE_KEY].copy()
new_var.loc[new_var.index.str.contains("ehrapycat")] = NON_NUMERIC_ENCODED_TAG
if FEATURE_TYPE_KEY in adata.var.keys():
new_var[FEATURE_TYPE_KEY] = adata.var[FEATURE_TYPE_KEY].copy()
new_var.loc[new_var.index.str.contains("ehrapycat"), FEATURE_TYPE_KEY] = CATEGORICAL_TAG

encoded_ann_data = AnnData(
encoded_x,
Expand Down Expand Up @@ -243,6 +255,9 @@ def encode(
new_var = pd.DataFrame(index=encoded_var_names)
new_var[EHRAPY_TYPE_KEY] = adata.var[EHRAPY_TYPE_KEY].copy()
new_var.loc[new_var.index.str.contains("ehrapycat")] = NON_NUMERIC_ENCODED_TAG
if FEATURE_TYPE_KEY in adata.var.keys():
new_var[FEATURE_TYPE_KEY] = adata.var[FEATURE_TYPE_KEY].copy()
new_var.loc[new_var.index.str.contains("ehrapycat"), FEATURE_TYPE_KEY] = CATEGORICAL_TAG

try:
encoded_ann_data = AnnData(
Expand Down
Loading

0 comments on commit 63fca78

Please sign in to comment.