Skip to content

Commit

Permalink
Update example plots
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Aug 23, 2024
1 parent ee9f371 commit c93668e
Show file tree
Hide file tree
Showing 12 changed files with 46 additions and 24 deletions.
1 change: 1 addition & 0 deletions examples/plot_huggingface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def fit_and_upload_model(token, repo_id, local_dir="./local_repo", random_state=


def plot_quantiles_by_latlon(df, quantiles, color_scheme="cividis"):
"""Plot quantile predictions on California Housing dataset by lat/lon."""
# Slider for varying the displayed quantile estimates.
slider = alt.binding_range(
name="Predicted Quantile: ",
Expand Down
1 change: 1 addition & 0 deletions examples/plot_predict_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def predict(qrf, X, quantiles=0.5, what=None):


def plot_ecdf(df):
"""Plot the ECDF for test samples."""
min_idx = df["index"].min()
max_idx = df["index"].max()

Expand Down
1 change: 1 addition & 0 deletions examples/plot_proximity_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def plot_digits_proximities(
height=225,
width=225,
):
"""Plot Digits dataset proximities for test samples."""
dim_x, dim_y = pixel_dim[0], pixel_dim[1]
dgt_x, dgt_y = len(str(dim_x)), len(str(dim_y))

Expand Down
6 changes: 4 additions & 2 deletions examples/plot_quantile_conformalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ def cqr_strategy(alpha, X_train, X_test, y_train, y_test, random_state=None):


def plot_prediction_intervals_by_strategy(df):
def plot_prediction_intervals(df, domain):
"""Plot prediction intervals by interval estimate strategy."""

def _plot_prediction_intervals(df, domain):
# Slider for varying the target coverage level.
slider = alt.binding_range(name="Coverage Target: ", min=0, max=1, step=0.1)
coverage_val = alt.param(name="coverage", value=0.9, bind=slider)
Expand Down Expand Up @@ -280,7 +282,7 @@ def plot_prediction_intervals(df, domain):
int(np.max((df[["y_test", "y_pred"]].max(axis=0)))), # max of all axes
]
df_i = df.query(f"strategy == '{strategy}'").reset_index(drop=True)
base = plot_prediction_intervals(df_i, domain)
base = _plot_prediction_intervals(df_i, domain)
chart |= base.properties(height=225, width=300, title=strategies[strategy])

return chart
Expand Down
5 changes: 3 additions & 2 deletions examples/plot_quantile_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def make_toy_dataset(n_samples, bounds, add_noise=True, random_state=0):
)


