Core/Rules: Better/Exposed Stabilizer
- expose a `stabilizer` parameter for all built-in rules other than
  `Epsilon`
- implement `Stabilizer` as a class
- create a stabilizer function from the stabilizer/epsilon value via the
  newly implemented `Stabilizer.ensure`, for both the `Epsilon` rule and
  all other rules that apply (sketched below)
- stabilizer/epsilon can be a float/int, in which case a `Stabilizer`
  object will be instantiated with the appropriate epsilon value
- stabilizer/epsilon can be a callable, in which case the callable itself
  will be used as the stabilizer
- add BatchNorm to `layer_map_base`, and make `layer_map_base` a function
  such that a `stabilizer` for `Norm()` may be passed
- add tests for `Stabilizer`
- add explanations for the use of the new `stabilizer` and `epsilon`
  arguments to `how-to/use-rules-composites-and-canonizers.rst`

- BatchNorm is now ignored by default instead of using its gradient (!!);
  the reason for this change is that ignoring BatchNorm is a better choice
  than using its gradient
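
A minimal sketch of the `Stabilizer.ensure` semantics described above; the
`clip`/`norm_scale` options and other details of the actual implementation in
`zennit.core` are omitted here:

.. code-block:: python

    import torch

    class Stabilizer:
        """Sketch of the stabilizer semantics, not the actual implementation."""
        def __init__(self, epsilon=1e-6):
            self.epsilon = epsilon

        def __call__(self, input):
            # add epsilon while conserving the sign; zeros count as positive
            sign = (input == 0.).to(input) + input.sign()
            return input + sign * self.epsilon

        @classmethod
        def ensure(cls, value):
            # a float/int is wrapped in a Stabilizer with that epsilon value
            if isinstance(value, (int, float)):
                return cls(epsilon=value)
            # any callable, including a Stabilizer instance, is used directly
            if callable(value):
                return value
            raise TypeError(f'cannot create a stabilizer from {value!r}')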
chr5tphr committed Jul 20, 2022
1 parent 45763e1 commit 60537d8
Showing 5 changed files with 407 additions and 97 deletions.
90 changes: 76 additions & 14 deletions docs/source/how-to/use-rules-composites-and-canonizers.rst
@@ -107,14 +107,69 @@ layers:
    # remove the hooks
    handles.remove()
Furthermore, a ``stabilizer`` argument may be specified for all **Rules** based
on :py:class:`~zennit.core.BasicHook`, which is used to stabilize the
denominator in the respective **Rule**. It serves the same purpose as the
``epsilon`` argument of :py:class:`~zennit.rules.Epsilon`, which only uses a
different argument name to reflect the different intent behind ``stabilizer``
and ``epsilon``. A ``float`` can be supplied to use the default stabilizer, which
adds the value of ``epsilon`` while conserving the sign of the input. For more
control over the stabilization, a ``callable`` with signature ``(input:
torch.Tensor) -> torch.Tensor`` may be supplied. For this, Zennit provides the
class :py:class:`zennit.core.Stabilizer` with a few options (follow the link to
the API reference for an overview). However, if
:py:class:`~zennit.core.Stabilizer` does not provide the desired stabilization,
a custom function can be supplied instead.

.. code-block:: python

    import torch
    from torch.nn import Linear
    from zennit.rules import ZPlus, Epsilon
    from zennit.core import Stabilizer

    # instantiate a few rules
    rules = [
        # specifying a float results in the default stabilizer, which adds an
        # epsilon value that conserves the sign of the input
        ZPlus(stabilizer=1e-3),
        Epsilon(epsilon=1e-3),
        # a callable can be supplied for a custom stabilizer; Zennit provides
        # zennit.core.Stabilizer with a few choices for the stabilization
        ZPlus(stabilizer=Stabilizer(epsilon=1e-4, clip=True)),
        Epsilon(epsilon=Stabilizer(epsilon=1e-4, norm_scale=True)),
        # if Stabilizer does not provide the desired stabilization, a simple
        # function (input: torch.Tensor) -> torch.Tensor may be supplied
        ZPlus(stabilizer=lambda x: x + 1e-4),
        Epsilon(epsilon=lambda x: ((x == 0.) + x.sign()) * x.abs().clip(min=1e-3)),
    ]

    dense_layer = Linear(3 * 32 * 32, 32 * 32)
    input = torch.randn(1, 3 * 32 * 32, requires_grad=True)

    # generate an attribution for each rule
    attributions = []
    for rule in rules:
        handles = rule.register(dense_layer)
        output = dense_layer(input)
        attribution, = torch.autograd.grad(
            output, input, grad_outputs=torch.ones_like(output)
        )
        attributions.append(attribution)
        # be sure to remove the hook before registering a new one
        handles.remove()
See :doc:`/how-to/write-custom-rules` for further technical detail on how to
write custom rules.

.. note::

    Some rules, in particular the ones that modify parameters (e.g.
    :py:class:`~zennit.rules.ZPlus`, :py:class:`~zennit.rules.AlphaBeta`, ...)
    are not thread-safe in the backward phase, because they modify the model
    parameters for a brief moment. For most users, this is unlikely to cause
    any problems, and it may be avoided by using locks in appropriate
    locations; one possible pattern is sketched below.
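
The following is a hypothetical sketch (not from the Zennit documentation) of
such a pattern, which serializes use of the hooked module across threads with
a ``threading.Lock``; the setup and names are purely illustrative:

.. code-block:: python

    import threading

    import torch
    from torch.nn import Linear
    from zennit.rules import ZPlus

    # hypothetical setup: a single hooked layer shared between worker threads
    layer = Linear(8, 4)
    handles = ZPlus().register(layer)
    lock = threading.Lock()

    def attribute(input, results, index):
        # the rule briefly modifies the layer's parameters during the backward
        # pass, so forward and backward are serialized under one lock
        with lock:
            output = layer(input)
            results[index], = torch.autograd.grad(
                output, input, grad_outputs=torch.ones_like(output)
            )

    inputs = [torch.randn(1, 8, requires_grad=True) for _ in range(4)]
    results = [None] * len(inputs)
    threads = [
        threading.Thread(target=attribute, args=(inp, results, n))
        for n, inp in enumerate(inputs)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    handles.remove()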


.. _use-composites:
@@ -168,9 +223,13 @@ be used (which is done for all activations in all LRP **Composites**, but not
in :py:class:`~zennit.composites.GuidedBackprop` or
:py:class:`~zennit.composites.ExcitationBackprop`).

.. note::

    For LRP, the gradient of **MaxPool** assigns values only to the *largest*
    inputs (winner-takes-all), which is already the intended behaviour for LRP
    rules. Other operations for which the gradient is already the intended
    behaviour for LRP are, for example, *constant padding*, *concatenation*,
    *cropping*, *indexing* and *slicing*.
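
The winner-takes-all behaviour can be verified directly with the plain
gradient; a small illustrative example (not from the original documentation):

.. code-block:: python

    import torch
    from torch.nn import MaxPool1d

    pool = MaxPool1d(kernel_size=4)
    input = torch.tensor([[[1., 3., 2., 0.]]], requires_grad=True)
    output = pool(input)
    grad, = torch.autograd.grad(output, input, torch.ones_like(output))
    # only the largest input receives relevance:
    print(grad)  # tensor([[[0., 1., 0., 0.]]])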

Composites may require arguments, e.g.
:py:class:`~zennit.composites.EpsilonGammaBox` requires keyword arguments
@@ -208,13 +267,16 @@ Composites may require arguments, e.g.
Some built-in **Composites** also expose some of the parameters of their
respective **Rules**, like the ``epsilon`` for :py:class:`~zennit.rules.Epsilon`,
the ``gamma`` for :py:class:`~zennit.rules.Gamma`, the ``stabilizer`` for the
denominator stabilization of all rules other than
:py:class:`~zennit.rules.Epsilon`, and ``zero_params`` for all **Rules** based
on :py:class:`~zennit.core.BasicHook`, which can, for example, be used to set
the bias to zero during the layer-wise relevance computation:

.. code-block:: python

    from zennit.composites import EpsilonGammaBox
    from zennit.core import Stabilizer

    # built-in Composites pass some parameters to the respective rules, which
    # can be used for some simple modifications; zero_params is applied to all
    # rules based on BasicHook which accept it
    composite = EpsilonGammaBox(
        low=-3.,
        high=3.,
        epsilon=1e-4,
        gamma=2.,
        stabilizer=Stabilizer(epsilon=1e-5, clip=True),
        zero_params='bias',
    )
@@ -248,10 +311,9 @@ and using :py:func:`~zennit.core.Composite.context`:

.. code-block:: python

    # register hooks for rules to all modules that apply within the context;
    # note that model and modified_model are the same model, the context
    # variable is purely visual; hooks are removed when the context is exited
    with composite.context(model) as modified_model:
        # execute the hooked/modified model
        output = modified_model(input)
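
For reference, the context manager above should be equivalent to registering
and removing the hooks explicitly via
:py:func:`~zennit.core.Composite.register` and
:py:func:`~zennit.core.Composite.remove`:

.. code-block:: python

    # explicit alternative to the context manager: register the hooks, use
    # the model, then remove the hooks manually afterwards
    composite.register(model)
    output = model(input)
    composite.remove()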
