Skip to content

Commit

Permalink
Update example plots
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Aug 9, 2024
1 parent 8428f26 commit cff164e
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 42 deletions.
10 changes: 3 additions & 7 deletions examples/plot_huggingface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,16 +183,12 @@ def plot_quantiles_by_latlon(df, quantiles):
name="Predicted Quantile: ",
)

q_val = alt.selection_point(
value=0.5,
bind=slider,
fields=["quantile"],
)
quantile_selection = alt.param(value=0.5, bind=slider, name="quantile")

chart = (
alt.Chart(df)
.add_params(q_val)
.transform_filter(f"datum['quantile'] == {q_val['quantile']}")
.add_params(quantile_selection)
.transform_filter("datum.quantile == quantile")
.mark_circle()
.encode(
x=alt.X(
Expand Down
19 changes: 10 additions & 9 deletions examples/plot_quantile_conformalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,6 @@ def plot_prediction_intervals(df, domain):

click = alt.selection_point(fields=["y_label"], bind="legend")

color_circle = alt.Color(
"y_label:N",
scale=alt.Scale(domain=["Yes", "No"], range=["#f2a619", "red"]),
title="Within Interval",
)
color_bar = alt.value("#e0f2ff")

tooltip = [
alt.Tooltip("y_test:Q", format="$,d", title="True Price"),
alt.Tooltip("y_pred:Q", format="$,d", title="Predicted Price"),
Expand Down Expand Up @@ -210,7 +203,15 @@ def plot_prediction_intervals(df, domain):
scale=alt.Scale(domain=domain, nice=False),
title="Predicted Prices",
),
color=alt.condition(click, color_circle, alt.value("lightgray")),
color=alt.condition(
click,
alt.Color(
"y_label:N",
scale=alt.Scale(domain=["Yes", "No"], range=["#f2a619", "red"]),
title="Within Interval",
),
alt.value("lightgray"),
),
opacity=alt.condition(click, alt.value(1), alt.value(0)),
tooltip=tooltip,
)
Expand All @@ -220,7 +221,7 @@ def plot_prediction_intervals(df, domain):
x=alt.X("y_pred:Q", scale=alt.Scale(domain=domain, padding=0), title=""),
y=alt.Y("y_pred_low:Q", scale=alt.Scale(domain=domain, padding=0), title=""),
y2=alt.Y2("y_pred_upp:Q", title=None),
color=alt.condition(click, color_bar, alt.value("lightgray")),
color=alt.condition(click, alt.value("#e0f2ff"), alt.value("lightgray")),
opacity=alt.condition(click, alt.value(0.8), alt.value(0)),
tooltip=tooltip,
)
Expand Down
16 changes: 9 additions & 7 deletions examples/plot_quantile_ranks.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,6 @@ def plot_fit_and_ranks(df):

base = alt.Chart(df)

color_points = alt.Color(
"outlier:N",
scale=alt.Scale(domain=["Yes", "No"], range=["red", "#f2a619"]),
title="Outlier",
)

points = (
base.add_params(interval_selection, click)
.transform_calculate(
Expand All @@ -75,7 +69,15 @@ def plot_fit_and_ranks(df):
.encode(
x=alt.X("x:Q"),
y=alt.Y("y:Q"),
color=alt.condition(click, color_points, alt.value("lightgray")),
color=alt.condition(
click,
alt.Color(
"outlier:N",
scale=alt.Scale(domain=["Yes", "No"], range=["red", "#f2a619"]),
title="Outlier",
),
alt.value("lightgray"),
),
tooltip=[
alt.Tooltip("x:Q", format=".3f", title="x"),
alt.Tooltip("y:Q", format=".3f", title="f(x)"),
Expand Down
12 changes: 5 additions & 7 deletions examples/plot_quantile_vs_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,6 @@ def plot_prediction_histograms(df, legend):

click = alt.selection_point(fields=["label"], bind="legend")

color = alt.condition(
click,
alt.Color("label:N", sort=list(legend.keys()), title=None),
alt.value("lightgray"),
)

chart = (
alt.Chart(df)
.add_params(quantile_selection, click)
Expand All @@ -102,7 +96,11 @@ def plot_prediction_histograms(df, legend):
title="Actual and Predicted Target Values",
),
y=alt.Y("count():Q", axis=alt.Axis(format=",d", title="Counts")),
color=color,
color=alt.condition(
click,
alt.Color("label:N", sort=list(legend.keys()), title=None),
alt.value("lightgray"),
),
xOffset=alt.XOffset("label:N"),
tooltip=[
alt.Tooltip("label:N", title="Label"),
Expand Down
23 changes: 11 additions & 12 deletions examples/plot_treeshap_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ def plot_shap_waterfall_with_quantiles(df, height=300):
step=0.5 if len(quantiles) == 1 else 1 / (len(quantiles) - 1),
name="Predicted Quantile: ",
)

q_val = alt.selection_point(value=0.5, bind=slider, fields=["quantile"])
quantile_selection = alt.param(value=0.5, bind=slider, name="quantile")

df_grouped = (
df.groupby("quantile")[df.columns.tolist()]
Expand Down Expand Up @@ -166,7 +165,7 @@ def plot_shap_waterfall_with_quantiles(df, height=300):

base = (
alt.Chart(df_grouped)
.transform_filter(q_val)
.transform_filter("datum.quantile == quantile")
.transform_calculate(
end_shifted=f"datum.shap_value > 0 ? datum.end - {x_shift} : datum.end + {x_shift}"
)
Expand Down Expand Up @@ -221,15 +220,15 @@ def plot_shap_waterfall_with_quantiles(df, height=300):
)
text_label_start = (
alt.Chart(df_text_labels)
.transform_filter(q_val)
.transform_filter(alt.datum["type"] == "start")
.transform_filter("datum.quantile == quantile")
.transform_filter("datum.type == 'start'")
.mark_text(align="left", color="black", dx=-16, dy=y_text_offset + 30)
.encode(text=alt.Text("label"), x=alt.X("x:Q"))
)
text_label_end = (
alt.Chart(df_text_labels)
.transform_filter(q_val)
.transform_filter(alt.datum["type"] == "end")
.transform_filter("datum.quantile == quantile")
.transform_filter("datum.type == 'end'")
.mark_text(align="left", color="black", dx=-8, dy=-y_text_offset - 15)
.encode(text=alt.Text("label"), x=alt.X("x:Q"))
)
Expand All @@ -255,23 +254,23 @@ def plot_shap_waterfall_with_quantiles(df, height=300):
)
tick_start_rule = (
alt.Chart(df_text_labels)
.transform_filter(q_val)
.transform_filter(alt.datum["type"] == "start")
.transform_filter("datum.quantile == quantile")
.transform_filter("datum.type == 'start'")
.mark_rule(color="black", opacity=1, y=height, y2=height + 6)
.encode(x=alt.X("x:Q"))
)
tick_end_rule = (
alt.Chart(df_text_labels)
.transform_filter(q_val)
.transform_filter(alt.datum["type"] == "end")
.transform_filter("datum.quantile == quantile")
.transform_filter("datum.type == 'end'")
.mark_rule(color="black", opacity=1, y=0, y2=-6)
.encode(x=alt.X("x:Q"))
)
rules = feature_bar_rule + end_bar_rule + tick_start_rule + tick_end_rule

chart = (
(bars + points + text + rules)
.add_params(q_val)
.add_params(quantile_selection)
.configure_view(strokeOpacity=0)
.properties(
width=600, height=height, title="Waterfall Plot of SHAP Values for QRF Predictions"
Expand Down

0 comments on commit cff164e

Please sign in to comment.