Skip to content

Commit

Permalink
Fix fit on LGBMModel
Browse files Browse the repository at this point in the history
  • Loading branch information
Y-oHr-N committed Mar 20, 2020
1 parent 439f976 commit 94389b8
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions optgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,21 +522,13 @@ def fit(
):
params.pop(attr, None)

params["objective"] = self._get_objective()
params["random_state"] = seed
params["verbose"] = -1

if self._n_classes is not None and self._n_classes > 2:
params["num_classes"] = self._n_classes

if callable(self.objective):
fobj = _ObjectiveFunctionWrapper(self.objective)
else:
fobj = None

self._objective = self._get_objective()

params["objective"] = self._objective

if callable(eval_metric):
params["metric"] = "None"
feval = _EvalFunctionWrapper(eval_metric)
Expand All @@ -552,6 +544,11 @@ def fit(
eval_name = params["metric"]
is_higher_better = _is_higher_better(params["metric"])

if callable(self.objective):
fobj = _ObjectiveFunctionWrapper(self.objective)
else:
fobj = None

if isinstance(init_model, lgb.LGBMModel):
init_model = init_model.booster_

Expand Down Expand Up @@ -612,6 +609,7 @@ def fit(
"best_iteration"
]
self._best_score = self.study_.best_value
self._objective = params["objective"]
self.best_params_ = {**params, **self.study_.best_params}
self.n_splits_ = cv.get_n_splits(X, y, groups=groups)

Expand Down

0 comments on commit 94389b8

Please sign in to comment.