Skip to content

Commit

Permalink
Fix AttributeError
Browse files Browse the repository at this point in the history
  • Loading branch information
Y-oHr-N committed Feb 5, 2020
1 parent 953d52b commit 577f3a3
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 25 deletions.
38 changes: 18 additions & 20 deletions optgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,24 +233,6 @@ def feature_importance(self, **kwargs: Any) -> np.ndarray:


class _BaseOGBMModel(lgb.LGBMModel):
@property
def _param_distributions(
self,
) -> Dict[str, optuna.distributions.BaseDistribution]:
if self.param_distributions is None:
return DEFAULT_PARAM_DISTRIBUTIONS

return self.param_distributions

@property
def _random_state(self) -> Optional[int]:
if self.random_state is None or isinstance(self.random_state, int):
return self.random_state

random_state = check_random_state(self.random_state)

return random_state.randint(0, MAX_INT)

def __init__(
self,
boosting_type: str = "gbdt",
Expand Down Expand Up @@ -315,6 +297,22 @@ def __init__(
def _check_is_fitted(self) -> None:
check_is_fitted(self, "n_features_")

def _get_param_distributions(
self,
) -> Dict[str, optuna.distributions.BaseDistribution]:
if self.param_distributions is None:
return DEFAULT_PARAM_DISTRIBUTIONS

return self.param_distributions

def _get_random_state(self) -> Optional[int]:
if self.random_state is None or isinstance(self.random_state, int):
return self.random_state

random_state = check_random_state(self.random_state)

return random_state.randint(0, MAX_INT)

def fit(
self,
X: TWO_DIM_ARRAYLIKE_TYPE,
Expand Down Expand Up @@ -374,7 +372,7 @@ def fit(
is_classifier = self._estimator_type == "classifier"
cv = check_cv(self.cv, y, is_classifier)

seed = self._random_state
seed = self._get_random_state()

params = self.get_params()

Expand Down Expand Up @@ -442,7 +440,7 @@ def fit(
objective = _Objective(
params,
dataset,
self._param_distributions,
self._get_param_distributions(),
eval_name,
is_higher_better,
callbacks=callbacks,
Expand Down
20 changes: 15 additions & 5 deletions tests/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from sklearn.datasets import load_iris
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.utils.estimator_checks import check_estimator

# from sklearn.utils.estimator_checks import check_estimator
from sklearn.utils.estimator_checks import check_set_params

from optgbm.sklearn import OGBMClassifier
from optgbm.sklearn import OGBMRegressor
Expand All @@ -31,14 +33,22 @@ def zero_one_loss(
return "zero_one_loss", np.mean(y_true != y_pred), False


@pytest.mark.skip
def test_ogbm_classifier() -> None:
check_estimator(OGBMClassifier)
clf = OGBMClassifier()
name = clf.__class__.__name__

# check_estimator(clf)

check_set_params(name, clf)


@pytest.mark.skip
def test_ogbm_regressor() -> None:
check_estimator(OGBMRegressor)
reg = OGBMRegressor()
name = reg.__class__.__name__

# check_estimator(reg)

check_set_params(name, reg)


@pytest.mark.parametrize("reg_sqrt", [False, True])
Expand Down

0 comments on commit 577f3a3

Please sign in to comment.