diff --git a/quantile_forest/_quantile_forest.py b/quantile_forest/_quantile_forest.py index d5392c2..4039490 100755 --- a/quantile_forest/_quantile_forest.py +++ b/quantile_forest/_quantile_forest.py @@ -320,10 +320,9 @@ def _get_y_train_leaves(self, X, y, sorter=None, sample_weight=None): sample_count = np.max(np.bincount(X_leaves_bootstrap[:, i])) if sample_count > max_samples_leaf: max_samples_leaf = sample_count - self.max_node_count = max_node_count # Initialize NumPy array (more efficient serialization than dict/list). - shape = (self.n_estimators, self.max_node_count, y_dim, max_samples_leaf) + shape = (self.n_estimators, max_node_count, y_dim, max_samples_leaf) y_train_leaves = np.zeros(shape, dtype=np.int64) for i, estimator in enumerate(self.estimators_): @@ -385,7 +384,7 @@ def _get_y_bound_leaves(self, y, y_train_leaves): if self.monotonic_cst is None: return None - y_bound_leaves = np.full((self.n_estimators, self.max_node_count, 2), [-np.inf, np.inf]) + y_bound_leaves = np.full((*y_train_leaves.shape[:2], 2), [-np.inf, np.inf]) for i, estimator in enumerate(self.estimators_): tree = estimator.tree_