diff --git a/hidimstat/test/test_loco.py b/hidimstat/test/test_loco.py index e17bd60..2b3b79f 100644 --- a/hidimstat/test/test_loco.py +++ b/hidimstat/test/test_loco.py @@ -1,11 +1,13 @@ import numpy as np -from sklearn.linear_model import LinearRegression +from sklearn.base import clone +from sklearn.linear_model import LinearRegression, LogisticRegression +from sklearn.metrics import log_loss from sklearn.model_selection import train_test_split from hidimstat.loco import LOCO -def test_LOCO(linear_scenario): +def test_loco(linear_scenario): X, y, beta = linear_scenario important_features = np.where(beta != 0)[0] non_important_features = np.where(beta == 0)[0] @@ -53,3 +55,26 @@ def test_LOCO(linear_scenario): importance = vim["importance"] assert importance[0].mean() > importance[1].mean() + + # Classification case + y_clf = np.where(y > np.median(y), 1, 0) + _, _, y_train_clf, y_test_clf = train_test_split(X, y_clf, random_state=0) + logistic_model = LogisticRegression() + logistic_model.fit(X_train, y_train_clf) + + loco_clf = LOCO( + estimator=logistic_model, + score_proba=True, + random_state=0, + n_jobs=1, + loss=log_loss, + ) + loco_clf.fit( + X_train, + y_train_clf, + groups=None, + ) + vim_clf = loco_clf.score(X_test, y_test_clf) + + importance_clf = vim_clf["importance"] + assert importance_clf.shape == (X.shape[1],) diff --git a/hidimstat/test/test_permutation_importance.py b/hidimstat/test/test_permutation_importance.py index 0c94ef7..1f2dfb5 100644 --- a/hidimstat/test/test_permutation_importance.py +++ b/hidimstat/test/test_permutation_importance.py @@ -1,11 +1,12 @@ import numpy as np -from sklearn.linear_model import LinearRegression +from sklearn.linear_model import LinearRegression, LogisticRegression +from sklearn.metrics import log_loss from sklearn.model_selection import train_test_split from hidimstat.permutation_importance import PermutationImportance -def test_CPI(linear_scenario): +def test_permutation_importance(linear_scenario): X, y, beta = linear_scenario important_features = np.where(beta != 0)[0] non_important_features = np.where(beta == 0)[0] @@ -55,3 +56,28 @@ def test_CPI(linear_scenario): importance = vim["importance"] assert importance[0].mean() > importance[1].mean() + + # Classification case + y_clf = np.where(y > np.median(y), 1, 0) + _, _, y_train_clf, y_test_clf = train_test_split(X, y_clf, random_state=0) + logistic_model = LogisticRegression() + logistic_model.fit(X_train, y_train_clf) + + pi_clf = PermutationImportance( + estimator=logistic_model, + n_permutations=20, + score_proba=True, + random_state=0, + n_jobs=1, + loss=log_loss, + ) + + pi_clf.fit( + X_train, + y_train_clf, + groups=None, + ) + vim_clf = pi_clf.score(X_test, y_test_clf) + + importance_clf = vim_clf["importance"] + assert importance_clf.shape == (X.shape[1],)