diff --git a/skltemplate/utils/tests/test_discovery.py b/skltemplate/utils/tests/test_discovery.py index 2be0f0c..a78c029 100644 --- a/skltemplate/utils/tests/test_discovery.py +++ b/skltemplate/utils/tests/test_discovery.py @@ -1,6 +1,8 @@ # Authors: scikit-learn-contrib developers # License: BSD 3 clause +import pytest + from skltemplate.utils.discovery import all_displays, all_estimators, all_functions @@ -8,6 +10,16 @@ def test_all_estimators(): estimators = all_estimators() assert len(estimators) == 3 + estimators = all_estimators(type_filter="classifier") + assert len(estimators) == 1 + + estimators = all_estimators(type_filter=["classifier", "transformer"]) + assert len(estimators) == 2 + + err_msg = "Parameter type_filter must be" + with pytest.raises(ValueError, match=err_msg): + all_estimators(type_filter="xxxx") + def test_all_displays(): displays = all_displays()