diff --git a/optgbm/sklearn.py b/optgbm/sklearn.py index 05e25e2..0eb9aa0 100644 --- a/optgbm/sklearn.py +++ b/optgbm/sklearn.py @@ -522,21 +522,13 @@ def fit( ): params.pop(attr, None) + params["objective"] = self._get_objective() params["random_state"] = seed params["verbose"] = -1 if self._n_classes is not None and self._n_classes > 2: params["num_classes"] = self._n_classes - if callable(self.objective): - fobj = _ObjectiveFunctionWrapper(self.objective) - else: - fobj = None - - self._objective = self._get_objective() - - params["objective"] = self._objective - if callable(eval_metric): params["metric"] = "None" feval = _EvalFunctionWrapper(eval_metric) @@ -552,6 +544,11 @@ def fit( eval_name = params["metric"] is_higher_better = _is_higher_better(params["metric"]) + if callable(self.objective): + fobj = _ObjectiveFunctionWrapper(self.objective) + else: + fobj = None + if isinstance(init_model, lgb.LGBMModel): init_model = init_model.booster_ @@ -612,6 +609,7 @@ def fit( "best_iteration" ] self._best_score = self.study_.best_value + self._objective = params["objective"] self.best_params_ = {**params, **self.study_.best_params} self.n_splits_ = cv.get_n_splits(X, y, groups=groups)