diff --git a/docs/source/tutorial/image-classification-vgg-resnet.ipynb b/docs/source/tutorial/image-classification-vgg-resnet.ipynb new file mode 100644 index 0000000..3eca276 --- /dev/null +++ b/docs/source/tutorial/image-classification-vgg-resnet.ipynb @@ -0,0 +1,1078 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "810f3c46", + "metadata": {}, + "source": [ + "# Image Classification with VGG and ResNet\n", + "\n", + "This tutorial will introduce the attribution of image classifiers using VGG11\n", + "and ResNet18 trained on ImageNet. Feel free to replace VGG11 and ResNet18 with\n", + "any other version of VGG or ResNet respectively.\n", + "\n", + "## Table of Contents\n", + "* [1. Preparation](<#1.-Preparation>)\n", + "* [2. VGG11 without BatchNorm](<#2.-VGG11-without-BatchNorm>)\n", + " * [2.1 Saliency Map](<#2.1-Saliency-Map>)\n", + " * [2.2 SmoothGrad using Attributors](<#2.2-SmoothGrad-using-Attributors>)\n", + " * [2.3 Layer-wise Relevance Propagation (LRP) with EpsilonPlusFlat](<#2.3-Layer-wise-Relevance-Propagation-(LRP)-with-EpsilonPlusFlat>)\n", + " * [2.4 LRP with EpsilonGammaBox](<#2.4-LRP-with-EpsilonGammaBox>)\n", + " * [2.5 LRP with EpsilonGammaBox with modified epsilon and stabilizer](<#2.5-LRP-with-EpsilonGammaBox-with-modified-epsilon-and-stabilizer>)\n", + " * [2.6 More Visualization](<#2.6-More-Visualization>)\n", + "* [3. VGG11 with BatchNorm](<#3.-VGG11-with-BatchNorm>)\n", + " * [3.1 LRP with EpsilonGammaBox](<#3.1-LRP-with-EpsilonGammaBox>)\n", + " * [3.2 LRP with modified EpsilonGammaBox and ignored BatchNorm](<#3.2-LRP-with-modified-EpsilonGammaBox-and-ignored-BatchNorm>)\n", + " * [3.3 LRP with custom NameMapComposite](<#3.3-LRP-with-custom-NameMapComposite>)\n", + "* [4. ResNet18](<#4.-ResNet18>)\n", + " * [4.1 LRP with EpsilonPlusFlat](<#4.1-LRP-with-EpsilonPlusFlat>)\n", + " * [4.2 LRP with EpsilonGammaBox](<#4.2-LRP-with-EpsilonGammaBox>)\n", + " * [4.3 LRP with custom LayerMapComposite](<#4.3-LRP-with-custom-LayerMapComposite>)\n", + "\n", + "## 1. Preparation\n", + "\n", + "First, we install **Zennit**. 
This includes its dependencies `Pillow`,\n", + "`torch` and `torchvision`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0e2f1fe", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install zennit" + ] + }, + { + "cell_type": "markdown", + "id": "3a2dc4cd", + "metadata": {}, + "source": [ + "Then, we import necessary modules, classes and functions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a9a3fa5e", + "metadata": {}, + "outputs": [], + "source": [ + "from itertools import islice\n", + "\n", + "import torch\n", + "from torch.nn import Linear\n", + "from PIL import Image\n", + "from torchvision.transforms import Compose, Resize, CenterCrop\n", + "from torchvision.transforms import ToTensor, Normalize\n", + "from torchvision.models import vgg11_bn, vgg11, resnet18\n", + "\n", + "from zennit.attribution import Gradient, SmoothGrad\n", + "from zennit.core import Stabilizer\n", + "from zennit.composites import EpsilonGammaBox, EpsilonPlusFlat\n", + "from zennit.composites import SpecialFirstLayerMapComposite, NameMapComposite\n", + "from zennit.image import imgify, imsave\n", + "from zennit.rules import Epsilon, ZPlus, ZBox, Norm, Pass, Flat\n", + "from zennit.types import Convolution, Activation, AvgPool, Linear as AnyLinear\n", + "from zennit.types import BatchNorm, MaxPool\n", + "from zennit.torchvision import VGGCanonizer, ResNetCanonizer" + ] + }, + { + "cell_type": "markdown", + "id": "e48e434e", + "metadata": {}, + "source": [ + "We download an image of the [Dornbusch\n", + "Lighthouse](https://en.wikipedia.org/wiki/Dornbusch_Lighthouse) from [Wikimedia\n", + "Commons](https://commons.wikimedia.org/wiki/File:2006_09_06_180_Leuchtturm.jpg):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bfaa9f3a", + "metadata": {}, + "outputs": [], + "source": [ + "torch.hub.download_url_to_file(\n", + " 'https://upload.wikimedia.org/wikipedia/commons/thumb/8/8b/2006_09_06_180_Leuchtturm.jpg/640px-2006_09_06_181_Leuchtturm.jpg',\n", + " 'dornbusch-lighthouse.jpg',\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "0798a0c1", + "metadata": {}, + "source": [ + "We load and prepare the data. 
The image is resized such that the shorter side\n", + "is 256 pixels, then center-cropped to `(224, 224)`, converted to a\n", + "`torch.Tensor`, and then normalized according to the channel-wise mean and\n", + "standard deviation of the ImageNet dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4989b65c", + "metadata": {}, + "outputs": [], + "source": [ + "# define the base image transform\n", + "transform_img = Compose([\n", + " Resize(256),\n", + " CenterCrop(224),\n", + "])\n", + "# define the normalization transform\n", + "transform_norm = Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n", + "# define the full tensor transform\n", + "transform = Compose([\n", + " transform_img,\n", + " ToTensor(),\n", + " transform_norm,\n", + "])\n", + "\n", + "# load the image\n", + "image = Image.open('dornbusch-lighthouse.jpg')\n", + "\n", + "# transform the PIL image and insert a batch-dimension\n", + "data = transform(image)[None]" + ] + }, + { + "cell_type": "markdown", + "id": "882b4dd8", + "metadata": {}, + "source": [ + "We can look at the original image and the cropped image:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "072a3ad0", + "metadata": {}, + "outputs": [], + "source": [ + "# display the original image\n", + "display(image)\n", + "# display the resized and cropped image\n", + "display(transform_img(image))" + ] + }, + { + "cell_type": "markdown", + "id": "3d45bd9b", + "metadata": {}, + "source": [ + "## 2. VGG11 without BatchNorm\n", + "\n", + "We start with VGG11 without BatchNorm.\n", + "First, we initialize the VGG11 model and optionally load the pre-trained parameters.\n", + "Set `weights='IMAGENET1K_V1'` to use the pre-trained model instead of the random\n", + "one:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "853bcf2d", + "metadata": {}, + "outputs": [], + "source": [ + "# load the model and set it to evaluation mode\n", + "model = vgg11(weights=None).eval()" + ] + }, + { + "cell_type": "markdown", + "id": "04ca56ed", + "metadata": {}, + "source": [ + "### 2.1 Saliency Map\n", + "We first compute the Saliency Map, which is the absolute gradient, by using the\n", + "`Gradient` **Attributor**:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a5ec149f", + "metadata": {}, + "outputs": [], + "source": [ + "# choose a target class for the attribution (label 437 is lighthouse)\n", + "target = torch.eye(1000)[[437]]\n", + "\n", + "# create the attributor, specifying model\n", + "with Gradient(model=model) as attributor:\n", + " # compute the model output and attribution\n", + " output, attribution = attributor(data, target)\n", + "\n", + "print(f'Prediction: {output.argmax(1)[0].item()}')" + ] + }, + { + "cell_type": "markdown", + "id": "e49b5056", + "metadata": {}, + "source": [ + "Visualize the attribution:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb87c828", + "metadata": {}, + "outputs": [], + "source": [ + "# take the absolute value and sum over the channels\n", + "relevance = attribution.abs().sum(1)\n", + "\n", + "# create an image of the visualized attribution; the relevance is only\n", + "# positive, so we use symmetric=False and an unsigned color-map\n", + "img = imgify(relevance, symmetric=False, cmap='hot')\n", + "\n", + "# show the image\n", + "display(img)" + ] + },
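+ { + "cell_type": "markdown", + "id": "f1a2b3c4", + "metadata": {}, + "source": [ + "As a quick sanity check, we can compare this attribution to the input gradient computed directly with `torch.autograd`; since the plain `Gradient` **Attributor** is expected to compute exactly this gradient, the two should match (this cell is a sketch that recomputes the forward pass):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f2e3d4c", + "metadata": {}, + "outputs": [], + "source": [ + "# recompute the input gradient directly and compare it to the attribution\n", + "data_grad = data.clone().requires_grad_(True)\n", + "out = model(data_grad)\n", + "grad, = torch.autograd.grad((out * target).sum(), data_grad)\n", + "print(torch.allclose(grad, attribution))" + ] + }, + { + "cell_type": "markdown", + "id": "e2e53bad", + "metadata": {}, + "source": [ + "### 2.2 SmoothGrad using Attributors\n", + "Model-agnostic attribution methods like *SmoothGrad* 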
are implemented as\n", + "*Attributors*. For these, we simply replace the `Gradient` **Attributor**,\n", + "e.g., with the `SmoothGrad` **Attributor**:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8846f4a6", + "metadata": {}, + "outputs": [], + "source": [ + "# choose a target class for the attribution (label 437 is lighthouse)\n", + "target = torch.eye(1000)[[437]]\n", + "\n", + "# create the attributor, specifying model\n", + "with SmoothGrad(noise_level=0.1, n_iter=20, model=model) as attributor:\n", + " # compute the model output and attribution\n", + " output, attribution = attributor(data, target)\n", + "\n", + "print(f'Prediction: {output.argmax(1)[0].item()}')" + ] + }, + { + "cell_type": "markdown", + "id": "6ad8f723", + "metadata": {}, + "source": [ + "Visualize the attribution:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fb34a8e2", + "metadata": {}, + "outputs": [], + "source": [ + "# take the absolute value and sum over the channels\n", + "relevance = attribution.abs().sum(1)\n", + "\n", + "# create an image of the visualized attribution; the relevance is only\n", + "# positive, so we use symmetric=False and an unsigned color-map\n", + "img = imgify(relevance, symmetric=False, cmap='hot')\n", + "\n", + "# show the image\n", + "display(img)" + ] + }, + { + "cell_type": "markdown", + "id": "d0510255", + "metadata": {}, + "source": [ + "### 2.3 Layer-wise Relevance Propagation (LRP) with EpsilonPlusFlat\n", + "We compute the LRP-attribution using the *Gradient* **Attributor**\n", + "together with the `EpsilonPlusFlat` **Composite**:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d1645b4", + "metadata": {}, + "outputs": [], + "source": [ + "# create a composite\n", + "composite = EpsilonPlusFlat()\n", + "\n", + "# choose a target class for the attribution (label 437 is lighthouse)\n", + "target = torch.eye(1000)[[437]]\n", + "\n", + "# create the attributor, specifying model and composite\n", + "with Gradient(model=model, composite=composite) as attributor:\n", + " # compute the model output and attribution\n", + " output, attribution = attributor(data, target)\n", + "\n", + "print(f'Prediction: {output.argmax(1)[0].item()}')" + ] + }, + { + "cell_type": "markdown", + "id": "e3e9ba55", + "metadata": {}, + "source": [ + "Visualize the attribution:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7eda5200", + "metadata": {}, + "outputs": [], + "source": [ + "# sum over the channels\n", + "relevance = attribution.sum(1)\n", + "\n", + "# create an image of the visualized attribution\n", + "img = imgify(relevance, symmetric=True, cmap='coldnhot')\n", + "\n", + "# show the image\n", + "display(img)" + ] + }, + { + "cell_type": "markdown", + "id": "c9169bb3", + "metadata": {}, + "source": [ + "### 2.4 LRP with EpsilonGammaBox\n", + "We now compute the LRP-attribution with the `EpsilonGammaBox` **Composite**.\n", + "The `EpsilonGammaBox` **Composite** uses the `ZBox` rule, which needs the\n", + "lowest and highest possible input values as arguments." + ] + },
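+ { + "cell_type": "markdown", + "id": "9a8b7c6d", + "metadata": {}, + "source": [ + "Below, `low` and `high` are built by normalizing a nested list literal of shape `(2, 1, 3, 1, 1)`, so that after unpacking each bound has shape `(1, 3, 1, 1)` and broadcasts against the input. As a sketch, an equivalent and arguably more explicit construction uses `torch.stack` (assuming `Normalize` accepts batched tensors, which recent torchvision versions do):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5e6f7a8b", + "metadata": {}, + "outputs": [], + "source": [ + "# equivalent construction of the per-channel bounds used by the ZBox rule\n", + "low_alt, high_alt = transform_norm(torch.stack([torch.zeros(1, 3, 1, 1), torch.ones(1, 3, 1, 1)]))\n", + "print(low_alt.shape, high_alt.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e17da931", + "metadata": {}, + "outputs": [], + "source": [ + "# the EpsilonGammaBox composite needs the lowest and highest values, which are\n", + "# here for ImageNet 0. and 1. 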
with a different normalization for each channel\n", + "low, high = transform_norm(torch.tensor([[[[[0.]]] * 3], [[[[1.]]] * 3]]))\n", + "\n", + "# create a composite, specifying required arguments\n", + "composite = EpsilonGammaBox(low=low, high=high)\n", + "\n", + "# choose a target class for the attribution (label 437 is lighthouse)\n", + "target = torch.eye(1000)[[437]]\n", + "\n", + "# create the attributor, specifying model and composite\n", + "with Gradient(model=model, composite=composite) as attributor:\n", + " # compute the model output and attribution\n", + " output, attribution = attributor(data, target)\n", + "\n", + "print(f'Prediction: {output.argmax(1)[0].item()}')" + ] + }, + { + "cell_type": "markdown", + "id": "363a14c5", + "metadata": {}, + "source": [ + "Visualize the attribution:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1032120", + "metadata": {}, + "outputs": [], + "source": [ + "# sum over the channels\n", + "relevance = attribution.sum(1)\n", + "\n", + "# create an image of the visualize attribution\n", + "img = imgify(relevance, symmetric=True, cmap='coldnhot')\n", + "\n", + "# show the image\n", + "display(img)" + ] + }, + { + "cell_type": "markdown", + "id": "03e742f2", + "metadata": {}, + "source": [ + "### 2.5 LRP with EpsilonGammaBox with modified epsilon and stabilizer\n", + "We again compute the LRP-attribution with the `EpsilonGammaBox` **Composite**.\n", + "This time, we change the epsilon and stabilizer values, which both stabilize the\n", + "denominator in their respective rules:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22c50553", + "metadata": {}, + "outputs": [], + "source": [ + "# the EpsilonGammaBox composite needs the lowest and highest values, which are\n", + "# here for ImageNet 0. and 1. with a different normalization for each channel\n", + "low, high = transform_norm(torch.tensor([[[[[0.]]] * 3], [[[[1.]]] * 3]]))\n", + "\n", + "# create a composite, specifying required arguments\n", + "composite = EpsilonGammaBox(\n", + " low=low,\n", + " high=high,\n", + " epsilon=Stabilizer(epsilon=0.1, norm_scale=True),\n", + " stabilizer=lambda x: ((x == 0.) 
+ x.sign()) * x.abs().clip(min=1e-6),\n", + ")\n", + "\n", + "# choose a target class for the attribution (label 437 is lighthouse)\n", + "target = torch.eye(1000)[[437]]\n", + "\n", + "# create the attributor, specifying model and composite\n", + "with Gradient(model=model, composite=composite) as attributor:\n", + " # compute the model output and attribution\n", + " output, attribution = attributor(data, target)\n", + "\n", + "print(f'Prediction: {output.argmax(1)[0].item()}')" + ] + }, + { + "cell_type": "markdown", + "id": "352c0883", + "metadata": {}, + "source": [ + "Visualize the attribution:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b8182b3c", + "metadata": {}, + "outputs": [], + "source": [ + "# sum over the channels\n", + "relevance = attribution.sum(1)\n", + "\n", + "# create an image of the visualized attribution\n", + "img = imgify(relevance, symmetric=True, cmap='coldnhot')\n", + "\n", + "# show the image\n", + "display(img)" + ] + }, + { + "cell_type": "markdown", + "id": "157bc57d", + "metadata": {}, + "source": [ + "### 2.6 More Visualization\n", + "We can try out different color-maps, either by using another built-in color-map or by using the color-map specification language:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d4b0f1b", + "metadata": {}, + "outputs": [], + "source": [ + "print('Built-in color-map bwr')\n", + "display(imgify(relevance, symmetric=True, cmap='bwr'))\n", + "\n", + "print('CMSL code for color-map from cyan to grey to purple')\n", + "display(imgify(relevance, symmetric=True, cmap='0ff,444,f0f'))\n", + "\n", + "print('CMSL code for grey scale for negative and red for positive values')\n", + "display(imgify(relevance, symmetric=True, cmap='fff,000,f00'))" + ] + }, + { + "cell_type": "markdown", + "id": "b6bf7029", + "metadata": {}, + "source": [ + "To directly save the visualized attribution, we can use `imsave` instead:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "838c2b4b", + "metadata": {}, + "outputs": [], + "source": [ + "# directly save the visualized attribution\n", + "imsave('attrib-1.png', relevance, symmetric=True, cmap='bwr')" + ] + },
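+ { + "cell_type": "markdown", + "id": "2b3c4d5e", + "metadata": {}, + "source": [ + "Alternatively, since `imgify` returns a PIL image, the visualization can also be saved with its `.save()` method; a small sketch (the file name `attrib-2.png` is only illustrative):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c8d9e0f", + "metadata": {}, + "outputs": [], + "source": [ + "# save the PIL image produced by imgify directly\n", + "imgify(relevance, symmetric=True, cmap='coldnhot').save('attrib-2.png')" + ] + }, + { + "cell_type": "markdown", + "id": "27b6dcd1", + "metadata": {}, + "source": [ + "## 3. VGG11 with BatchNorm\n", + "\n", + "VGG11 with BatchNorm requires a canonizer for LRP to function properly. We\n", + "initialize the VGG11-bn model and optionally load the pre-trained parameters. Set\n", + "`weights='IMAGENET1K_V1'` to use the pre-trained model instead of the random\n", + "one:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16b50d14", + "metadata": {}, + "outputs": [], + "source": [ + "# load the model and set it to evaluation mode\n", + "model = vgg11_bn(weights=None).eval()" + ] + }, + { + "cell_type": "markdown", + "id": "7f677171", + "metadata": {}, + "source": [ + "### 3.1 LRP with EpsilonGammaBox\n", + "We now compute the LRP-attribution with the `EpsilonGammaBox` **Composite**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2154e6c5", + "metadata": {}, + "outputs": [], + "source": [ + "# use the VGG-specific canonizer (alias for SequentialMergeBatchNorm, only\n", + "# needed with batch-norm)\n", + "canonizer = VGGCanonizer()\n", + "\n", + "# the EpsilonGammaBox composite needs the lowest and highest values, which are\n", + "# here for ImageNet 0. and 1. 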
with a different normalization for each channel\n", + "low, high = transform_norm(torch.tensor([[[[[0.]]] * 3], [[[[1.]]] * 3]]))\n", + "\n", + "# create a composite, specifying arguments and the canonizers\n", + "composite = EpsilonGammaBox(low=low, high=high, canonizers=[canonizer])\n", + "\n", + "# choose a target class for the attribution (label 437 is lighthouse)\n", + "target = torch.eye(1000)[[437]]\n", + "\n", + "# create the attributor, specifying model and composite\n", + "with Gradient(model=model, composite=composite) as attributor:\n", + " # compute the model output and attribution\n", + " output, attribution = attributor(data, target)\n", + "\n", + "print(f'Prediction: {output.argmax(1)[0].item()}')" + ] + }, + { + "cell_type": "markdown", + "id": "ae7a7f06", + "metadata": {}, + "source": [ + "Visualize the attribution:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e17ed91", + "metadata": {}, + "outputs": [], + "source": [ + "# sum over the channels\n", + "relevance = attribution.sum(1)\n", + "\n", + "# create an image of the visualized attribution\n", + "img = imgify(relevance, symmetric=True, cmap='coldnhot')\n", + "\n", + "# show the image\n", + "display(img)" + ] + }, + { + "cell_type": "markdown", + "id": "bc6b7066", + "metadata": {}, + "source": [ + "### 3.2 LRP with modified EpsilonGammaBox and ignored BatchNorm\n", + "We can modify the EpsilonGammaBox by changing its epsilon and gamma values and\n", + "by supplying different layer mappings. Here, we explicitly ignore the BatchNorm\n", + "layers instead of merging them into the adjacent linear layers, which is done\n", + "by not supplying the `VGGCanonizer`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4b04f3b2", + "metadata": {}, + "outputs": [], + "source": [ + "# the EpsilonGammaBox composite needs the lowest and highest values, which are\n", + "# here for ImageNet 0. and 1. 
with a different normalization for each channel\n", + "low, high = transform_norm(torch.tensor([[[[[0.]]] * 3], [[[[1.]]] * 3]]))\n", + "\n", + "# create a composite, specifying arguments and canonizers\n", + "composite = EpsilonGammaBox(\n", + " low=low,\n", + " high=high,\n", + " canonizers=[],\n", + " gamma=0.5, # change the gamma for all layers\n", + " epsilon=0.1, # change the epsilon for all layers\n", + " layer_map=[\n", + " (BatchNorm, Pass()), # explicitly ignore BatchNorm\n", + " ]\n", + ")\n", + "\n", + "# choose a target class for the attribution (label 437 is lighthouse)\n", + "target = torch.eye(1000)[[437]]\n", + "\n", + "# create the attributor, specifying model and composite\n", + "with Gradient(model=model, composite=composite) as attributor:\n", + " # compute the model output and attribution\n", + " output, attribution = attributor(data, target)\n", + "\n", + "print(f'Prediction: {output.argmax(1)[0].item()}')" + ] + }, + { + "cell_type": "markdown", + "id": "3573e7bf", + "metadata": {}, + "source": [ + "Visualize the attribution:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ea71e4d", + "metadata": {}, + "outputs": [], + "source": [ + "# sum over the channels\n", + "relevance = attribution.sum(1)\n", + "\n", + "# create an image of the visualized attribution\n", + "img = imgify(relevance, symmetric=True, cmap='coldnhot')\n", + "\n", + "# show the image\n", + "display(img)" + ] + }, + { + "cell_type": "markdown", + "id": "09320e02", + "metadata": {}, + "source": [ + "### 3.3 LRP with custom NameMapComposite\n", + "Here we demonstrate how rules can be assigned by module name using the abstract\n", + "`NameMapComposite`. Note that this particular case will only work with VGG11-bn,\n", + "so we redefine the model (set `weights='IMAGENET1K_V1'` to use the\n", + "pre-trained model):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8049e03d", + "metadata": {}, + "outputs": [], + "source": [ + "# load the model and set it to evaluation mode\n", + "model = vgg11_bn(weights=None).eval()" + ] + }, + { + "cell_type": "markdown", + "id": "0c788424", + "metadata": {}, + "source": [ + "We can take a look at all leaf modules by doing the following:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5f00750", + "metadata": {}, + "outputs": [], + "source": [ + "leaves = [\n", + " (name, module)\n", + " for name, module in model.named_modules()\n", + " if not [*islice(module.children(), 1)]\n", + "]\n", + "maxlen = max(len(name) for name, _ in leaves)\n", + "print('\\n'.join(f'{name:{maxlen}s}: {module}' for name, module in leaves))" + ] + }, + { + "cell_type": "markdown", + "id": "1e5cff96", + "metadata": {}, + "source": [ + "We can look at the architecture and then assign the rules directly to the names:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38edd796", + "metadata": {}, + "outputs": [], + "source": [ + "name_map = [\n", + " (['features.0'], Flat()), # Conv2d\n", + " (['features.1'], Pass()), # BatchNorm2d\n", + " (['features.2'], Pass()), # ReLU\n", + " (['features.3'], Norm()), # MaxPool2d\n", + " (['features.4'], ZPlus()), # Conv2d\n", + " (['features.5'], Pass()), # BatchNorm2d\n", + " (['features.6'], Pass()), # ReLU\n", + " (['features.7'], Norm()), # MaxPool2d\n", + " (['features.8'], ZPlus()), # Conv2d\n", + " (['features.9'], Pass()), # BatchNorm2d\n", + " (['features.10'], Pass()), # ReLU\n", + " (['features.11'], ZPlus()), # Conv2d\n", + " (['features.12'], Pass()), # 
BatchNorm2d\n", + " (['features.13'], Pass()), # ReLU\n", + " (['features.14'], Norm()), # MaxPool2d\n", + " (['features.15'], ZPlus()), # Conv2d\n", + " (['features.16'], Pass()), # BatchNorm2d\n", + " (['features.17'], Pass()), # ReLU\n", + " (['features.18'], ZPlus()), # Conv2d\n", + " (['features.19'], Pass()), # BatchNorm2d\n", + " (['features.20'], Pass()), # ReLU\n", + " (['features.21'], Norm()), # MaxPool2d\n", + " (['features.22'], ZPlus()), # Conv2d\n", + " (['features.23'], Pass()), # BatchNorm2d\n", + " (['features.24'], Pass()), # ReLU\n", + " (['features.25'], ZPlus()), # Conv2d\n", + " (['features.26'], Pass()), # BatchNorm2d\n", + " (['features.27'], Pass()), # ReLU\n", + " (['features.28'], Norm()), # MaxPool2d\n", + " (['avgpool'], Norm()), # AdaptiveAvgPool2d\n", + " (['classifier.0'], Epsilon()), # Linear\n", + " (['classifier.1'], Pass()), # ReLU\n", + " (['classifier.2'], Pass()), # Dropout\n", + " (['classifier.3'], Epsilon()), # Linear\n", + " (['classifier.4'], Pass()), # ReLU\n", + " (['classifier.5'], Pass()), # Dropout\n", + " (['classifier.6'], Epsilon()), # Linear\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "7580d344", + "metadata": {}, + "source": [ + "Now we can use the name-map to create the composite:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52e325b8", + "metadata": {}, + "outputs": [], + "source": [ + "# use the VGG-specific canonizer (alias for SequentialMergeBatchNorm, only\n", + "# needed with batch-norm)\n", + "canonizer = VGGCanonizer()\n", + "\n", + "# create a composite, specifying arguments and canonizers\n", + "composite = NameMapComposite(\n", + " name_map=name_map,\n", + " canonizers=[canonizer],\n", + ")\n", + "\n", + "# choose a target class for the attribution (label 437 is lighthouse)\n", + "target = torch.eye(1000)[[437]]\n", + "\n", + "# create the attributor, specifying model and composite\n", + "with Gradient(model=model, composite=composite) as attributor:\n", + " # compute the model output and attribution\n", + " output, attribution = attributor(data, target)\n", + "\n", + "print(f'Prediction: {output.argmax(1)[0].item()}')" + ] + }, + { + "cell_type": "markdown", + "id": "c194e075", + "metadata": {}, + "source": [ + "Visualize the attribution:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac63a151", + "metadata": {}, + "outputs": [], + "source": [ + "# sum over the channels\n", + "relevance = attribution.sum(1)\n", + "\n", + "# create an image of the visualized attribution\n", + "img = imgify(relevance, symmetric=True, cmap='coldnhot')\n", + "\n", + "# show the image\n", + "display(img)" + ] + }, + { + "cell_type": "markdown", + "id": "ba184117", + "metadata": {}, + "source": [ + "## 4. ResNet18\n", + "We initialize the ResNet18 model and optionally load the pre-trained parameters. 
Set\n", + "`weights='IMAGENET1K_V1'` to use the pre-trained model instead of the random one:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53bd97d9", + "metadata": {}, + "outputs": [], + "source": [ + "# load the model and set it to evaluation mode\n", + "model = resnet18(weights=None).eval()" + ] + }, + { + "cell_type": "markdown", + "id": "e6dee524", + "metadata": {}, + "source": [ + "### 4.1 LRP with EpsilonPlusFlat\n", + "Compute the LRP-attribution using the ``EpsilonPlusFlat`` **Composite**:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a2cd10b", + "metadata": {}, + "outputs": [], + "source": [ + "# use the ResNet-specific canonizer\n", + "canonizer = ResNetCanonizer()\n", + "\n", + "# create a composite, specifying the canonizers\n", + "composite = EpsilonPlusFlat(canonizers=[canonizer])\n", + "\n", + "# choose a target class for the attribution (label 437 is lighthouse)\n", + "target = torch.eye(1000)[[437]]\n", + "\n", + "# create the attributor, specifying model and composite\n", + "with Gradient(model=model, composite=composite) as attributor:\n", + " # compute the model output and attribution\n", + " output, attribution = attributor(data, target)\n", + "\n", + "print(f'Prediction: {output.argmax(1)[0].item()}')" + ] + }, + { + "cell_type": "markdown", + "id": "58799748", + "metadata": {}, + "source": [ + "Visualize the attribution:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e55ced93", + "metadata": {}, + "outputs": [], + "source": [ + "# sum over the channels\n", + "relevance = attribution.sum(1)\n", + "\n", + "# create an image of the visualize attribution\n", + "img = imgify(relevance, symmetric=True, cmap='coldnhot')\n", + "\n", + "# show the image\n", + "display(img)" + ] + }, + { + "cell_type": "markdown", + "id": "d08826d3", + "metadata": {}, + "source": [ + "### 4.2 LRP with EpsilonGammaBox\n", + "Compute the LRP-attribution using the ``EpsilonGammaBox`` **Composite**:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d0921d8d", + "metadata": {}, + "outputs": [], + "source": [ + "# use the ResNet-specific canonizer\n", + "canonizer = ResNetCanonizer()\n", + "\n", + "# the ZBox rule needs the lowest and highest values, which are here for\n", + "# ImageNet 0. and 1. 
with a different normalization for each channel\n", + "low, high = transform_norm(torch.tensor([[[[[0.]]] * 3], [[[[1.]]] * 3]]))\n", + "\n", + "# create a composite, specifying the canonizers, if any\n", + "composite = EpsilonGammaBox(low=low, high=high, canonizers=[canonizer])\n", + "\n", + "# choose a target class for the attribution (label 437 is lighthouse)\n", + "target = torch.eye(1000)[[437]]\n", + "\n", + "# create the attributor, specifying model and composite\n", + "with Gradient(model=model, composite=composite) as attributor:\n", + " # compute the model output and attribution\n", + " output, attribution = attributor(data, target)\n", + "\n", + "print(f'Prediction: {output.argmax(1)[0].item()}')" + ] + }, + { + "cell_type": "markdown", + "id": "5db70111", + "metadata": {}, + "source": [ + "Visualize the attribution:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec6021bd", + "metadata": {}, + "outputs": [], + "source": [ + "# sum over the channels\n", + "relevance = attribution.sum(1)\n", + "\n", + "# create an image of the visualized attribution\n", + "img = imgify(relevance, symmetric=True, cmap='coldnhot')\n", + "\n", + "# show the image\n", + "display(img)" + ] + }, + { + "cell_type": "markdown", + "id": "9230131a", + "metadata": {}, + "source": [ + "### 4.3 LRP with custom LayerMapComposite\n", + "In this demonstration, we want to ignore the BatchNorm layers and attribute the\n", + "residual connections not by contribution but equally. This setup requires no **Canonizer**, only a special Composite which ignores the BatchNorm modules for LRP by assigning them the `Pass` rule." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b249f6b1", + "metadata": {}, + "outputs": [], + "source": [ + "# the ZBox rule needs the lowest and highest values, which are here for\n", + "# ImageNet 0. and 1. with a different normalization for each channel\n", + "low, high = transform_norm(torch.tensor([[[[[0.]]] * 3], [[[[1.]]] * 3]]))\n", + "\n", + "# create a composite; no canonizers are needed here\n", + "composite = SpecialFirstLayerMapComposite(\n", + " # the layer map is a list of tuples, where the first element is the target\n", + " # layer type, and the second is the rule template\n", + " layer_map=[\n", + " (Activation, Pass()), # ignore activations\n", + " (AvgPool, Norm()), # normalize relevance for any AvgPool\n", + " (Convolution, ZPlus()), # any convolutional layer\n", + " (Linear, Epsilon(epsilon=1e-6)), # the dense torch.nn.Linear, not zennit.types.Linear\n", + " (BatchNorm, Pass()), # ignore BatchNorm\n", + " ],\n", + " # the first map is only used once, for the first module that matches the\n", + " # map, i.e. 
here the first layer of type AnyLinear\n", + " first_map = [\n", + " (AnyLinear, ZBox(low, high))\n", + " ]\n", + ")\n", + "\n", + "# choose a target class for the attribution (label 437 is lighthouse)\n", + "target = torch.eye(1000)[[437]]\n", + "\n", + "# create the attributor, specifying model and composite\n", + "with Gradient(model=model, composite=composite) as attributor:\n", + " # compute the model output and attribution\n", + " output, attribution = attributor(data, target)\n", + "\n", + "print(f'Prediction: {output.argmax(1)[0].item()}')" + ] + }, + { + "cell_type": "markdown", + "id": "c33934f9", + "metadata": {}, + "source": [ + "Visualize the attribution:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "04c40cc9", + "metadata": {}, + "outputs": [], + "source": [ + "# sum over the channels\n", + "relevance = attribution.sum(1)\n", + "\n", + "# create an image of the visualize attribution\n", + "img = imgify(relevance, symmetric=True, cmap='coldnhot')\n", + "\n", + "# show the image\n", + "display(img)" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorial/image-classification-vgg.ipynb b/docs/source/tutorial/image-classification-vgg.ipynb deleted file mode 100644 index 74f52eb..0000000 --- a/docs/source/tutorial/image-classification-vgg.ipynb +++ /dev/null @@ -1,248 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "810f3c46", - "metadata": {}, - "source": [ - "# Image Classification with VGG\n", - "\n", - "This tutorial will introduce the attribution of image classifiers using VGG11\n", - "on ImageNet. Feel free to replace VGG11 with any other version of VGG.\n", - "\n", - "First, we install **Zennit**. 
This includes its dependencies `Pillow`,\n", - "`torch` and `torchvision`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c0e2f1fe", - "metadata": {}, - "outputs": [], - "source": [ - "%pip install zennit" - ] - }, - { - "cell_type": "markdown", - "id": "3a2dc4cd", - "metadata": {}, - "source": [ - "Then, we import necessary modules, classes and functions:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a9a3fa5e", - "metadata": {}, - "outputs": [], - "source": [ - "import logging\n", - "\n", - "import torch\n", - "from PIL import Image\n", - "from torchvision.transforms import Compose, Resize, CenterCrop\n", - "from torchvision.transforms import ToTensor, Normalize\n", - "from torchvision.models import vgg11_bn\n", - "\n", - "from zennit.attribution import Gradient, SmoothGrad\n", - "from zennit.composites import EpsilonPlusFlat, EpsilonGammaBox\n", - "from zennit.image import imgify, imsave\n", - "from zennit.torchvision import VGGCanonizer" - ] - }, - { - "cell_type": "markdown", - "id": "e48e434e", - "metadata": {}, - "source": [ - "We download an image of the [Dornbusch\n", - "Lighthouse](https://en.wikipedia.org/wiki/Dornbusch_Lighthouse) from [Wikimedia\n", - "Commons](https://commons.wikimedia.org/wiki/File:2006_09_06_180_Leuchtturm.jpg):" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bfaa9f3a", - "metadata": {}, - "outputs": [], - "source": [ - "torch.hub.download_url_to_file(\n", - " 'https://upload.wikimedia.org/wikipedia/commons/thumb/8/8b/2006_09_06_180_Leuchtturm.jpg/640px-2006_09_06_181_Leuchtturm.jpg',\n", - " 'dornbusch-lighthouse.jpg',\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "0798a0c1", - "metadata": {}, - "source": [ - "We load and prepare the data. The image is resized such that the shorter side\n", - "is 256 pixels in size, then center-cropped to `(224, 224)`, converted to a\n", - "`torch.Tensor`, and then normalized according the channel-wise mean and\n", - "standard deviation of the ImageNet dataset:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4989b65c", - "metadata": {}, - "outputs": [], - "source": [ - "# define the base image transform\n", - "transform_img = Compose([\n", - " Resize(256),\n", - " CenterCrop(224),\n", - "])\n", - "# define the full tensor transform\n", - "transform = Compose([\n", - " transform_img,\n", - " ToTensor(),\n", - " Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n", - "])\n", - "\n", - "# load the image\n", - "image = Image.open('dornbusch-lighthouse.jpg')\n", - "\n", - "# transform the PIL image and insert a batch-dimension\n", - "data = transform(image)[None]" - ] - }, - { - "cell_type": "markdown", - "id": "882b4dd8", - "metadata": {}, - "source": [ - "We can look at the original image and the cropped image:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "072a3ad0", - "metadata": {}, - "outputs": [], - "source": [ - "# display the original image\n", - "display(image)\n", - "# display the resized and cropped image\n", - "display(transform_img(image))" - ] - }, - { - "cell_type": "markdown", - "id": "3d45bd9b", - "metadata": {}, - "source": [ - "Then, we initialize the model and load the hyperparameters. 
Set\n", - "`pretrained=True` to use the pre-trained model instead of the random one:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "16b50d14", - "metadata": {}, - "outputs": [], - "source": [ - "# load the model and set it to evaluation mode\n", - "model = vgg11_bn(pretrained=False).eval()" - ] - }, - { - "cell_type": "markdown", - "id": "7f677171", - "metadata": {}, - "source": [ - "Compute the attribution using the ``EpsilonPlusFlat`` **Composite**:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e17da931", - "metadata": {}, - "outputs": [], - "source": [ - "# use the VGG-specific canonizer (alias for SequentialMergeBatchNorm, only\n", - "# needed with batch-norm)\n", - "canonizer = VGGCanonizer()\n", - "\n", - "# create a composite, specifying the canonizers, if any\n", - "composite = EpsilonPlusFlat(canonizers=[canonizer])\n", - "\n", - "# choose a target class for the attribution (label 437 is lighthouse)\n", - "target = torch.eye(1000)[[437]]\n", - "\n", - "# create the attributor, specifying model and composite\n", - "with Gradient(model=model, composite=composite) as attributor:\n", - " # compute the model output and attribution\n", - " output, attribution = attributor(data, target)\n", - "\n", - "print(f'Prediction: {output.argmax(1)[0].item()}')" - ] - }, - { - "cell_type": "markdown", - "id": "e49b5056", - "metadata": {}, - "source": [ - "Visualize the attribution:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7eda5200", - "metadata": {}, - "outputs": [], - "source": [ - "# sum over the channels\n", - "relevance = attribution.sum(1)\n", - "\n", - "# create an image of the visualize attribution\n", - "img = imgify(relevance, symmetric=True, cmap='coldnhot')\n", - "\n", - "# show the image\n", - "display(img)" - ] - }, - { - "cell_type": "markdown", - "id": "b6bf7029", - "metadata": {}, - "source": [ - "Here, `imgify` produces a PIL-image, which can be saved with `.save()`.\n", - "To directly save the visualized attribution, we can use `imsave` instead:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "838c2b4b", - "metadata": {}, - "outputs": [], - "source": [ - "# directly save the visualized attribution\n", - "imsave('attrib-1.png', relevance, symmetric=True, cmap='bwr')" - ] - } - ], - "metadata": { - "jupytext": { - "formats": "ipynb,md:myst" - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/source/tutorial/index.rst b/docs/source/tutorial/index.rst index c9d5165..d89d94b 100644 --- a/docs/source/tutorial/index.rst +++ b/docs/source/tutorial/index.rst @@ -5,9 +5,8 @@ .. toctree:: :maxdepth: 1 - image-classification-vgg + image-classification-vgg-resnet .. - image-classification-with-vgg-and-resnet image-segmentation-with-unet text-classification-with-tbd audio-classification-with-tbd