Skip to content

Commit

Permalink
Refactor weighted leaves
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Aug 17, 2024
1 parent 830f492 commit 1e8add1
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions quantile_forest/_quantile_forest_fast.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,8 @@ cpdef vector[double] calc_weighted_quantile(
for i in range(<intp_t>(sorted_quantile_indices.size())):
sorted_quantile_indices[i] = <double>i

# Sort the quantiles in ascending order to help reuse calculations for efficiency.
# The sorted quantile indices allow assigning output based on the original ordering.
# Sort the quantiles in ascending order to help reuse calculations.
# The sorted indices allow assigning output based on the input ordering.
parallel_qsort_asc(quantiles, sorted_quantile_indices, 0, n_quantiles - 1)

out = vector[double](n_quantiles)
Expand Down Expand Up @@ -746,16 +746,17 @@ cdef class QuantileForest:
leaf_weights = vector[double](n_train)

for i in range(n_samples):
n_total_samples = 0
n_total_trees = 0
for j in range(n_trees):
if X_indices is None or X_indices[i, j] is True:
n_leaf_samples[j] = 0
for k in range(max_idx):
if self.y_train_leaves[j, X_leaves[i, j], 0, k] != 0:
n_leaf_samples[j] += 1
n_total_samples += n_leaf_samples[j]
n_total_trees += 1
if weighted_leaves:
n_total_samples = 0
n_total_trees = 0
for j in range(n_trees):
if X_indices is None or X_indices[i, j] is True:
n_leaf_samples[j] = 0
for k in range(max_idx):
if self.y_train_leaves[j, X_leaves[i, j], 0, k] != 0:
n_leaf_samples[j] += 1
n_total_samples += n_leaf_samples[j]
n_total_trees += 1

for j in range(n_outputs):
for k in range(<intp_t>(train_indices.size())):
Expand Down

0 comments on commit 1e8add1

Please sign in to comment.