Remove _get_param_distributions
Y-oHr-N committed Feb 20, 2020
1 parent 79fd388 commit 685d8d9
Showing 1 changed file with 14 additions and 11 deletions.
optgbm/sklearn.py (14 additions, 11 deletions)
@@ -107,11 +107,19 @@ def __call__(self, env: LightGBMCallbackEnv) -> None:
 
 
 class _Objective(object):
+    @property
+    def _param_distributions(
+        self,
+    ) -> Dict[str, optuna.distributions.BaseDistribution]:
+        if self.param_distributions is None:
+            return DEFAULT_PARAM_DISTRIBUTIONS
+
+        return self.param_distributions
+
     def __init__(
         self,
         params: Dict[str, Any],
         dataset: lgb.Dataset,
-        param_distributions: Dict[str, optuna.distributions.BaseDistribution],
         eval_name: str,
         is_higher_better: bool,
         callbacks: Optional[List[Callable]] = None,
@@ -123,6 +131,9 @@ def __init__(
         feval: Optional[Callable] = None,
         fobj: Optional[Callable] = None,
         n_estimators: int = 100,
+        param_distributions: Optional[
+            Dict[str, optuna.distributions.BaseDistribution]
+        ] = None,
     ) -> None:
         self.callbacks = callbacks
         self.categorical_feature = categorical_feature
@@ -204,7 +215,7 @@ def _get_callbacks(self, trial: optuna.trial.Trial) -> List[Callable]:
     def _get_params(self, trial: optuna.trial.Trial) -> Dict[str, Any]:
         params: Dict[str, Any] = self.params.copy()
 
-        for name, distribution in self.param_distributions.items():
+        for name, distribution in self._param_distributions.items():
             params[name] = trial._suggest(name, distribution)
 
         return params
@@ -339,14 +350,6 @@ def _get_objective(self) -> str:
         else:
             return "binary"
 
-    def _get_param_distributions(
-        self,
-    ) -> Dict[str, optuna.distributions.BaseDistribution]:
-        if self.param_distributions is None:
-            return DEFAULT_PARAM_DISTRIBUTIONS
-
-        return self.param_distributions
-
     def _get_random_state(self) -> Optional[int]:
         if self.random_state is None or isinstance(self.random_state, int):
             return self.random_state
@@ -545,7 +548,6 @@ def fit(
         objective = _Objective(
             params,
             dataset,
-            self._get_param_distributions(),
             eval_name,
             is_higher_better,
             callbacks=callbacks,
@@ -557,6 +559,7 @@
             feval=feval,
             fobj=fobj,
             n_estimators=self.n_estimators,
+            param_distributions=self.param_distributions,
         )
 
         self.study_.optimize(

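For orientation, the net effect of the change: `param_distributions` is now an optional keyword argument of `_Objective`, a private `_param_distributions` property falls back to `DEFAULT_PARAM_DISTRIBUTIONS` when it is None, and the estimator's `fit` passes `self.param_distributions` through directly instead of calling the removed `_get_param_distributions` helper. Below is a minimal, self-contained sketch of that pattern, assuming the Optuna 1.x API in use at the time (`IntUniformDistribution`, `Trial._suggest`); the default distribution and the `__call__` body here are placeholders, not OptGBM's actual defaults or training loop.

from typing import Any, Dict, Optional

import optuna

# Placeholder defaults for illustration only; OptGBM's real
# DEFAULT_PARAM_DISTRIBUTIONS covers more LightGBM parameters.
DEFAULT_PARAM_DISTRIBUTIONS: Dict[str, optuna.distributions.BaseDistribution] = {
    "num_leaves": optuna.distributions.IntUniformDistribution(2, 127),
}


class _Objective(object):
    def __init__(
        self,
        params: Dict[str, Any],
        param_distributions: Optional[
            Dict[str, optuna.distributions.BaseDistribution]
        ] = None,
    ) -> None:
        self.params = params
        self.param_distributions = param_distributions

    @property
    def _param_distributions(
        self,
    ) -> Dict[str, optuna.distributions.BaseDistribution]:
        # Fall back to the module-level defaults when no distributions were given.
        if self.param_distributions is None:
            return DEFAULT_PARAM_DISTRIBUTIONS

        return self.param_distributions

    def _get_params(self, trial: optuna.trial.Trial) -> Dict[str, Any]:
        params: Dict[str, Any] = self.params.copy()

        # Sample each tunable parameter from its distribution, as in the diff above.
        for name, distribution in self._param_distributions.items():
            params[name] = trial._suggest(name, distribution)

        return params

    def __call__(self, trial: optuna.trial.Trial) -> float:
        params = self._get_params(trial)

        # Stand-in for the real LightGBM cross-validation step; return a value
        # derived from the sampled parameters so the sketch runs end to end.
        return float(params["num_leaves"])


if __name__ == "__main__":
    # No param_distributions given, so the property falls back to the defaults.
    study = optuna.create_study(direction="minimize")
    study.optimize(_Objective({"objective": "binary"}), n_trials=5)

With this arrangement the caller can omit `param_distributions` entirely, which is what `fit` now does by forwarding `self.param_distributions` (possibly None) and letting the objective resolve the defaults itself.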