diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index c384125..3dd11da 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -18,6 +18,7 @@ from sklearn.model_selection import train_test_split # from sklearn.utils.estimator_checks import check_estimator +from sklearn.utils.estimator_checks import check_estimators_pickle from sklearn.utils.estimator_checks import check_set_params from optgbm.sklearn import OGBMClassifier @@ -48,6 +49,7 @@ def test_ogbm_classifier() -> None: # check_estimator(clf) + check_estimators_pickle(name, clf) check_set_params(name, clf) @@ -57,6 +59,7 @@ def test_ogbm_regressor() -> None: # check_estimator(reg) + check_estimators_pickle(name, reg) check_set_params(name, reg)