From 3a6714d4d2941530bcbb4c1de3c37e7f961c6c9a Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Fri, 22 Nov 2024 10:48:44 -0800 Subject: [PATCH] Make scikit-learn optional again (#596) (#597) --- python/treelite/sklearn/exporter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/treelite/sklearn/exporter.py b/python/treelite/sklearn/exporter.py index 43a44daf..fd204d25 100644 --- a/python/treelite/sklearn/exporter.py +++ b/python/treelite/sklearn/exporter.py @@ -4,8 +4,6 @@ from typing import Any import numpy as np -from sklearn.ensemble import RandomForestClassifier -from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from ..core import TreeliteError from ..model import Model @@ -61,6 +59,7 @@ def _export_tree( # pylint: disable=too-many-locals try: from sklearn import __version__ as sklearn_version + from sklearn.tree import DecisionTreeClassifier from sklearn.tree._tree import Tree as SKLearnTree except ImportError as e: raise TreeliteError("This function requires scikit-learn package") from e @@ -126,7 +125,7 @@ def _export_tree( return subestimator -def export_model(model: Model): +def export_model(model: Model) -> Any: """ Export a model as a scikit-learn RandomForest. @@ -153,7 +152,8 @@ def export_model(model: Model): # pylint: disable=too-many-locals try: from sklearn import __version__ as sklearn_version - from sklearn.ensemble import RandomForestRegressor + from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor + from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor except ImportError as e: raise TreeliteError("This function requires scikit-learn package") from e