Skip to content

Commit

Permalink
Update example plots
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Aug 10, 2024
1 parent cff164e commit 8db4ac6
Show file tree
Hide file tree
Showing 12 changed files with 203 additions and 192 deletions.
48 changes: 24 additions & 24 deletions examples/plot_huggingface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@
(QRF) model from Hugging Face Hub and use it to estimate new quantiles. In
this scenario, a QRF has been trained with default parameters on a train-test
split of the California housing dataset and uploaded to Hugging Face Hub. The
model is downloaded, and inference is performed over several quantiles for
each instance in the dataset. The estimates are visualized by the latitude and
longitude of each instance. The model used is available on Hugging Face Hub
model is downloaded and used to perform inference across several quantiles for
each dataset sample. The results are visualized by the latitude and longitude
of each sample. The model used is available on Hugging Face Hub
`here <https://huggingface.co/quantile-forest/california-housing-example>`_.
"""

import os
import pickle
import shutil
import tempfile

import altair as alt
import numpy as np
Expand Down Expand Up @@ -134,13 +135,16 @@ def fit_and_upload_model(token, repo_id, local_dir="./local_repo", random_state=

# Create the repository on the Hugging Face Hub if it does not exist.
# Push the model to the repository.
hub_utils.push(
repo_id=repo_id,
source=local_dir,
token=token, # personal token to be downloaded from Hugging Face
commit_message="Model commit",
create_remote=True,
)
try:
hub_utils.push(
repo_id=repo_id,
source=local_dir,
token=token, # personal token to be downloaded from Hugging Face
commit_message="Model commit",
create_remote=True,
)
except Exception as e:
print(f"Error pushing model to Hugging Face Hub: {e}")

os.remove(model_filename)
shutil.rmtree(local_dir)
Expand All @@ -149,17 +153,13 @@ def fit_and_upload_model(token, repo_id, local_dir="./local_repo", random_state=
if not load_existing:
fit_and_upload_model(token, repo_id, random_state=random_seed)

# Download the repository locally.
local_dir = "./local_repo"
if os.path.exists(local_dir):
shutil.rmtree(local_dir)
hub_utils.download(repo_id=repo_id, dst=local_dir)

# Load the fitted model.
# Download the repository locally and load the fitted model.
model_filename = "model.pkl"
with open(f"{local_dir}/{model_filename}", "rb") as file:
qrf = pickle.load(file)
shutil.rmtree(local_dir)
local_dir = "./local_repo"
with tempfile.TemporaryDirectory() as local_dir:
hub_utils.download(repo_id=repo_id, dst=local_dir)
with open(f"{local_dir}/{model_filename}", "rb") as file:
qrf = pickle.load(file)

# Estimate quantiles.
X, y = datasets.fetch_california_housing(as_frame=True, return_X_y=True)
Expand All @@ -174,7 +174,7 @@ def fit_and_upload_model(token, repo_id, local_dir="./local_repo", random_state=
)


def plot_quantiles_by_latlon(df, quantiles):
def plot_quantiles_by_latlon(df, quantiles, color_scheme="cividis"):
# Slider for varying the displayed quantile estimates.
slider = alt.binding_range(
min=0,
Expand All @@ -183,11 +183,11 @@ def plot_quantiles_by_latlon(df, quantiles):
name="Predicted Quantile: ",
)

quantile_selection = alt.param(value=0.5, bind=slider, name="quantile")
quantile_val = alt.param(value=0.5, bind=slider, name="quantile")

chart = (
alt.Chart(df)
.add_params(quantile_selection)
.add_params(quantile_val)
.transform_filter("datum.quantile == quantile")
.mark_circle()
.encode(
Expand All @@ -203,7 +203,7 @@ def plot_quantiles_by_latlon(df, quantiles):
scale=alt.Scale(zero=False),
title="Latitude",
),
color=alt.Color("value:Q", scale=alt.Scale(scheme="cividis"), title="Prediction"),
color=alt.Color("value:Q", scale=alt.Scale(scheme=color_scheme), title="Prediction"),
size=alt.Size("Population:Q"),
tooltip=[
alt.Tooltip("index:N", title="Row ID"),
Expand Down
25 changes: 12 additions & 13 deletions examples/plot_predict_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
============================================
This example demonstrates how to extract the empirical distribution from a
quantile regression forest (QRF) for one or more samples to calculate a
user-specified function of interest. While a QRF is designed to estimate
quantiles from the empirical distribution calculated for each sample, it can
also be useful to use this empirical distribution to calculate other
quantities of interest. Here, we calculate the empirical cumulative
distribution function (ECDF) for a test sample.
quantile regression forest (QRF) to calculate user-specified functions of
interest. While a QRF is designed to estimate quantiles, their empirical
distributions can also be used to calculate other statistical quantities. In
this scenario, we compute the empirical cumulative distribution function
(ECDF) for test samples.
"""

