From cff164eef7c9d64eaa143a551ee3dd801fb7eb10 Mon Sep 17 00:00:00 2001 From: Reid Johnson Date: Fri, 9 Aug 2024 04:35:55 -0700 Subject: [PATCH] Update example plots --- examples/plot_huggingface_model.py | 10 +++------- examples/plot_quantile_conformalized.py | 19 ++++++++++--------- examples/plot_quantile_ranks.py | 16 +++++++++------- examples/plot_quantile_vs_standard.py | 12 +++++------- examples/plot_treeshap_example.py | 23 +++++++++++------------ 5 files changed, 38 insertions(+), 42 deletions(-) diff --git a/examples/plot_huggingface_model.py b/examples/plot_huggingface_model.py index f2a546c..9508bfb 100755 --- a/examples/plot_huggingface_model.py +++ b/examples/plot_huggingface_model.py @@ -183,16 +183,12 @@ def plot_quantiles_by_latlon(df, quantiles): name="Predicted Quantile: ", ) - q_val = alt.selection_point( - value=0.5, - bind=slider, - fields=["quantile"], - ) + quantile_selection = alt.param(value=0.5, bind=slider, name="quantile") chart = ( alt.Chart(df) - .add_params(q_val) - .transform_filter(f"datum['quantile'] == {q_val['quantile']}") + .add_params(quantile_selection) + .transform_filter("datum.quantile == quantile") .mark_circle() .encode( x=alt.X( diff --git a/examples/plot_quantile_conformalized.py b/examples/plot_quantile_conformalized.py index 2b2153e..52d1857 100755 --- a/examples/plot_quantile_conformalized.py +++ b/examples/plot_quantile_conformalized.py @@ -165,13 +165,6 @@ def plot_prediction_intervals(df, domain): click = alt.selection_point(fields=["y_label"], bind="legend") - color_circle = alt.Color( - "y_label:N", - scale=alt.Scale(domain=["Yes", "No"], range=["#f2a619", "red"]), - title="Within Interval", - ) - color_bar = alt.value("#e0f2ff") - tooltip = [ alt.Tooltip("y_test:Q", format="$,d", title="True Price"), alt.Tooltip("y_pred:Q", format="$,d", title="Predicted Price"), @@ -210,7 +203,15 @@ def plot_prediction_intervals(df, domain): scale=alt.Scale(domain=domain, nice=False), title="Predicted Prices", ), - color=alt.condition(click, color_circle, alt.value("lightgray")), + color=alt.condition( + click, + alt.Color( + "y_label:N", + scale=alt.Scale(domain=["Yes", "No"], range=["#f2a619", "red"]), + title="Within Interval", + ), + alt.value("lightgray"), + ), opacity=alt.condition(click, alt.value(1), alt.value(0)), tooltip=tooltip, ) @@ -220,7 +221,7 @@ def plot_prediction_intervals(df, domain): x=alt.X("y_pred:Q", scale=alt.Scale(domain=domain, padding=0), title=""), y=alt.Y("y_pred_low:Q", scale=alt.Scale(domain=domain, padding=0), title=""), y2=alt.Y2("y_pred_upp:Q", title=None), - color=alt.condition(click, color_bar, alt.value("lightgray")), + color=alt.condition(click, alt.value("#e0f2ff"), alt.value("lightgray")), opacity=alt.condition(click, alt.value(0.8), alt.value(0)), tooltip=tooltip, ) diff --git a/examples/plot_quantile_ranks.py b/examples/plot_quantile_ranks.py index 3ec3a7b..746052e 100644 --- a/examples/plot_quantile_ranks.py +++ b/examples/plot_quantile_ranks.py @@ -60,12 +60,6 @@ def plot_fit_and_ranks(df): base = alt.Chart(df) - color_points = alt.Color( - "outlier:N", - scale=alt.Scale(domain=["Yes", "No"], range=["red", "#f2a619"]), - title="Outlier", - ) - points = ( base.add_params(interval_selection, click) .transform_calculate( @@ -75,7 +69,15 @@ def plot_fit_and_ranks(df): .encode( x=alt.X("x:Q"), y=alt.Y("y:Q"), - color=alt.condition(click, color_points, alt.value("lightgray")), + color=alt.condition( + click, + alt.Color( + "outlier:N", + scale=alt.Scale(domain=["Yes", "No"], range=["red", "#f2a619"]), + title="Outlier", + ), + alt.value("lightgray"), + ), tooltip=[ alt.Tooltip("x:Q", format=".3f", title="x"), alt.Tooltip("y:Q", format=".3f", title="f(x)"), diff --git a/examples/plot_quantile_vs_standard.py b/examples/plot_quantile_vs_standard.py index 6601150..56aa5ea 100755 --- a/examples/plot_quantile_vs_standard.py +++ b/examples/plot_quantile_vs_standard.py @@ -76,12 +76,6 @@ def plot_prediction_histograms(df, legend): click = alt.selection_point(fields=["label"], bind="legend") - color = alt.condition( - click, - alt.Color("label:N", sort=list(legend.keys()), title=None), - alt.value("lightgray"), - ) - chart = ( alt.Chart(df) .add_params(quantile_selection, click) @@ -102,7 +96,11 @@ def plot_prediction_histograms(df, legend): title="Actual and Predicted Target Values", ), y=alt.Y("count():Q", axis=alt.Axis(format=",d", title="Counts")), - color=color, + color=alt.condition( + click, + alt.Color("label:N", sort=list(legend.keys()), title=None), + alt.value("lightgray"), + ), xOffset=alt.XOffset("label:N"), tooltip=[ alt.Tooltip("label:N", title="Label"), diff --git a/examples/plot_treeshap_example.py b/examples/plot_treeshap_example.py index 1532380..6bc13b2 100644 --- a/examples/plot_treeshap_example.py +++ b/examples/plot_treeshap_example.py @@ -106,8 +106,7 @@ def plot_shap_waterfall_with_quantiles(df, height=300): step=0.5 if len(quantiles) == 1 else 1 / (len(quantiles) - 1), name="Predicted Quantile: ", ) - - q_val = alt.selection_point(value=0.5, bind=slider, fields=["quantile"]) + quantile_selection = alt.param(value=0.5, bind=slider, name="quantile") df_grouped = ( df.groupby("quantile")[df.columns.tolist()] @@ -166,7 +165,7 @@ def plot_shap_waterfall_with_quantiles(df, height=300): base = ( alt.Chart(df_grouped) - .transform_filter(q_val) + .transform_filter("datum.quantile == quantile") .transform_calculate( end_shifted=f"datum.shap_value > 0 ? datum.end - {x_shift} : datum.end + {x_shift}" ) @@ -221,15 +220,15 @@ def plot_shap_waterfall_with_quantiles(df, height=300): ) text_label_start = ( alt.Chart(df_text_labels) - .transform_filter(q_val) - .transform_filter(alt.datum["type"] == "start") + .transform_filter("datum.quantile == quantile") + .transform_filter("datum.type == 'start'") .mark_text(align="left", color="black", dx=-16, dy=y_text_offset + 30) .encode(text=alt.Text("label"), x=alt.X("x:Q")) ) text_label_end = ( alt.Chart(df_text_labels) - .transform_filter(q_val) - .transform_filter(alt.datum["type"] == "end") + .transform_filter("datum.quantile == quantile") + .transform_filter("datum.type == 'end'") .mark_text(align="left", color="black", dx=-8, dy=-y_text_offset - 15) .encode(text=alt.Text("label"), x=alt.X("x:Q")) ) @@ -255,15 +254,15 @@ def plot_shap_waterfall_with_quantiles(df, height=300): ) tick_start_rule = ( alt.Chart(df_text_labels) - .transform_filter(q_val) - .transform_filter(alt.datum["type"] == "start") + .transform_filter("datum.quantile == quantile") + .transform_filter("datum.type == 'start'") .mark_rule(color="black", opacity=1, y=height, y2=height + 6) .encode(x=alt.X("x:Q")) ) tick_end_rule = ( alt.Chart(df_text_labels) - .transform_filter(q_val) - .transform_filter(alt.datum["type"] == "end") + .transform_filter("datum.quantile == quantile") + .transform_filter("datum.type == 'end'") .mark_rule(color="black", opacity=1, y=0, y2=-6) .encode(x=alt.X("x:Q")) ) @@ -271,7 +270,7 @@ def plot_shap_waterfall_with_quantiles(df, height=300): chart = ( (bars + points + text + rules) - .add_params(q_val) + .add_params(quantile_selection) .configure_view(strokeOpacity=0) .properties( width=600, height=height, title="Waterfall Plot of SHAP Values for QRF Predictions"