diff --git a/examples/plot_proximity_counts.py b/examples/plot_proximity_counts.py index 18addfc..e54e0af 100644 --- a/examples/plot_proximity_counts.py +++ b/examples/plot_proximity_counts.py @@ -101,7 +101,7 @@ def fillna(df): ) df_test = digits_wide_to_long(X_test).merge(y_test.reset_index(), on="index", how="left") -df_test_corrupt = digits_wide_to_long(X_test_corrupt).rename(columns={"value": "value_corrupt"}) +df_test_corrupt = digits_wide_to_long(X_test_corrupt).rename(columns={"value": "value_cor"}) df_prox = pd.DataFrame( { @@ -113,19 +113,17 @@ def fillna(df): df = ( df_test.merge(df_test_corrupt, on=["index", "x", "y"], how="left") .merge(df_prox, on="index", how="left") - .rename(columns={"index": "samp_idx"}) .explode("prox") .assign( **{ - "index": lambda x: pd.factorize(x["samp_idx"])[0], + "index": lambda x: pd.factorize(x["index"])[0], "prox_idx": lambda x: x["prox"].apply(lambda y: y[0]), "prox_val": lambda x: x["prox"].apply(lambda y: y[1]), - "prox_count": lambda x: x["prox"].apply(lambda y: y[2]), - "samp_id": lambda x: get_digits_id(x, "samp_idx"), + "prox_cnt": lambda x: x["prox"].apply(lambda y: y[2]), "prox_id": lambda x: get_digits_id(x, "prox_val"), } ) - .drop(columns=["prox"]) + .drop(columns=["prox", "prox_val"]) ) @@ -137,7 +135,7 @@ def plot_digits_proximities( height=225, width=225, ): - n_samples = df["samp_idx"].nunique() + n_samples = df["index"].nunique() n_subplot_rows = n_prox // n_prox_per_row subplot_dim = (width - subplot_spacing * (2 + max(n_subplot_rows - 2, 0))) / n_subplot_rows @@ -164,13 +162,12 @@ def plot_digits_proximities( .encode( x=alt.X("x:N", axis=None), y=alt.Y("y:N", axis=None), - color=alt.Color("value_corrupt:Q", legend=None, scale=alt.Scale(scheme="greys")), - # opacity=alt.condition(alt.datum["value_corrupt"] == -1, alt.value(0), alt.value(1)), - opacity=alt.condition(alt.datum["value_corrupt"] == 0, alt.value(0), alt.value(1)), + color=alt.Color("value_cor:Q", legend=None, scale=alt.Scale(scheme="greys")), + # opacity=alt.condition(alt.datum["value_cor"] == -1, alt.value(0), alt.value(1)), + opacity=alt.condition(alt.datum["value_cor"] == 0, alt.value(0), alt.value(1)), tooltip=[ alt.Tooltip("target:Q", title="Digit"), - alt.Tooltip("value_corrupt:Q", title="Pixel Value"), - alt.Tooltip("value:Q", title="Pixel Value (original)"), + alt.Tooltip("value_cor:Q", title="Pixel Value"), alt.Tooltip("x:Q", title="Pixel X"), alt.Tooltip("y:Q", title="Pixel Y"), ], @@ -186,7 +183,7 @@ def plot_digits_proximities( color=color, opacity=opacity, tooltip=[ - alt.Tooltip("prox_count", title="Training Proximity Count"), + alt.Tooltip("prox_cnt", title="Proximity Count"), alt.Tooltip("target:Q", title="Digit"), ], )