From de9eb2c4bb4d470875cc0300e44498161f13c992 Mon Sep 17 00:00:00 2001 From: Arne Gevaert Date: Wed, 23 Aug 2023 14:32:56 +0200 Subject: [PATCH 1/2] Add attribution visualization functions + bugfixes --- attribench/data/__init__.py | 9 +- .../functional/_compute_attributions.py | 42 ++++---- attribench/util/__init__.py | 1 + attribench/util/visualize_attributions.py | 102 ++++++++++++++++++ 4 files changed, 130 insertions(+), 24 deletions(-) create mode 100644 attribench/util/__init__.py create mode 100644 attribench/util/visualize_attributions.py diff --git a/attribench/data/__init__.py b/attribench/data/__init__.py index f58b1ea..acef1a1 100644 --- a/attribench/data/__init__.py +++ b/attribench/data/__init__.py @@ -1,5 +1,10 @@ from .hdf5_dataset._hdf5_dataset import HDF5Dataset from .hdf5_dataset._hdf5_dataset_writer import HDF5DatasetWriter from ._index_dataset import IndexDataset -from .attributions_dataset._attributions_dataset import AttributionsDataset -from .attributions_dataset._attributions_dataset_writer import AttributionsDatasetWriter +from .attributions_dataset._attributions_dataset import ( + AttributionsDataset, + GroupedAttributionsDataset, +) +from .attributions_dataset._attributions_dataset_writer import ( + AttributionsDatasetWriter, +) diff --git a/attribench/functional/_compute_attributions.py b/attribench/functional/_compute_attributions.py index d11540d..072b229 100644 --- a/attribench/functional/_compute_attributions.py +++ b/attribench/functional/_compute_attributions.py @@ -15,13 +15,11 @@ def compute_attributions( writer: Optional[AttributionsDatasetWriter] = None, device: Optional[torch.device] = None, ) -> Optional[Dict[str, torch.Tensor]]: - """Compute attributions for a given model and dataset using a dictionary of + """Compute attributions for a given model and dataset using a dictionary of attribution methods, and optionally write them to a HDF5 file. If the `writer` is `None`, the attributions are simply returned in a dictionary. 
Otherwise, the attributions are written to the HDF5 file and `None` is returned. - TODO don't write to file, just return the dict - Parameters ---------- model : nn.Module @@ -57,27 +55,27 @@ def compute_attributions( pin_memory=True, ) - result_dict: Dict[str, List[torch.Tensor]] = {method_name: [ - torch.zeros(1) for _ in range(len(index_dataset)) - ] for method_name in method_dict.keys()} + num_samples = len(index_dataset) + sample_shape = None + result_dict: Dict[str, torch.Tensor] = {} for batch_indices, batch_x, batch_y in tqdm(dataloader): + if sample_shape is None: + sample_shape = batch_x.shape[1:] + result_dict = { + method_name: torch.zeros(num_samples, *sample_shape) + for method_name in method_dict.keys() + } batch_x = batch_x.to(device) batch_y = batch_y.to(device) for method_name, method in method_dict.items(): - with torch.no_grad(): - attrs = method(batch_x, batch_y) - if writer is None: - for idx in batch_indices: - result_dict[method_name][idx] = attrs[idx, ...].cpu() - else: - writer.write( - batch_indices.cpu().numpy(), - attrs.cpu().numpy(), - method_name, - ) + attrs = method(batch_x, batch_y) + if writer is None: + result_dict[method_name][batch_indices, ...] 
= attrs.cpu() + else: + writer.write( + batch_indices.cpu().numpy(), + attrs.cpu().numpy(), + method_name, + ) if writer is None: - result_dict_cat = { - method_name: torch.cat(attrs_list) - for method_name, attrs_list in result_dict.items() - } - return result_dict_cat + return result_dict diff --git a/attribench/util/__init__.py b/attribench/util/__init__.py new file mode 100644 index 0000000..e1df90d --- /dev/null +++ b/attribench/util/__init__.py @@ -0,0 +1 @@ +from .visualize_attributions import visualize_attributions \ No newline at end of file diff --git a/attribench/util/visualize_attributions.py b/attribench/util/visualize_attributions.py new file mode 100644 index 0000000..2be5e9d --- /dev/null +++ b/attribench/util/visualize_attributions.py @@ -0,0 +1,102 @@ +import torch +from matplotlib import pyplot as plt +from matplotlib.figure import Figure +from typing import Optional, Dict + + +def _plot_heatmap( + fig: Figure, + ax: plt.Axes, + attributions: torch.Tensor, + image: torch.Tensor, + cmap: str, + center_at_zero: bool, + title: Optional[str], + overlay: bool, +): + vmax = ( + attributions.abs().max().item() + if center_at_zero + else attributions.max().item() + ) + if overlay: + ax.imshow(image, alpha=0.5) + vmin = -vmax if center_at_zero else attributions.min().item() + alpha = 0.5 if overlay else 1.0 + img = ax.imshow(attributions, cmap=cmap, vmin=vmin, vmax=vmax, alpha=alpha) + fig.colorbar(img, ax=ax) + if title is not None: + ax.set_title(title) + + +def visualize_attributions( + attributions: Dict[str, torch.Tensor], + image: torch.Tensor, + cmap: str = "bwr", + center_at_zero: bool = True, + overlay = False, +) -> Figure: + """Visualize attributions. + Attributions can be visualized by overlaying them on the original + image, by plotting them as a heatmap, or by plotting the original image + with a transparency mask over it, making pixels with higher attribution + values more visible. 
+ + Each attribution tensor is assumed to have shape (C, H, W), + where C is the number of channels, and + H and W are the height and width of the image. The channel dimension is + eliminated by averaging over it. + + Parameters + ---------- + attributions : Dict[str, torch.Tensor] + Dictionary mapping method names to attributions. The attributions + should have shape (C, H, W). + image : torch.Tensor + Original image. Shape: (C, H, W). + cmap : str, optional + Colormap to use for plotting the heatmap, by default "bwr". + center_at_zero : bool, optional + Whether to center the colormap at zero, making a zero attribution + value correspond to white in a diverging colormap. By default True. + overlay : bool, optional + Whether to overlay the attributions on the original image, by default + False. + """ + # Checking inputs + num_methods = len(attributions.keys()) + + n_rows = num_methods // 4 + 1 + n_cols = 4 + + fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, n_rows * 4)) + if num_methods != 1: + axs = axs.flatten() + + # Plot original image + axs[0].imshow(image) + axs[0].set_title("Original image") + + # Plot heatmaps + for idx, method_name in enumerate(attributions.keys()): + if num_methods != 1: + ax = axs[idx + 1] + else: + ax = axs + assert isinstance(ax, plt.Axes) + _plot_heatmap( + fig, + ax, + attributions[method_name].mean(dim=0), + image, + cmap=cmap, + center_at_zero=center_at_zero, + title=method_name, + overlay=overlay + ) + + if n_rows * n_cols > num_methods + 1: + for ax in axs[num_methods + 1:]: + ax.remove() + + return fig From b2e09a47bdaaa0decf1b90bf0fd89e0cb6625c16 Mon Sep 17 00:00:00 2001 From: Arne Gevaert Date: Wed, 23 Aug 2023 14:34:34 +0200 Subject: [PATCH 2/2] Bump version 0.1.2 -> 0.1.3 --- attribench/__init__.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/attribench/__init__.py b/attribench/__init__.py index 
e11a141..e6c579b 100644 --- a/attribench/__init__.py +++ b/attribench/__init__.py @@ -2,4 +2,4 @@ from ._method_factory import MethodFactory from ._model_factory import ModelFactory, BasicModelFactory -__version__ = "0.1.2" +__version__ = "0.1.3" diff --git a/pyproject.toml b/pyproject.toml index aa804c4..aaa78ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" [project] name = "attribench" -version = "0.1.2" +version = "0.1.3" description = "A benchmark for feature attribution techniques" readme = "README.rst" authors = [ @@ -59,7 +59,7 @@ Homepage = "https://github.com/arnegevaert/benchmark" Documentation = "http://attribench.readthedocs.io/" [tool.bumpver] -current_version = "0.1.2" +current_version = "0.1.3" version_pattern = "MAJOR.MINOR.PATCH" commit_message = "Bump version {old_version} -> {new_version}" commit = true