from itertools import chain
Expand Down Expand Up @@ -93,20 +92,20 @@ def predict(reg, X, quantiles=0.5, what=None):
"y_val": list(chain.from_iterable([y_i.quantiles for y_i in y_ecdf])),
"y_val2": list(chain.from_iterable([y_i.quantiles for y_i in y_ecdf]))[1:] + [np.nan],
"proba": list(chain.from_iterable([y_i.probabilities for y_i in y_ecdf])),
"sample_idx": [idx] * n_quantiles,
"index": [idx] * n_quantiles,
}
)
dfs.append(df_i)
df = pd.concat(dfs, ignore_index=True)


def plot_ecdf(df):
min_idx = df["sample_idx"].min()
max_idx = df["sample_idx"].max()
min_idx = df["index"].min()
max_idx = df["index"].max()

# Slider for determining the sample index for which the custom function is being visualized.
slider = alt.binding_range(min=min_idx, max=max_idx, step=1, name="Sample Index: ")
sample_selection = alt.param(value=0, bind=slider, name="sample_idx")
slider = alt.binding_range(min=min_idx, max=max_idx, step=1, name="Test Sample Index: ")
index_selection = alt.selection_point(value=0, bind=slider, fields=["index"])

tooltip = [
alt.Tooltip("y_val:Q", title="Response Value"),
Expand Down Expand Up @@ -136,8 +135,8 @@ def plot_ecdf(df):

chart = (
(circles + lines)
.add_params(sample_selection)
.transform_filter(alt.datum["sample_idx"] == sample_selection)
.add_params(index_selection)
.transform_filter(index_selection)
.properties(
height=400,
width=650,
Expand Down
53 changes: 18 additions & 35 deletions examples/plot_proximity_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
==================================================
This example demonstrates the use of quantile regression forest (QRF)
proximity counts to identify similar samples in an unsupervised manner, as the
target values are not used during model fitting. In this scenario, we train a
QRF on a noisy dataset to predict individual pixel values (i.e., denoise). We
then retrieve the proximity values for samples in a noisy test set. For each
test sample digit, we visualize it alongside a set of similar (non-noisy)
training samples determined by their proximity counts, as well as the
non-noisy digit. The similar samples are ordered from the highest to the
lowest proximity count for each digit, arranged from left to right and top to
bottom. This example illustrates the effectiveness of proximity counts in
identifying similar samples, even when using noisy training and test data.
proximity counts to identify similar samples. In this scenario, we train a QRF
on a noisy dataset to predict individual pixel values in an unsupervised
manner (the target labels are not used during training) for denoising
purposes. We then retrieve the proximity values for the noisy test samples. We
visualize each test sample alongside a set of similar (non-noisy) training
samples determined by their proximity counts. These similar samples are
ordered from the highest to the lowest proximity count. This illustrates how
proximity counts can effectively identify similar samples even in noisy
conditions.
"""

import altair as alt
Expand Down Expand Up @@ -41,12 +40,7 @@

def add_gaussian_noise(X, mean=0, std=0.1, random_state=None):
"""Add Gaussian noise to input data."""
if random_state is None:
rng = check_random_state(0)
elif isinstance(random_state, int):
rng = check_random_state(random_state)
else:
rng = random_state
rng = check_random_state(random_state)

scaler = MinMaxScaler()
X_scaled = scaler.fit_transform(X)
Expand Down Expand Up @@ -97,7 +91,6 @@ def extract_floats(combined_df, scale=100):
.join(y_test)
.reset_index()
.join(df_prox)
.iloc[:n_test_samples]
.explode("prox")
.assign(
**{
Expand Down Expand Up @@ -140,26 +133,16 @@ def plot_digits_proximities(
subplot_dim = (width - subplot_spacing * (n_subplot_rows - 1)) / n_subplot_rows

# Slider for determining the test index for which the data is being visualized.
slider = alt.binding_range(
min=0,
max=n_samples - 1,
step=1,
name="Test Index: ",
)

idx_val = alt.selection_point(
value=0,
bind=slider,
fields=["index"],
)
slider = alt.binding_range(min=0, max=n_samples - 1, step=1, name="Test Sample Index: ")
index_selection = alt.selection_point(value=0, bind=slider, fields=["index"])

scale = alt.Scale(domain=[x_min, x_max], scheme="greys")
opacity = (alt.value(0), alt.value(0.67))
opacity = (alt.value(0), alt.value(0.5))

base = alt.Chart(df).add_params(idx_val).transform_filter(idx_val)
base = alt.Chart(df).add_params(index_selection).transform_filter(index_selection)

chart1 = (
base.transform_filter(f"datum.prox_idx == 0")
base.transform_filter("datum.prox_idx == 0")
.transform_fold(fold=pixel_cols, as_=["pixel", "value"])
.transform_calculate(value_clean=f"floor(datum.value / {pixel_scale})")
.transform_calculate(value_noisy=f"datum.value - (datum.value_clean * {pixel_scale})")
Expand All @@ -172,7 +155,7 @@ def plot_digits_proximities(
opacity=alt.condition(alt.datum["value_noisy"] == 0, *opacity),
tooltip=[
alt.Tooltip("target:Q", title="Digit"),
alt.Tooltip("value_noisy:Q", format=".3f", title="Pixel Value"),
alt.Tooltip("value_noisy:Q", format=",.3f", title="Pixel Value"),
alt.Tooltip("x:Q", title="Pixel X"),
alt.Tooltip("y:Q", title="Pixel Y"),
],
Expand Down Expand Up @@ -213,7 +196,7 @@ def plot_digits_proximities(
)

chart3 = (
base.transform_filter(f"datum.prox_idx == 0")
base.transform_filter("datum.prox_idx == 0")
.transform_fold(fold=pixel_cols, as_=["pixel", "value"])
.transform_calculate(value_clean=f"floor(datum.value / {pixel_scale})")
.transform_calculate(x=pixel_x, y=pixel_y)
Expand All @@ -225,7 +208,7 @@ def plot_digits_proximities(
opacity=alt.condition(alt.datum["value_clean"] == 0, *opacity),
tooltip=[
alt.Tooltip("target:Q", title="Digit"),
alt.Tooltip("value_clean:Q", title="Pixel Value"),
alt.Tooltip("value_clean:Q", format=",.3f", title="Pixel Value"),
alt.Tooltip("x:Q", title="Pixel X"),
alt.Tooltip("y:Q", title="Pixel Y"),
],
Expand Down
33 changes: 15 additions & 18 deletions examples/plot_quantile_conformalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
This example demonstrates the use of a quantile regression forest (QRF) to
construct reliable prediction intervals using conformalized quantile
regression (CQR). CQR provides prediction intervals that attain valid
coverage, whereas QRF may require additional calibration for reliable interval
estimates. In this example, by using CQR, we achieve a level of coverage
(i.e., the percentage of samples that actually fall within their prediction
interval) that is generally closer to the target level. This example is
adapted from `"Prediction intervals: Quantile Regression Forests"
regression (CQR). While QRFs can estimate quantiles, they may require
additional calibration to provide reliable interval estimates. CQR provides
prediction intervals that attain valid coverage. In this example, we use CQR
to enhance QRF by producing prediction intervals that achieve a level of
coverage (i.e., the percentage of samples that actually fall within their
prediction interval) that is generally closer to the target level. This
example is adapted from `"Prediction intervals: Quantile Regression Forests"
<https://www.kaggle.com/code/carlmcbrideellis/prediction-intervals-quantile-regression-forests>`_
by Carl McBride Ellis.
"""
Expand Down Expand Up @@ -69,6 +70,7 @@ def mean_width_score(y_pred_low, y_pred_upp):


def qrf_strategy(alpha, X_train, X_test, y_train, y_test, random_state=None):
"""QRF (baseline) strategy."""
quantiles = [alpha / 2, 1 - alpha / 2]

qrf = RandomForestQuantileRegressor(random_state=random_state)
Expand All @@ -89,6 +91,7 @@ def qrf_strategy(alpha, X_train, X_test, y_train, y_test, random_state=None):


def cqr_strategy(alpha, X_train, X_test, y_train, y_test, random_state=None):
"""Conformalized Quantile Regression (CQR) strategy with a QRF."""
quantiles = [alpha / 2, 1 - alpha / 2]

# Create calibration set.
Expand Down Expand Up @@ -160,8 +163,7 @@ def plot_prediction_intervals_by_strategy(df):
def plot_prediction_intervals(df, domain):
# Slider for varying the target coverage level.
slider = alt.binding_range(min=0, max=1, step=0.1, name="Coverage Target: ")
cov_selection = alt.param(value=0.9, bind=slider, name="coverage")
cov_tol = 0.01
coverage_val = alt.param(value=0.9, bind=slider, name="coverage")

click = alt.selection_point(fields=["y_label"], bind="legend")

Expand All @@ -175,10 +177,7 @@ def plot_prediction_intervals(df, domain):

base = (
alt.Chart(df)
.transform_filter(
(1 - alt.datum["alpha"] - cov_tol <= cov_selection)
& (1 - alt.datum["alpha"] + cov_tol >= cov_selection)
)
.transform_filter("round((1 - datum.alpha) * 100) / 100 == coverage")
.transform_calculate(
y_label=(
"((datum.y_test >= datum.y_pred_low) & (datum.y_test <= datum.y_pred_upp))"
Expand Down Expand Up @@ -249,8 +248,8 @@ def plot_prediction_intervals(df, domain):
)
.transform_calculate(
coverage_text=(
f"'Coverage: ' + format(datum.coverage * 100, '.1f') + '%'"
f" + ' (target = ' + format((1 - datum.alpha) * 100, '.1f') + '%)'"
"'Coverage: ' + format(datum.coverage * 100, '.1f') + '%'"
" + ' (target = ' + format((1 - datum.alpha) * 100, '.1f') + '%)'"
)
)
.mark_text(align="left", baseline="top")
Expand All @@ -262,9 +261,7 @@ def plot_prediction_intervals(df, domain):
)
text_with = (
base.transform_aggregate(width="mean(width)", groupby=["strategy"])
.transform_calculate(
width_text=f"'Interval Width: ' + format({alt.datum['width']}, '$,d')"
)
.transform_calculate(width_text="'Interval Width: ' + format(datum.width, '$,d')")
.mark_text(align="left", baseline="top")
.encode(
x=alt.value(5),
Expand All @@ -275,7 +272,7 @@ def plot_prediction_intervals(df, domain):

chart = (
bar + tick_low + tick_upp + circle + diagonal + text_coverage + text_with
).add_params(cov_selection)
).add_params(coverage_val)

return chart

Expand Down
Loading

0 comments on commit 8db4ac6

Please sign in to comment.