Skip to content

Commit

Permalink
Refactor use of attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Aug 16, 2024
1 parent baf13bb commit 8d6901b
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions quantile_forest/_quantile_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,9 @@ def _get_y_train_leaves(self, X, y, sorter=None, sample_weight=None):
sample_count = np.max(np.bincount(X_leaves_bootstrap[:, i]))
if sample_count > max_samples_leaf:
max_samples_leaf = sample_count
self.max_node_count = max_node_count

# Initialize NumPy array (more efficient serialization than dict/list).
shape = (self.n_estimators, self.max_node_count, y_dim, max_samples_leaf)
shape = (self.n_estimators, max_node_count, y_dim, max_samples_leaf)
y_train_leaves = np.zeros(shape, dtype=np.int64)

for i, estimator in enumerate(self.estimators_):
Expand Down Expand Up @@ -385,7 +384,7 @@ def _get_y_bound_leaves(self, y, y_train_leaves):
if self.monotonic_cst is None:
return None

y_bound_leaves = np.full((self.n_estimators, self.max_node_count, 2), [-np.inf, np.inf])
y_bound_leaves = np.full((*y_train_leaves.shape[:2], 2), [-np.inf, np.inf])

for i, estimator in enumerate(self.estimators_):
tree = estimator.tree_
Expand Down

0 comments on commit 8d6901b

Please sign in to comment.