Skip to content

Commit

Permalink
Add input checks for imputers (#625)
Browse files Browse the repository at this point in the history
Signed-off-by: zethson <[email protected]>
  • Loading branch information
Zethson authored Dec 17, 2023
1 parent 3dcc9ed commit 94f257c
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 110 deletions.
233 changes: 131 additions & 102 deletions ehrapy/preprocessing/_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,30 +239,34 @@ def knn_impute(
print(
"[bold yellow]scikit-learn-intelex is not available. Install via [blue]pip install scikit-learn-intelex [yellow] for faster imputations."
)

with Progress(
"[progress.description]{task.description}",
SpinnerColumn(),
refresh_per_second=1500,
) as progress:
progress.add_task("[blue]Running KNN imputation", total=1)
# numerical only data needs no encoding since KNN Imputation can be applied directly
if np.issubdtype(adata.X.dtype, np.number):
_knn_impute(adata, var_names, n_neighbours)
else:
# ordinal encoding is used since non-numerical data can not be imputed using KNN Imputation
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using KNN imputation
_knn_impute(adata, var_names, n_neighbours)
# imputing on encoded columns might result in float numbers; those can not be decoded
# cast them to int to ensure they can be decoded
adata.X[::, column_indices] = np.rint(adata.X[::, column_indices]).astype(int)
# knn imputer transforms X dtype to numerical (encoded), but object is needed for decoding
adata.X = adata.X.astype("object")
# decode ordinal encoding to obtain imputed original data
adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])
try:
with Progress(
"[progress.description]{task.description}",
SpinnerColumn(),
refresh_per_second=1500,
) as progress:
progress.add_task("[blue]Running KNN imputation", total=1)
# numerical only data needs no encoding since KNN Imputation can be applied directly
if np.issubdtype(adata.X.dtype, np.number):
_knn_impute(adata, var_names, n_neighbours)
else:
# ordinal encoding is used since non-numerical data can not be imputed using KNN Imputation
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using KNN imputation
_knn_impute(adata, var_names, n_neighbours)
# imputing on encoded columns might result in float numbers; those can not be decoded
# cast them to int to ensure they can be decoded
adata.X[::, column_indices] = np.rint(adata.X[::, column_indices]).astype(int)
# knn imputer transforms X dtype to numerical (encoded), but object is needed for decoding
adata.X = adata.X.astype("object")
# decode ordinal encoding to obtain imputed original data
adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])
except ValueError as e:
if "Data matrix has wrong shape" in str(e):
print("[bold red]Check that your matrix does not contain any NaN values!")
raise

if _check_module_importable("sklearnex"): # pragma: no cover
unpatch_sklearn()
Expand Down Expand Up @@ -356,62 +360,69 @@ def miss_forest_impute(
from sklearn.ensemble import ExtraTreesRegressor, RandomForestClassifier
from sklearn.impute import IterativeImputer

with Progress(
"[progress.description]{task.description}",
SpinnerColumn(),
refresh_per_second=1500,
) as progress:
progress.add_task("[blue]Running MissForest imputation", total=1)

if settings.n_jobs == 1: # pragma: no cover
print("[bold yellow]The number of jobs is only 1. To decrease the runtime set [blue]ep.settings.n_jobs=-1.")

imp_num = IterativeImputer(
estimator=ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs),
initial_strategy=num_initial_strategy,
max_iter=max_iter,
random_state=random_state,
)
# initial strategy here will not be parametrized since only most_frequent will be applied to non numerical data
imp_cat = IterativeImputer(
estimator=RandomForestClassifier(n_estimators=n_estimators, n_jobs=settings.n_jobs),
initial_strategy="most_frequent",
max_iter=max_iter,
random_state=random_state,
)

if isinstance(var_names, list):
var_indices = _get_column_indices(adata, var_names) # type: ignore
adata.X[::, var_indices] = imp_num.fit_transform(adata.X[::, var_indices])
elif isinstance(var_names, dict) or var_names is None:
if var_names:
try:
non_num_vars = var_names["non_numerical"]
num_vars = var_names["numerical"]
except KeyError: # pragma: no cover
raise ValueError(
"One or both of your keys provided for var_names are unknown. Only "
"numerical and non_numerical are available!"
) from None
non_num_indices = _get_column_indices(adata, non_num_vars)
num_indices = _get_column_indices(adata, num_vars)

# infer non numerical and numerical indices automatically
else:
non_num_indices_set = _get_non_numerical_column_indices(adata.X)
num_indices = [idx for idx in range(adata.X.shape[1]) if idx not in non_num_indices_set]
non_num_indices = list(non_num_indices_set)
try:
with Progress(
"[progress.description]{task.description}",
SpinnerColumn(),
refresh_per_second=1500,
) as progress:
progress.add_task("[blue]Running MissForest imputation", total=1)

if settings.n_jobs == 1: # pragma: no cover
print(
"[bold yellow]The number of jobs is only 1. To decrease the runtime set [blue]ep.settings.n_jobs=-1."
)

