From 1e8add1354fd5ceb14b5c28349d6be3b6e027b7c Mon Sep 17 00:00:00 2001 From: Reid Johnson Date: Sat, 17 Aug 2024 05:05:15 -0700 Subject: [PATCH] Refactor weighted leaves --- quantile_forest/_quantile_forest_fast.pyx | 25 ++++++++++++----------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/quantile_forest/_quantile_forest_fast.pyx b/quantile_forest/_quantile_forest_fast.pyx index 6fd4436..8d34f38 100755 --- a/quantile_forest/_quantile_forest_fast.pyx +++ b/quantile_forest/_quantile_forest_fast.pyx @@ -331,8 +331,8 @@ cpdef vector[double] calc_weighted_quantile( for i in range((sorted_quantile_indices.size())): sorted_quantile_indices[i] = i - # Sort the quantiles in ascending order to help reuse calculations for efficiency. - # The sorted quantile indices allow assigning output based on the original ordering. + # Sort the quantiles in ascending order to help reuse calculations. + # The sorted indices allow assigning output based on the input ordering. parallel_qsort_asc(quantiles, sorted_quantile_indices, 0, n_quantiles - 1) out = vector[double](n_quantiles) @@ -746,16 +746,17 @@ cdef class QuantileForest: leaf_weights = vector[double](n_train) for i in range(n_samples): - n_total_samples = 0 - n_total_trees = 0 - for j in range(n_trees): - if X_indices is None or X_indices[i, j] is True: - n_leaf_samples[j] = 0 - for k in range(max_idx): - if self.y_train_leaves[j, X_leaves[i, j], 0, k] != 0: - n_leaf_samples[j] += 1 - n_total_samples += n_leaf_samples[j] - n_total_trees += 1 + if weighted_leaves: + n_total_samples = 0 + n_total_trees = 0 + for j in range(n_trees): + if X_indices is None or X_indices[i, j] is True: + n_leaf_samples[j] = 0 + for k in range(max_idx): + if self.y_train_leaves[j, X_leaves[i, j], 0, k] != 0: + n_leaf_samples[j] += 1 + n_total_samples += n_leaf_samples[j] + n_total_trees += 1 for j in range(n_outputs): for k in range((train_indices.size())):