From 36aec2ca0dae27052cbbf384bae433a67133e4b0 Mon Sep 17 00:00:00 2001 From: Reid Johnson Date: Tue, 3 Oct 2023 18:26:59 -0700 Subject: [PATCH] Add unit tests for quantile ordering --- quantile_forest/tests/test_quantile_forest.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/quantile_forest/tests/test_quantile_forest.py b/quantile_forest/tests/test_quantile_forest.py index e8cc59f..8e8331f 100755 --- a/quantile_forest/tests/test_quantile_forest.py +++ b/quantile_forest/tests/test_quantile_forest.py @@ -1096,6 +1096,16 @@ def test_calc_quantile(): expected = [np.mean(i)] assert_allclose(actual, expected) + # Check that quantile order is respected. + for i in inputs: + for q in [quantiles, quantiles[::-1]]: + actual = calc_quantile(i, q) + for idx in range(len(quantiles) - 1): + if q[idx] <= q[idx + 1]: + assert np.all(np.less_equal(actual[idx], actual[idx + 1])) + else: + assert np.all(np.less_equal(actual[idx + 1], actual[idx])) + inputs = [] # Check that empty array is returned for empty list. @@ -1182,6 +1192,16 @@ def _dicts_to_input_pairs(input_dicts): expected = [np.mean(i2)] assert_allclose(actual, expected) + # Check that quantile order is respected. + for (i1, w1) in _dicts_to_weighted_inputs(inputs): + for q in [quantiles, quantiles[::-1]]: + actual = calc_weighted_quantile(i1, w1, q) + for idx in range(len(quantiles) - 1): + if q[idx] <= q[idx + 1]: + assert np.all(np.less_equal(actual[idx], actual[idx + 1])) + else: + assert np.all(np.less_equal(actual[idx + 1], actual[idx])) + inputs = [1, 2, 3, 3, 3, 3, 3, 3, 4, 5] weights = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]