
Commit

Merge pull request #176 from arnegevaert/util_functions
Util functions
arnegevaert authored Aug 23, 2023
2 parents e556cd8 + b2e09a4 commit c022f1b
Showing 6 changed files with 133 additions and 27 deletions.
2 changes: 1 addition & 1 deletion attribench/__init__.py
@@ -2,4 +2,4 @@
from ._method_factory import MethodFactory
from ._model_factory import ModelFactory, BasicModelFactory

__version__ = "0.1.2"
__version__ = "0.1.3"
9 changes: 7 additions & 2 deletions attribench/data/__init__.py
@@ -1,5 +1,10 @@
from .hdf5_dataset._hdf5_dataset import HDF5Dataset
from .hdf5_dataset._hdf5_dataset_writer import HDF5DatasetWriter
from ._index_dataset import IndexDataset
-from .attributions_dataset._attributions_dataset import AttributionsDataset
-from .attributions_dataset._attributions_dataset_writer import AttributionsDatasetWriter
+from .attributions_dataset._attributions_dataset import (
+    AttributionsDataset,
+    GroupedAttributionsDataset,
+)
+from .attributions_dataset._attributions_dataset_writer import (
+    AttributionsDatasetWriter,
+)
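For reference, a minimal import sketch (not part of this commit) showing the public names that attribench.data re-exports after this change, including the newly exposed GroupedAttributionsDataset:

from attribench.data import (
    AttributionsDataset,
    AttributionsDatasetWriter,
    GroupedAttributionsDataset,
    HDF5Dataset,
    HDF5DatasetWriter,
    IndexDataset,
)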
42 changes: 20 additions & 22 deletions attribench/functional/_compute_attributions.py
@@ -15,13 +15,11 @@ def compute_attributions(
    writer: Optional[AttributionsDatasetWriter] = None,
    device: Optional[torch.device] = None,
) -> Optional[Dict[str, torch.Tensor]]:
    """Compute attributions for a given model and dataset using a dictionary of
    attribution methods, and optionally write them to a HDF5 file. If the `writer`
    is `None`, the attributions are simply returned in a dictionary.
    Otherwise, the attributions are written to the HDF5 file and `None` is returned.
-    TODO don't write to file, just return the dict
    Parameters
    ----------
    model : nn.Module
@@ -57,27 +55,27 @@ def compute_attributions(
        pin_memory=True,
    )

-    result_dict: Dict[str, List[torch.Tensor]] = {method_name: [
-        torch.zeros(1) for _ in range(len(index_dataset))
-    ] for method_name in method_dict.keys()}
+    num_samples = len(index_dataset)
+    sample_shape = None
+    result_dict: Dict[str, torch.Tensor] = {}
    for batch_indices, batch_x, batch_y in tqdm(dataloader):
+        if sample_shape is None:
+            sample_shape = batch_x.shape[1:]
+            result_dict = {
+                method_name: torch.zeros(num_samples, *sample_shape)
+                for method_name in method_dict.keys()
+            }
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        for method_name, method in method_dict.items():
-            with torch.no_grad():
-                attrs = method(batch_x, batch_y)
-            if writer is None:
-                for idx in batch_indices:
-                    result_dict[method_name][idx] = attrs[idx, ...].cpu()
-            else:
-                writer.write(
-                    batch_indices.cpu().numpy(),
-                    attrs.cpu().numpy(),
-                    method_name,
-                )
+            attrs = method(batch_x, batch_y)
+            if writer is None:
+                result_dict[method_name][batch_indices, ...] = attrs.cpu()
+            else:
+                writer.write(
+                    batch_indices.cpu().numpy(),
+                    attrs.cpu().numpy(),
+                    method_name,
+                )
    if writer is None:
-        result_dict_cat = {
-            method_name: torch.cat(attrs_list)
-            for method_name, attrs_list in result_dict.items()
-        }
-        return result_dict_cat
+        return result_dict
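The core of this change is the storage strategy: instead of collecting per-sample tensors in Python lists and concatenating them at the end, the function now allocates one (num_samples, *sample_shape) tensor per method on the first batch and fills it with a single batched index assignment. It also indexes the attributions by their position within the batch, where the old loop indexed attrs with the global dataset index. A standalone sketch of that pattern (illustrative only, not part of the commit; the shapes and method names below are made up):

import torch

num_samples, sample_shape = 8, (3, 4, 4)
method_names = ["method_a", "method_b"]

# Pre-allocate one result tensor per method, as the new code does on the
# first batch once sample_shape is known.
result_dict = {
    name: torch.zeros(num_samples, *sample_shape) for name in method_names
}

# One batch: the dataset indices it covers and one attribution map per sample.
batch_indices = torch.tensor([2, 5, 7])
attrs = torch.rand(len(batch_indices), *sample_shape)

# A single vectorized write replaces the old `for idx in batch_indices:` loop.
result_dict["method_a"][batch_indices, ...] = attrs

print(result_dict["method_a"][5].equal(attrs[1]))  # True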
1 change: 1 addition & 0 deletions attribench/util/__init__.py
@@ -0,0 +1 @@
from .visualize_attributions import visualize_attributions
102 changes: 102 additions & 0 deletions attribench/util/visualize_attributions.py
@@ -0,0 +1,102 @@
import torch
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from typing import Optional, Dict


def _plot_heatmap(
    fig: Figure,
    ax: plt.Axes,
    attributions: torch.Tensor,
    image: torch.Tensor,
    cmap: str,
    center_at_zero: bool,
    title: Optional[str],
    overlay: bool,
):
    vmax = (
        attributions.abs().max().item()
        if center_at_zero
        else attributions.max().item()
    )
    if overlay:
        ax.imshow(image, alpha=0.5)
    vmin = -vmax if center_at_zero else attributions.min().item()
    alpha = 0.5 if overlay else 1.0
    img = ax.imshow(attributions, cmap=cmap, vmin=vmin, vmax=vmax, alpha=alpha)
    fig.colorbar(img, ax=ax)
    if title is not None:
        ax.set_title(title)


def visualize_attributions(
    attributions: Dict[str, torch.Tensor],
    image: torch.Tensor,
    cmap: str = "bwr",
    center_at_zero: bool = True,
    overlay = False,
) -> Figure:
    """Visualize attributions.
    Attributions can be visualized by overlaying them on the original
    image, by plotting them as a heatmap, or by plotting the original image
    with a transparency mask over it, making pixels with higher attribution
    values more visible.
    The shape of images and attributions is assumed to be (N, C, H, W),
    where N is the number of samples, C is the number of channels, and
    H and W are the height and width of the images. The channel dimension is
    eliminated by averaging over it.
    Parameters
    ----------
    attributions : Dict[str, torch.Tensor]
        Dictionary mapping method names to attributions. The attributions
        should have shape (C, H, W).
    image : torch.Tensor, optional
        Original image. Shape: (C, H, W), by default None.
    cmap : str, optional
        Colormap to use for plotting the heatmap, by default "bwr".
    center_at_zero : bool, optional
        Whether to center the colormap at zero, making a zero attribution
        value correspond to white in a diverging colormap. By default True.
    overlay : bool, optional
        Whether to overlay the attributions on the original image, by default
        False.
    """
    # Checking inputs
    num_methods = len(attributions.keys())

    n_rows = num_methods // 4 + 1
    n_cols = 4

    fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, n_rows * 4))
    if num_methods != 1:
        axs = axs.flatten()

    # Plot original image
    axs[0].imshow(image)
    axs[0].set_title("Original image")

    # Plot heatmaps
    for idx, method_name in enumerate(attributions.keys()):
        if num_methods != 1:
            ax = axs[idx + 1]
        else:
            ax = axs
        assert isinstance(ax, plt.Axes)
        _plot_heatmap(
            fig,
            ax,
            attributions[method_name].mean(dim=0),
            image,
            cmap=cmap,
            center_at_zero=center_at_zero,
            title=method_name,
            overlay=overlay
        )

    if n_rows * n_cols > num_methods + 1:
        for ax in axs[num_methods + 1:]:
            ax.remove()

    return fig
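A short usage sketch for the new helper (not part of this commit; the method names and random tensors are hypothetical). The import path follows attribench/util/__init__.py above. Since the image is handed straight to plt.imshow, this sketch passes it channels-last:

import torch
from attribench.util import visualize_attributions

image = torch.rand(32, 32, 3)  # (H, W, C), as plt.imshow expects
attributions = {
    "MethodA": torch.rand(3, 32, 32),        # (C, H, W), averaged over C internally
    "MethodB": torch.rand(3, 32, 32) - 0.5,  # signed attributions
}

fig = visualize_attributions(attributions, image, cmap="bwr", center_at_zero=True)
fig.savefig("attributions.png")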
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "attribench"
version = "0.1.2"
version = "0.1.3"
description = "A benchmark for feature attribution techniques"
readme = "README.rst"
authors = [
@@ -59,7 +59,7 @@ Homepage = "https://github.com/arnegevaert/benchmark"
Documentation = "http://attribench.readthedocs.io/"

[tool.bumpver]
current_version = "0.1.2"
current_version = "0.1.3"
version_pattern = "MAJOR.MINOR.PATCH"
commit_message = "Bump version {old_version} -> {new_version}"
commit = true
