Skip to content

Commit

Permalink
fix: undefined name errors in beeswarm
Browse files Browse the repository at this point in the history
And also temporarily remove the SHAP interaction beeswarm plot code
(it's not runnable at the moment, because we ensured that the ndim of
SHAP values was 2 above, so it can never be 3 to trigger the interaction
beeswarm plot here).

The removal of code here is mostly just to appease ruff.
Will create a separate issue for this to re-introduce SHAP interaction
values plotting again later on. Maybe it should be implemented as a
separate function.
  • Loading branch information
thatlittleboy committed May 29, 2023
1 parent 9f47185 commit d396f1c
Showing 1 changed file with 84 additions and 82 deletions.
166 changes: 84 additions & 82 deletions shap/plots/_beeswarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0),
# out_names = shap_exp.output_names

order = convert_ordering(order, values)


# # deprecation warnings
# if auto_size_plot is not None:
Expand Down Expand Up @@ -134,88 +133,91 @@ def beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0),
partition_tree = None
else:
partition_tree = clustering

if partition_tree is not None:
assert partition_tree.shape[1] == 4, "The clustering provided by the Explanation object does not seem to be a partition tree (which is all shap.plots.bar supports)!"

# plotting SHAP interaction values
if len(values.shape) == 3:

if plot_type == "compact_dot":
new_values = values.reshape(values.shape[0], -1)
new_features = np.tile(features, (1, 1, features.shape[1])).reshape(features.shape[0], -1)

new_feature_names = []
for c1 in feature_names:
for c2 in feature_names:
if c1 == c2:
new_feature_names.append(c1)
else:
new_feature_names.append(c1 + "* - " + c2)

return beeswarm(
new_values, new_features, new_feature_names,
max_display=max_display, plot_type="dot", color=color, axis_color=axis_color,
title=title, alpha=alpha, show=show, sort=sort,
color_bar=color_bar, plot_size=plot_size, class_names=class_names,
color_bar_label="*" + color_bar_label
)

if max_display is None:
max_display = 7
else:
max_display = min(len(feature_names), max_display)

interaction_sort_inds = order#np.argsort(-np.abs(values.sum(1)).sum(0))

# get plotting limits
delta = 1.0 / (values.shape[1] ** 2)
slow = np.nanpercentile(values, delta)
shigh = np.nanpercentile(values, 100 - delta)
v = max(abs(slow), abs(shigh))
slow = -v
shigh = v

