Skip to content

Commit

Permalink
Update example plots
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Aug 3, 2024
1 parent 6344ce2 commit 2901cc9
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions examples/plot_proximity_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def fillna(df):
)

df_test = digits_wide_to_long(X_test).merge(y_test.reset_index(), on="index", how="left")
df_test_corrupt = digits_wide_to_long(X_test_corrupt).rename(columns={"value": "value_corrupt"})
df_test_corrupt = digits_wide_to_long(X_test_corrupt).rename(columns={"value": "value_cor"})

df_prox = pd.DataFrame(
{
Expand All @@ -113,19 +113,17 @@ def fillna(df):
df = (
df_test.merge(df_test_corrupt, on=["index", "x", "y"], how="left")
.merge(df_prox, on="index", how="left")
.rename(columns={"index": "samp_idx"})
.explode("prox")
.assign(
**{
"index": lambda x: pd.factorize(x["samp_idx"])[0],
"index": lambda x: pd.factorize(x["index"])[0],
"prox_idx": lambda x: x["prox"].apply(lambda y: y[0]),
"prox_val": lambda x: x["prox"].apply(lambda y: y[1]),
"prox_count": lambda x: x["prox"].apply(lambda y: y[2]),
"samp_id": lambda x: get_digits_id(x, "samp_idx"),
"prox_cnt": lambda x: x["prox"].apply(lambda y: y[2]),
"prox_id": lambda x: get_digits_id(x, "prox_val"),
}
)
.drop(columns=["prox"])
.drop(columns=["prox", "prox_val"])
)


Expand All @@ -137,7 +135,7 @@ def plot_digits_proximities(
height=225,
width=225,
):
n_samples = df["samp_idx"].nunique()
n_samples = df["index"].nunique()
n_subplot_rows = n_prox // n_prox_per_row
subplot_dim = (width - subplot_spacing * (2 + max(n_subplot_rows - 2, 0))) / n_subplot_rows

Expand All @@ -164,13 +162,12 @@ def plot_digits_proximities(
.encode(
x=alt.X("x:N", axis=None),
y=alt.Y("y:N", axis=None),
color=alt.Color("value_corrupt:Q", legend=None, scale=alt.Scale(scheme="greys")),
# opacity=alt.condition(alt.datum["value_corrupt"] == -1, alt.value(0), alt.value(1)),
opacity=alt.condition(alt.datum["value_corrupt"] == 0, alt.value(0), alt.value(1)),
color=alt.Color("value_cor:Q", legend=None, scale=alt.Scale(scheme="greys")),
# opacity=alt.condition(alt.datum["value_cor"] == -1, alt.value(0), alt.value(1)),
opacity=alt.condition(alt.datum["value_cor"] == 0, alt.value(0), alt.value(1)),
tooltip=[
alt.Tooltip("target:Q", title="Digit"),
alt.Tooltip("value_corrupt:Q", title="Pixel Value"),
alt.Tooltip("value:Q", title="Pixel Value (original)"),
alt.Tooltip("value_cor:Q", title="Pixel Value"),
alt.Tooltip("x:Q", title="Pixel X"),
alt.Tooltip("y:Q", title="Pixel Y"),
],
Expand All @@ -186,7 +183,7 @@ def plot_digits_proximities(
color=color,
opacity=opacity,
tooltip=[
alt.Tooltip("prox_count", title="Training Proximity Count"),
alt.Tooltip("prox_cnt", title="Proximity Count"),
alt.Tooltip("target:Q", title="Digit"),
],
)
Expand Down

0 comments on commit 2901cc9

Please sign in to comment.