Skip to content

Commit

Permalink
fix: ruff rule F821 (undefined name) (#54)
Browse files Browse the repository at this point in the history
* chore: remove F821 from ruff ignore

* fix: missing mpl import

* fix: undefined name errors in beeswarm

And also temporarily remove the SHAP interaction beeswarm plot code
(it's not runnable at the moment, because we ensured that the ndim of
SHAP values was 2 above, so it can never be 3 to trigger the interaction
beeswarm plot here).

The removal of code here is mostly just to appease ruff.
Will create a separate issue for this to re-introduce SHAP interaction
values plotting again later on. Maybe it should be implemented as a
separate function.
  • Loading branch information
thatlittleboy authored May 30, 2023
1 parent 886f86c commit 14c76da
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 86 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,5 @@ select = ["F"]
ignore = [
"F401", # unused imports
"F811", # Redefinition of unused variable
"F821", # Undefined name
"F841", # Local variables assigned but unused
]
166 changes: 84 additions & 82 deletions shap/plots/_beeswarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0),
# out_names = shap_exp.output_names

order = convert_ordering(order, values)


# # deprecation warnings
# if auto_size_plot is not None:
Expand Down Expand Up @@ -134,88 +133,91 @@ def beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0),
partition_tree = None
else:
partition_tree = clustering

if partition_tree is not None:
assert partition_tree.shape[1] == 4, "The clustering provided by the Explanation object does not seem to be a partition tree (which is all shap.plots.bar supports)!"

# plotting SHAP interaction values
if len(values.shape) == 3:

if plot_type == "compact_dot":
new_values = values.reshape(values.shape[0], -1)
new_features = np.tile(features, (1, 1, features.shape[1])).reshape(features.shape[0], -1)

new_feature_names = []
for c1 in feature_names:
for c2 in feature_names:
if c1 == c2:
new_feature_names.append(c1)
else:
new_feature_names.append(c1 + "* - " + c2)

return beeswarm(
new_values, new_features, new_feature_names,
max_display=max_display, plot_type="dot", color=color, axis_color=axis_color,
title=title, alpha=alpha, show=show, sort=sort,
color_bar=color_bar, plot_size=plot_size, class_names=class_names,
color_bar_label="*" + color_bar_label
)

if max_display is None:
max_display = 7
else:
max_display = min(len(feature_names), max_display)

interaction_sort_inds = order#np.argsort(-np.abs(values.sum(1)).sum(0))

# get plotting limits
delta = 1.0 / (values.shape[1] ** 2)
slow = np.nanpercentile(values, delta)
shigh = np.nanpercentile(values, 100 - delta)
v = max(abs(slow), abs(shigh))
slow = -v
shigh = v