def plot_fit_and_intervals(df):
def plot_predictions_and_intervals(df):
"""Plot model predictions and prediction intervals with ground truth."""
area_pred = (
alt.Chart(df)
.transform_filter(~alt.datum["test"]) # filter to non-test data
Expand Down Expand Up @@ -143,5 +144,5 @@ def plot_fit_and_intervals(df):
return chart


chart = plot_fit_and_intervals(df)
chart = plot_predictions_and_intervals(df)
chart
25 changes: 17 additions & 8 deletions examples/plot_quantile_extrapolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
func = lambda x: x * np.sin(x)
func_str = "f(x) = x sin(x)"
quantiles = [0.025, 0.975, 0.5]
qrf_params = {"max_samples_leaf": None, "min_samples_leaf": 4, "random_state": random_state}
qrf_params = {"min_samples_leaf": 4, "max_samples_leaf": None, "random_state": random_state}


def make_func_Xy(func, bounds, n_samples, add_noise=True, random_state=0):
Expand Down Expand Up @@ -404,7 +404,16 @@ def get_coverage_xtr(bounds_list, train_indices, test_indices, y_train, level, *


def plot_qrf_vs_xtrapolation_comparison(df, func_str):
def plot_extrapolations(df, title="", legend=False, func_str="", x_domain=None, y_domain=None):
"""Plot comparison of QRF vs Xtrapolation on extrapolated data."""

def _plot_extrapolations(
df,
title="",
legend=False,
func_str="",
x_domain=None,
y_domain=None,
):
x_scale = None
if x_domain is not None:
x_scale = alt.Scale(domain=x_domain, nice=False, padding=0)
Expand Down Expand Up @@ -521,29 +530,29 @@ def plot_extrapolations(df, title="", legend=False, func_str="", x_domain=None,
xtra_mapper = {"bb_mid": "y_pred", "bb_low": "y_pred_low", "bb_upp": "y_pred_upp"}

chart1 = alt.layer(
plot_extrapolations(
_plot_extrapolations(
df.query("~(test_left | test_right)").assign(**{"coverage": lambda x: x["cov_qrf"]}),
title="Extrapolation with Standard QRF",
**kwargs,
).resolve_scale(color="independent"),
plot_extrapolations(df.query("test_left").assign(extrapolate=True), **kwargs),
plot_extrapolations(df.query("test_right").assign(extrapolate=True), **kwargs),
_plot_extrapolations(df.query("test_left").assign(extrapolate=True), **kwargs),
_plot_extrapolations(df.query("test_right").assign(extrapolate=True), **kwargs),
)
chart2 = alt.layer(
plot_extrapolations(
_plot_extrapolations(
df.query("~(test_left | test_right)").assign(**{"coverage": lambda x: x["cov_xtr"]}),
title="Extrapolation with Xtrapolation Procedure",
legend=True,
**kwargs,
).resolve_scale(color="independent"),
plot_extrapolations(
_plot_extrapolations(
df.query("test_left")
.assign(extrapolate=True)
.drop(columns=["y_pred", "y_pred_low", "y_pred_upp"])
.rename(xtra_mapper, axis="columns"),
**kwargs,
),
plot_extrapolations(
_plot_extrapolations(
df.query("test_right")
.assign(extrapolate=True)
.drop(columns=["y_pred", "y_pred_low", "y_pred_upp"])
Expand Down
5 changes: 3 additions & 2 deletions examples/plot_quantile_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@
df = pd.concat(dfs, ignore_index=True)


def plot_interpolations(df, legend):
def plot_interpolation_predictions(df, legend):
"""Plot predictions by quantile interpolation methods."""
# Slider for varying the prediction interval that determines the quantiles being interpolated.
slider = alt.binding_range(name="Prediction Interval: ", min=0, max=1, step=0.01)
interval_val = alt.param(name="interval", value=0.9, bind=slider)
Expand Down Expand Up @@ -163,5 +164,5 @@ def plot_interpolations(df, legend):
return chart


chart = plot_interpolations(df, legend)
chart = plot_interpolation_predictions(df, legend)
chart
14 changes: 8 additions & 6 deletions examples/plot_quantile_intervals.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@
df = pd.DataFrame(data).pipe(lambda x: x * 100_000) # convert to dollars


def plot_calibration_and_intervals(df):
def plot_calibration(df):
def plot_california_calibration_and_intervals(df):
"""Plot calibration and intervals on California Housing dataset."""

def _plot_calibration(df):
domain = [
int(np.min(np.minimum(df["y_true"], df["y_pred"]))), # min of both axes
int(np.max(np.maximum(df["y_true"], df["y_pred"]))), # max of both axes
Expand Down Expand Up @@ -111,7 +113,7 @@ def plot_calibration(df):
chart = bar + tick_low + tick_upp + circle + diagonal
return chart

def plot_intervals(df):
def _plot_intervals(df):
df = df.copy()

# Order samples by interval width.
Expand Down Expand Up @@ -175,12 +177,12 @@ def plot_intervals(df):
chart = bar + tick_low + tick_upp + circle
return chart

chart1 = plot_calibration(df).properties(height=250, width=325)
chart2 = plot_intervals(df).properties(height=250, width=325)
chart1 = _plot_calibration(df).properties(height=250, width=325)
chart2 = _plot_intervals(df).properties(height=250, width=325)
chart = chart1 | chart2

return chart


chart = plot_calibration_and_intervals(df)
chart = plot_california_calibration_and_intervals(df)
chart
7 changes: 4 additions & 3 deletions examples/plot_quantile_multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@


def make_func_Xy(funcs, bounds, n_samples):
"""Make a dataset from a specified function."""
"""Make a dataset from specified function(s)."""
x = np.linspace(*bounds, n_samples)
y = np.empty((len(x), len(funcs)))
for i, func in enumerate(funcs):
Expand All @@ -53,11 +53,11 @@ def make_func_Xy(funcs, bounds, n_samples):

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=random_state)

qrf = RandomForestQuantileRegressor(max_samples_leaf=None, max_depth=4, random_state=random_state)
qrf = RandomForestQuantileRegressor(max_depth=4, max_samples_leaf=None, random_state=random_state)
qrf.fit(X_train, y_train) # fit on all of the targets simultaneously

# Get multi-target predictions at specified quantiles.
y_pred = qrf.predict(X, quantiles=quantiles) # shape = (n_samples, n_targets, n_quantiles)
y_pred = qrf.predict(X, quantiles=quantiles) # output shape = (n_samples, n_targets, n_quantiles)

df = pd.DataFrame(
{
Expand All @@ -72,6 +72,7 @@ def make_func_Xy(funcs, bounds, n_samples):


def plot_multitargets(df, legend):
"""Plot predictions and prediction intervals for multi-target outputs."""
# Slider for varying the displayed prediction intervals.
slider = alt.binding_range(name="Prediction Interval: ", min=0, max=1, step=0.05)
interval_val = alt.param(name="interval", value=0.95, bind=slider)
Expand Down
3 changes: 2 additions & 1 deletion examples/plot_quantile_ranks.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def make_toy_dataset(n_samples, bounds, random_state=0):
X, y = make_toy_dataset(n_samples, bounds, random_state=0)

qrf = RandomForestQuantileRegressor(
max_samples_leaf=None,
min_samples_leaf=50,
max_samples_leaf=None,
random_state=random_state,
).fit(X, y)

Expand All @@ -51,6 +51,7 @@ def make_toy_dataset(n_samples, bounds, random_state=0):


def plot_pred_and_ranks(df):
"""Plot quantile predictions and ranks."""
# Slider for varying the interval that defines the upper and lower quantile rank thresholds.
slider = alt.binding_range(name="Rank Interval Threshold: ", min=0, max=1, step=0.01)
interval_val = alt.param(name="interval", value=0.05, bind=slider)
Expand Down
1 change: 1 addition & 0 deletions examples/plot_quantile_vs_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def make_skewed_dataset(a=7, loc=-1, scale=1, random_state=0):


def plot_prediction_histograms(df, legend):
"""Plot histogram of predictions by model."""
# Slider for varying the quantile value used for generating the QRF histogram.
slider = alt.binding_range(
name="Predicted Quantile: ",
Expand Down
1 change: 1 addition & 0 deletions examples/plot_treeshap_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def get_shap_value_by_index(shap_values, index):


def plot_shap_waterfall_with_quantiles(df, height=300):
"""Plot SHAP waterfall plot by quantile predictions."""
df = df.copy()

# Slider for varying the applied quantile estimates.
Expand Down

0 comments on commit c93668e

Please sign in to comment.