From 94f257c168faa0d4e50873508bcf6ad5a7e574f8 Mon Sep 17 00:00:00 2001
From: Lukas Heumos
Date: Sun, 17 Dec 2023 23:05:24 +0100
Subject: [PATCH] Add input checks for imputers (#625)

Signed-off-by: zethson
---
 ehrapy/preprocessing/_imputation.py | 233 ++++++++++++++++------------
 tests/anndata/test_anndata_ext.py   |   8 -
 2 files changed, 131 insertions(+), 110 deletions(-)

diff --git a/ehrapy/preprocessing/_imputation.py b/ehrapy/preprocessing/_imputation.py
index fc0d9998..11fb448f 100644
--- a/ehrapy/preprocessing/_imputation.py
+++ b/ehrapy/preprocessing/_imputation.py
@@ -239,30 +239,34 @@ def knn_impute(
         print(
             "[bold yellow]scikit-learn-intelex is not available. Install via [blue]pip install scikit-learn-intelex [yellow] for faster imputations."
         )
-
-    with Progress(
-        "[progress.description]{task.description}",
-        SpinnerColumn(),
-        refresh_per_second=1500,
-    ) as progress:
-        progress.add_task("[blue]Running KNN imputation", total=1)
-        # numerical only data needs no encoding since KNN Imputation can be applied directly
-        if np.issubdtype(adata.X.dtype, np.number):
-            _knn_impute(adata, var_names, n_neighbours)
-        else:
-            # ordinal encoding is used since non-numerical data can not be imputed using KNN Imputation
-            enc = OrdinalEncoder()
-            column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
-            adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
-            # impute the data using KNN imputation
-            _knn_impute(adata, var_names, n_neighbours)
-            # imputing on encoded columns might result in float numbers; those can not be decoded
-            # cast them to int to ensure they can be decoded
-            adata.X[::, column_indices] = np.rint(adata.X[::, column_indices]).astype(int)
-            # knn imputer transforms X dtype to numerical (encoded), but object is needed for decoding
-            adata.X = adata.X.astype("object")
-            # decode ordinal encoding to obtain imputed original data
-            adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])
+    try:
+        with Progress(
+            "[progress.description]{task.description}",
+            SpinnerColumn(),
+            refresh_per_second=1500,
+        ) as progress:
+            progress.add_task("[blue]Running KNN imputation", total=1)
+            # numerical only data needs no encoding since KNN Imputation can be applied directly
+            if np.issubdtype(adata.X.dtype, np.number):
+                _knn_impute(adata, var_names, n_neighbours)
+            else:
+                # ordinal encoding is used since non-numerical data can not be imputed using KNN Imputation
+                enc = OrdinalEncoder()
+                column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
+                adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
+                # impute the data using KNN imputation
+                _knn_impute(adata, var_names, n_neighbours)
+                # imputing on encoded columns might result in float numbers; those can not be decoded
+                # cast them to int to ensure they can be decoded
+                adata.X[::, column_indices] = np.rint(adata.X[::, column_indices]).astype(int)
+                # knn imputer transforms X dtype to numerical (encoded), but object is needed for decoding
+                adata.X = adata.X.astype("object")
+                # decode ordinal encoding to obtain imputed original data
+                adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])
+    except ValueError as e:
+        if "Data matrix has wrong shape" in str(e):
+            print("[bold red]Check that your matrix does not contain any NaN values!")
+        raise
 
     if _check_module_importable("sklearnex"):  # pragma: no cover
         unpatch_sklearn()
@@ -356,62 +360,69 @@ def miss_forest_impute(
     from sklearn.ensemble import ExtraTreesRegressor, RandomForestClassifier
     from sklearn.impute import IterativeImputer
 
-    with Progress(
-        "[progress.description]{task.description}",
-        SpinnerColumn(),
-        refresh_per_second=1500,
-    ) as progress:
-        progress.add_task("[blue]Running MissForest imputation", total=1)
-
-        if settings.n_jobs == 1:  # pragma: no cover
-            print("[bold yellow]The number of jobs is only 1. To decrease the runtime set [blue]ep.settings.n_jobs=-1.")
-
-        imp_num = IterativeImputer(
-            estimator=ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs),
-            initial_strategy=num_initial_strategy,
-            max_iter=max_iter,
-            random_state=random_state,
-        )
-        # initial strategy here will not be parametrized since only most_frequent will be applied to non numerical data
-        imp_cat = IterativeImputer(
-            estimator=RandomForestClassifier(n_estimators=n_estimators, n_jobs=settings.n_jobs),
-            initial_strategy="most_frequent",
-            max_iter=max_iter,
-            random_state=random_state,
-        )
-
-        if isinstance(var_names, list):
-            var_indices = _get_column_indices(adata, var_names)  # type: ignore
-            adata.X[::, var_indices] = imp_num.fit_transform(adata.X[::, var_indices])
-        elif isinstance(var_names, dict) or var_names is None:
-            if var_names:
-                try:
-                    non_num_vars = var_names["non_numerical"]
-                    num_vars = var_names["numerical"]
-                except KeyError:  # pragma: no cover
-                    raise ValueError(
-                        "One or both of your keys provided for var_names are unknown. Only "
-                        "numerical and non_numerical are available!"
-                    ) from None
-                non_num_indices = _get_column_indices(adata, non_num_vars)
-                num_indices = _get_column_indices(adata, num_vars)
-
-            # infer non numerical and numerical indices automatically
-            else:
-                non_num_indices_set = _get_non_numerical_column_indices(adata.X)
-                num_indices = [idx for idx in range(adata.X.shape[1]) if idx not in non_num_indices_set]
-                non_num_indices = list(non_num_indices_set)
+    try:
+        with Progress(
+            "[progress.description]{task.description}",
+            SpinnerColumn(),
+            refresh_per_second=1500,
+        ) as progress:
+            progress.add_task("[blue]Running MissForest imputation", total=1)
+
+            if settings.n_jobs == 1:  # pragma: no cover
+                print(
+                    "[bold yellow]The number of jobs is only 1. To decrease the runtime set [blue]ep.settings.n_jobs=-1."
+                )
+
+            imp_num = IterativeImputer(
+                estimator=ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs),
+                initial_strategy=num_initial_strategy,
+                max_iter=max_iter,
+                random_state=random_state,
+            )
+            # initial strategy here will not be parametrized since only most_frequent will be applied to non numerical data
+            imp_cat = IterativeImputer(
+                estimator=RandomForestClassifier(n_estimators=n_estimators, n_jobs=settings.n_jobs),
+                initial_strategy="most_frequent",
+                max_iter=max_iter,
+                random_state=random_state,
+            )
+
+            if isinstance(var_names, list):
+                var_indices = _get_column_indices(adata, var_names)  # type: ignore
+                adata.X[::, var_indices] = imp_num.fit_transform(adata.X[::, var_indices])
+            elif isinstance(var_names, dict) or var_names is None:
+                if var_names:
+                    try:
+                        non_num_vars = var_names["non_numerical"]
+                        num_vars = var_names["numerical"]
+                    except KeyError:  # pragma: no cover
+                        raise ValueError(
+                            "One or both of your keys provided for var_names are unknown. Only "
+                            "numerical and non_numerical are available!"
+                        ) from None
+                    non_num_indices = _get_column_indices(adata, non_num_vars)
+                    num_indices = _get_column_indices(adata, num_vars)
+
+                # infer non numerical and numerical indices automatically
+                else:
+                    non_num_indices_set = _get_non_numerical_column_indices(adata.X)
+                    num_indices = [idx for idx in range(adata.X.shape[1]) if idx not in non_num_indices_set]
+                    non_num_indices = list(non_num_indices_set)
 
-        # encode all non numerical columns
-        if non_num_indices:
-            enc = OrdinalEncoder()
-            adata.X[::, non_num_indices] = enc.fit_transform(adata.X[::, non_num_indices])
-        # this step is the most expensive one and might extremely slow down the impute process
-        if num_indices:
-            adata.X[::, num_indices] = imp_num.fit_transform(adata.X[::, num_indices])
-        if non_num_indices:
-            adata.X[::, non_num_indices] = imp_cat.fit_transform(adata.X[::, non_num_indices])
-            adata.X[::, non_num_indices] = enc.inverse_transform(adata.X[::, non_num_indices])
+            # encode all non numerical columns
+            if non_num_indices:
+                enc = OrdinalEncoder()
+                adata.X[::, non_num_indices] = enc.fit_transform(adata.X[::, non_num_indices])
+            # this step is the most expensive one and might extremely slow down the impute process
+            if num_indices:
+                adata.X[::, num_indices] = imp_num.fit_transform(adata.X[::, num_indices])
+            if non_num_indices:
+                adata.X[::, non_num_indices] = imp_cat.fit_transform(adata.X[::, non_num_indices])
+                adata.X[::, non_num_indices] = enc.inverse_transform(adata.X[::, non_num_indices])
+    except ValueError as e:
+        if "Data matrix has wrong shape" in str(e):
+            print("[bold red]Check that your matrix does not contain any NaN values!")
+        raise
 
     if _check_module_importable("sklearnex"):  # pragma: no cover
         unpatch_sklearn()
@@ -1025,29 +1036,47 @@ def mice_forest_impute(
         adata = adata.copy()
 
     _warn_imputation_threshold(adata, var_names, threshold=warning_threshold)
-
-    with Progress(
-        "[progress.description]{task.description}",
-        SpinnerColumn(),
-        refresh_per_second=1500,
-    ) as progress:
-        progress.add_task("[blue]Running miceforest", total=1)
-        if np.issubdtype(adata.X.dtype, np.number):
-            _miceforest_impute(
-                adata, var_names, save_all_iterations, random_state, inplace, iterations, variable_parameters, verbose
-            )
-        else:
-            # ordinal encoding is used since non-numerical data can not be imputed using miceforest
-            enc = OrdinalEncoder()
-            column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
-            adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
-            # impute the data using miceforest
-            _miceforest_impute(
-                adata, var_names, save_all_iterations, random_state, inplace, iterations, variable_parameters, verbose
-            )
-            adata.X = adata.X.astype("object")
-            # decode ordinal encoding to obtain imputed original data
-            adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])
+    try:
+        with Progress(
+            "[progress.description]{task.description}",
+            SpinnerColumn(),
+            refresh_per_second=1500,
+        ) as progress:
+            progress.add_task("[blue]Running miceforest", total=1)
+            if np.issubdtype(adata.X.dtype, np.number):
+                _miceforest_impute(
+                    adata,
+                    var_names,
+                    save_all_iterations,
+                    random_state,
+                    inplace,
+                    iterations,
+                    variable_parameters,
+                    verbose,
+                )
+            else:
+                # ordinal encoding is used since non-numerical data can not be imputed using miceforest
+                enc = OrdinalEncoder()
+                column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
+                adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
+                # impute the data using miceforest
+                _miceforest_impute(
+                    adata,
+                    var_names,
+                    save_all_iterations,
+                    random_state,
+                    inplace,
+                    iterations,
+                    variable_parameters,
+                    verbose,
+                )
+                adata.X = adata.X.astype("object")
+                # decode ordinal encoding to obtain imputed original data
+                adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])
+    except ValueError as e:
+        if "Data matrix has wrong shape" in str(e):
+            print("[bold red]Check that your matrix does not contain any NaN values!")
+        raise
 
     if var_names:
         logg.debug(
diff --git a/tests/anndata/test_anndata_ext.py b/tests/anndata/test_anndata_ext.py
index c42be655..a848e377 100644
--- a/tests/anndata/test_anndata_ext.py
+++ b/tests/anndata/test_anndata_ext.py
@@ -347,14 +347,6 @@ def _setup_anndata_to_df() -> tuple[list, list, list]:
 
         return col1_val, col2_val, col3_val
 
-    def test_generate_anndata(self):
-        adata = generate_anndata((3, 3), include_nlp=False)
-        assert adata.X.shape == (3, 3)
-
-        adata = generate_anndata((2, 2), include_nlp=True)
-        assert adata.X.shape == (2, 2)
-        assert "nlp" in adata.obs.columns
-
 
 class TestAnnDataUtil:
     def setup_method(self):
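
Note: all three hunks above apply the same guard pattern: the imputation body now runs inside a try block, a ValueError whose message contains "Data matrix has wrong shape" triggers a hint about NaN values, and the original exception is always re-raised. The stand-alone sketch below illustrates that pattern outside of ehrapy; `run_imputation` and `impute_with_hint` are hypothetical placeholder names introduced only for this example, and a plain print stands in for the console markup used in the patch.

# Minimal sketch of the error-handling pattern introduced by this patch.
# Assumption: `run_imputation` is a placeholder for the wrapped imputation body,
# not an ehrapy function.
def run_imputation(matrix):
    # Simulate the downstream error that the patch special-cases.
    raise ValueError("Data matrix has wrong shape")


def impute_with_hint(matrix):
    try:
        run_imputation(matrix)
    except ValueError as e:
        # Only the known shape error gets the extra hint; the exception is
        # re-raised either way so the caller still sees the original traceback.
        if "Data matrix has wrong shape" in str(e):
            print("Check that your matrix does not contain any NaN values!")
        raise


if __name__ == "__main__":
    try:
        impute_with_hint(matrix=None)
    except ValueError:
        pass  # the hint was printed and the original error still propagated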