Skip to content

Commit

Permalink
Update plots (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson authored Jul 29, 2024
1 parent 2d2f6d5 commit 18f3275
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 114 deletions.
63 changes: 35 additions & 28 deletions examples/plot_huggingface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
housing dataset and uploaded to Hugging Face Hub. The model is downloaded and
inference is performed over several quantiles for each instance in the dataset.
The estimates are visualized by the latitude and longitude of each instance.
The model used is availabe on Hugging Face Hub:
https://huggingface.co/quantile-forest/california-housing-example
"""

import os
Expand All @@ -29,6 +31,9 @@
repo_id = "quantile-forest/california-housing-example"
load_existing = True

quantiles = list((np.arange(5) * 25) / 100)
sample_frac = 1


def fit_and_upload_model(token, repo_id, local_dir="./local_repo"):
"""Function used to fit the model and upload it to Hugging Face Hub."""
Expand Down Expand Up @@ -139,6 +144,35 @@ def fit_and_upload_model(token, repo_id, local_dir="./local_repo"):
shutil.rmtree(local_dir)


if not load_existing:
fit_and_upload_model(token, repo_id)

# Download the repository locally.
local_dir = "./local_repo"
if os.path.exists(local_dir):
shutil.rmtree(local_dir)
hub_utils.download(repo_id=repo_id, dst=local_dir)

# Load the fitted model.
model_filename = "model.pkl"
with open(f"{local_dir}/{model_filename}", "rb") as file:
qrf = pickle.load(file)
shutil.rmtree(local_dir)

# Estimate quantiles.
X, y = datasets.fetch_california_housing(as_frame=True, return_X_y=True)
y_pred = qrf.predict(X, quantiles=quantiles) * 100_000 # predict in dollars

df = (
pd.DataFrame(y_pred, columns=quantiles)
.reset_index()
.sample(frac=sample_frac, random_state=0)
.melt(id_vars=["index"], var_name="quantile", value_name="value")
.merge(X[["Latitude", "Longitude", "Population"]].reset_index(), on="index", how="right")
)
print(df)


def plot_quantiles_by_latlon(df, quantiles):
# Slider for varying the displayed quantile estimates.
slider = alt.binding_range(
Expand Down Expand Up @@ -172,7 +206,7 @@ def plot_quantiles_by_latlon(df, quantiles):
scale=alt.Scale(zero=False),
title="Latitude",
),
color=alt.Color("value:Q", scale=alt.Scale(scheme="viridis"), title="Prediction"),
color=alt.Color("value:Q", scale=alt.Scale(scheme="cividis"), title="Prediction"),
size=alt.Size("Population:Q"),
tooltip=[
alt.Tooltip("index:N", title="Row ID"),
Expand All @@ -190,32 +224,5 @@ def plot_quantiles_by_latlon(df, quantiles):
return chart


if not load_existing:
fit_and_upload_model(token, repo_id)

# Download the repository locally.
local_dir = "./local_repo"
if os.path.exists(local_dir):
shutil.rmtree(local_dir)
hub_utils.download(repo_id=repo_id, dst=local_dir)

# Load the fitted model.
model_filename = "model.pkl"
with open(f"{local_dir}/{model_filename}", "rb") as file:
qrf = pickle.load(file)
shutil.rmtree(local_dir)

# Estimate quantiles.
quantiles = list((np.arange(11) * 10) / 100)
X, y = datasets.fetch_california_housing(as_frame=True, return_X_y=True)
y_pred = qrf.predict(X, quantiles=quantiles) * 100_000 # predict in dollars

df = (
pd.DataFrame(y_pred, columns=quantiles)
.reset_index()
.melt(id_vars=["index"], var_name="quantile", value_name="value")
.merge(X[["Latitude", "Longitude", "Population"]].reset_index(), on="index", how="right")
)

chart = plot_quantiles_by_latlon(df, quantiles)
chart
62 changes: 40 additions & 22 deletions examples/plot_predict_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

np.random.seed(0)

n_test_samples = 10


def predict(reg, X, quantiles=0.5, what=None):
"""Custom prediction method that allows user-specified function.
Expand Down Expand Up @@ -65,7 +67,7 @@ def predict(reg, X, quantiles=0.5, what=None):


X, y = datasets.load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=n_test_samples, random_state=0)

