Skip to content

Commit

Permalink
Update example plots
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Aug 6, 2024
1 parent f10e1cd commit f72ba1b
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 91 deletions.
2 changes: 1 addition & 1 deletion examples/plot_huggingface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
repo_id = "quantile-forest/california-housing-example"
load_existing = True

quantiles = list((np.arange(5) * 25) / 100)
quantiles = np.arange(0, 1.25, 0.25).round(2).tolist()
sample_frac = 1


Expand Down
50 changes: 26 additions & 24 deletions examples/plot_proximity_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@

rng = check_random_state(0)

n_examples = 25
n_test = 25
noise_std = 0.1

# Load the Digits dataset.
X, y = datasets.load_digits(return_X_y=True, as_frame=True)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=n_test, random_state=0)


def add_gaussian_noise(X, mean=0, std=0.1, random_state=None):
Expand Down Expand Up @@ -72,27 +72,29 @@ def extract_floats(combined_df, scale=100):
return df1, df2


# Randomly corrupt a fraction of the training and test data.
X_train_corrupt = X_train.pipe(add_gaussian_noise, std=noise_std, random_state=rng)
X_test_corrupt = X_test.pipe(add_gaussian_noise, std=noise_std, random_state=rng)
# Randomly add noise to the training and test data.
X_train_noisy = X_train.pipe(add_gaussian_noise, std=noise_std, random_state=rng)
X_test_noisy = X_test.pipe(add_gaussian_noise, std=noise_std, random_state=rng)

# We set `max_samples_leaf=None` so that all leaf node samples are stored.
# We set `max_samples_leaf=None` to ensure that every sample in the training
# data is stored in the leaf nodes. By doing this, we allow the model to
# consider all samples as potential candidates for proximity calculations.
qrf = RandomForestQuantileRegressor(max_samples_leaf=None, random_state=0)
qrf.fit(X_train_corrupt, X_train)
qrf.fit(X_train_noisy, X_train)

# Get the proximity counts.
proximities = qrf.proximity_counts(X_test_corrupt)
proximities = qrf.proximity_counts(X_test_noisy)

df_prox = pd.DataFrame(
{"prox": [[(j, *p) for j, p in enumerate(proximities[i])] for i in range(len(X_test))]}
)

