Skip to content

Commit

Permalink
README: Update, better usage, relative links
Browse files Browse the repository at this point in the history
In README.md:
- bring the Usage part up-to-date, with some better formatting
- add a small demonstration in Usage
- move Documentation to its own Section, so it can be found more easily
- remove statement about "interfaces can break anytime", the interface
  should be mostly stable now
- add another link to the documentation after Usage
- give more details in the introduction to the Example
- add a section Example Heatmaps, which shows example heatmaps for
  various supported attribution methods for VGG16 and ResNet50 produced
  by feed_forward.py
- add a sub-section about the documentation in Contributing with notes
  about building
- make absolute github links relative, such that they can point to their
  respective revision. They were previously absolute links for display
  on PyPI, which cannot correctly disply relative files.
- add badge for tests
- fix inconsistent capitalizations
- fix Pytest link pointing to Pylint

New images:
- add two images to show heatmaps for VGG16 and ResNet50

In setup.py
- define a function to fetch the long description for PyPI, which
  replaces relative links with absolute ones on github. This fixes the
  previous problem with relative links on PyPI, and will also display
  and reference files for the correct release, instead of always for
  master.
- due to `python -m build` by default copying all files for the sdist,
  and then building the wheel in the same folder in an isolated
  environment, the revison would not be correct within the wheel. The
  sdist however creates a PKG-INFO file, which will contain the
  long_description with the absolute links, which will be used if the
  git version cannot be used. The fallback is to use master, as done
  previously.
  • Loading branch information