pl.figure(figsize=(1.5 * max_display + 1, 0.8 * max_display + 1))
pl.subplot(1, max_display, 1)
proj_values = values[:, interaction_sort_inds[0], interaction_sort_inds]
proj_values[:, 1:] *= 2 # because off diag effects are split in half
beeswarm(
proj_values, features[:, interaction_sort_inds] if features is not None else None,
feature_names=feature_names[interaction_sort_inds],
sort=False, show=False, color_bar=False,
plot_size=None,
max_display=max_display
)
pl.xlim((slow, shigh))
pl.xlabel("")
title_length_limit = 11
pl.title(shorten_text(feature_names[interaction_sort_inds[0]], title_length_limit))
for i in range(1, min(len(interaction_sort_inds), max_display)):
ind = interaction_sort_inds[i]
pl.subplot(1, max_display, i + 1)
proj_values = values[:, ind, interaction_sort_inds]
proj_values *= 2
proj_values[:, i] /= 2 # because only off diag effects are split in half
summary(
proj_values, features[:, interaction_sort_inds] if features is not None else None,
sort=False,
feature_names=["" for i in range(len(feature_names))],
show=False,
color_bar=False,
plot_size=None,
max_display=max_display
)
pl.xlim((slow, shigh))
pl.xlabel("")
if i == min(len(interaction_sort_inds), max_display) // 2:
pl.xlabel(labels['INTERACTION_VALUE'])
pl.title(shorten_text(feature_names[ind], title_length_limit))
pl.tight_layout(pad=0, w_pad=0, h_pad=0.0)
pl.subplots_adjust(hspace=0, wspace=0.1)
if show:
pl.show()
return
# FIXME: introduce beeswarm interaction values as a separate function `beeswarm_interaction()` (?)
# In the meantime, users can use the `shap.summary_plot()` function.
#
# # plotting SHAP interaction values
# if len(values.shape) == 3:
#
# if plot_type == "compact_dot":
# new_values = values.reshape(values.shape[0], -1)
# new_features = np.tile(features, (1, 1, features.shape[1])).reshape(features.shape[0], -1)
#
# new_feature_names = []
# for c1 in feature_names:
# for c2 in feature_names:
# if c1 == c2:
# new_feature_names.append(c1)
# else:
# new_feature_names.append(c1 + "* - " + c2)
#
# return beeswarm(
# new_values, new_features, new_feature_names,
# max_display=max_display, plot_type="dot", color=color, axis_color=axis_color,
# title=title, alpha=alpha, show=show, sort=sort,
# color_bar=color_bar, plot_size=plot_size, class_names=class_names,
# color_bar_label="*" + color_bar_label
# )
#
# if max_display is None:
# max_display = 7
# else:
# max_display = min(len(feature_names), max_display)
#
# interaction_sort_inds = order#np.argsort(-np.abs(values.sum(1)).sum(0))
#
# # get plotting limits
# delta = 1.0 / (values.shape[1] ** 2)
# slow = np.nanpercentile(values, delta)
# shigh = np.nanpercentile(values, 100 - delta)
# v = max(abs(slow), abs(shigh))
# slow = -v
# shigh = v
#
# pl.figure(figsize=(1.5 * max_display + 1, 0.8 * max_display + 1))
# pl.subplot(1, max_display, 1)
# proj_values = values[:, interaction_sort_inds[0], interaction_sort_inds]
# proj_values[:, 1:] *= 2 # because off diag effects are split in half
# beeswarm(
# proj_values, features[:, interaction_sort_inds] if features is not None else None,
# feature_names=feature_names[interaction_sort_inds],
# sort=False, show=False, color_bar=False,
# plot_size=None,
# max_display=max_display
# )
# pl.xlim((slow, shigh))
# pl.xlabel("")
# title_length_limit = 11
# pl.title(shorten_text(feature_names[interaction_sort_inds[0]], title_length_limit))
# for i in range(1, min(len(interaction_sort_inds), max_display)):
# ind = interaction_sort_inds[i]
# pl.subplot(1, max_display, i + 1)
# proj_values = values[:, ind, interaction_sort_inds]
# proj_values *= 2
# proj_values[:, i] /= 2 # because only off diag effects are split in half
# summary(
# proj_values, features[:, interaction_sort_inds] if features is not None else None,
# sort=False,
# feature_names=["" for i in range(len(feature_names))],
# show=False,
# color_bar=False,
# plot_size=None,
# max_display=max_display
# )
# pl.xlim((slow, shigh))
# pl.xlabel("")
# if i == min(len(interaction_sort_inds), max_display) // 2:
# pl.xlabel(labels['INTERACTION_VALUE'])
# pl.title(shorten_text(feature_names[ind], title_length_limit))
# pl.tight_layout(pad=0, w_pad=0, h_pad=0.0)
# pl.subplots_adjust(hspace=0, wspace=0.1)
# if show:
# pl.show()
# return

# determine how many top features we will plot
if max_display is None:
Expand All @@ -234,9 +236,9 @@ def beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0),
clust_order = sort_inds(partition_tree, np.abs(values))

# now relax the requirement to match the partition tree ordering for connections above cluster_threshold
dist = scipy.spatial.distance.squareform(scipy.cluster.hierarchy.cophenet(partition_tree))
dist = sp.spatial.distance.squareform(sp.cluster.hierarchy.cophenet(partition_tree))
feature_order = get_sort_order(dist, clust_order, cluster_threshold, feature_order)

# if the last feature we can display is connected in a tree to the next feature then we can't just cut
# off the feature ordering, so we need to merge some tree nodes and then try again.
if max_display < len(feature_order) and dist[feature_order[max_display-1],feature_order[max_display-2]] <= cluster_threshold:
Expand Down
7 changes: 4 additions & 3 deletions shap/plots/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@
from ..utils import OpChain
from . import colors
import numpy as np
import matplotlib.pyplot as pl


def convert_color(color):
try:
color = pl.get_cmap(color)
except:
except Exception:
pass

if color == "shap_red":
color = colors.red_rgb
elif color == "shap_blue":
color = colors.blue_rgb

return color

def convert_ordering(ordering, shap_values):
Expand Down

0 comments on commit 14c76da

Please sign in to comment.