Skip to content

Commit

Permalink
Refactor leaf mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Sep 2, 2024
1 parent 03d238c commit cc70bd0
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions quantile_forest/_quantile_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ class calls the ``fit`` method of the ``ForestRegressor`` and creates a
from sklearn.utils._param_validation import Interval, RealNotInt
except ImportError:
param_validation = False
from sklearn.utils.parallel import Parallel, delayed
from sklearn.utils.validation import check_is_fitted

from ._quantile_forest_fast import QuantileForest
Expand Down Expand Up @@ -203,7 +202,7 @@ def fit(self, X, y, sample_weight=None, sparse_pickle=False):
return self

@staticmethod
def _map_indices_to_leaves(
def _get_y_train_leaves_slice(
bootstrap_indices,
X_leaves_bootstrap,
sample_weight,
Expand Down Expand Up @@ -397,12 +396,8 @@ def _get_y_train_leaves(self, X, y, sorter=None, sample_weight=None):
if sample_weight is not None:
sample_weight = np.squeeze(sample_weight)

y_train_leaves = Parallel(
n_jobs=self.n_jobs,
verbose=self.verbose,
prefer="threads",
)(
delayed(self._map_indices_to_leaves)(
y_train_leaves = [
self._get_y_train_leaves_slice(
bootstrap_indices[:, i],
X_leaves_bootstrap[:, i],
sample_weight,
Expand All @@ -412,7 +407,7 @@ def _get_y_train_leaves(self, X, y, sorter=None, sample_weight=None):
estimator.random_state,
)
for i, estimator in enumerate(self.estimators_)
)
]

return np.array(y_train_leaves)

Expand Down

0 comments on commit cc70bd0

Please sign in to comment.