Skip to content

Commit

Permalink
Merge pull request #64 from Y-oHr-N/use-parse-version
Browse files Browse the repository at this point in the history
Use parse_version
  • Loading branch information
Y-oHr-N authored Mar 8, 2020
2 parents fd76bab + bcd66fc commit de8e867
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions optgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import time

from pkg_resources import parse_version
from typing import Any
from typing import Callable
from typing import Dict
Expand All @@ -14,24 +15,15 @@
import lightgbm as lgb
import numpy as np
import optuna
import sklearn

from sklearn.base import ClassifierMixin
from sklearn.base import RegressorMixin
from sklearn.model_selection import BaseCrossValidator
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import check_random_state
from sklearn.utils import safe_indexing
from sklearn.utils.validation import check_is_fitted

try: # lightgbm<=2.2.3
from lightgbm.sklearn import _eval_function_wrapper as _EvalFunctionWrapper
from lightgbm.sklearn import (
_objective_function_wrapper as _ObjectiveFunctionWrapper,
)
except ImportError:
from lightgbm.sklearn import _EvalFunctionWrapper
from lightgbm.sklearn import _ObjectiveFunctionWrapper

from .utils import check_cv
from .utils import check_fit_params
from .utils import check_X
Expand All @@ -40,6 +32,20 @@
from .utils import RANDOM_STATE_TYPE
from .utils import TWO_DIM_ARRAYLIKE_TYPE

if parse_version(lgb.__version__) >= parse_version("2.3"):
from lightgbm.sklearn import _EvalFunctionWrapper
from lightgbm.sklearn import _ObjectiveFunctionWrapper
else:
from lightgbm.sklearn import _eval_function_wrapper as _EvalFunctionWrapper
from lightgbm.sklearn import (
_objective_function_wrapper as _ObjectiveFunctionWrapper,
)

if parse_version(sklearn.__version__) >= parse_version("0.22"):
from sklearn.utils import _safe_indexing as safe_indexing
else:
from sklearn.utils import safe_indexing

MAX_INT = np.iinfo(np.int32).max

OBJECTIVE2METRIC = {
Expand Down Expand Up @@ -240,14 +246,14 @@ def __init__(
def from_representations(
cls, representations: List[str], weights: Optional[np.ndarray] = None
) -> "_VotingBooster":
try: # lightgbm<=2.2.3
if parse_version(lgb.__version__) >= parse_version("2.3"):
boosters = [
lgb.Booster(params={"model_str": model_str})
lgb.Booster(model_str=model_str, silent=True)
for model_str in representations
]
except TypeError:
else:
boosters = [
lgb.Booster(model_str=model_str, silent=True)
lgb.Booster(params={"model_str": model_str})
for model_str in representations
]

Expand Down

0 comments on commit de8e867

Please sign in to comment.