diff --git a/examples/plot_huggingface_model.py b/examples/plot_huggingface_model.py index 9f56e3e..7d1c373 100755 --- a/examples/plot_huggingface_model.py +++ b/examples/plot_huggingface_model.py @@ -179,7 +179,7 @@ def plot_quantiles_by_latlon(df, quantiles): min=0, max=1, step=0.5 if len(quantiles) == 1 else 1 / (len(quantiles) - 1), - name="Quantile:", + name="Quantile: ", ) q_val = alt.selection_point( diff --git a/examples/plot_predict_custom.py b/examples/plot_predict_custom.py index 03e6c20..d70f563 100755 --- a/examples/plot_predict_custom.py +++ b/examples/plot_predict_custom.py @@ -23,7 +23,7 @@ np.random.seed(0) -n_test_samples = 10 +n_test_samples = 100 def predict(reg, X, quantiles=0.5, what=None): @@ -98,7 +98,7 @@ def predict(reg, X, quantiles=0.5, what=None): def plot_ecdf(df): min_idx = df["sample_idx"].min() max_idx = df["sample_idx"].max() - slider = alt.binding_range(min=min_idx, max=max_idx, step=1, name="Sample Index:") + slider = alt.binding_range(min=min_idx, max=max_idx, step=1, name="Sample Index: ") sample_selection = alt.param(value=0, bind=slider, name="sample_idx") tooltip = [ diff --git a/examples/plot_quantile_conformalized.py b/examples/plot_quantile_conformalized.py index 58b0e98..62689b6 100755 --- a/examples/plot_quantile_conformalized.py +++ b/examples/plot_quantile_conformalized.py @@ -8,9 +8,9 @@ while QRF may require additional calibration for reliable interval estimates. Notice that in this example, by using CQR we obtain a level of coverage (i.e., percentage of samples that actaully fall within their prediction interval) -that is closer to the target level. Adapted from "Prediction intervals: -Quantile Regression Forests" by Carl McBride Ellis: -https://www.kaggle.com/code/carlmcbrideellis/prediction-intervals-quantile-regression-forests. +that is generally closer to the target level. Adapted from `"Prediction +intervals: Quantile Regression Forests" by Carl McBride Ellis +`_. """ import altair as alt @@ -30,8 +30,7 @@ random_state = 0 rng = check_random_state(random_state) -cov_pct = 95 # the "coverage level" -alpha = (100 - cov_pct) / 100 +coverages = list(np.arange(11) / 10) # the "coverage level" # Load the California Housing Prices dataset. california = datasets.fetch_california_housing() @@ -113,7 +112,7 @@ def cqr_strategy(alpha, X_train, X_test, y_train, y_test): conf_scores = (np.vstack((a, b)).T).max(axis=1) # Get the 1-alpha quantile `s` from the distribution of conformity scores. - s = np.quantile(conf_scores, (1 - alpha) * (1 + (1 / (len(y_calib))))) + s = np.quantile(conf_scores, np.clip((1 - alpha) * (1 + (1 / (len(y_calib)))), 0, 1)) # Subtract `s` from the lower quantile and add it to the upper quantile. y_conf_low = y_pred_low - s @@ -129,12 +128,15 @@ def cqr_strategy(alpha, X_train, X_test, y_train, y_test): # Get strategy outputs as a data frame. -args = (alpha, X_train, X_test, y_train, y_test) -df = pd.concat([qrf_strategy(*args), cqr_strategy(*args)]) +dfs = [] +for cov_frac in coverages: + alpha = float(round(1 - cov_frac, 2)) + args = (alpha, X_train, X_test, y_train, y_test) + dfs.append(pd.concat([qrf_strategy(*args), cqr_strategy(*args)]).assign(alpha=alpha)) +df = pd.concat(dfs) -# Calculate coverage and width metrics. metrics = ( - df.groupby("strategy") + df.groupby(["alpha", "strategy"]) .apply( lambda grp: pd.Series( { @@ -147,10 +149,14 @@ def cqr_strategy(alpha, X_train, X_test, y_train, y_test): ) # Merge the metrics into the data frame. -df = df.merge(metrics, on="strategy", how="left") +df = df.merge(metrics, on=["alpha", "strategy"], how="left") def plot_prediction_intervals(df, domain): + slider = alt.binding_range(min=0, max=1, step=0.1, name="Coverage: ") + cov_selection = alt.param(value=0.9, bind=slider, name="coverage") + cov_tol = 0.01 + click = alt.selection_point(fields=["y_label"], bind="legend") color_circle = alt.Color( @@ -168,10 +174,17 @@ def plot_prediction_intervals(df, domain): alt.Tooltip("y_label:N", title="Within Interval"), ] - base = alt.Chart(df).transform_calculate( - y_label=( - "((datum.y_test >= datum.y_pred_low) & (datum.y_test <= datum.y_pred_upp))" - " ? 'Yes' : 'No'" + base = ( + alt.Chart(df) + .transform_filter( + (1 - alt.datum["alpha"] - cov_tol <= cov_selection) + & (1 - alt.datum["alpha"] + cov_tol >= cov_selection) + ) + .transform_calculate( + y_label=( + "((datum.y_test >= datum.y_pred_low) & (datum.y_test <= datum.y_pred_upp))" + " ? 'Yes' : 'No'" + ) ) ) @@ -224,11 +237,13 @@ def plot_prediction_intervals(df, domain): ) text_coverage = ( - base.transform_aggregate(coverage="mean(coverage)", groupby=["strategy"]) + base.transform_aggregate( + alpha="mean(alpha)", coverage="mean(coverage)", groupby=["strategy"] + ) .transform_calculate( coverage_text=( - f"'Coverage: ' + format({alt.datum['coverage'] * 100}, '.1f') + '%'" - f" + ' (target = {cov_pct}%)'" + f"'Coverage: ' + format(datum.coverage * 100, '.1f') + '%'" + f" + ' (target = ' + format((1 - datum.alpha) * 100, '.1f') + '%)'" ) ) .mark_text(align="left", baseline="top") @@ -251,7 +266,9 @@ def plot_prediction_intervals(df, domain): ) ) - chart = bar + tick_low + tick_upp + circle + diagonal + text_coverage + text_with + chart = (bar + tick_low + tick_upp + circle + diagonal + text_coverage + text_with).add_params( + cov_selection + ) return chart diff --git a/examples/plot_quantile_interpolation.py b/examples/plot_quantile_interpolation.py index a34bd5d..2cabb1f 100755 --- a/examples/plot_quantile_interpolation.py +++ b/examples/plot_quantile_interpolation.py @@ -76,8 +76,8 @@ def plot_interpolations(df, legend): - slider = alt.binding_range(min=0, max=1, step=0.01, name="Prediction Interval:") - interval_selection = alt.param(value=0.95, bind=slider, name="interval") + slider = alt.binding_range(min=0, max=1, step=0.01, name="Prediction Interval: ") + interval_selection = alt.param(value=0.9, bind=slider, name="interval") interval_tol = 0.001 click = alt.selection_point(fields=["method"], bind="legend") diff --git a/examples/plot_quantile_multioutput.py b/examples/plot_quantile_multioutput.py index 1524931..b922605 100755 --- a/examples/plot_quantile_multioutput.py +++ b/examples/plot_quantile_multioutput.py @@ -80,7 +80,7 @@ def make_func_Xy(funcs, bounds, n_samples): def plot_multioutputs(df, legend): - slider = alt.binding_range(min=0, max=1, step=0.05, name="Prediction Interval:") + slider = alt.binding_range(min=0, max=1, step=0.05, name="Prediction Interval: ") interval_selection = alt.param(value=0.95, bind=slider, name="interval") interval_tol = 0.001 diff --git a/examples/plot_quantile_vs_standard.py b/examples/plot_quantile_vs_standard.py index bc511e1..6a608d1 100755 --- a/examples/plot_quantile_vs_standard.py +++ b/examples/plot_quantile_vs_standard.py @@ -6,11 +6,13 @@ forest and a standard random forest regressor on a synthetic, right-skewed dataset. In a right-skewed distribution, the mean is to the right of the median. As illustrated by a greater overlap in the frequencies of the actual -and predicted values, the median estimated by a quantile regressor can be a -more reliable estimator of a skewed distribution than the mean. +and predicted values, the median (quantile = 0.5) estimated by a quantile +regressor can be a more reliable estimator of a skewed distribution than the +mean. """ import altair as alt +import numpy as np import pandas as pd import scipy as sp from sklearn.ensemble import RandomForestRegressor @@ -29,6 +31,8 @@ y = skewnorm_rv.rvs(n_samples) X = rng.randn(n_samples, 2) * y.reshape(-1, 1) +quantiles = list(np.arange(101) / 100) + regr_rf = RandomForestRegressor(n_estimators=10, random_state=0) regr_qrf = RandomForestQuantileRegressor(n_estimators=10, random_state=0) @@ -38,7 +42,7 @@ regr_qrf.fit(X_train, y_train) y_pred_rf = regr_rf.predict(X_test) # standard RF predictions (mean) -y_pred_qrf = regr_qrf.predict(X_test, quantiles=0.5) # QRF predictions (median) +y_pred_qrf = regr_qrf.predict(X_test, quantiles=quantiles) # QRF predictions (quantiles) legend = { "Actual": "#c0c0c0", @@ -46,10 +50,30 @@ "QRF (Median)": "#006aff", } -df = pd.DataFrame({"actual": y_test, "rf": y_pred_rf, "qrf": y_pred_qrf}) +df = pd.concat( + [ + pd.DataFrame({"actual": y_test, "rf": y_pred_rf, "qrf": y_pred_qrf[..., q_idx]}).assign( + quantile=quantile + ) + for q_idx, quantile in enumerate(quantiles) + ] +) def plot_prediction_histograms(df, legend): + slider = alt.binding_range( + min=0, + max=1, + step=0.5 if len(quantiles) == 1 else 1 / (len(quantiles) - 1), + name="Quantile: ", + ) + + q_val = alt.selection_point( + value=0.5, + bind=slider, + fields=["quantile"], + ) + click = alt.selection_point(fields=["label"], bind="legend") color = alt.condition( @@ -60,14 +84,15 @@ def plot_prediction_histograms(df, legend): chart = ( alt.Chart(df) - .transform_calculate(calculate=f"round({alt.datum['actual']} * 10) / 10", as_="Actual") - .transform_calculate(calculate=f"round({alt.datum['rf']} * 10) / 10", as_="RF (Mean)") - .transform_calculate(calculate=f"round({alt.datum['qrf']} * 10) / 10", as_="QRF (Median)") - .transform_fold(["Actual", "RF (Mean)", "QRF (Median)"], as_=["label", "value"]) + .transform_filter(q_val) + .transform_calculate(calculate=f"round(datum.actual * 10) / 10", as_="Actual") + .transform_calculate(calculate=f"round(datum.rf * 10) / 10", as_="RF (Mean)") + .transform_calculate(calculate=f"round(datum.qrf * 10) / 10", as_="QRF (Quantile)") + .transform_fold(["Actual", "RF (Mean)", "QRF (Quantile)"], as_=["label", "value"]) .mark_bar() .encode( x=alt.X( - "value:O", + "value:N", axis=alt.Axis( labelAngle=0, labelExpr="datum.value % 0.5 == 0 ? datum.value : null", @@ -83,7 +108,7 @@ def plot_prediction_histograms(df, legend): alt.Tooltip("count():Q", format=",d", title="Counts"), ], ) - .add_params(click) + .add_params(q_val, click) .configure_range(category=alt.RangeScheme(list(legend.values()))) .properties(height=400, width=650) ) diff --git a/examples/plot_treeshap_example.py b/examples/plot_treeshap_example.py index 8950926..d1f6826 100644 --- a/examples/plot_treeshap_example.py +++ b/examples/plot_treeshap_example.py @@ -102,7 +102,7 @@ def plot_shap_waterfall_with_quantiles(df, height=300): min=0, max=1, step=0.5 if len(quantiles) == 1 else 1 / (len(quantiles) - 1), - name="Quantile:", + name="Quantile: ", ) q_val = alt.selection_point(