Skip to content

Commit

Permalink
Add test for clf scenario for LOCO and PI
Browse files Browse the repository at this point in the history
  • Loading branch information
paillarj committed Oct 15, 2024
1 parent 9ac51c7 commit f653bc3
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 4 deletions.
29 changes: 27 additions & 2 deletions hidimstat/test/test_loco.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.base import clone
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

from hidimstat.loco import LOCO


def test_LOCO(linear_scenario):
def test_loco(linear_scenario):
X, y, beta = linear_scenario
important_features = np.where(beta != 0)[0]
non_important_features = np.where(beta == 0)[0]
Expand Down Expand Up @@ -53,3 +55,26 @@ def test_LOCO(linear_scenario):

importance = vim["importance"]
assert importance[0].mean() > importance[1].mean()

# Classification case
y_clf = np.where(y > np.median(y), 1, 0)
_, _, y_train_clf, y_test_clf = train_test_split(X, y_clf, random_state=0)
logistic_model = LogisticRegression()
logistic_model.fit(X_train, y_train_clf)

loco_clf = LOCO(
estimator=logistic_model,
score_proba=True,
random_state=0,
n_jobs=1,
loss=log_loss,
)
loco_clf.fit(
X_train,
y_train_clf,
groups=None,
)
vim_clf = loco_clf.score(X_test, y_test_clf)

importance_clf = vim_clf["importance"]
assert importance_clf.shape == (X.shape[1],)
30 changes: 28 additions & 2 deletions hidimstat/test/test_permutation_importance.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

from hidimstat.permutation_importance import PermutationImportance


def test_CPI(linear_scenario):
def test_permutation_importance(linear_scenario):
X, y, beta = linear_scenario
important_features = np.where(beta != 0)[0]
non_important_features = np.where(beta == 0)[0]
Expand Down Expand Up @@ -55,3 +56,28 @@ def test_CPI(linear_scenario):

importance = vim["importance"]
assert importance[0].mean() > importance[1].mean()

# Classification case
y_clf = np.where(y > np.median(y), 1, 0)
_, _, y_train_clf, y_test_clf = train_test_split(X, y_clf, random_state=0)
logistic_model = LogisticRegression()
logistic_model.fit(X_train, y_train_clf)

pi_clf = PermutationImportance(
estimator=logistic_model,
n_permutations=20,
score_proba=True,
random_state=0,
n_jobs=1,
loss=log_loss,
)

pi_clf.fit(
X_train,
y_train_clf,
groups=None,
)
vim_clf = pi_clf.score(X_test, y_test_clf)

importance_clf = vim_clf["importance"]
assert importance_clf.shape == (X.shape[1],)

0 comments on commit f653bc3

Please sign in to comment.