Skip to content

Commit

Permalink
Update examples
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Feb 19, 2024
1 parent 18cf5ef commit 8b1170d
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 30 deletions.
12 changes: 8 additions & 4 deletions quantile_forest/tests/examples/plot_quantile_conformalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
construct reliable prediction intervals using conformalized quantile
regression (CQR). CQR offers prediction intervals that attain valid coverage,
while QRF may require additional calibration for reliable interval estimates.
Adapted from "Prediction intervals: Quantile Regression Forests" by Carl McBride
Ellis:
Notice that in this example, by using CQR we obtain a level of coverage (i.e.,
percentage of samples that actually fall within their prediction interval)
that is closer to the target level. Adapted from "Prediction intervals:
Quantile Regression Forests" by Carl McBride Ellis:
https://www.kaggle.com/code/carlmcbrideellis/prediction-intervals-quantile-regression-forests.
"""

Expand Down Expand Up @@ -225,7 +227,7 @@ def plot_prediction_intervals(df, domain):
)
.transform_calculate(
coverage_text=(
"'Coverage: ' + format(datum.coverage * 100, '.1f') + '%'"
f"'Coverage: ' + format({alt.datum['coverage'] * 100}, '.1f') + '%'"
f" + ' (target = {cov_pct}%)'"
)
)
Expand All @@ -240,7 +242,9 @@ def plot_prediction_intervals(df, domain):
base.transform_aggregate(
coverage="mean(coverage)", width="mean(width)", groupby=["strategy"]
)
.transform_calculate(width_text="'Interval Width: ' + format(datum.width, '$,d')")
.transform_calculate(
width_text=f"'Interval Width: ' + format({alt.datum['width']}, '$,d')"
)
.mark_text(align="left", baseline="top")
.encode(
x=alt.value(5),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
regressor for multiple target variables. For each target, multiple quantiles
can be estimated simultaneously. In this example, the target variable has
two output values for each sample, with a single regressor used to estimate
three quantiles (the median and interval) for each target output.
three quantiles (the median and 95% interval) for each target output.
"""

import altair as alt
Expand Down
44 changes: 22 additions & 22 deletions quantile_forest/tests/examples/plot_quantile_vs_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,49 +46,49 @@
"QRF (Median)": "#006aff",
}

df = pd.DataFrame({"Actual": y_test, "RF (Mean)": y_pred_rf, "QRF (Median)": y_pred_qrf})
df = pd.DataFrame({"actual": y_test, "rf": y_pred_rf, "qrf": y_pred_qrf})


def plot_prediction_histograms(df, legend):
click = alt.selection_point(fields=["estimator"], bind="legend")
click = alt.selection_point(fields=["label"], bind="legend")

color = alt.condition(
click,
alt.Color("estimator:N", sort=list(legend.keys()), title=None),
alt.Color("label:N", sort=list(legend.keys()), title=None),
alt.value("lightgray"),
)

chart = (
alt.Chart(df, width=alt.Step(6))
.transform_fold(list(legend.keys()), as_=["estimator", "y_pred"])
.transform_joinaggregate(total="count(*)", groupby=["estimator"])
alt.Chart(df)
.transform_calculate(calculate=f"round({alt.datum['actual']} * 10) / 10", as_="Actual")
.transform_calculate(calculate=f"round({alt.datum['rf']} * 10) / 10", as_="RF (Mean)")
.transform_calculate(calculate=f"round({alt.datum['qrf']} * 10) / 10", as_="QRF (Median)")
.transform_fold(["Actual", "RF (Mean)", "QRF (Median)"], as_=["label", "value"])
.transform_joinaggregate(total="count(*)", groupby=["label"])
.transform_calculate(pct="1 / datum.total")
.mark_bar()
.encode(
x=alt.X("estimator:N", axis=alt.Axis(labels=False, title=None)),
y=alt.Y("sum(pct):Q", axis=alt.Axis(title="Frequency")),
color=color,
column=alt.Column(
"y_pred:Q",
bin=alt.Bin(maxbins=80),
header=alt.Header(
labelExpr="datum.value % 1 == 0 ? floor(datum.value) : null",
labelOrient="bottom",
titleOrient="bottom",
x=alt.X(
"value:O",
axis=alt.Axis(
labelAngle=0,
labelExpr="datum.value % 0.5 == 0 ? datum.value : null",
),
title="Actual and Predicted Target Values",
title="Value",
),
y=alt.Y("sum(pct):Q", axis=alt.Axis(format=".0%", title="Percentage")),
color=color,
xOffset=alt.XOffset("label:N"),
tooltip=[
alt.Tooltip("estimator:N", title=" "),
alt.Tooltip("label:N", title="Label"),
alt.Tooltip("value:O", title="Value (binned)"),
alt.Tooltip("sum(pct):Q", format=".0%", title="Percentage"),
],
)
.add_params(click)
.configure_facet(spacing=0)
.configure_range(category=alt.RangeScheme(list(legend.values())))
.configure_scale(bandPaddingInner=0.2)
.configure_view(stroke=None)
.properties(height=400, width=650)
)

return chart


Expand Down
6 changes: 3 additions & 3 deletions quantile_forest/tests/examples/plot_quantile_weighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def timing():
t1 = time.time()


X, y = datasets.make_regression(n_samples=250, n_features=4, n_targets=5, random_state=0)
X, y = datasets.make_regression(n_samples=500, n_features=4, n_targets=5, random_state=0)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)

Expand All @@ -40,8 +40,8 @@ def timing():
"QRF Unweighted Quantile": "#001751",
}

est_sizes = [1, 5, 10, 25, 50, 100]
n_repeats = 10
est_sizes = [1, 5, 10, 25, 50, 75, 100]
n_repeats = 5

timings = np.empty((len(est_sizes), n_repeats, 3))
for i, n_estimators in enumerate(est_sizes):
Expand Down

0 comments on commit 8b1170d

Please sign in to comment.