imp_num = IterativeImputer(
estimator=ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs),
initial_strategy=num_initial_strategy,
max_iter=max_iter,
random_state=random_state,
)
# initial strategy here will not be parametrized since only most_frequent will be applied to non numerical data
imp_cat = IterativeImputer(
estimator=RandomForestClassifier(n_estimators=n_estimators, n_jobs=settings.n_jobs),
initial_strategy="most_frequent",
max_iter=max_iter,
random_state=random_state,
)

# encode all non numerical columns
if non_num_indices:
enc = OrdinalEncoder()
adata.X[::, non_num_indices] = enc.fit_transform(adata.X[::, non_num_indices])
# this step is the most expensive one and might extremely slow down the impute process
if num_indices:
adata.X[::, num_indices] = imp_num.fit_transform(adata.X[::, num_indices])
if non_num_indices:
adata.X[::, non_num_indices] = imp_cat.fit_transform(adata.X[::, non_num_indices])
adata.X[::, non_num_indices] = enc.inverse_transform(adata.X[::, non_num_indices])
if isinstance(var_names, list):
var_indices = _get_column_indices(adata, var_names) # type: ignore
adata.X[::, var_indices] = imp_num.fit_transform(adata.X[::, var_indices])
elif isinstance(var_names, dict) or var_names is None:
if var_names:
try:
non_num_vars = var_names["non_numerical"]
num_vars = var_names["numerical"]
except KeyError: # pragma: no cover
raise ValueError(
"One or both of your keys provided for var_names are unknown. Only "
"numerical and non_numerical are available!"
) from None
non_num_indices = _get_column_indices(adata, non_num_vars)
num_indices = _get_column_indices(adata, num_vars)

# infer non numerical and numerical indices automatically
else:
non_num_indices_set = _get_non_numerical_column_indices(adata.X)
num_indices = [idx for idx in range(adata.X.shape[1]) if idx not in non_num_indices_set]
non_num_indices = list(non_num_indices_set)

# encode all non numerical columns
if non_num_indices:
enc = OrdinalEncoder()
adata.X[::, non_num_indices] = enc.fit_transform(adata.X[::, non_num_indices])
# this step is the most expensive one and might extremely slow down the impute process
if num_indices:
adata.X[::, num_indices] = imp_num.fit_transform(adata.X[::, num_indices])
if non_num_indices:
adata.X[::, non_num_indices] = imp_cat.fit_transform(adata.X[::, non_num_indices])
adata.X[::, non_num_indices] = enc.inverse_transform(adata.X[::, non_num_indices])
except ValueError as e:
if "Data matrix has wrong shape" in str(e):
print("[bold red]Check that your matrix does not contain any NaN values!")
raise

if _check_module_importable("sklearnex"): # pragma: no cover
unpatch_sklearn()
Expand Down Expand Up @@ -1025,29 +1036,47 @@ def mice_forest_impute(
adata = adata.copy()

_warn_imputation_threshold(adata, var_names, threshold=warning_threshold)

with Progress(
"[progress.description]{task.description}",
SpinnerColumn(),
refresh_per_second=1500,
) as progress:
progress.add_task("[blue]Running miceforest", total=1)
if np.issubdtype(adata.X.dtype, np.number):
_miceforest_impute(
adata, var_names, save_all_iterations, random_state, inplace, iterations, variable_parameters, verbose
)
else:
# ordinal encoding is used since non-numerical data can not be imputed using miceforest
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using miceforest
_miceforest_impute(
adata, var_names, save_all_iterations, random_state, inplace, iterations, variable_parameters, verbose
)
adata.X = adata.X.astype("object")
# decode ordinal encoding to obtain imputed original data
adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])
try:
with Progress(
"[progress.description]{task.description}",
SpinnerColumn(),
refresh_per_second=1500,
) as progress:
progress.add_task("[blue]Running miceforest", total=1)
if np.issubdtype(adata.X.dtype, np.number):
_miceforest_impute(
adata,
var_names,
save_all_iterations,
random_state,
inplace,
iterations,
variable_parameters,
verbose,
)
else:
# ordinal encoding is used since non-numerical data can not be imputed using miceforest
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using miceforest
_miceforest_impute(
adata,
var_names,
save_all_iterations,
random_state,
inplace,
iterations,
variable_parameters,
verbose,
)
adata.X = adata.X.astype("object")
# decode ordinal encoding to obtain imputed original data
adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])
except ValueError as e:
if "Data matrix has wrong shape" in str(e):
print("[bold red]Check that your matrix does not contain any NaN values!")
raise

if var_names:
logg.debug(
Expand Down
8 changes: 0 additions & 8 deletions tests/anndata/test_anndata_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,14 +347,6 @@ def _setup_anndata_to_df() -> tuple[list, list, list]:

return col1_val, col2_val, col3_val

def test_generate_anndata(self):
    """Check that generate_anndata yields the requested X shape, with and without NLP data."""
    # Without NLP content only the matrix shape is checked.
    plain = generate_anndata((3, 3), include_nlp=False)
    assert plain.X.shape == (3, 3)

    # With NLP content the shape must still match and an "nlp" column must exist in obs.
    with_nlp = generate_anndata((2, 2), include_nlp=True)
    assert with_nlp.X.shape == (2, 2)
    assert "nlp" in with_nlp.obs.columns


class TestAnnDataUtil:
def setup_method(self):
Expand Down

0 comments on commit 94f257c

Please sign in to comment.