From 577f3a3e3a39e4e7eea1f816facdf6ea86f6df1a Mon Sep 17 00:00:00 2001
From: Y-oHr-N
Date: Wed, 5 Feb 2020 17:00:14 +0900
Subject: [PATCH] Fix AttributeError

---
 optgbm/sklearn.py     | 38 ++++++++++++++++++--------------------
 tests/test_sklearn.py | 20 +++++++++++++++-----
 2 files changed, 33 insertions(+), 25 deletions(-)

diff --git a/optgbm/sklearn.py b/optgbm/sklearn.py
index 8114bde..5f437ae 100644
--- a/optgbm/sklearn.py
+++ b/optgbm/sklearn.py
@@ -233,24 +233,6 @@ def feature_importance(self, **kwargs: Any) -> np.ndarray:
 
 
 class _BaseOGBMModel(lgb.LGBMModel):
-    @property
-    def _param_distributions(
-        self,
-    ) -> Dict[str, optuna.distributions.BaseDistribution]:
-        if self.param_distributions is None:
-            return DEFAULT_PARAM_DISTRIBUTIONS
-
-        return self.param_distributions
-
-    @property
-    def _random_state(self) -> Optional[int]:
-        if self.random_state is None or isinstance(self.random_state, int):
-            return self.random_state
-
-        random_state = check_random_state(self.random_state)
-
-        return random_state.randint(0, MAX_INT)
-
     def __init__(
         self,
         boosting_type: str = "gbdt",
@@ -315,6 +297,22 @@ def __init__(
     def _check_is_fitted(self) -> None:
         check_is_fitted(self, "n_features_")
 
+    def _get_param_distributions(
+        self,
+    ) -> Dict[str, optuna.distributions.BaseDistribution]:
+        if self.param_distributions is None:
+            return DEFAULT_PARAM_DISTRIBUTIONS
+
+        return self.param_distributions
+
+    def _get_random_state(self) -> Optional[int]:
+        if self.random_state is None or isinstance(self.random_state, int):
+            return self.random_state
+
+        random_state = check_random_state(self.random_state)
+
+        return random_state.randint(0, MAX_INT)
+
     def fit(
         self,
         X: TWO_DIM_ARRAYLIKE_TYPE,
@@ -374,7 +372,7 @@ def fit(
 
         is_classifier = self._estimator_type == "classifier"
         cv = check_cv(self.cv, y, is_classifier)
-        seed = self._random_state
+        seed = self._get_random_state()
 
         params = self.get_params()
 
@@ -442,7 +440,7 @@ def fit(
         objective = _Objective(
             params,
             dataset,
-            self._param_distributions,
+            self._get_param_distributions(),
             eval_name,
             is_higher_better,
             callbacks=callbacks,
diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py
index 52b01bd..8aaed73 100644
--- a/tests/test_sklearn.py
+++ b/tests/test_sklearn.py
@@ -15,7 +15,9 @@
 from sklearn.datasets import load_iris
 from sklearn.datasets import load_wine
 from sklearn.model_selection import train_test_split
-from sklearn.utils.estimator_checks import check_estimator
+
+# from sklearn.utils.estimator_checks import check_estimator
+from sklearn.utils.estimator_checks import check_set_params
 
 from optgbm.sklearn import OGBMClassifier
 from optgbm.sklearn import OGBMRegressor
@@ -31,14 +33,22 @@ def zero_one_loss(
     return "zero_one_loss", np.mean(y_true != y_pred), False
 
 
-@pytest.mark.skip
 def test_ogbm_classifier() -> None:
-    check_estimator(OGBMClassifier)
+    clf = OGBMClassifier()
+    name = clf.__class__.__name__
+
+    # check_estimator(clf)
+
+    check_set_params(name, clf)
 
 
-@pytest.mark.skip
 def test_ogbm_regressor() -> None:
-    check_estimator(OGBMRegressor)
+    reg = OGBMRegressor()
+    name = reg.__class__.__name__
+
+    # check_estimator(reg)
+
+    check_set_params(name, reg)
 
 
 @pytest.mark.parametrize("reg_sqrt", [False, True])
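
Reviewer note (illustrative sketch, not part of the patch): the updated tests
replace the skipped check_estimator calls with the narrower check_set_params
check from sklearn.utils.estimator_checks. The snippet below is a rough,
hand-rolled approximation of the same parameter round-trip, using only the
standard scikit-learn estimator conventions (get_params/set_params); treat it
as a quick local sanity check, not a substitute for the full estimator checks.

    from optgbm.sklearn import OGBMClassifier

    clf = OGBMClassifier()

    # get_params() exposes the constructor arguments of the unfitted
    # estimator; set_params() must accept the same keys, return the
    # estimator itself, and leave the parameter values unchanged.
    params = clf.get_params()

    assert clf.set_params(**params) is clf
    assert clf.get_params() == params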