chr5tphr committed Mar 31, 2022
1 parent c76d7ed commit a5a622d
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 90 deletions.
196 changes: 109 additions & 87 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
# Zennit
![Zennit-Logo](https://raw.githubusercontent.com/chr5tphr/zennit/master/share/img/zennit.png)
![Zennit-Logo](share/img/zennit.png)

[![Documentation Status](https://readthedocs.org/projects/zennit/badge/?version=latest)](https://zennit.readthedocs.io/en/latest/?badge=latest)
[![tests](https://github.com/chr5tphr/zennit/actions/workflows/tests.yml/badge.svg)](https://github.com/chr5tphr/zennit/actions/workflows/tests.yml)
[![PyPI Version](https://img.shields.io/pypi/v/zennit)](https://pypi.org/project/zennit/)
[![License](https://img.shields.io/pypi/l/zennit)](https://github.com/chr5tphr/zennit/blob/master/COPYING.LESSER)

Zennit (**Z**ennit **e**xplains **n**eural **n**etworks **i**n **t**orch)
is a high-level framework in Python using PyTorch for explaining/exploring neural networks.
Its design philosophy is intended to provide high customizability and integration as a standardized solution
for applying LRP-based attribution methods in research.
Zennit strictly requires models to use PyTorch's `torch.nn.Module` structure
(including activation functions).
Zennit (**Z**ennit **e**xplains **n**eural **n**etworks **i**n **t**orch) is a
high-level framework in Python using Pytorch for explaining/exploring neural
networks. Its design philosophy is intended to provide high customizability and
integration as a standardized solution for applying rule-based attribution
methods in research, with a strong focus on Layerwise Relevance Propagation
(LRP). Zennit strictly requires models to use Pytorch's `torch.nn.Module`
structure (including activation functions).

Zennit is currently under development and has not yet reached a stable state.
Interfaces may change suddenly and without warning, so please be careful when attempting to use Zennit in its current
state.
Zennit is currently under active development, but should be mostly stable.

The latest documentation is hosted on Read the Docs at [zennit.readthedocs.io](https://zennit.readthedocs.io/en/latest/).

If you find Zennit useful for your research, please consider citing our related [paper](https://arxiv.org/abs/2106.13200):
If you find Zennit useful for your research, please consider citing our related
[paper](https://arxiv.org/abs/2106.13200):
```
@article{anders2021software,
author = {Anders, Christopher J. and
Expand All @@ -33,6 +32,10 @@ If you find Zennit useful for your research, please consider citing our related
}
```

## Documentation
The latest documentation is hosted at
[zennit.readthedocs.io](https://zennit.readthedocs.io/en/latest/).

## Install

To install directly from PyPI using pip, use:
Expand All @@ -47,35 +50,65 @@ $ pip install ./zennit
```

## Usage
An example can be found in `share/example/feed_forward.py`.
Currently, only feed-forward type models are supported.

At its heart, Zennit registers hooks at PyTorch's Module level, to modify the backward pass to produce LRP
attributions (instead of the usual gradient).
All rules are implemented as hooks (`zennit/rules.py`) and most use the LRP-specific `BasicHook` (`zennit/core.py`).
**Composites** are a way of choosing the right hook for the right layer.
In addition to the abstract **NameMapComposite**, which assigns hooks to layers by name, and **LayerMapComposite**,
which assigns hooks to layers based on their Type, there exist explicit Composites, which currently are
* EpsilonGammaBox (ZBox in input, epsilon in dense, Gamma 0.25 in convolutions)
* EpsilonPlus (PresetA in iNNvestigate)
* EpsilonPlusFlat (PresetAFlat in iNNvestigate)
* EpsilonAlpha2Beta1 (PresetB in iNNvestigate)
* EpsilonAlpha2Beta1Flat (PresetBFlat in iNNvestigate).

They may be used by directly importing from `zennit.composites`, or by using
their snake-case name as key for `zennit.composites.COMPOSITES`. Additionally,
there are **Canonizers**, which modify models such that LRP may be applied, if
needed. Currently, there are `MergeBatchNorm`, `AttributeCanonizer` and
`CompositeCanonizer`. There are two versions of the abstract `MergeBatchNorm`,
`SequentialMergeBatchNorm`, which automatically detects BatchNorm layers
followed by linear layers in sequential networks, and `NamedMergeBatchNorm`,
which expects a list of tuples to assign one or more linear layers to one batch
norm layer. `AttributeCanonizer` temporarily overwrites attributes of
applicable modules, e.g. for ResNet50, the forward function (attribute) of the
Bottleneck modules is overwritten to handle the residual connection.
At its heart, Zennit registers hooks at Pytorch's Module level, to modify the
backward pass to produce rule-based attributions like LRP (instead of the usual
gradient). All rules are implemented as hooks
([`zennit/rules.py`](zennit/rules.py)) and most use the LRP basis
`BasicHook` ([`zennit/core.py`](zennit/core.py)).

**Composites** ([`zennit/composites.py`](zennit/composites.py)) are a way of
choosing the right hook for the right layer. In addition to the abstract
**NameMapComposite**, which assigns hooks to layers by name, and
**LayerMapComposite**, which assigns hooks to layers based on their Type, there
exist explicit **Composites**, some of which are `EpsilonGammaBox` (`ZBox` in
input, `Epsilon` in dense, `Gamma` in convolutions) or `EpsilonPlus` (`Epsilon`
in dense, `ZPlus` in convolutions). All composites may be used by directly
importing from `zennit.composites`, or by using their snake-case name as key
for `zennit.composites.COMPOSITES`.

**Canonizers** ([`zennit/canonizers.py`](zennit/canonizers.py)) temporarily
transform models into a canonical form, if required, like
`SequentialMergeBatchNorm`, which automatically detects and merges BatchNorm
layers followed by linear layers in sequential networks, or
`AttributeCanonizer`, which temporarily overwrites attributes of applicable
modules, e.g. to handle the residual connection in ResNet-Bottleneck modules.

**Attributors** ([`zennit/attribution.py`](zennit/attribution.py)) directly
execute the necessary steps to apply certain attribution methods, like the
simple `Gradient`, `SmoothGrad` or `Occlusion`. An optional **Composite** may
be passed, which will be applied during the **Attributor**'s execution to
compute the modified gradient, or hybrid methods.

Using all of these components, an LRP-type attribution for VGG16 with
batch-norm layers with respect to label 0 may be computed using:

```python
import torch
from torchvision.models import vgg16_bn

from zennit.composites import EpsilonGammaBox
from zennit.canonizers import SequentialMergeBatchNorm
from zennit.attribution import Gradient


data = torch.randn(1, 3, 224, 224)
model = vgg16_bn()

canonizers = [SequentialMergeBatchNorm()]
composite = EpsilonGammaBox(low=-3., high=3., canonizers=canonizers)

with Gradient(model=model, composite=composite) as attributor:
out, relevance = attributor(data, torch.eye(1000)[[0]])
```

For more details and examples, have a look at our
[**documentation**](https://zennit.readthedocs.io/en/latest/).

## Example
This example requires bash, cURL and (magic-)file.
This example demonstrates how the script at
[`share/example/feed_forward.py`](share/example/feed_forward.py) can be used to
generate attribution heatmaps for VGG16.
It requires bash, cURL and (magic-)file.

Create a virtual environment, install Zennit and download the example scripts:
```shell
Expand All @@ -95,7 +128,7 @@ $ mkdir params data results
$ bash download-lighthouses.sh --output data/lighthouses
$ curl -o params/vgg16-397923af.pth 'https://download.pytorch.org/models/vgg16-397923af.pth'
```
This creates the needed directories and downloads the pre-trained vgg16 parameters and 8 images of light houses from wikimedia commons into the required label-directory structure for the imagenet dataset in Pytorch.
This creates the needed directories and downloads the pre-trained VGG16 parameters and 8 images of light houses from Wikimedia Commons into the required label-directory structure for the Imagenet dataset in Pytorch.

The `feed_forward.py` example may then be run using:
```shell
Expand All @@ -109,13 +142,17 @@ $ .venv/bin/python feed_forward.py \
--relevance-norm symmetric \
--cmap coldnhot
```
which computes the lrp heatmaps according to the `epsilon_gamma_box` rule and stores them in `results`, along with the respective input images.
Other possible composites that can be passed to `--composites` are, e.g., `epsilon_plus`, `epsilon_alpha2_beta1_flat`, `guided_backprop`, `excitation_backprop`.
which computes the LRP heatmaps according to the `epsilon_gamma_box` rule and
stores them in `results`, along with the respective input images. Other
possible composites that can be passed to `--composites` are, e.g.,
`epsilon_plus`, `epsilon_alpha2_beta1_flat`, `guided_backprop`,
`excitation_backprop`.

The resulting heatmaps may look like the following:
![beacon heatmaps](https://raw.githubusercontent.com/chr5tphr/zennit/master/share/img/beacon_vgg16_epsilon_gamma_box.png)
![beacon heatmaps](share/img/beacon_vgg16_epsilon_gamma_box.png)

Alternatively, heatmaps for SmoothGrad with absolute relevances may be computed by omitting `--composite` and supplying `--attributor`:
Alternatively, heatmaps for SmoothGrad with absolute relevances may be computed
by omitting `--composite` and supplying `--attributor`:
```shell
$ .venv/bin/python feed_forward.py \
data/lighthouses \
Expand All @@ -129,7 +166,8 @@ $ .venv/bin/python feed_forward.py \
```
For Integrated Gradients, `--attributor integrads` may be provided.

Heatmaps for Occlusion Analysis with unaligned relevances may be computed by executing:
Heatmaps for Occlusion Analysis with unaligned relevances may be computed by
executing:
```shell
$ .venv/bin/python feed_forward.py \
data/lighthouses \
Expand All @@ -142,54 +180,38 @@ $ .venv/bin/python feed_forward.py \
--cmap hot
```

The following is a slightly modified excerpt of `share/example/feed_forward.py`:
```python
...
# the maximal input shape, needed for the ZBox rule
shape = (batch_size, 3, 224, 224)

composite_kwargs = {
'low': norm_fn(torch.zeros(*shape, device=device)), # the lowest and ...
'high': norm_fn(torch.ones(*shape, device=device)), # the highest pixel value for ZBox
'canonizers': [VGG16Canonizer()] # the torchvision specific vgg16 canonizer
}

# create a composite specified by a name; the COMPOSITES dict includes all preset composites
# provided by zennit.
composite = COMPOSITES['epsilon_gamma_box'](**composite_kwargs)

# disable requires_grad for all parameters, we do not need their modified gradients
for param in model.parameters():
param.requires_grad = False

# create the composite context outside the main loop, such that the canonizers and hooks do not
# need to be registered and removed for each step.
with composite.context(model) as modified_model:
for data, target in loader:
# we use data without the normalization applied for visualization, and with the
# normalization applied as the model input
data_norm = norm_fn(data.to(device))
data_norm.requires_grad_()

# one-hot encoding of the target labels of size (len(target), 1000)
output_relevance = torch.eye(n_outputs, device=device)[target]

out = modified_model(data_norm)
# a simple backward pass will accumulate the relevance in data_norm.grad
torch.autograd.backward((out,), (output_relevance,))
...
```
## Example Heatmaps
Heatmaps of various attribution methods for VGG16 and ResNet50, all generated using
[`share/example/feed_forward.py`](share/example/feed_forward.py), can be found below.

<details>
<summary>Heatmaps for VGG16</summary>

![vgg16 heatmaps](share/img/beacon_vgg16_various.webp)
</details>

<details>
<summary>Heatmaps for ResNet50</summary>

![resnet50 heatmaps](share/img/beacon_resnet50_various.webp)
</details>

## Contributing

### Code Style
We use [PEP8](https://www.python.org/dev/peps/pep-0008) with a line-width of 120 characters.
For docstrings we use [numpydoc](https://numpydoc.readthedocs.io/en/latest/format.html).
We use [PEP8](https://www.python.org/dev/peps/pep-0008) with a line-width of 120 characters. For
docstrings we use [numpydoc](https://numpydoc.readthedocs.io/en/latest/format.html).

We use [`flake8`](https://pypi.org/project/flake8/) for quick style checks and [`pylint`](https://pypi.org/project/pylint/) for thorough style checks.
We use [`flake8`](https://pypi.org/project/flake8/) for quick style checks and
[`pylint`](https://pypi.org/project/pylint/) for thorough style checks.

### Testing
Tests are written using [pytest](https://pypi.org/project/pylint/) and executed in a separate environment using [tox](https://tox.readthedocs.io/en/latest/).
Tests are written using [Pytest](https://docs.pytest.org) and executed
in a separate environment using [Tox](https://tox.readthedocs.io/en/latest/).

A full style check and all tests can be run by simply calling `tox` in the repository root.

### Documentation
The documentation is written using [Sphinx](https://www.sphinx-doc.org). It can be built at
`docs/build` using the respective Tox environment with `tox -e docs`. To rebuild the full
documentation, `tox -e docs -- -aE` can be used.
39 changes: 36 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,42 @@
#!/usr/bin/env python3
import re
from setuptools import setup, find_packages
from subprocess import run, CalledProcessError


with open('README.md', 'r', encoding='utf-8') as fd:
long_description = fd.read()
def get_long_description(project_path):
'''Fetch the README contents and replace relative links with absolute ones
pointing to github for correct behaviour on PyPI.
'''
try:
revision = run(
['git', 'describe', '--tags'],
capture_output=True,
check=True,
text=True
).stdout[:-1]
except CalledProcessError:
try:
with open('PKG-INFO', 'r') as fd:
body = fd.read().partition('\n\n')[2]
if body:
return body
except FileNotFoundError:
revision = 'master'

with open('README.md', 'r', encoding='utf-8') as fd:
long_description = fd.read()

link_root = {
'': f'https://github.com/{project_path}/blob',
'!': f'https://raw.githubusercontent.com/{project_path}',
}

def replace(mobj):
return f'{mobj[1]}[{mobj[2]}]({link_root[mobj[1]]}/{revision}/{mobj[3]})'

link_rexp = re.compile(r'(!?)\[([^\]]*)\]\((?!https?://|/)([^\)]+)\)')
return link_rexp.sub(replace, long_description)


setup(
Expand All @@ -12,7 +45,7 @@
author='chrstphr',
author_email='[email protected]',
description='Attribution of Neural Networks using PyTorch',
long_description=long_description,
long_description=get_long_description('chr5tphr/zennit'),
long_description_content_type='text/markdown',
url='https://github.com/chr5tphr/zennit',
packages=find_packages(include=['zennit*']),
Expand Down
Binary file added share/img/beacon_resnet50_various.webp
Binary file not shown.
Binary file added share/img/beacon_vgg16_various.webp
Binary file not shown.

0 comments on commit a5a622d

Please sign in to comment.