Skip to content

Commit

Permalink
Update example plots
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Jul 31, 2024
1 parent 1cd79ba commit 4e8ae27
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 36 deletions.
2 changes: 1 addition & 1 deletion examples/plot_huggingface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def plot_quantiles_by_latlon(df, quantiles):
min=0,
max=1,
step=0.5 if len(quantiles) == 1 else 1 / (len(quantiles) - 1),
name="Quantile:",
name="Quantile: ",
)

q_val = alt.selection_point(
Expand Down
4 changes: 2 additions & 2 deletions examples/plot_predict_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

np.random.seed(0)

n_test_samples = 10
n_test_samples = 100


def predict(reg, X, quantiles=0.5, what=None):
Expand Down Expand Up @@ -98,7 +98,7 @@ def predict(reg, X, quantiles=0.5, what=None):
def plot_ecdf(df):
min_idx = df["sample_idx"].min()
max_idx = df["sample_idx"].max()
slider = alt.binding_range(min=min_idx, max=max_idx, step=1, name="Sample Index:")
slider = alt.binding_range(min=min_idx, max=max_idx, step=1, name="Sample Index: ")
sample_selection = alt.param(value=0, bind=slider, name="sample_idx")

tooltip = [
Expand Down
55 changes: 36 additions & 19 deletions examples/plot_quantile_conformalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
while QRF may require additional calibration for reliable interval estimates.
Notice that in this example, by using CQR we obtain a level of coverage (i.e.,
percentage of samples that actually fall within their prediction interval)
that is closer to the target level. Adapted from "Prediction intervals:
Quantile Regression Forests" by Carl McBride Ellis:
https://www.kaggle.com/code/carlmcbrideellis/prediction-intervals-quantile-regression-forests.
that is generally closer to the target level. Adapted from `"Prediction
intervals: Quantile Regression Forests" by Carl McBride Ellis
<https://www.kaggle.com/code/carlmcbrideellis/prediction-intervals-quantile-regression-forests>`_.
"""

import altair as alt
Expand All @@ -30,8 +30,7 @@
random_state = 0
rng = check_random_state(random_state)

cov_pct = 95 # the "coverage level"
alpha = (100 - cov_pct) / 100
coverages = list(np.arange(11) / 10) # the "coverage level"

# Load the California Housing Prices dataset.
california = datasets.fetch_california_housing()
Expand Down Expand Up @@ -113,7 +112,7 @@ def cqr_strategy(alpha, X_train, X_test, y_train, y_test):
conf_scores = (np.vstack((a, b)).T).max(axis=1)

# Get the 1-alpha quantile `s` from the distribution of conformity scores.
s = np.quantile(conf_scores, (1 - alpha) * (1 + (1 / (len(y_calib)))))
s = np.quantile(conf_scores, np.clip((1 - alpha) * (1 + (1 / (len(y_calib)))), 0, 1))

# Subtract `s` from the lower quantile and add it to the upper quantile.
y_conf_low = y_pred_low - s
Expand All @@ -129,12 +128,15 @@ def cqr_strategy(alpha, X_train, X_test, y_train, y_test):


# Get strategy outputs as a data frame.
args = (alpha, X_train, X_test, y_train, y_test)
df = pd.concat([qrf_strategy(*args), cqr_strategy(*args)])
dfs = []
for cov_frac in coverages:
alpha = float(round(1 - cov_frac, 2))
args = (alpha, X_train, X_test, y_train, y_test)
dfs.append(pd.concat([qrf_strategy(*args), cqr_strategy(*args)]).assign(alpha=alpha))
df = pd.concat(dfs)

# Calculate coverage and width metrics.
metrics = (
df.groupby("strategy")
df.groupby(["alpha", "strategy"])
.apply(
lambda grp: pd.Series(
{
Expand All @@ -147,10 +149,14 @@ def cqr_strategy(alpha, X_train, X_test, y_train, y_test):
)

# Merge the metrics into the data frame.
df = df.merge(metrics, on="strategy", how="left")
df = df.merge(metrics, on=["alpha", "strategy"], how="left")


def plot_prediction_intervals(df, domain):
slider = alt.binding_range(min=0, max=1, step=0.1, name="Coverage: ")
cov_selection = alt.param(value=0.9, bind=slider, name="coverage")
cov_tol = 0.01

click = alt.selection_point(fields=["y_label"], bind="legend")

color_circle = alt.Color(
Expand All @@ -168,10 +174,17 @@ def plot_prediction_intervals(df, domain):
alt.Tooltip("y_label:N", title="Within Interval"),
]

base = alt.Chart(df).transform_calculate(
y_label=(
"((datum.y_test >= datum.y_pred_low) & (datum.y_test <= datum.y_pred_upp))"
" ? 'Yes' : 'No'"
base = (
alt.Chart(df)
.transform_filter(
(1 - alt.datum["alpha"] - cov_tol <= cov_selection)
& (1 - alt.datum["alpha"] + cov_tol >= cov_selection)
)
.transform_calculate(
y_label=(
"((datum.y_test >= datum.y_pred_low) & (datum.y_test <= datum.y_pred_upp))"
" ? 'Yes' : 'No'"
)
)
)

Expand Down Expand Up @@ -224,11 +237,13 @@ def plot_prediction_intervals(df, domain):
)

text_coverage = (
base.transform_aggregate(coverage="mean(coverage)", groupby=["strategy"])
base.transform_aggregate(
alpha="mean(alpha)", coverage="mean(coverage)", groupby=["strategy"]
)
.transform_calculate(
coverage_text=(
f"'Coverage: ' + format({alt.datum['coverage'] * 100}, '.1f') + '%'"
f" + ' (target = {cov_pct}%)'"
f"'Coverage: ' + format(datum.coverage * 100, '.1f') + '%'"
f" + ' (target = ' + format((1 - datum.alpha) * 100, '.1f') + '%)'"
)
)
.mark_text(align="left", baseline="top")
Expand All @@ -251,7 +266,9 @@ def plot_prediction_intervals(df, domain):
)
)

chart = bar + tick_low + tick_upp + circle + diagonal + text_coverage + text_with
chart = (bar + tick_low + tick_upp + circle + diagonal + text_coverage + text_with).add_params(
cov_selection
)

return chart

Expand Down
4 changes: 2 additions & 2 deletions examples/plot_quantile_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@


def plot_interpolations(df, legend):
slider = alt.binding_range(min=0, max=1, step=0.01, name="Prediction Interval:")
interval_selection = alt.param(value=0.95, bind=slider, name="interval")
slider = alt.binding_range(min=0, max=1, step=0.01, name="Prediction Interval: ")
interval_selection = alt.param(value=0.9, bind=slider, name="interval")
interval_tol = 0.001

click = alt.selection_point(fields=["method"], bind="legend")
Expand Down
2 changes: 1 addition & 1 deletion examples/plot_quantile_multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def make_func_Xy(funcs, bounds, n_samples):


def plot_multioutputs(df, legend):
slider = alt.binding_range(min=0, max=1, step=0.05, name="Prediction Interval:")
slider = alt.binding_range(min=0, max=1, step=0.05, name="Prediction Interval: ")
interval_selection = alt.param(value=0.95, bind=slider, name="interval")
interval_tol = 0.001

Expand Down
45 changes: 35 additions & 10 deletions examples/plot_quantile_vs_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
forest and a standard random forest regressor on a synthetic, right-skewed
dataset. In a right-skewed distribution, the mean is to the right of the
median. As illustrated by a greater overlap in the frequencies of the actual
and predicted values, the median estimated by a quantile regressor can be a
more reliable estimator of a skewed distribution than the mean.
and predicted values, the median (quantile = 0.5) estimated by a quantile
regressor can be a more reliable estimator of a skewed distribution than the
mean.
"""

import altair as alt
import numpy as np
import pandas as pd
import scipy as sp
from sklearn.ensemble import RandomForestRegressor
Expand All @@ -29,6 +31,8 @@
y = skewnorm_rv.rvs(n_samples)
X = rng.randn(n_samples, 2) * y.reshape(-1, 1)

quantiles = list(np.arange(101) / 100)

regr_rf = RandomForestRegressor(n_estimators=10, random_state=0)
regr_qrf = RandomForestQuantileRegressor(n_estimators=10, random_state=0)

Expand All @@ -38,18 +42,38 @@
regr_qrf.fit(X_train, y_train)

y_pred_rf = regr_rf.predict(X_test) # standard RF predictions (mean)
y_pred_qrf = regr_qrf.predict(X_test, quantiles=0.5) # QRF predictions (median)
y_pred_qrf = regr_qrf.predict(X_test, quantiles=quantiles) # QRF predictions (quantiles)

legend = {
"Actual": "#c0c0c0",
"RF (Mean)": "#f2a619",
"QRF (Median)": "#006aff",
}

df = pd.DataFrame({"actual": y_test, "rf": y_pred_rf, "qrf": y_pred_qrf})
df = pd.concat(
[
pd.DataFrame({"actual": y_test, "rf": y_pred_rf, "qrf": y_pred_qrf[..., q_idx]}).assign(
quantile=quantile
)
for q_idx, quantile in enumerate(quantiles)
]
)


def plot_prediction_histograms(df, legend):
slider = alt.binding_range(
min=0,
max=1,
step=0.5 if len(quantiles) == 1 else 1 / (len(quantiles) - 1),
name="Quantile: ",
)

q_val = alt.selection_point(
value=0.5,
bind=slider,
fields=["quantile"],
)

click = alt.selection_point(fields=["label"], bind="legend")

color = alt.condition(
Expand All @@ -60,14 +84,15 @@ def plot_prediction_histograms(df, legend):

chart = (
alt.Chart(df)
.transform_calculate(calculate=f"round({alt.datum['actual']} * 10) / 10", as_="Actual")
.transform_calculate(calculate=f"round({alt.datum['rf']} * 10) / 10", as_="RF (Mean)")
.transform_calculate(calculate=f"round({alt.datum['qrf']} * 10) / 10", as_="QRF (Median)")
.transform_fold(["Actual", "RF (Mean)", "QRF (Median)"], as_=["label", "value"])
.transform_filter(q_val)
.transform_calculate(calculate=f"round(datum.actual * 10) / 10", as_="Actual")
.transform_calculate(calculate=f"round(datum.rf * 10) / 10", as_="RF (Mean)")
.transform_calculate(calculate=f"round(datum.qrf * 10) / 10", as_="QRF (Quantile)")
.transform_fold(["Actual", "RF (Mean)", "QRF (Quantile)"], as_=["label", "value"])
.mark_bar()
.encode(
x=alt.X(
"value:O",
"value:N",
axis=alt.Axis(
labelAngle=0,
labelExpr="datum.value % 0.5 == 0 ? datum.value : null",
Expand All @@ -83,7 +108,7 @@ def plot_prediction_histograms(df, legend):
alt.Tooltip("count():Q", format=",d", title="Counts"),
],
)
.add_params(click)
.add_params(q_val, click)
.configure_range(category=alt.RangeScheme(list(legend.values())))
.properties(height=400, width=650)
)
Expand Down
2 changes: 1 addition & 1 deletion examples/plot_treeshap_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def plot_shap_waterfall_with_quantiles(df, height=300):
min=0,
max=1,
step=0.5 if len(quantiles) == 1 else 1 / (len(quantiles) - 1),
name="Quantile:",
name="Quantile: ",
)

q_val = alt.selection_point(
Expand Down

0 comments on commit 4e8ae27

Please sign in to comment.