reg = RandomForestQuantileRegressor().fit(X_train, y_train)

Expand All @@ -75,30 +77,41 @@ def predict(reg, X, quantiles=0.5, what=None):
# Output array with the user-specified function applied to each sample's empirical distribution.
y_out = predict(reg, X_test, what=func)

# Calculate the ECDF from output array.
y_ecdf = [sp.stats.ecdf(y_i).cdf for y_i in y_out]

df = pd.DataFrame(
{
"y_value": list(chain.from_iterable([y_i.quantiles for y_i in y_ecdf])),
"y_value2": list(chain.from_iterable([y_i.quantiles for y_i in y_ecdf]))[1:] + [np.nan],
"probability": list(chain.from_iterable([y_i.probabilities for y_i in y_ecdf])),
}
)
dfs = []
for idx in range(n_test_samples):
# Calculate the ECDF from output array.
y_ecdf = [sp.stats.ecdf(y_i).cdf for y_i in y_out[idx].reshape(1, -1)]
n_quantiles = len(list(chain.from_iterable([y_i.quantiles for y_i in y_ecdf])))

df_i = pd.DataFrame(
{
"y_val": list(chain.from_iterable([y_i.quantiles for y_i in y_ecdf])),
"y_val2": list(chain.from_iterable([y_i.quantiles for y_i in y_ecdf]))[1:] + [np.nan],
"proba": list(chain.from_iterable([y_i.probabilities for y_i in y_ecdf])),
"sample_idx": [idx] * n_quantiles,
}
)
dfs.append(df_i)
df = pd.concat(dfs)


def plot_ecdf(df):
min_idx = df["sample_idx"].min()
max_idx = df["sample_idx"].max()
slider = alt.binding_range(min=min_idx, max=max_idx, step=1, name="Sample Index:")
sample_selection = alt.param(value=0, bind=slider, name="sample_idx")

tooltip = [
alt.Tooltip("y_value", title="Response Value"),
alt.Tooltip("probability", title="Probability"),
alt.Tooltip("y_val", title="Response Value"),
alt.Tooltip("proba", title="Probability"),
]

circles = (
alt.Chart(df)
.mark_circle(color="#006aff", opacity=1, size=50)
.encode(
x=alt.X("y_value", title="Response Value"),
y=alt.Y("probability", title="Probability"),
x=alt.X("y_val", title="Response Value"),
y=alt.Y("proba", title="Probability"),
tooltip=tooltip,
)
)
Expand All @@ -107,17 +120,22 @@ def plot_ecdf(df):
alt.Chart(df)
.mark_line(color="#006aff", size=2)
.encode(
x=alt.X("y_value", title="Response Value"),
x2=alt.X2("y_value2"),
y=alt.Y("probability", title="Probability"),
x=alt.X("y_val", title="Response Value"),
x2=alt.X2("y_val2"),
y=alt.Y("proba", title="Probability"),
tooltip=tooltip,
)
)

