Add init_model as a fit parameter
Y-oHr-N committed Mar 19, 2020
1 parent 68c9bbc commit 439f976
Showing 1 changed file with 20 additions and 0 deletions.
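
For orientation, a minimal sketch of what this change enables: warm-starting an OptGBM estimator from an existing LightGBM model passed to fit. The optgbm.sklearn.LGBMRegressor import path is an assumption inferred from the class names in this diff, not something shown in the commit itself.

    # Hedged sketch: the optgbm.sklearn.LGBMRegressor import is an assumption
    # inferred from this diff; only the fit() signature change is exercised here.
    import lightgbm as lgb
    from sklearn.datasets import make_regression

    from optgbm.sklearn import LGBMRegressor  # assumed import path

    X, y = make_regression(n_samples=200, n_features=10, random_state=0)

    # Train a plain LightGBM model first.
    base = lgb.LGBMRegressor(n_estimators=50, random_state=0).fit(X, y)

    # Continue training from it: per the new signature, init_model accepts a
    # Booster, an LGBMModel, or a model filename.
    reg = LGBMRegressor(n_estimators=50, random_state=0)
    reg.fit(X, y, init_model=base)
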
optgbm/sklearn.py: 20 additions & 0 deletions
@@ -120,6 +120,7 @@ def __init__(
feature_name: Union[List[str], str] = "auto",
feval: Optional[Callable] = None,
fobj: Optional[Callable] = None,
init_model: Optional[Union[lgb.Booster, lgb.LGBMModel, str]] = None,
n_estimators: int = 100,
param_distributions: Optional[
Dict[str, distributions.BaseDistribution]
@@ -135,6 +136,7 @@ def __init__(
self.feature_name = feature_name
self.feval = feval
self.fobj = fobj
self.init_model = init_model
self.is_higher_better = is_higher_better
self.n_estimators = n_estimators
self.n_samples = n_samples
@@ -155,6 +157,7 @@ def __call__(self, trial: trial_module.Trial) -> float:
feval=self.feval,
fobj=self.fobj,
folds=self.cv,
init_model=self.init_model,
num_boost_round=self.n_estimators,
) # Dict[str, List[float]]
best_iteration = callbacks[0]._best_iteration # type: ignore
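
The objective above forwards the stored init_model to lgb.cv. As a standalone illustration (not part of this commit), lgb.cv itself accepts init_model and continues each fold's booster from it:

    # Standalone illustration (not from this commit): lgb.cv accepts an
    # init_model, which is what the tuning objective forwards to it.
    import lightgbm as lgb
    from sklearn.datasets import make_regression

    X, y = make_regression(n_samples=200, n_features=10, random_state=0)
    dtrain = lgb.Dataset(X, y)

    base = lgb.train({"objective": "regression"}, dtrain, num_boost_round=20)

    cv_results = lgb.cv(
        {"objective": "regression"},
        dtrain,
        num_boost_round=20,
        init_model=base,  # each fold's booster starts from `base`
    )
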
@@ -386,6 +389,7 @@ def _make_booster(
feature_name: Union[List[str], str] = "auto",
categorical_feature: Union[List[int], List[str], str] = "auto",
callbacks: Optional[List[Callable]] = None,
init_model: Optional[Union[lgb.Booster, lgb.LGBMModel, str]] = None,
) -> Union[_VotingBooster, lgb.Booster]:
if self.refit:
booster = lgb.train(
@@ -396,6 +400,7 @@
feature_name=feature_name,
categorical_feature=categorical_feature,
callbacks=callbacks,
init_model=init_model,
)

booster.free_dataset()
@@ -424,6 +429,7 @@ def fit(
feature_name: Union[List[str], str] = "auto",
categorical_feature: Union[List[int], List[str], str] = "auto",
callbacks: Optional[List[Callable]] = None,
init_model: Optional[Union[lgb.Booster, lgb.LGBMModel, str]] = None,
groups: Optional[OneDimArrayLikeType] = None,
**fit_params: Any
) -> "LGBMModel":
@@ -467,6 +473,10 @@ def fit(
callbacks
List of callback functions that are applied at each iteration.
init_model
Filename of LightGBM model, Booster instance or LGBMModel instance
used for continue training.
groups
Group labels for the samples used while splitting the dataset into
train/test set. If `group` is not None, this parameter is ignored.
@@ -542,6 +552,9 @@ def fit(
eval_name = params["metric"]
is_higher_better = _is_higher_better(params["metric"])

if isinstance(init_model, lgb.LGBMModel):
init_model = init_model.booster_

if self.study is None:
sampler = samplers.TPESampler(seed=seed)

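
The isinstance branch above normalizes an LGBMModel to its underlying Booster before it is handed to the tuning and training code. A small standalone check of the attribute this relies on (illustrative only, not part of the commit):

    # Illustrative only: a fitted lgb.LGBMModel exposes its underlying
    # Booster via the `booster_` attribute, which the branch above uses.
    import lightgbm as lgb
    from sklearn.datasets import make_regression

    X, y = make_regression(n_samples=100, n_features=5, random_state=0)
    model = lgb.LGBMRegressor(n_estimators=10).fit(X, y)

    assert isinstance(model, lgb.LGBMModel)
    assert isinstance(model.booster_, lgb.Booster)
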
@@ -578,6 +591,7 @@ def fit(
feature_name=feature_name,
feval=feval,
fobj=fobj,
init_model=init_model,
n_estimators=self.n_estimators,
param_distributions=self.param_distributions,
)
@@ -841,6 +855,7 @@ def fit(
feature_name: Union[List[str], str] = "auto",
categorical_feature: Union[List[int], List[str], str] = "auto",
callbacks: Optional[List[Callable]] = None,
init_model: Optional[Union[lgb.Booster, lgb.LGBMModel, str]] = None,
groups: Optional[OneDimArrayLikeType] = None,
**fit_params: Any
) -> "LGBMClassifier":
@@ -884,6 +899,10 @@ def fit(
callbacks
List of callback functions that are applied at each iteration.
init_model
Filename of LightGBM model, Booster instance or LGBMModel instance
used for continue training.
groups
Group labels for the samples used while splitting the dataset into
train/test set. If `group` is not None, this parameter is ignored.
@@ -913,6 +932,7 @@ def fit(
feature_name=feature_name,
categorical_feature=categorical_feature,
callbacks=callbacks,
init_model=init_model,
groups=groups,
**fit_params
)
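
The classifier's fit simply forwards init_model to the parent implementation. A hedged sketch of the filename form on the classification side (the optgbm.sklearn.LGBMClassifier import path is an assumption based on the return annotation in this diff; the file name is illustrative):

    # Hedged sketch: the optgbm.sklearn.LGBMClassifier import is an assumption
    # based on the "LGBMClassifier" annotation in this diff; "base_model.txt"
    # is just an illustrative file name.
    import lightgbm as lgb
    from sklearn.datasets import make_classification

    from optgbm.sklearn import LGBMClassifier  # assumed import path

    X, y = make_classification(n_samples=200, n_features=10, random_state=0)

    # Save a plain LightGBM model to disk, then continue training from the file.
    lgb.LGBMClassifier(n_estimators=30).fit(X, y).booster_.save_model("base_model.txt")

    clf = LGBMClassifier(n_estimators=30)
    clf.fit(X, y, init_model="base_model.txt")  # filename form of init_model
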