From 928a48999ec662649176a994fd1bca3444fcb501 Mon Sep 17 00:00:00 2001 From: Reid Johnson Date: Wed, 21 Feb 2024 19:10:32 -0800 Subject: [PATCH] Reorder multioutputs (#35) --- docs/user_guide.rst | 6 ++--- quantile_forest/_quantile_forest.py | 22 +++++++++---------- quantile_forest/_quantile_forest_fast.pyx | 10 ++++----- .../examples/plot_quantile_multioutput.py | 6 ++--- quantile_forest/tests/test_quantile_forest.py | 4 ++-- 5 files changed, 24 insertions(+), 24 deletions(-) diff --git a/docs/user_guide.rst b/docs/user_guide.rst index b814795..c6afd18 100755 --- a/docs/user_guide.rst +++ b/docs/user_guide.rst @@ -95,7 +95,7 @@ The output of the `predict` method is an array with one column for each specifie >>> (y_pred[:, 0] >= y_pred[:, 1]).all() True -Multi-target quantile regression is also supported. If the target values are multi-dimensional, then the final output column will correspond to the number of targets:: +Multi-target quantile regression is also supported. If the target values are multi-dimensional, then the second output dimension will correspond to the number of targets:: >>> from sklearn import datasets >>> from quantile_forest import RandomForestQuantileRegressor @@ -111,9 +111,9 @@ Multi-target quantile regression is also supported. If the target values are mul True >>> y_pred.shape[0] == len(X) True - >>> y_pred.shape[1] == len(quantiles) + >>> y_pred.shape[-1] == len(quantiles) True - >>> y_pred.shape[-1] == y.shape[1] + >>> y_pred.shape[1] == y.shape[1] True Quantile Weighting diff --git a/quantile_forest/_quantile_forest.py b/quantile_forest/_quantile_forest.py index 64a64c7..5dc2049 100755 --- a/quantile_forest/_quantile_forest.py +++ b/quantile_forest/_quantile_forest.py @@ -516,8 +516,8 @@ def predict( Returns ------- - y_pred : array of shape (n_samples, n_quantiles, n_outputs) - If quantiles is set to None, then return ``E(Y | X)``. 
Else, for + y_pred : array of shape (n_samples, n_outputs, n_quantiles) + If quantiles is set to 'mean', then return ``E(Y | X)``. Else, for all quantiles, return ``y`` at ``q`` for which ``F(Y=y|x) = q``, where ``q`` is the quantile. """ @@ -565,27 +565,27 @@ def predict( if self.max_samples_leaf == 1: # optimize for single-sample-per-leaf performance y_train_leaves = np.asarray(self.forest_.y_train_leaves) y_train = np.asarray(self.forest_.y_train).T - y_pred = np.empty((len(X), len(quantiles), y_train.shape[1])) - for i in range(y_train.shape[1]): + y_pred = np.empty((len(X), y_train.shape[1], len(quantiles))) + for output in range(y_train.shape[1]): leaf_values = np.empty((len(X), self.n_estimators)) for tree in range(self.n_estimators): - if X_indices is None: - train_indices = y_train_leaves[tree, X_leaves[:, tree], i, 0] + if X_indices is None: # IB scoring + train_indices = y_train_leaves[tree, X_leaves[:, tree], output, 0] else: # OOB scoring indices = X_indices[:, tree] == 1 leaves = X_leaves[indices, tree] train_indices = np.zeros(len(X), dtype=int) - train_indices[indices] = y_train_leaves[tree, leaves, i, 0] - leaf_values[:, tree] = y_train[train_indices - 1, i] + train_indices[indices] = y_train_leaves[tree, leaves, output, 0] + leaf_values[:, tree] = y_train[train_indices - 1, output] leaf_values[train_indices == 0, tree] = np.nan if len(quantiles) == 1 and quantiles[0] == -1: # calculate mean func = np.mean if X_indices is None else np.nanmean - y_pred[..., i] = np.expand_dims(func(leaf_values, axis=1), axis=1) + y_pred[:, output, :] = np.expand_dims(func(leaf_values, axis=1), axis=1) else: # calculate quantiles func = np.quantile if X_indices is None else np.nanquantile method = interpolation.decode() - y_pred[..., i] = func(leaf_values, quantiles, method=method, axis=1).T - else: + y_pred[:, output, :] = func(leaf_values, quantiles, method=method, axis=1).T + else: # get predictions for arbitrary leaf sizes y_pred = self.forest_.predict( 
quantiles, X_leaves, diff --git a/quantile_forest/_quantile_forest_fast.pyx b/quantile_forest/_quantile_forest_fast.pyx index 982d2d2..e61080c 100755 --- a/quantile_forest/_quantile_forest_fast.pyx +++ b/quantile_forest/_quantile_forest_fast.pyx @@ -705,7 +705,7 @@ cdef class QuantileForest: raise ValueError(f"Invalid interpolation method {interpolation}.") # Initialize NumPy array with NaN values and get view for nogil. - preds = np.full((n_samples, n_quantiles, n_outputs), np.nan, dtype=np.float64) + preds = np.full((n_samples, n_outputs, n_quantiles), np.nan, dtype=np.float64) preds_view = preds # memoryview with nogil: @@ -832,7 +832,7 @@ cdef class QuantileForest: if not use_mean: for k in range((leaf_preds.size())): if leaf_preds[k].size() == 1: - preds_view[i, k, j] = leaf_preds[k][0] + preds_view[i, j, k] = leaf_preds[k][0] elif leaf_preds[k].size() > 1: pred = calc_quantile( leaf_preds[k], @@ -840,12 +840,12 @@ cdef class QuantileForest: interpolation, issorted=False, ) - preds_view[i, k, j] = pred[0] + preds_view[i, j, k] = pred[0] else: if leaf_preds[0].size() == 1: - preds_view[i, 0, j] = leaf_preds[0][0] + preds_view[i, j, 0] = leaf_preds[0][0] elif leaf_preds[0].size() > 1: - preds_view[i, 0, j] = calc_mean(leaf_preds[0]) + preds_view[i, j, 0] = calc_mean(leaf_preds[0]) return np.asarray(preds_view) diff --git a/quantile_forest/tests/examples/plot_quantile_multioutput.py b/quantile_forest/tests/examples/plot_quantile_multioutput.py index 6f445cb..4766a77 100755 --- a/quantile_forest/tests/examples/plot_quantile_multioutput.py +++ b/quantile_forest/tests/examples/plot_quantile_multioutput.py @@ -63,9 +63,9 @@ def make_func_Xy(funcs, bounds, n_samples): "x": np.tile(X.squeeze(), len(funcs)), "y": y.reshape(-1, order="F"), "y_true": np.concatenate([f["signal"](X.squeeze()) for f in funcs]), - "y_pred": np.concatenate([y_pred[:, 1, i] for i in range(len(funcs))]), - "y_pred_low": np.concatenate([y_pred[:, 0, i] for i in range(len(funcs))]), - 
"y_pred_upp": np.concatenate([y_pred[:, 2, i] for i in range(len(funcs))]), + "y_pred": np.concatenate([y_pred[:, i, 1] for i in range(len(funcs))]), + "y_pred_low": np.concatenate([y_pred[:, i, 0] for i in range(len(funcs))]), + "y_pred_upp": np.concatenate([y_pred[:, i, 2] for i in range(len(funcs))]), "target": np.concatenate([[f"{i}"] * len(X) for i in range(len(funcs))]), } ) diff --git a/quantile_forest/tests/test_quantile_forest.py b/quantile_forest/tests/test_quantile_forest.py index 0f7d6ef..89f4cc3 100755 --- a/quantile_forest/tests/test_quantile_forest.py +++ b/quantile_forest/tests/test_quantile_forest.py @@ -470,8 +470,8 @@ def check_predict_quantiles( ) score = est.score(X.reshape(-1, 1), y, quantiles=0.5) assert y_pred.ndim == (3 if isinstance(quantiles, list) else 2) - assert y_pred.shape[-1] == y.shape[1] - assert np.any(y_pred[..., 0] != y_pred[..., 1]) + assert y_pred.shape[1] == y.shape[1] + assert np.any(y_pred[:, 0, ...] != y_pred[:, 1, ...]) assert score > 0.97 # Check that specifying `quantiles` overwrites `default_quantiles`.