From b396aff33376eb463e3179095efef52550fbbe60 Mon Sep 17 00:00:00 2001 From: Reid Johnson Date: Sat, 3 Aug 2024 01:17:08 -0700 Subject: [PATCH] Update example plots --- examples/plot_quantile_conformalized.py | 2 ++ examples/plot_quantile_multioutput.py | 2 ++ examples/plot_quantile_vs_standard.py | 2 ++ examples/plot_treeshap_example.py | 6 ++++-- 4 files changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/plot_quantile_conformalized.py b/examples/plot_quantile_conformalized.py index 34dc98a..e9c9e9f 100755 --- a/examples/plot_quantile_conformalized.py +++ b/examples/plot_quantile_conformalized.py @@ -23,6 +23,8 @@ from quantile_forest import RandomForestQuantileRegressor +alt.data_transformers.disable_max_rows() + strategies = { "qrf": "Quantile Regression Forest (QRF)", "cqr": "Conformalized Quantile Regression (CQR)", diff --git a/examples/plot_quantile_multioutput.py b/examples/plot_quantile_multioutput.py index 81716be..2583df4 100755 --- a/examples/plot_quantile_multioutput.py +++ b/examples/plot_quantile_multioutput.py @@ -16,6 +16,8 @@ from quantile_forest import RandomForestQuantileRegressor +alt.data_transformers.disable_max_rows() + np.random.seed(0) n_samples = 2500 diff --git a/examples/plot_quantile_vs_standard.py b/examples/plot_quantile_vs_standard.py index 6aff07c..08f9d60 100755 --- a/examples/plot_quantile_vs_standard.py +++ b/examples/plot_quantile_vs_standard.py @@ -21,6 +21,8 @@ from quantile_forest import RandomForestQuantileRegressor +alt.data_transformers.disable_max_rows() + rng = check_random_state(0) # Create right-skewed dataset. diff --git a/examples/plot_treeshap_example.py b/examples/plot_treeshap_example.py index 44aaeb7..f186dd7 100644 --- a/examples/plot_treeshap_example.py +++ b/examples/plot_treeshap_example.py @@ -22,6 +22,8 @@ from quantile_forest import RandomForestQuantileRegressor +alt.data_transformers.disable_max_rows() + n_samples = 500 test_idx = 0 quantiles = list((np.arange(11) * 10) / 100) @@ -38,8 +40,8 @@ def get_shap_values(qrf, X, quantile=0.5, **kwargs): # Use Tree SHAP to generate explanations. explainer = shap.TreeExplainer(model, X) - qrf_pred = qrf.predict(X.to_numpy(), quantiles=quantile, **kwargs) - rf_pred = qrf.predict(X.to_numpy(), quantiles="mean", aggregate_leaves_first=False) + qrf_pred = qrf.predict(X, quantiles=quantile, **kwargs) + rf_pred = qrf.predict(X, quantiles="mean", aggregate_leaves_first=False) scaling = 1.0 / len(qrf.estimators_) # scale factor based on the number of estimators base_offset = qrf_pred - rf_pred # difference between the QRF and RF (baseline) predictions