pl.figure(figsize=(1.5 * max_display + 1, 0.8 * max_display + 1))
pl.subplot(1, max_display, 1)
proj_values = values[:, interaction_sort_inds[0], interaction_sort_inds]
proj_values[:, 1:] *= 2 # because off diag effects are split in half
beeswarm(
proj_values, features[:, interaction_sort_inds] if features is not None else None,
feature_names=feature_names[interaction_sort_inds],
sort=False, show=False, color_bar=False,
plot_size=None,
max_display=max_display
)
pl.xlim((slow, shigh))
pl.xlabel("")
title_length_limit = 11
pl.title(shorten_text(feature_names[interaction_sort_inds[0]], title_length_limit))
for i in range(1, min(len(interaction_sort_inds), max_display)):
ind = interaction_sort_inds[i]
pl.subplot(1, max_display, i + 1)
proj_values = values[:, ind, interaction_sort_inds]
proj_values *= 2
proj_values[:, i] /= 2 # because only off diag effects are split in half
summary(
proj_values, features[:, interaction_sort_inds] if features is not None else None,
sort=False,
feature_names=["" for i in range(len(feature_names))],
show=False,
color_bar=False,
plot_size=None,
max_display=max_display
)
pl.xlim((slow, shigh))
pl.xlabel("")
if i == min(len(interaction_sort_inds), max_display) // 2:
pl.xlabel(labels['INTERACTION_VALUE'])
pl.title(shorten_text(feature_names[ind], title_length_limit))
pl.tight_layout(pad=0, w_pad=0, h_pad=0.0)
pl.subplots_adjust(hspace=0, wspace=0.1)
if show:
pl.show()
return
# FIXME: introduce beeswarm interaction values as a separate function `beeswarm_interaction()` (?)
# In the meantime, users can use the `shap.summary_plot()` function.
#
# # plotting SHAP interaction values
# if len(values.shape) == 3:
#
# if plot_type == "compact_dot":
# new_values = values.reshape(values.shape[0], -1)
# new_features = np.tile(features, (1, 1, features.shape[1])).reshape(features.shape[0], -1)
#
# new_feature_names = []
# for c1 in feature_names:
# for c2 in feature_names:
# if c1 == c2:
# new_feature_names.append(c1)
# else:
# new_feature_names.append(c1 + "* - " + c2)
#
# return beeswarm(
# new_values, new_features, new_feature_names,
# max_display=max_display, plot_type="dot", color=color, axis_color=axis_color,
# title=title, alpha=alpha, show=show, sort=sort,
# color_bar=color_bar, plot_size=plot_size, class_names=class_names,
# color_bar_label="*" + color_bar_label
# )
#
# if max_display is None:
# max_display = 7
# else:
# max_display = min(len(feature_names), max_display)
#
# interaction_sort_inds = order#np.argsort(-np.abs(values.sum(1)).sum(0))
#
# # get plotting limits
# delta = 1.0 / (values.shape[1] ** 2)
# slow = np.nanpercentile(values, delta)
# shigh = np.nanpercentile(values, 100 - delta)
# v = max(abs(slow), abs(shigh))
# slow = -v
# shigh = v
#
# pl.figure(figsize=(1.5 * max_display + 1, 0.8 * max_display + 1))
# pl.subplot(1, max_display, 1)
# proj_values = values[:, interaction_sort_inds[0], interaction_sort_inds]
# proj_values[:, 1:] *= 2 # because off diag effects are split in half
# beeswarm(
# proj_values, features[:, interaction_sort_inds] if features is not None else None,
# feature_names=feature_names[interaction_sort_inds],
# sort=False, show=False, color_bar=False,
# plot_size=None,
# max_display=max_display
# )
# pl.xlim((slow, shigh))
# pl.xlabel("")
# title_length_limit = 11
# pl.title(shorten_text(feature_names[interaction_sort_inds[0]], title_length_limit))
# for i in range(1, min(len(interaction_sort_inds), max_display)):
# ind = interaction_sort_inds[i]
# pl.subplot(1, max_display, i + 1)
# proj_values = values[:, ind, interaction_sort_inds]
# proj_values *= 2
# proj_values[:, i] /= 2 # because only off diag effects are split in half
# summary(
# proj_values, features[:, interaction_sort_inds] if features is not None else None,
# sort=False,
# feature_names=["" for i in range(len(feature_names))],
# show=False,
# color_bar=False,
# plot_size=None,
# max_display=max_display
# )
# pl.xlim((slow, shigh))
# pl.xlabel("")
# if i == min(len(interaction_sort_inds), max_display) // 2:
# pl.xlabel(labels['INTERACTION_VALUE'])
# pl.title(shorten_text(feature_names[ind], title_length_limit))
# pl.tight_layout(pad=0, w_pad=0, h_pad=0.0)
# pl.subplots_adjust(hspace=0, wspace=0.1)
# if show:
# pl.show()
# return

# determine how many top features we will plot
if max_display is None:
Expand All @@ -234,9 +236,9 @@ def beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0),
clust_order = sort_inds(partition_tree, np.abs(values))

# now relax the requirement to match the parition tree ordering for connections above cluster_threshold
dist = scipy.spatial.distance.squareform(scipy.cluster.hierarchy.cophenet(partition_tree))
dist = sp.spatial.distance.squareform(sp.cluster.hierarchy.cophenet(partition_tree))
feature_order = get_sort_order(dist, clust_order, cluster_threshold, feature_order)

# if the last feature we can display is connected in a tree the next feature then we can't just cut
# off the feature ordering, so we need to merge some tree nodes and then try again.
if max_display < len(feature_order) and dist[feature_order[max_display-1],feature_order[max_display-2]] <= cluster_threshold:
Expand Down

0 comments on commit d396f1c

Please sign in to comment.