Skip to content

Commit

Permalink
Update example plots
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Aug 5, 2024
1 parent 3a20fdd commit 4ee6d6e
Showing 1 changed file with 56 additions and 53 deletions.
109 changes: 56 additions & 53 deletions examples/plot_proximity_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@
target values are not used during model fitting. In this scenario, we train a
QRF on a noisy dataset to predict individual pixel values (i.e., denoise). We
then retrieve the proximity values for samples in a noisy test set. For each
test sample digit, we visualize it alongside a set of similar training samples
determined by their proximity counts, as well as the non-noisy digit. The
model is trained only on noisy training samples, but we visualize the
non-noisy training proximities. The similar samples are ordered from the
highest to the lowest proximity count for each digit, arranged from left to
right and top to bottom. This example illustrates the effectiveness of
proximity counts in identifying similar samples, even when using noisy
training and test data.
test sample digit, we visualize it alongside a set of similar (non-noisy)
training samples determined by their proximity counts, as well as the
non-noisy digit. The similar samples are ordered from the highest to the
lowest proximity count for each digit, arranged from left to right and top to
bottom. This example illustrates the effectiveness of proximity counts in
identifying similar samples, even when using noisy training and test data.
"""

import altair as alt
Expand Down Expand Up @@ -61,6 +59,19 @@ def add_gaussian_noise(X, mean=0, std=0.1, random_state=None):
return X_noisy


def combine_floats(df1, df2, scale=100):
"""Combine two floats from separate data frames into a single number."""
combined_df = df1 * scale + df2
return combined_df


def extract_floats(combined_df, scale=100):
"""Extract the original floats from the combined data frame."""
df1 = np.floor(combined_df / scale)
df2 = combined_df - (df1 * scale)
return df1, df2


# Randomly corrupt a fraction of the training and test data.
X_train_corrupt = X_train.pipe(add_gaussian_noise, std=noise_std, random_state=rng)
X_test_corrupt = X_test.pipe(add_gaussian_noise, std=noise_std, random_state=rng)
Expand All @@ -77,7 +88,7 @@ def add_gaussian_noise(X, mean=0, std=0.1, random_state=None):
)

df = (
X_test.join(X_test_corrupt.add_suffix("_corrupt"))
combine_floats(X_test, X_test_corrupt)
.join(y_test)
.reset_index()
.join(df_prox)
Expand All @@ -95,7 +106,7 @@ def add_gaussian_noise(X, mean=0, std=0.1, random_state=None):
)

df_lookup = (
X_train.join(X_train_corrupt.add_suffix("_corrupt"))
combine_floats(X_train, X_train_corrupt)
.assign(**{"index": np.arange(len(X_train))})
.join(y_train)
)
Expand All @@ -112,7 +123,7 @@ def plot_digits_proximities(
):
n_samples = df["index"].nunique()
n_subplot_rows = n_prox // n_prox_per_row
subplot_dim = (width - subplot_spacing * (2 + max(n_subplot_rows - 2, 0))) / n_subplot_rows
subplot_dim = (width - subplot_spacing * (n_subplot_rows - 1)) / n_subplot_rows

slider = alt.binding_range(
min=0,
Expand All @@ -127,48 +138,51 @@ def plot_digits_proximities(
fields=["index"],
)

color = alt.Color("value:Q", legend=None, scale=alt.Scale(scheme="greys"))
opacity = alt.condition(alt.datum["value"] == 0, alt.value(0), alt.value(1))

base = alt.Chart(df).add_params(idx_val).transform_filter(idx_val)

chart1 = (
base.transform_fold(
fold=[f"pixel_{y}_{x}_corrupt" for y in range(8) for x in range(8)],
as_=["pixel", "value_cor"],
)
.transform_calculate(
x="substring(datum.pixel, 8, 9)",
y="substring(datum.pixel, 6, 7)",
base.transform_filter(f"datum.prox_idx == 0")
.transform_fold(
fold=[f"pixel_{y}_{x}" for y in range(8) for x in range(8)],
as_=["pixel", "value"],
)
.transform_calculate(value_orig="floor(datum.value / 100)")
.transform_calculate(value_corr="datum.value - (datum.value_orig * 100)")
.transform_calculate(x="substring(datum.pixel, 8, 9)", y="substring(datum.pixel, 6, 7)")
.mark_rect()
.encode(
x=alt.X("x:N", axis=None),
y=alt.Y("y:N", axis=None),
color=alt.Color("value_cor:Q", legend=None, scale=alt.Scale(scheme="greys")),
# opacity=alt.condition(alt.datum["value_cor"] == -1, alt.value(0), alt.value(1)),
opacity=alt.condition(alt.datum["value_cor"] == 0, alt.value(0), alt.value(1)),
color=alt.Color("value_corr:Q", legend=None, scale=alt.Scale(scheme="greys")),
opacity=alt.condition(alt.datum["value_corr"] == 0, alt.value(0), alt.value(0.67)),
tooltip=[
alt.Tooltip("target:Q", title="Digit"),
alt.Tooltip("value_cor:Q", title="Pixel Value"),
alt.Tooltip("value_corr:Q", format=".3f", title="Pixel Value"),
alt.Tooltip("x:Q", title="Pixel X"),
alt.Tooltip("y:Q", title="Pixel Y"),
],
)
.properties(height=height, width=width, title="Test Digit (corrupted)")
)

chart2_i = (
chart2 = (
base.mark_rect()
.transform_filter(f"datum.prox_idx < {n_prox}")
.encode(
x=alt.X("x:N", axis=None),
y=alt.Y("y:N", axis=None),
color=color,
opacity=opacity,
color=alt.Color("value_orig:Q", legend=None, scale=alt.Scale(scheme="greys")),
opacity=alt.condition(alt.datum["value_orig"] == 0, alt.value(0), alt.value(0.67)),
tooltip=[
alt.Tooltip("prox_cnt", title="Proximity Count"),
alt.Tooltip("target:Q", title="Digit"),
],
facet=alt.Facet(
"prox_idx:N",
columns=n_prox // n_prox_per_row,
title=None,
header=alt.Header(labels=False, labelFontSize=0, labelPadding=0),
),
)
.transform_lookup(
lookup="prox_val",
Expand All @@ -182,54 +196,43 @@ def plot_digits_proximities(
fold=[f"pixel_{y}_{x}" for y in range(8) for x in range(8)],
as_=["pixel", "value"],
)
.transform_calculate(
x="substring(datum.pixel, 8, 9)",
y="substring(datum.pixel, 6, 7)",
.transform_calculate(value_orig="floor(datum.value / 100)")
.transform_calculate(x="substring(datum.pixel, 8, 9)", y="substring(datum.pixel, 6, 7)")
.properties(
height=subplot_dim, width=subplot_dim, title=f"Proximity Digits (top {n_prox})"
)
.properties(height=subplot_dim, width=subplot_dim)
)

chart2_plots = [chart2_i.transform_filter(f"datum.prox_idx == {i}") for i in range(n_prox)]
chart2_rows = [chart2_plots[i : i + n_prox_per_row] for i in range(0, n_prox, n_prox_per_row)]

chart2 = alt.hconcat()
for row in chart2_rows:
rowplot = alt.vconcat()
for item in row:
rowplot |= item
chart2 &= rowplot
chart2 = chart2.properties(title=f"Proximity Digits (top {n_prox})")

chart3 = (
base.transform_fold(
base.transform_filter(f"datum.prox_idx == 0")
.transform_fold(
fold=[f"pixel_{y}_{x}" for y in range(8) for x in range(8)],
as_=["pixel", "value"],
)
.transform_calculate(
x="substring(datum.pixel, 8, 9)",
y="substring(datum.pixel, 6, 7)",
)
.transform_calculate(value_orig="floor(datum.value / 100)")
.transform_calculate(x="substring(datum.pixel, 8, 9)", y="substring(datum.pixel, 6, 7)")
.mark_rect()
.encode(
x=alt.X("x:N", axis=None),
y=alt.Y("y:N", axis=None),
color=alt.Color("value:Q", legend=None, scale=alt.Scale(scheme="greys")),
opacity=alt.condition(alt.datum["value"] == 0, alt.value(0), alt.value(1)),
color=alt.Color("value_orig:Q", legend=None, scale=alt.Scale(scheme="greys")),
opacity=alt.condition(alt.datum["value_orig"] == 0, alt.value(0), alt.value(0.67)),
tooltip=[
alt.Tooltip("target:Q", title="Digit"),
alt.Tooltip("value:Q", title="Pixel Value"),
alt.Tooltip("value_orig:Q", title="Pixel Value"),
alt.Tooltip("x:Q", title="Pixel X"),
alt.Tooltip("y:Q", title="Pixel Y"),
],
)
.properties(height=height, width=width, title="Test Digit (original)")
)

chart_spacer = alt.Chart(pd.DataFrame()).mark_rect().properties(width=subplot_dim)
chart_spacer = alt.Chart(pd.DataFrame()).mark_rect().properties(width=subplot_dim * 2)

chart = (
(chart1 | chart_spacer | chart2 | chart_spacer | chart3)
.configure_concat(spacing=subplot_spacing)
.configure_concat(spacing=0)
.configure_facet(spacing=subplot_spacing)
.configure_title(anchor="middle")
.configure_view(strokeOpacity=0)
)
Expand Down

0 comments on commit 4ee6d6e

Please sign in to comment.