From 4ee6d6e6d9de6753c9294149d40a1f2de2f7f9af Mon Sep 17 00:00:00 2001
From: Reid Johnson
Date: Mon, 5 Aug 2024 04:32:45 -0700
Subject: [PATCH] Update example plots

---
Review notes (this section is ignored by `git am`):

* Pixel packing. `combine_floats` stores `df1 * scale + df2` in one column
  per pixel, and the Vega expressions `floor(datum.value / 100)` and
  `datum.value - (datum.value_orig * 100)` unpack it with the default scale
  of 100 hard-coded. The round trip is exact only when the first frame holds
  whole numbers and the second frame lies in [0, 100). Whether the corrupted
  pixels satisfy this depends on `add_gaussian_noise`, whose body is not part
  of this patch -- TODO confirm that it cannot return negative values.
* The facet grid now uses `columns=n_prox_per_row`. The submitted
  `n_prox // n_prox_per_row` is the number of rows and only gives the same
  layout as the removed hconcat/vconcat code when the grid is square.
* Dropped the `f` prefix from two `"datum.prox_idx == 0"` literals that have
  no placeholders.
* Line breaks and indentation were reconstructed from the hunk line counts;
  run `git apply --check` before applying. The `index` blob hash below
  predates the two review changes above.

 examples/plot_proximity_counts.py | 109 +++++++++++++++---------------
 1 file changed, 56 insertions(+), 53 deletions(-)

diff --git a/examples/plot_proximity_counts.py b/examples/plot_proximity_counts.py
index 64352f3..c9f7d92 100644
--- a/examples/plot_proximity_counts.py
+++ b/examples/plot_proximity_counts.py
@@ -7,14 +7,12 @@
 target values are not used during model fitting. In this scenario, we train a
 QRF on a noisy dataset to predict individual pixel values (i.e., denoise). We
 then retrieve the proximity values for samples in a noisy test set. For each
