Skip to content

Commit

Permalink
Make scikit-learn optional again (#596) (#598)
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 authored Nov 22, 2024
1 parent da70beb commit 7b20244
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions python/treelite/sklearn/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from typing import Any

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

from ..core import TreeliteError
from ..model import Model
Expand Down Expand Up @@ -61,6 +59,7 @@ def _export_tree(
# pylint: disable=too-many-locals
try:
from sklearn import __version__ as sklearn_version
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree._tree import Tree as SKLearnTree
except ImportError as e:
raise TreeliteError("This function requires scikit-learn package") from e
Expand Down Expand Up @@ -126,7 +125,7 @@ def _export_tree(
return subestimator


def export_model(model: Model):
def export_model(model: Model) -> Any:
"""
Export a model as a scikit-learn RandomForest.
Expand All @@ -153,7 +152,8 @@ def export_model(model: Model):
# pylint: disable=too-many-locals
try:
from sklearn import __version__ as sklearn_version
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
except ImportError as e:
raise TreeliteError("This function requires scikit-learn package") from e

Expand Down

0 comments on commit 7b20244

Please sign in to comment.