diff --git a/quantile_forest/_quantile_forest_fast.pyx b/quantile_forest/_quantile_forest_fast.pyx index a3edff1..730b3d2 100755 --- a/quantile_forest/_quantile_forest_fast.pyx +++ b/quantile_forest/_quantile_forest_fast.pyx @@ -317,6 +317,12 @@ cpdef vector[double] calc_weighted_quantile( if not issorted: parallel_qsort_asc(inputs, weights, 0, n_inputs-1) + # Get monotonic sorting of quantiles for efficient calculation. + sorted_quantile_idx = vector[double](n_quantiles) + for i in range((sorted_quantile_idx.size())): + sorted_quantile_idx[i] = i + parallel_qsort_asc(quantiles, sorted_quantile_idx, 0, n_quantiles-1) + cum_weights = vector[double](n_inputs) # Calculate the empirical cumulative distribution function (ECDF). @@ -331,16 +337,20 @@ cpdef vector[double] calc_weighted_quantile( out = vector[double](n_quantiles) + idx_floor = 0 + idx_ceil = 1 + for i in range(n_quantiles): quantile = quantiles[i] + # Assign the output based on the input quantile ordering. + i = sorted_quantile_idx[i] + # Calculate the quantile's proportion of total weight. p = quantile * f + C # Find the first index where the proportion of weight exceeds p. - idx_floor = 0 - idx_ceil = 1 - for j in range(n_inputs): + for j in range(idx_floor, n_inputs): if p >= cum_weights[j]: if weights[j] > 0: idx_floor = j