From 439f976c73b31e9da36953d72db6bad8369c5a27 Mon Sep 17 00:00:00 2001 From: Y-oHr-N Date: Thu, 19 Mar 2020 23:53:49 +0900 Subject: [PATCH] Add init_model as a fit parameter --- optgbm/sklearn.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/optgbm/sklearn.py b/optgbm/sklearn.py index 233363a..05e25e2 100644 --- a/optgbm/sklearn.py +++ b/optgbm/sklearn.py @@ -120,6 +120,7 @@ def __init__( feature_name: Union[List[str], str] = "auto", feval: Optional[Callable] = None, fobj: Optional[Callable] = None, + init_model: Optional[Union[lgb.Booster, lgb.LGBMModel, str]] = None, n_estimators: int = 100, param_distributions: Optional[ Dict[str, distributions.BaseDistribution] @@ -135,6 +136,7 @@ def __init__( self.feature_name = feature_name self.feval = feval self.fobj = fobj + self.init_model = init_model self.is_higher_better = is_higher_better self.n_estimators = n_estimators self.n_samples = n_samples @@ -155,6 +157,7 @@ def __call__(self, trial: trial_module.Trial) -> float: feval=self.feval, fobj=self.fobj, folds=self.cv, + init_model=self.init_model, num_boost_round=self.n_estimators, ) # Dict[str, List[float]] best_iteration = callbacks[0]._best_iteration # type: ignore @@ -386,6 +389,7 @@ def _make_booster( feature_name: Union[List[str], str] = "auto", categorical_feature: Union[List[int], List[str], str] = "auto", callbacks: Optional[List[Callable]] = None, + init_model: Optional[Union[lgb.Booster, lgb.LGBMModel, str]] = None, ) -> Union[_VotingBooster, lgb.Booster]: if self.refit: booster = lgb.train( @@ -396,6 +400,7 @@ def _make_booster( feature_name=feature_name, categorical_feature=categorical_feature, callbacks=callbacks, + init_model=init_model, ) booster.free_dataset() @@ -424,6 +429,7 @@ def fit( feature_name: Union[List[str], str] = "auto", categorical_feature: Union[List[int], List[str], str] = "auto", callbacks: Optional[List[Callable]] = None, + init_model: Optional[Union[lgb.Booster, lgb.LGBMModel, str]] = None, groups: Optional[OneDimArrayLikeType] = None, **fit_params: Any ) -> "LGBMModel": @@ -467,6 +473,10 @@ def fit( callbacks List of callback functions that are applied at each iteration. + init_model + Filename of LightGBM model, Booster instance or LGBMModel instance + used for continue training. + groups Group labels for the samples used while splitting the dataset into train/test set. If `group` is not None, this parameter is ignored. @@ -542,6 +552,9 @@ def fit( eval_name = params["metric"] is_higher_better = _is_higher_better(params["metric"]) + if isinstance(init_model, lgb.LGBMModel): + init_model = init_model.booster_ + if self.study is None: sampler = samplers.TPESampler(seed=seed) @@ -578,6 +591,7 @@ def fit( feature_name=feature_name, feval=feval, fobj=fobj, + init_model=init_model, n_estimators=self.n_estimators, param_distributions=self.param_distributions, ) @@ -841,6 +855,7 @@ def fit( feature_name: Union[List[str], str] = "auto", categorical_feature: Union[List[int], List[str], str] = "auto", callbacks: Optional[List[Callable]] = None, + init_model: Optional[Union[lgb.Booster, lgb.LGBMModel, str]] = None, groups: Optional[OneDimArrayLikeType] = None, **fit_params: Any ) -> "LGBMClassifier": @@ -884,6 +899,10 @@ def fit( callbacks List of callback functions that are applied at each iteration. + init_model + Filename of LightGBM model, Booster instance or LGBMModel instance + used for continue training. + groups Group labels for the samples used while splitting the dataset into train/test set. If `group` is not None, this parameter is ignored. @@ -913,6 +932,7 @@ def fit( feature_name=feature_name, categorical_feature=categorical_feature, callbacks=callbacks, + init_model=init_model, groups=groups, **fit_params )