diff --git a/doc/changes/devel/12620.bugfix.rst b/doc/changes/devel/12620.bugfix.rst new file mode 100644 index 00000000000..0e8d53f02b1 --- /dev/null +++ b/doc/changes/devel/12620.bugfix.rst @@ -0,0 +1 @@ +Fix for new sklearn metadata routing protocol in decoding search_light, by `Alex Gramfort`_ diff --git a/mne/decoding/tests/test_search_light.py b/mne/decoding/tests/test_search_light.py index a5fc53865cc..d78b123f746 100644 --- a/mne/decoding/tests/test_search_light.py +++ b/mne/decoding/tests/test_search_light.py @@ -14,7 +14,7 @@ from mne.decoding.transformer import Vectorizer from mne.utils import _record_warnings, check_version, use_log_level -pytest.importorskip("sklearn") +sklearn = pytest.importorskip("sklearn") NEW_MULTICLASS_SAMPLE_WEIGHT = check_version("sklearn", "1.4") @@ -186,7 +186,17 @@ def transform(self, X): assert isinstance(pipe.estimators_[0], BaggingClassifier) -def test_generalization_light(): +@pytest.fixture() +def metadata_routing(): + """Temporarily enable metadata routing for new sklearn.""" + if NEW_MULTICLASS_SAMPLE_WEIGHT: + sklearn.set_config(enable_metadata_routing=True) + yield + if NEW_MULTICLASS_SAMPLE_WEIGHT: + sklearn.set_config(enable_metadata_routing=False) + + +def test_generalization_light(metadata_routing): """Test GeneralizingEstimator.""" from sklearn.linear_model import LogisticRegression from sklearn.metrics import roc_auc_score @@ -194,7 +204,9 @@ def test_generalization_light(): from sklearn.pipeline import make_pipeline if NEW_MULTICLASS_SAMPLE_WEIGHT: - logreg = OneVsRestClassifier(LogisticRegression(random_state=0)) + clf = LogisticRegression(random_state=0) + clf.set_fit_request(sample_weight=True) + logreg = OneVsRestClassifier(clf) else: logreg = LogisticRegression( solver="liblinear", @@ -208,10 +220,7 @@ def test_generalization_light(): gl = GeneralizingEstimator(logreg) assert_equal(repr(gl)[:23], "") # transforms @@ -346,7 +355,6 @@ def predict_proba(self, X): @pytest.mark.slowtest def test_sklearn_compliance(): """Test LinearModel compliance with sklearn.""" - pytest.importorskip("sklearn") from sklearn.linear_model import LogisticRegression from sklearn.utils.estimator_checks import check_estimator diff --git a/tools/vulture_allowlist.py b/tools/vulture_allowlist.py index 5c3d41c356e..c0ac3317e09 100644 --- a/tools/vulture_allowlist.py +++ b/tools/vulture_allowlist.py @@ -15,6 +15,7 @@ verbose_debug few_surfaces disabled_event_channels +metadata_routing # Others exc_value