From 7b20244f0b77c9b6120ff9a842e555af3cccee54 Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Fri, 22 Nov 2024 11:02:50 -0800 Subject: [PATCH] Make scikit-learn optional again (#596) (#598) --- python/treelite/sklearn/exporter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/treelite/sklearn/exporter.py b/python/treelite/sklearn/exporter.py index 43a44daf..fd204d25 100644 --- a/python/treelite/sklearn/exporter.py +++ b/python/treelite/sklearn/exporter.py @@ -4,8 +4,6 @@ from typing import Any import numpy as np -from sklearn.ensemble import RandomForestClassifier -from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from ..core import TreeliteError from ..model import Model @@ -61,6 +59,7 @@ def _export_tree( # pylint: disable=too-many-locals try: from sklearn import __version__ as sklearn_version + from sklearn.tree import DecisionTreeClassifier from sklearn.tree._tree import Tree as SKLearnTree except ImportError as e: raise TreeliteError("This function requires scikit-learn package") from e @@ -126,7 +125,7 @@ def _export_tree( return subestimator -def export_model(model: Model): +def export_model(model: Model) -> Any: """ Export a model as a scikit-learn RandomForest. @@ -153,7 +152,8 @@ def export_model(model: Model): # pylint: disable=too-many-locals try: from sklearn import __version__ as sklearn_version - from sklearn.ensemble import RandomForestRegressor + from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor + from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor except ImportError as e: raise TreeliteError("This function requires scikit-learn package") from e