df = (
combine_floats(X_test, X_test_corrupt)
combine_floats(X_test, X_test_noisy)
.join(y_test)
.reset_index()
.join(df_prox)
.iloc[:n_examples]
.iloc[:n_test]
.explode("prox")
.assign(
**{
Expand All @@ -106,7 +108,7 @@ def extract_floats(combined_df, scale=100):
)

df_lookup = (
combine_floats(X_train, X_train_corrupt)
combine_floats(X_train, X_train_noisy)
.assign(**{"index": np.arange(len(X_train))})
.join(y_train)
)
Expand Down Expand Up @@ -146,23 +148,23 @@ def plot_digits_proximities(
fold=[f"pixel_{y}_{x}" for y in range(8) for x in range(8)],
as_=["pixel", "value"],
)
.transform_calculate(value_orig="floor(datum.value / 100)")
.transform_calculate(value_corr="datum.value - (datum.value_orig * 100)")
.transform_calculate(value_clean="floor(datum.value / 100)")
.transform_calculate(value_noisy="datum.value - (datum.value_clean * 100)")
.transform_calculate(x="substring(datum.pixel, 8, 9)", y="substring(datum.pixel, 6, 7)")
.mark_rect()
.encode(
x=alt.X("x:N", axis=None),
y=alt.Y("y:N", axis=None),
color=alt.Color("value_corr:Q", legend=None, scale=alt.Scale(scheme="greys")),
opacity=alt.condition(alt.datum["value_corr"] == 0, alt.value(0), alt.value(0.67)),
color=alt.Color("value_noisy:Q", legend=None, scale=alt.Scale(scheme="greys")),
opacity=alt.condition(alt.datum["value_noisy"] == 0, alt.value(0), alt.value(0.67)),
tooltip=[
alt.Tooltip("target:Q", title="Digit"),
alt.Tooltip("value_corr:Q", format=".3f", title="Pixel Value"),
alt.Tooltip("value_noisy:Q", format=".3f", title="Pixel Value"),
alt.Tooltip("x:Q", title="Pixel X"),
alt.Tooltip("y:Q", title="Pixel Y"),
],
)
.properties(height=height, width=width, title="Test Digit (corrupted)")
.properties(height=height, width=width, title="Test Digit (noisy)")
)

chart2 = (
Expand All @@ -171,8 +173,8 @@ def plot_digits_proximities(
.encode(
x=alt.X("x:N", axis=None),
y=alt.Y("y:N", axis=None),
color=alt.Color("value_orig:Q", legend=None, scale=alt.Scale(scheme="greys")),
opacity=alt.condition(alt.datum["value_orig"] == 0, alt.value(0), alt.value(0.67)),
color=alt.Color("value_clean:Q", legend=None, scale=alt.Scale(scheme="greys")),
opacity=alt.condition(alt.datum["value_clean"] == 0, alt.value(0), alt.value(0.67)),
tooltip=[
alt.Tooltip("prox_cnt", title="Proximity Count"),
alt.Tooltip("target:Q", title="Digit"),
Expand All @@ -196,7 +198,7 @@ def plot_digits_proximities(
fold=[f"pixel_{y}_{x}" for y in range(8) for x in range(8)],
as_=["pixel", "value"],
)
.transform_calculate(value_orig="floor(datum.value / 100)")
.transform_calculate(value_clean="floor(datum.value / 100)")
.transform_calculate(x="substring(datum.pixel, 8, 9)", y="substring(datum.pixel, 6, 7)")
.properties(
height=subplot_dim, width=subplot_dim, title=f"Proximity Digits (top {n_prox})"
Expand All @@ -209,17 +211,17 @@ def plot_digits_proximities(
fold=[f"pixel_{y}_{x}" for y in range(8) for x in range(8)],
as_=["pixel", "value"],
)
.transform_calculate(value_orig="floor(datum.value / 100)")
.transform_calculate(value_clean="floor(datum.value / 100)")
.transform_calculate(x="substring(datum.pixel, 8, 9)", y="substring(datum.pixel, 6, 7)")
.mark_rect()
.encode(
x=alt.X("x:N", axis=None),
y=alt.Y("y:N", axis=None),
color=alt.Color("value_orig:Q", legend=None, scale=alt.Scale(scheme="greys")),
opacity=alt.condition(alt.datum["value_orig"] == 0, alt.value(0), alt.value(0.67)),
color=alt.Color("value_clean:Q", legend=None, scale=alt.Scale(scheme="greys")),
opacity=alt.condition(alt.datum["value_clean"] == 0, alt.value(0), alt.value(0.67)),
tooltip=[
alt.Tooltip("target:Q", title="Digit"),
alt.Tooltip("value_orig:Q", title="Pixel Value"),
alt.Tooltip("value_clean:Q", title="Pixel Value"),
alt.Tooltip("x:Q", title="Pixel X"),
alt.Tooltip("y:Q", title="Pixel Y"),
],
Expand Down
5 changes: 3 additions & 2 deletions examples/plot_quantile_conformalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
random_state = 0
rng = check_random_state(random_state)

coverages = list(np.arange(11) / 10) # the "coverage level"
coverages = np.arange(0, 1.1, 0.1).round(1).tolist() # the "coverage level"

# Load the California Housing Prices dataset.
california = datasets.fetch_california_housing()
Expand Down Expand Up @@ -146,7 +146,8 @@ def cqr_strategy(alpha, X_train, X_test, y_train, y_test):
"coverage": coverage_score(grp["y_test"], grp["y_pred_low"], grp["y_pred_upp"]),
"width": mean_width_score(grp["y_pred_low"], grp["y_pred_upp"]),
}
)
),
include_groups=False,
)
.reset_index()
)
Expand Down
2 changes: 1 addition & 1 deletion examples/plot_quantile_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from quantile_forest import RandomForestQuantileRegressor

intervals = list(np.arange(101) / 100)
intervals = np.arange(0, 1.01, 0.01).round(2).tolist()

# Create toy dataset.
X = np.array([[-1, -1], [-1, -1], [-1, -1], [1, 1], [1, 1]])
Expand Down
112 changes: 52 additions & 60 deletions examples/plot_quantile_multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
Multiple-Output Quantile Regression with QRFs
=============================================
This example demonstrates fitting a single quantile regressor for multiple
This example demonstrates how to fit a single quantile regressor for multiple
target variables on a toy dataset. For each target, multiple quantiles can be
estimated simultaneously. In this example, the target variable has two output
values for each sample, with a single regressor used to estimate three
quantiles (the median and interval points) for each target.
values for each sample, with a single regressor used to estimate many
quantiles simultaneously. Three of these quantiles are visualized concurrently
for each target: the median line and the area defined by the interval points.
"""

import altair as alt
Expand All @@ -16,13 +17,11 @@

from quantile_forest import RandomForestQuantileRegressor

alt.data_transformers.disable_max_rows()

np.random.seed(0)

n_samples = 2500
bounds = [0, 100]
intervals = list(np.arange(21) / 20)
quantiles = np.arange(0, 1.025, 0.025).round(3).tolist()

# Define functions that generate targets; each function maps to one target.
funcs = [
Expand Down Expand Up @@ -50,6 +49,11 @@ def make_func_Xy(funcs, bounds, n_samples):
return np.atleast_2d(x).T, y


def format_frac(fraction):
    """Format a number as a compact string with up to 3 significant digits.

    The result is used to build per-quantile column names (e.g. ``q_0.025``),
    so it must be the shortest plain rendering of the value: ``0.5`` rather
    than ``0.500``, and ``1`` rather than ``1.0``.

    Parameters
    ----------
    fraction : int or float
        Value to format, typically a quantile in the range [0, 1].

    Returns
    -------
    str
        The value rendered with the ``g`` presentation type at 3 significant
        digits, e.g. ``0 -> "0"``, ``0.5 -> "0.5"``, ``0.025 -> "0.025"``,
        ``1.0 -> "1"``.
    """
    # The `g` presentation type already removes insignificant trailing zeros
    # and a dangling decimal point. Stripping "0" characters by hand is not
    # only redundant but wrong for values rendered without a decimal point
    # (e.g. 10 would become "1"), so rely on the format specifier alone.
    return f"{fraction:.3g}"


# Create the dataset with multiple target variables.
X, y = make_func_Xy(funcs, bounds, n_samples)

Expand All @@ -58,33 +62,30 @@ def make_func_Xy(funcs, bounds, n_samples):
qrf = RandomForestQuantileRegressor(max_samples_leaf=None, max_depth=4, random_state=0)
qrf.fit(X_train, y_train) # fit on all of the targets simultaneously

dfs = []
for idx, interval in enumerate(intervals):
# Get multiple-output predictions at median and prediction intervals.
quantiles = [0.5, round(0.5 - interval / 2, 3), round(0.5 + interval / 2, 3)]
y_pred = qrf.predict(X, quantiles=quantiles)
# Get multiple-output predictions at many quantiles.
y_pred = qrf.predict(X, quantiles=quantiles)

df_i = pd.DataFrame(
df = pd.DataFrame(
{
"x": np.tile(X.squeeze(), len(funcs)),
"y": y.reshape(-1, order="F"),
"y_true": np.concatenate([f["signal"](X.squeeze()) for f in funcs]),
"y_pred": np.concatenate([y_pred[:, i, len(quantiles) // 2] for i in range(len(funcs))]),
"target": np.concatenate([[f"{i}"] * len(X) for i in range(len(funcs))]),
}
).join(
pd.DataFrame(
{
"x": np.tile(X.squeeze(), len(funcs)),
"y": y.reshape(-1, order="F"),
"y_true": np.concatenate([f["signal"](X.squeeze()) for f in funcs]),
"y_pred": np.concatenate([y_pred[:, i, 0] for i in range(len(funcs))]),
"y_pred_low": np.concatenate([y_pred[:, i, 1] for i in range(len(funcs))]),
"y_pred_upp": np.concatenate([y_pred[:, i, 2] for i in range(len(funcs))]),
"quantile_low": np.concatenate([[quantiles[1]] * len(X) for i in range(len(funcs))]),
"quantile_upp": np.concatenate([[quantiles[2]] * len(X) for i in range(len(funcs))]),
"target": np.concatenate([[f"{i}"] * len(X) for i in range(len(funcs))]),
f"q_{format_frac(q)}": np.concatenate([y_pred[:, t, idx] for t in range(len(funcs))])
for idx, q in enumerate(quantiles)
}
)
dfs.append(df_i)
df = pd.concat(dfs)
)


def plot_multioutputs(df, legend):
slider = alt.binding_range(min=0, max=1, step=0.05, name="Prediction Interval: ")
interval_selection = alt.param(value=0.95, bind=slider, name="interval")
interval_tol = 0.001

click = alt.selection_point(fields=["target"], bind="legend")

Expand Down Expand Up @@ -112,52 +113,43 @@ def plot_multioutputs(df, legend):
alt.Tooltip("quantile_upp:Q", format=".3f", title="Upper Quantile"),
]

points = (
base = (
alt.Chart(df)
.mark_circle(color="black", opacity=0.25, size=25)
.encode(
x=alt.X("x:Q", scale=alt.Scale(nice=False)),
y=alt.Y("y:Q"),
color=alt.condition(click, alt.Color("target:N"), alt.value("lightgray")),
tooltip=tooltip,
.transform_calculate(
quantile_low=f"round((0.5 - interval / 2) * 1000) / 1000",
quantile_upp=f"round((0.5 + interval / 2) * 1000) / 1000",
quantile_low_col="'q_' + datum.quantile_low",
quantile_upp_col="'q_' + datum.quantile_upp",
)
.transform_calculate(
y_pred_low=f"datum[datum.quantile_low_col]",
y_pred_upp=f"datum[datum.quantile_upp_col]",
)
)

line = (
alt.Chart(df)
.mark_line(color="black", size=3)
.encode(
x=alt.X("x:Q", scale=alt.Scale(nice=False), title="x"),
y=alt.Y("y_pred:Q", title="y"),
color=color,
tooltip=tooltip,
)
points = base.mark_circle(color="black", opacity=0.25, size=25).encode(
x=alt.X("x:Q", scale=alt.Scale(nice=False)),
y=alt.Y("y:Q"),
color=alt.condition(click, alt.Color("target:N"), alt.value("lightgray")),
tooltip=tooltip,
)

area = (
alt.Chart(df)
.mark_area(opacity=0.25)
.encode(
x=alt.X("x:Q", scale=alt.Scale(nice=False), title="x"),
y=alt.Y("y_pred_low:Q", title="y"),
y2=alt.Y2("y_pred_upp:Q", title=None),
color=color,
tooltip=tooltip,
)
line = base.mark_line(color="black", size=3).encode(
x=alt.X("x:Q", scale=alt.Scale(nice=False), title="x"),
y=alt.Y("y_pred:Q", title="y"),
color=color,
tooltip=tooltip,
)

area = base.mark_area(opacity=0.25).encode(
x=alt.X("x:Q", scale=alt.Scale(nice=False), title="x"),
y=alt.Y("y_pred_low:Q", title="y"),
y2=alt.Y2("y_pred_upp:Q", title=None),
color=color,
tooltip=tooltip,
)
chart = (
(points + area + line)
.transform_filter(
(
(alt.datum.quantile_low >= (0.5 - interval_selection / 2 - interval_tol))
& (alt.datum.quantile_low <= (0.5 - interval_selection / 2 + interval_tol))
)
| (
(alt.datum.quantile_upp >= (0.5 + interval_selection / 2 - interval_tol))
& (alt.datum.quantile_upp <= (0.5 + interval_selection / 2 + interval_tol))
)
)
.add_params(interval_selection, click)
.configure_range(category=alt.RangeScheme(list(legend.values())))
.properties(height=400, width=650, title="Multi-target Prediction Intervals")
Expand Down
4 changes: 2 additions & 2 deletions examples/plot_quantile_vs_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

rng = check_random_state(0)

quantiles = np.arange(0, 1.05, 0.05).round(2).tolist()

# Create right-skewed dataset.
n_samples = 5000
a, loc, scale = 5, -1, 1
Expand All @@ -33,8 +35,6 @@
y = skewnorm_rv.rvs(n_samples)
X = rng.randn(n_samples, 2) * y.reshape(-1, 1)

quantiles = list(np.arange(21) * 5 / 100)

regr_rf = RandomForestRegressor(n_estimators=10, random_state=0)
regr_qrf = RandomForestQuantileRegressor(n_estimators=10, random_state=0)

Expand Down
2 changes: 1 addition & 1 deletion examples/plot_treeshap_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

n_samples = 500
test_idx = 0
quantiles = list((np.arange(11) * 10) / 100)
quantiles = np.arange(0, 1.1, 0.1).round(1).tolist()


def get_shap_values(qrf, X, quantile=0.5, **kwargs):
Expand Down

0 comments on commit f72ba1b

Please sign in to comment.