diff --git a/python/treelite/sklearn/exporter.py b/python/treelite/sklearn/exporter.py index 43a44daf..fd204d25 100644 --- a/python/treelite/sklearn/exporter.py +++ b/python/treelite/sklearn/exporter.py @@ -4,8 +4,6 @@ from typing import Any import numpy as np -from sklearn.ensemble import RandomForestClassifier -from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from ..core import TreeliteError from ..model import Model @@ -61,6 +59,7 @@ def _export_tree( # pylint: disable=too-many-locals try: from sklearn import __version__ as sklearn_version + from sklearn.tree import DecisionTreeClassifier from sklearn.tree._tree import Tree as SKLearnTree except ImportError as e: raise TreeliteError("This function requires scikit-learn package") from e @@ -126,7 +125,7 @@ def _export_tree( return subestimator -def export_model(model: Model): +def export_model(model: Model) -> Any: """ Export a model as a scikit-learn RandomForest. @@ -153,7 +152,8 @@ def export_model(model: Model): # pylint: disable=too-many-locals try: from sklearn import __version__ as sklearn_version - from sklearn.ensemble import RandomForestRegressor + from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor + from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor except ImportError as e: raise TreeliteError("This function requires scikit-learn package") from e