chart = (circles + lines).properties(
height=400,
width=650,
title="Empirical Cumulative Distribution Function (ECDF) Plot",
chart = (
(circles + lines)
.transform_filter(alt.datum.sample_idx == sample_selection)
.add_params(sample_selection)
.properties(
height=400,
width=650,
title="Empirical Cumulative Distribution Function (ECDF) Plot",
)
)
return chart

Expand Down
82 changes: 55 additions & 27 deletions examples/plot_quantile_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,19 @@

from quantile_forest import RandomForestQuantileRegressor

intervals = list(np.arange(101) / 100)

# Create toy dataset.
X = np.array([[-1, -1], [-1, -1], [-1, -1], [1, 1], [1, 1]])
y = np.array([-2, -1, 0, 1, 2])

est = RandomForestQuantileRegressor(
qrf = RandomForestQuantileRegressor(
n_estimators=1,
max_samples_leaf=None,
bootstrap=False,
random_state=0,
)
est.fit(X, y)
qrf.fit(X, y)

interpolations = {
"Linear": "#006aff",
Expand All @@ -41,30 +43,43 @@
legend = {"Actual": "#000000"}
legend.update(interpolations)

# Initialize data with actual values.
data = {
"method": ["Actual"] * len(y),
"X": [f"Sample {idx + 1} ({x})" for idx, x in enumerate(X.tolist())],
"y_pred": y.tolist(),
"y_pred_low": y.tolist(),
"y_pred_upp": y.tolist(),
}

# Populate data based on prediction results with different interpolations.
for interpolation in interpolations:
# Get predictions at 95% prediction intervals and median.
y_pred = est.predict(X, quantiles=[0.025, 0.5, 0.975], interpolation=interpolation.lower())

data["method"].extend([interpolation] * len(y))
data["X"].extend([f"Sample {idx + 1} ({x})" for idx, x in enumerate(X.tolist())])
data["y_pred"].extend(y_pred[:, 1])
data["y_pred_low"].extend(y_pred[:, 0])
data["y_pred_upp"].extend(y_pred[:, 2])

df = pd.DataFrame(data)
dfs = []
for idx, interval in enumerate(intervals):
# Initialize data with actual values.
data = {
"method": ["Actual"] * len(y),
"X": [f"Sample {idx + 1} ({x})" for idx, x in enumerate(X.tolist())],
"y_pred": y.tolist(),
"y_pred_low": y.tolist(),
"y_pred_upp": y.tolist(),
"quantile_low": [None] * len(y),
"quantile_upp": [None] * len(y),
}

# Populate data based on prediction results with different interpolations.
for interpolation in interpolations:
# Get predictions at median and prediction intervals.
quantiles = [0.5, round(0.5 - interval / 2, 3), round(0.5 + interval / 2, 3)]
y_pred = qrf.predict(X, quantiles=quantiles, interpolation=interpolation.lower())

data["method"].extend([interpolation] * len(y))
data["X"].extend([f"Sample {idx + 1} ({x})" for idx, x in enumerate(X.tolist())])
data["y_pred"].extend(y_pred[:, 0])
data["y_pred_low"].extend(y_pred[:, 1])
data["y_pred_upp"].extend(y_pred[:, 2])
data["quantile_low"].extend([quantiles[1]] * len(y))
data["quantile_upp"].extend([quantiles[2]] * len(y))

df_i = pd.DataFrame(data)
dfs.append(df_i)
df = pd.concat(dfs)


def plot_interpolations(df, legend):
slider = alt.binding_range(min=0, max=1, step=0.01, name="Prediction Interval:")
interval_selection = alt.param(value=0.95, bind=slider, name="interval")
interval_tol = 0.001

click = alt.selection_point(fields=["method"], bind="legend")

color = alt.condition(
Expand All @@ -76,9 +91,11 @@ def plot_interpolations(df, legend):
tooltip = [
alt.Tooltip("method:N", title="Method"),
alt.Tooltip("X:N", title="X Values"),
alt.Tooltip("y_pred:N", format=".3f", title="Median Y Value"),
alt.Tooltip("y_pred_low:N", format=".3f", title="Lower Y Value"),
alt.Tooltip("y_pred_upp:N", format=".3f", title="Upper Y Value"),
alt.Tooltip("y_pred:Q", format=".3f", title="Predicted Y"),
alt.Tooltip("y_pred_low:Q", format=".3f", title="Predicted Lower Y"),
alt.Tooltip("y_pred_upp:Q", format=".3f", title="Predicted Upper Y"),
alt.Tooltip("quantile_low:Q", format=".3f", title="Lower Quantile"),
alt.Tooltip("quantile_upp:Q", format=".3f", title="Upper Quantile"),
]

point = (
Expand Down Expand Up @@ -116,7 +133,18 @@ def plot_interpolations(df, legend):

chart = (
(area + point)
.add_params(click)
.transform_filter(
(
(alt.datum.quantile_low >= (0.5 - interval_selection / 2 - interval_tol))
& (alt.datum.quantile_low <= (0.5 - interval_selection / 2 + interval_tol))
)
| (
(alt.datum.quantile_upp >= (0.5 + interval_selection / 2 - interval_tol))
& (alt.datum.quantile_upp <= (0.5 + interval_selection / 2 + interval_tol))
)
| (alt.datum.method == "Actual")
)
.add_params(interval_selection, click)
.properties(height=400)
.facet(
column=alt.Column(
Expand Down
Loading

0 comments on commit 18f3275

Please sign in to comment.