diff --git a/optgbm/sklearn.py b/optgbm/sklearn.py index 686b3de..821f7fc 100644 --- a/optgbm/sklearn.py +++ b/optgbm/sklearn.py @@ -4,6 +4,7 @@ import logging import time +from pkg_resources import parse_version from typing import Any from typing import Callable from typing import Dict @@ -14,24 +15,15 @@ import lightgbm as lgb import numpy as np import optuna +import sklearn from sklearn.base import ClassifierMixin from sklearn.base import RegressorMixin from sklearn.model_selection import BaseCrossValidator from sklearn.preprocessing import LabelEncoder from sklearn.utils import check_random_state -from sklearn.utils import safe_indexing from sklearn.utils.validation import check_is_fitted -try: # lightgbm<=2.2.3 - from lightgbm.sklearn import _eval_function_wrapper as _EvalFunctionWrapper - from lightgbm.sklearn import ( - _objective_function_wrapper as _ObjectiveFunctionWrapper, - ) -except ImportError: - from lightgbm.sklearn import _EvalFunctionWrapper - from lightgbm.sklearn import _ObjectiveFunctionWrapper - from .utils import check_cv from .utils import check_fit_params from .utils import check_X @@ -40,6 +32,20 @@ from .utils import RANDOM_STATE_TYPE from .utils import TWO_DIM_ARRAYLIKE_TYPE +if parse_version(lgb.__version__) >= parse_version("2.3"): + from lightgbm.sklearn import _EvalFunctionWrapper + from lightgbm.sklearn import _ObjectiveFunctionWrapper +else: + from lightgbm.sklearn import _eval_function_wrapper as _EvalFunctionWrapper + from lightgbm.sklearn import ( + _objective_function_wrapper as _ObjectiveFunctionWrapper, + ) + +if parse_version(sklearn.__version__) >= parse_version("0.22"): + from sklearn.utils import _safe_indexing as safe_indexing +else: + from sklearn.utils import safe_indexing + MAX_INT = np.iinfo(np.int32).max OBJECTIVE2METRIC = { @@ -240,14 +246,14 @@ def __init__( def from_representations( cls, representations: List[str], weights: Optional[np.ndarray] = None ) -> "_VotingBooster": - try: # lightgbm<=2.2.3 + if parse_version(lgb.__version__) >= parse_version("2.3"): boosters = [ - lgb.Booster(params={"model_str": model_str}) + lgb.Booster(model_str=model_str, silent=True) for model_str in representations ] - except TypeError: + else: boosters = [ - lgb.Booster(model_str=model_str, silent=True) + lgb.Booster(params={"model_str": model_str}) for model_str in representations ]