-test sample digit, we visualize it alongside a set of similar training samples
-determined by their proximity counts, as well as the non-noisy digit. The
-model is trained only on noisy training samples, but we visualize the
-non-noisy training proximities. The similar samples are ordered from the
-highest to the lowest proximity count for each digit, arranged from left to
-right and top to bottom. This example illustrates the effectiveness of
-proximity counts in identifying similar samples, even when using noisy
-training and test data.
+test sample digit, we visualize it alongside a set of similar (non-noisy)
+training samples determined by their proximity counts, as well as the
+non-noisy digit. The similar samples are ordered from the highest to the
+lowest proximity count for each digit, arranged from left to right and top to
+bottom. This example illustrates the effectiveness of proximity counts in
+identifying similar samples, even when using noisy training and test data.
 """
 
 import altair as alt
@@ -61,6 +59,19 @@ def add_gaussian_noise(X, mean=0, std=0.1, random_state=None):
     return X_noisy
 
 
+def combine_floats(df1, df2, scale=100):
+    """Combine two floats from separate data frames into a single number."""
+    combined_df = df1 * scale + df2
+    return combined_df
+
+
+def extract_floats(combined_df, scale=100):
+    """Extract the original floats from the combined data frame."""
+    df1 = np.floor(combined_df / scale)
+    df2 = combined_df - (df1 * scale)
+    return df1, df2
+
+
 # Randomly corrupt a fraction of the training and test data.
 X_train_corrupt = X_train.pipe(add_gaussian_noise, std=noise_std, random_state=rng)
 X_test_corrupt = X_test.pipe(add_gaussian_noise, std=noise_std, random_state=rng)
@@ -77,7 +88,7 @@ def add_gaussian_noise(X, mean=0, std=0.1, random_state=None):
 )
 
 df = (
-    X_test.join(X_test_corrupt.add_suffix("_corrupt"))
+    combine_floats(X_test, X_test_corrupt)
     .join(y_test)
     .reset_index()
     .join(df_prox)
@@ -95,7 +106,7 @@ def add_gaussian_noise(X, mean=0, std=0.1, random_state=None):
 )
 
 df_lookup = (
-    X_train.join(X_train_corrupt.add_suffix("_corrupt"))
+    combine_floats(X_train, X_train_corrupt)
     .assign(**{"index": np.arange(len(X_train))})
     .join(y_train)
 )
@@ -112,7 +123,7 @@ def plot_digits_proximities(
 ):
     n_samples = df["index"].nunique()
     n_subplot_rows = n_prox // n_prox_per_row
-    subplot_dim = (width - subplot_spacing * (2 + max(n_subplot_rows - 2, 0))) / n_subplot_rows
+    subplot_dim = (width - subplot_spacing * (n_subplot_rows - 1)) / n_subplot_rows
 
     slider = alt.binding_range(
         min=0,
@@ -127,30 +138,26 @@ def plot_digits_proximities(
         fields=["index"],
     )
 
-    color = alt.Color("value:Q", legend=None, scale=alt.Scale(scheme="greys"))
-    opacity = alt.condition(alt.datum["value"] == 0, alt.value(0), alt.value(1))
-
     base = alt.Chart(df).add_params(idx_val).transform_filter(idx_val)
 
     chart1 = (
-        base.transform_fold(
-            fold=[f"pixel_{y}_{x}_corrupt" for y in range(8) for x in range(8)],
-            as_=["pixel", "value_cor"],
-        )
-        .transform_calculate(
-            x="substring(datum.pixel, 8, 9)",
-            y="substring(datum.pixel, 6, 7)",
+        base.transform_filter("datum.prox_idx == 0")
+        .transform_fold(
+            fold=[f"pixel_{y}_{x}" for y in range(8) for x in range(8)],
+            as_=["pixel", "value"],
         )
+        .transform_calculate(value_orig="floor(datum.value / 100)")
+        .transform_calculate(value_corr="datum.value - (datum.value_orig * 100)")
+        .transform_calculate(x="substring(datum.pixel, 8, 9)", y="substring(datum.pixel, 6, 7)")
         .mark_rect()
         .encode(
             x=alt.X("x:N", axis=None),
             y=alt.Y("y:N", axis=None),
-            color=alt.Color("value_cor:Q", legend=None, scale=alt.Scale(scheme="greys")),
-            # opacity=alt.condition(alt.datum["value_cor"] == -1, alt.value(0), alt.value(1)),
-            opacity=alt.condition(alt.datum["value_cor"] == 0, alt.value(0), alt.value(1)),
+            color=alt.Color("value_corr:Q", legend=None, scale=alt.Scale(scheme="greys")),
+            opacity=alt.condition(alt.datum["value_corr"] == 0, alt.value(0), alt.value(0.67)),
             tooltip=[
                 alt.Tooltip("target:Q", title="Digit"),
-                alt.Tooltip("value_cor:Q", title="Pixel Value"),
+                alt.Tooltip("value_corr:Q", format=".3f", title="Pixel Value"),
                 alt.Tooltip("x:Q", title="Pixel X"),
                 alt.Tooltip("y:Q", title="Pixel Y"),
             ],
@@ -158,17 +165,24 @@ def plot_digits_proximities(
         .properties(height=height, width=width, title="Test Digit (corrupted)")
     )
 
-    chart2_i = (
+    chart2 = (
         base.mark_rect()
+        .transform_filter(f"datum.prox_idx < {n_prox}")
         .encode(
             x=alt.X("x:N", axis=None),
             y=alt.Y("y:N", axis=None),
-            color=color,
-            opacity=opacity,
+            color=alt.Color("value_orig:Q", legend=None, scale=alt.Scale(scheme="greys")),
+            opacity=alt.condition(alt.datum["value_orig"] == 0, alt.value(0), alt.value(0.67)),
             tooltip=[
                 alt.Tooltip("prox_cnt", title="Proximity Count"),
                 alt.Tooltip("target:Q", title="Digit"),
             ],
+            facet=alt.Facet(
+                "prox_idx:N",
+                columns=n_prox_per_row,
+                title=None,
+                header=alt.Header(labels=False, labelFontSize=0, labelPadding=0),
+            ),
         )
         .transform_lookup(
             lookup="prox_val",
@@ -182,42 +196,30 @@ def plot_digits_proximities(
             fold=[f"pixel_{y}_{x}" for y in range(8) for x in range(8)],
             as_=["pixel", "value"],
         )
-        .transform_calculate(
-            x="substring(datum.pixel, 8, 9)",
-            y="substring(datum.pixel, 6, 7)",
+        .transform_calculate(value_orig="floor(datum.value / 100)")
+        .transform_calculate(x="substring(datum.pixel, 8, 9)", y="substring(datum.pixel, 6, 7)")
+        .properties(
+            height=subplot_dim, width=subplot_dim, title=f"Proximity Digits (top {n_prox})"
         )
-        .properties(height=subplot_dim, width=subplot_dim)
     )
 
-    chart2_plots = [chart2_i.transform_filter(f"datum.prox_idx == {i}") for i in range(n_prox)]
-    chart2_rows = [chart2_plots[i : i + n_prox_per_row] for i in range(0, n_prox, n_prox_per_row)]
-
-    chart2 = alt.hconcat()
-    for row in chart2_rows:
-        rowplot = alt.vconcat()
-        for item in row:
-            rowplot |= item
-        chart2 &= rowplot
-    chart2 = chart2.properties(title=f"Proximity Digits (top {n_prox})")
-
     chart3 = (
-        base.transform_fold(
+        base.transform_filter("datum.prox_idx == 0")
+        .transform_fold(
             fold=[f"pixel_{y}_{x}" for y in range(8) for x in range(8)],
             as_=["pixel", "value"],
         )
-        .transform_calculate(
-            x="substring(datum.pixel, 8, 9)",
-            y="substring(datum.pixel, 6, 7)",
-        )
+        .transform_calculate(value_orig="floor(datum.value / 100)")
+        .transform_calculate(x="substring(datum.pixel, 8, 9)", y="substring(datum.pixel, 6, 7)")
         .mark_rect()
         .encode(
             x=alt.X("x:N", axis=None),
             y=alt.Y("y:N", axis=None),
-            color=alt.Color("value:Q", legend=None, scale=alt.Scale(scheme="greys")),
-            opacity=alt.condition(alt.datum["value"] == 0, alt.value(0), alt.value(1)),
+            color=alt.Color("value_orig:Q", legend=None, scale=alt.Scale(scheme="greys")),
+            opacity=alt.condition(alt.datum["value_orig"] == 0, alt.value(0), alt.value(0.67)),
             tooltip=[
                 alt.Tooltip("target:Q", title="Digit"),
-                alt.Tooltip("value:Q", title="Pixel Value"),
+                alt.Tooltip("value_orig:Q", title="Pixel Value"),
                 alt.Tooltip("x:Q", title="Pixel X"),
                 alt.Tooltip("y:Q", title="Pixel Y"),
             ],
@@ -225,11 +227,12 @@ def plot_digits_proximities(
         .properties(height=height, width=width, title="Test Digit (original)")
     )
 
-    chart_spacer = alt.Chart(pd.DataFrame()).mark_rect().properties(width=subplot_dim)
+    chart_spacer = alt.Chart(pd.DataFrame()).mark_rect().properties(width=subplot_dim * 2)
 
     chart = (
         (chart1 | chart_spacer | chart2 | chart_spacer | chart3)
-        .configure_concat(spacing=subplot_spacing)
+        .configure_concat(spacing=0)
+        .configure_facet(spacing=subplot_spacing)
         .configure_title(anchor="middle")
         .configure_view(strokeOpacity=0)
     )