Skip to content

Commit

Permalink
Update example plots
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Aug 3, 2024
1 parent 024962c commit b396aff
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 2 deletions.
2 changes: 2 additions & 0 deletions examples/plot_quantile_conformalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

from quantile_forest import RandomForestQuantileRegressor

alt.data_transformers.disable_max_rows()

strategies = {
"qrf": "Quantile Regression Forest (QRF)",
"cqr": "Conformalized Quantile Regression (CQR)",
Expand Down
2 changes: 2 additions & 0 deletions examples/plot_quantile_multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from quantile_forest import RandomForestQuantileRegressor

alt.data_transformers.disable_max_rows()

np.random.seed(0)

n_samples = 2500
Expand Down
2 changes: 2 additions & 0 deletions examples/plot_quantile_vs_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

from quantile_forest import RandomForestQuantileRegressor

alt.data_transformers.disable_max_rows()

rng = check_random_state(0)

# Create right-skewed dataset.
Expand Down
6 changes: 4 additions & 2 deletions examples/plot_treeshap_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

from quantile_forest import RandomForestQuantileRegressor

alt.data_transformers.disable_max_rows()

n_samples = 500
test_idx = 0
quantiles = list((np.arange(11) * 10) / 100)
Expand All @@ -38,8 +40,8 @@ def get_shap_values(qrf, X, quantile=0.5, **kwargs):
# Use Tree SHAP to generate explanations.
explainer = shap.TreeExplainer(model, X)

qrf_pred = qrf.predict(X.to_numpy(), quantiles=quantile, **kwargs)
rf_pred = qrf.predict(X.to_numpy(), quantiles="mean", aggregate_leaves_first=False)
qrf_pred = qrf.predict(X, quantiles=quantile, **kwargs)
rf_pred = qrf.predict(X, quantiles="mean", aggregate_leaves_first=False)

scaling = 1.0 / len(qrf.estimators_) # scale factor based on the number of estimators
base_offset = qrf_pred - rf_pred # difference between the QRF and RF (baseline) predictions
Expand Down

0 comments on commit b396aff

Please sign in to comment.