Skip to content

Commit

Permalink
Squeeze bootstrap indices (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson authored Feb 12, 2024
1 parent e0cbe7d commit c59741e
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions quantile_forest/_quantile_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ def _get_y_train_leaves(self, X, y_dim, sorter=None, sample_weight=None):
if sorter is not None:
# Reassign bootstrap indices to account for target sorting.
bootstrap_indices = np.argsort(sorter, axis=0)[bootstrap_indices]
if bootstrap_indices.shape[-1] == 1:
bootstrap_indices = np.squeeze(bootstrap_indices, -1)

bootstrap_indices += 1 # for sparse matrix (0s as empty)

Expand Down Expand Up @@ -319,10 +321,15 @@ def _get_y_train_leaves(self, X, y_dim, sorter=None, sample_weight=None):
if not isinstance(y_indices, list):
y_indices = list(y_indices)
y_indices = random.sample(y_indices, max_samples_leaf)
y_indices = np.asarray(y_indices).reshape(y_dim, -1)

for j in range(y_dim):
y_train_leaves[i, leaf_idx, j, : len(y_indices[j])] = y_indices[j]
if sorter is not None:
y_indices = np.asarray(y_indices).reshape(y_dim, -1)

for j in range(y_dim):
y_train_leaves[i, leaf_idx, j, : len(y_indices[j])] = y_indices[j]
else:
for j in range(y_dim):
y_train_leaves[i, leaf_idx, j, : len(y_indices)] = y_indices

return y_train_leaves

Expand Down

0 comments on commit c59741e

Please sign in to comment.