Skip to content

Commit

Permalink
fix for sklearn metadata (#12620)
Browse files Browse the repository at this point in the history
Co-authored-by: Eric Larson <[email protected]>
  • Loading branch information
agramfort and larsoner authored May 20, 2024
1 parent 5c2a255 commit b12396d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/12620.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix for new sklearn metadata routing protocol in decoding search_light, by `Alex Gramfort`_
24 changes: 16 additions & 8 deletions mne/decoding/tests/test_search_light.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from mne.decoding.transformer import Vectorizer
from mne.utils import _record_warnings, check_version, use_log_level

pytest.importorskip("sklearn")
sklearn = pytest.importorskip("sklearn")

NEW_MULTICLASS_SAMPLE_WEIGHT = check_version("sklearn", "1.4")

Expand Down Expand Up @@ -186,15 +186,27 @@ def transform(self, X):
assert isinstance(pipe.estimators_[0], BaggingClassifier)


def test_generalization_light():
@pytest.fixture()
def metadata_routing():
"""Temporarily enable metadata routing for new sklearn."""
if NEW_MULTICLASS_SAMPLE_WEIGHT:
sklearn.set_config(enable_metadata_routing=True)
yield
if NEW_MULTICLASS_SAMPLE_WEIGHT:
sklearn.set_config(enable_metadata_routing=False)


def test_generalization_light(metadata_routing):
"""Test GeneralizingEstimator."""
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import make_pipeline

if NEW_MULTICLASS_SAMPLE_WEIGHT:
logreg = OneVsRestClassifier(LogisticRegression(random_state=0))
clf = LogisticRegression(random_state=0)
clf.set_fit_request(sample_weight=True)
logreg = OneVsRestClassifier(clf)
else:
logreg = LogisticRegression(
solver="liblinear",
Expand All @@ -208,10 +220,7 @@ def test_generalization_light():
gl = GeneralizingEstimator(logreg)
assert_equal(repr(gl)[:23], "<GeneralizingEstimator(")
gl.fit(X, y)
# TODO: Need to fix this for 1.4.2+
# https://scikit-learn.org/stable/metadata_routing.html
if not NEW_MULTICLASS_SAMPLE_WEIGHT:
gl.fit(X, y, sample_weight=np.ones_like(y))
gl.fit(X, y, sample_weight=np.ones_like(y))

assert_equal(gl.__repr__()[-28:], ", fitted with 10 estimators>")
# transforms
Expand Down Expand Up @@ -346,7 +355,6 @@ def predict_proba(self, X):
@pytest.mark.slowtest
def test_sklearn_compliance():
"""Test LinearModel compliance with sklearn."""
pytest.importorskip("sklearn")
from sklearn.linear_model import LogisticRegression
from sklearn.utils.estimator_checks import check_estimator

Expand Down
1 change: 1 addition & 0 deletions tools/vulture_allowlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
verbose_debug
few_surfaces
disabled_event_channels
metadata_routing

# Others
exc_value
Expand Down

0 comments on commit b12396d

Please sign in to comment.