Core/Rules: ParamMod and Generalized Gamma Rule
- change the gamma rule to its generalized version
- since the generalized gamma version needs the unmodified output of the
  function (and therefore ignores zero_params), we need a way to specify
  the zero_params on a per-modifier basis
- this was solved by implementing the previous mod_params function as a
  class ParamMod, which takes the same arguments, except the module, in
  its __init__
- param_modifiers can now be specified as instances of ParamMod, where
  `zero_params` etc. may be supplied (see the sketch below)
- this also makes passing mod_params keyword arguments to BasicHook
  obsolete, as these are now stored in the param_modifiers themselves
- passing simple functions (or callables) in param_modifiers is still
  okay; the function will then be used to instantiate a ParamMod
  object with the default parameters
- all param_modifiers in zennit.rules are now replaced with ParamMod
  instances, and ParamMod subclasses are implemented for the common
  ClampMod, GammaMod and NoMod
- tests now do not need to check for the param_keys in BasicHook
  anymore, as this is done by comparing the ParamMod instances
- updated and extended the gamma rule docstring
- added zero_params to docstrings in composites and rules
- updated docstrings for BasicHook
- added and modified tests for ParamMod
- added paragraph and code on ParamMod in how-to/write-custom-rules
- update ResNet50 heatmaps in README.md
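
A rough sketch of the new usage (the modifier function here is only an
example):

    from zennit.core import BasicHook, ParamMod

    # previously, zero_params was a keyword argument of BasicHook itself:
    #     BasicHook(param_modifiers=[modifier], zero_params=['bias'])
    # now, each modifier carries its own settings via a ParamMod instance:
    hook = BasicHook(
        param_modifiers=[
            ParamMod(lambda param, name: param.clamp(min=0.), zero_params=['bias']),
        ],
    )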
chr5tphr committed Jul 15, 2022
1 parent 7380555 commit 45763e1
Showing 9 changed files with 390 additions and 138 deletions.
15 changes: 15 additions & 0 deletions docs/source/bibliography.bib
@@ -181,3 +181,18 @@ @article{anders2021software
year = {2021},
url = {https://arxiv.org/abs/2106.13200},
}

@article{andeol2021learning,
author = {L{\'{e}}o And{\'{e}}ol and
Yusei Kawakami and
Yuichiro Wada and
Takafumi Kanamori and
Klaus{-}Robert M{\"{u}}ller and
Gr{\'{e}}goire Montavon},
title = {Learning Domain Invariant Representations by Joint Wasserstein Distance
Minimization},
journal = {CoRR},
volume = {abs/2106.04923},
year = {2021},
url = {https://arxiv.org/abs/2106.04923},
}
45 changes: 41 additions & 4 deletions docs/source/how-to/write-custom-rules.rst
@@ -207,9 +207,10 @@ functions:

* ``input_modifiers``, which is a tuple of :math:`K` functions, each with a
single argument to modify the input tensor,
* ``param_modifiers``, which is a tuple of :math:`K` functions, each with two
arguments, the parameter tensor ``obj`` and its name ``name`` (e.g. ``weight``
or ``bias``), to modify the parameter,
* ``param_modifiers``, which is a tuple of :math:`K` functions or
:py:class:`~zennit.core.ParamMod` instances, each with two arguments, the
parameter tensor ``obj`` and its name ``name`` (e.g. ``weight`` or ``bias``),
to modify the parameter,
* ``output_modifiers``, which is a tuple of :math:`K` functions, each with a
single argument to modify the output tensor, each produced by applying the
module with a modified input and its respective modified parameters,
@@ -379,10 +380,46 @@ modifiers in order to take negative input values into account.
We recommend taking a look at the implementation of each rule in
:py:mod:`zennit.rules` for more examples.

There are two more arguments to :py:class:`~zennit.core.BasicHook`:
For more control over the parameter modification,
:py:class:`~zennit.core.ParamMod` instances may be used in ``param_modifiers``.
A common use-case for this is to specify a number of parameter names which
should be set to zero instead of applying the modification:

.. code-block:: python

    import torch
    from zennit.core import BasicHook, ParamMod

    # clip the parameters to be non-negative, but set the bias to zero instead
    lrp_zplus_hook = BasicHook(
        param_modifiers=[ParamMod(lambda x, _: x.clip(min=0.), zero_params='bias')],
    )

    input = torch.randn(1, 4, requires_grad=True)
    module = torch.nn.Linear(4, 4)

    # register the hook, compute the modified gradient, then remove the hook again
    handles = lrp_zplus_hook.register(module)
    output = module(input)
    grad, = torch.autograd.grad(output, input, torch.ones_like(output))
    handles.remove()

This is used in all built-in rules based on :py:class:`~zennit.core.BasicHook`,
where the argument ``zero_params`` is passed to all applicable
:py:class:`~zennit.core.ParamMod` instances.

There are two more arguments to :py:class:`~zennit.core.ParamMod`:

* ``param_keys``, an optional list of parameter names that should be modified,
  which, when ``None`` (default), will modify all parameters, and
* ``require_params``, an optional flag to indicate whether the specified
  ``param_keys`` are mandatory (``True``, default). A missing parameter with
  ``require_params=True`` will cause a ``RuntimeError`` during the backward pass.
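
As a minimal sketch of these two arguments (the modifier function is only
illustrative), the following :py:class:`~zennit.core.ParamMod` modifies only
the ``weight`` and silently skips modules that do not have one:

.. code-block:: python

    from zennit.core import ParamMod

    # only modify 'weight'; do not raise if a module has no 'weight' attribute
    weight_only = ParamMod(
        lambda param, name: param.abs(),
        param_keys=['weight'],
        require_params=False,
    )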

During the backward pass inside :py:class:`~zennit.core.BasicHook`, functions
will be internally converted to :py:class:`~zennit.core.ParamMod` with default
parameters.
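
Internally, this conversion happens through
:py:meth:`~zennit.core.ParamMod.ensure`; roughly sketched:

.. code-block:: python

    from zennit.core import ParamMod

    def absmod(param, name):
        '''An example modifier which uses the absolute value of each parameter.'''
        return param.abs()

    # a plain callable is wrapped into a ParamMod with default settings
    wrapped = ParamMod.ensure(absmod)
    # an existing ParamMod instance is returned as-is
    same = ParamMod.ensure(ParamMod(absmod, zero_params=['bias']))
    # anything that is neither a ParamMod nor callable raises a TypeError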

The built-in rules furthermore introduce subclasses of
:py:class:`~zennit.core.ParamMod` for the common modifiers
:py:class:`~zennit.rules.ClampMod`, :py:class:`~zennit.rules.GammaMod`, and
:py:class:`~zennit.rules.NoMod`.
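
For example, an LRP-ZPlus-style hook like the one above may be written with
:py:class:`~zennit.rules.ClampMod`; this sketch assumes ``ClampMod`` accepts
``min``/``max`` keyword arguments and forwards the remaining keyword arguments,
such as ``zero_params``, to :py:class:`~zennit.core.ParamMod`:

.. code-block:: python

    from zennit.core import BasicHook
    from zennit.rules import ClampMod

    # clamp all parameters to be non-negative, but set the bias to zero instead
    lrp_zplus_hook = BasicHook(
        param_modifiers=[ClampMod(min=0., zero_params='bias')],
    )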
Binary file modified share/img/beacon_resnet50_various.webp
Binary file not shown.
12 changes: 12 additions & 0 deletions src/zennit/composites.py
@@ -178,6 +178,8 @@ class EpsilonGammaBox(SpecialFirstLayerMapComposite):
first_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]`
Applicable mapping for the first layer, same format as `layer_map`. This will be prepended to the ``first_map``
defined by the composite.
zero_params: list[str], optional
A list of parameter names that shall be set to zero. If `None` (default), no parameters are set to zero.
canonizers: list[:py:class:`zennit.canonizers.Canonizer`], optional
List of canonizer instances to be applied before applying hooks.
'''
@@ -220,6 +222,8 @@ class EpsilonPlus(LayerMapComposite):
layer_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]`
A mapping as a list of tuples, with a tuple of applicable module types and a Hook. This will be prepended to
the ``layer_map`` defined by the composite.
zero_params: list[str], optional
A list of parameter names that shall be set to zero. If `None` (default), no parameters are set to zero.
canonizers: list[:py:class:`zennit.canonizers.Canonizer`], optional
List of canonizer instances to be applied before applying hooks.
'''
@@ -247,6 +251,8 @@ class EpsilonAlpha2Beta1(LayerMapComposite):
layer_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]`
A mapping as a list of tuples, with a tuple of applicable module types and a Hook. This will be prepended to
the ``layer_map`` defined by the composite.
zero_params: list[str], optional
A list of parameter names that shall be set to zero. If `None` (default), no parameters are set to zero.
canonizers: list[:py:class:`zennit.canonizers.Canonizer`], optional
List of canonizer instances to be applied before applying hooks.
'''
@@ -277,6 +283,8 @@ class EpsilonPlusFlat(SpecialFirstLayerMapComposite):
first_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]`
Applicable mapping for the first layer, same format as `layer_map`. This will be prepended to the ``first_map``
defined by the composite.
zero_params: list[str], optional
A list of parameter names that shall be set to zero. If `None` (default), no parameters are set to zero.
canonizers: list[:py:class:`zennit.canonizers.Canonizer`], optional
List of canonizer instances to be applied before applying hooks.
'''
@@ -312,6 +320,8 @@ class EpsilonAlpha2Beta1Flat(SpecialFirstLayerMapComposite):
first_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]`
Applicable mapping for the first layer, same format as `layer_map`. This will be prepended to the ``first_map``
defined by the composite.
zero_params: list[str], optional
A list of parameter names that shall be set to zero. If `None` (default), no parameters are set to zero.
canonizers: list[:py:class:`zennit.canonizers.Canonizer`], optional
List of canonizer instances to be applied before applying hooks.
'''
@@ -387,6 +397,8 @@ class ExcitationBackprop(LayerMapComposite):
layer_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]`
A mapping as a list of tuples, with a tuple of applicable module types and a Hook. This will be prepended to
the ``layer_map`` defined by the composite.
zero_params: list[str], optional
A list of parameter names that shall be set to zero. If `None` (default), no parameters are set to zero.
canonizers: list[:py:class:`zennit.canonizers.Canonizer`], optional
List of canonizer instances to be applied before applying hooks.
'''
180 changes: 104 additions & 76 deletions src/zennit/core.py
@@ -151,61 +151,102 @@ def modifier_wrapper(input, name):
return zero_params_wrapper


@contextmanager
def mod_params(module, modifier, param_keys=None, zero_params=None, require_params=True):
'''Context manager to temporarily modify parameter attributes (all by default) of a module.
class ParamMod:
'''Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.
Parameters
----------
module: :py:obj:`torch.nn.Module`
Module of which to modify parameters. If `requires_params` is `True`, it must have all elements given in
`param_keys` as attributes (attributes are allowed to be `None`, in which case they are ignored).
modifier: function
A function used to modify parameter attributes. If `param_keys` is empty, this is not used.
param_keys: list[str], optional
A list of parameter names that shall be modified. If `None` (default), all parameters are modified (which may be
none). If `[]`, no parameters are modified and `modifier` is ignored.
A list of parameter names that shall be modified. If `None` (default), all parameters are modified (which may
be none). If `[]`, no parameters are modified and `modifier` is ignored.
zero_params: list[str], optional
A list of parameter names that shall be set to zero. If `None` (default), no parameters are set to zero.
require_params: bool, optional
Whether existence of `module`'s params is mandatory (True by default). If the attribute exists but is `None`,
it is not considered missing, and the modifier is not applied.
'''
def __init__(self, modifier, param_keys=None, zero_params=None, require_params=True):
self.modifier = modifier
self.param_keys = param_keys
self.zero_params = zero_params
self.require_params = require_params

Raises
------
RuntimeError
If `require_params` is `True` and `module` is missing an attribute listed in `param_keys`.
@classmethod
def ensure(cls, modifier):
'''If ``modifier`` is an instance of ParamMod, return it as-is, if it is callable, create a new instance with
``modifier`` as the ParamMod's function, otherwise raise a TypeError.
Yields
------
module: :py:obj:`torch.nn.Module`
The `module` with appropriate parameters temporarily modified.
'''
try:
stored_params = {}
if param_keys is None:
param_keys = [name for name, _ in module.named_parameters(recurse=False)]
if zero_params is None:
zero_params = []

missing = [key for key in param_keys if not hasattr(module, key)]
if require_params and missing:
missing_str = '\', \''.join(missing)
raise RuntimeError(f'Module {module} requires missing parameters: \'{missing_str}\'')

modifier = zero_wrap(zero_params)(modifier)

for key in param_keys:
if key not in missing:
param = getattr(module, key)
if param is not None:
stored_params[key] = param
setattr(module, key, torch.nn.Parameter(modifier(param.data, key)))
Parameters
----------
modifier : :py:obj:`ParamMod` or callable
The modifier which, if necessary, will be used to construct a ParamMod.
yield module
finally:
for key, value in stored_params.items():
setattr(module, key, value)
Returns
-------
:py:obj:`ParamMod`
Either ``modifier`` as is, or a :py:obj:`ParamMod` constructed using ``modifier``.
Raises
------
TypeError
If ``modifier`` is neither an instance of :py:obj:`ParamMod`, nor callable.
'''
if isinstance(modifier, cls):
return modifier
if callable(modifier):
return cls(modifier)
raise TypeError(f'{modifier} is neither an instance of {cls}, nor callable!')

@contextmanager
def __call__(self, module):
'''Context manager to temporarily modify parameter attributes (all by default) of a module.
Parameters
----------
module: :py:obj:`torch.nn.Module`
Module of which to modify parameters. If `self.require_params` is `True`, it must have all elements given
in `self.param_keys` as attributes (attributes are allowed to be `None`, in which case they are ignored).
Raises
------
RuntimeError
If `self.require_params` is `True` and `module` is missing an attribute listed in `self.param_keys`.
Yields
------
module: :py:obj:`torch.nn.Module`
The `module` with appropriate parameters temporarily modified.
'''
try:
stored_params = {}
param_keys = self.param_keys
zero_params = self.zero_params

if param_keys is None:
param_keys = [name for name, _ in module.named_parameters(recurse=False)]
if zero_params is None:
zero_params = []

missing = [key for key in param_keys if not hasattr(module, key)]
if self.require_params and missing:
missing_str = '\', \''.join(missing)
raise RuntimeError(f'Module {module} requires missing parameters: \'{missing_str}\'')

modifier = zero_wrap(zero_params)(self.modifier)

for key in param_keys:
if key not in missing:
param = getattr(module, key)
if param is not None:
stored_params[key] = param
setattr(module, key, torch.nn.Parameter(modifier(param.data, key)))

yield module
finally:
for key, value in stored_params.items():
setattr(module, key, value)


def collect_leaves(module):
@@ -319,37 +360,34 @@ def register(self, module):


class BasicHook(Hook):
'''A hook to compute the layerwise attribution of the module it is attached to.
A `BasicHook` instance may only be registered with a single module.
'''A hook to compute the layer-wise attribution of the module it is attached to.
A BasicHook instance may only be registered with a single module.
Parameters
----------
input_modifiers: list[callable], optional
A list of functions to produce multiple inputs. Default is a single input which is the identity.
param_modifiers: list[callable], optional
A list of functions to temporarily modify the parameters of the attached module for each input produced
with `input_modifiers`. Default is unmodified parameters for each input.
A list of functions ``(input: torch.Tensor) -> torch.Tensor`` to produce multiple inputs. Default is a single
input which is the identity.
param_modifiers: list[:py:obj:`~zennit.core.ParamMod` or callable], optional
A list of ParamMod instances or functions ``(obj: torch.Tensor, name: str) -> torch.Tensor``, with parameter
tensor ``obj``, registered in the root model as ``name``, to temporarily modify the parameters of the attached
module for each input produced with `input_modifiers`. Default is unmodified parameters for each input. Use a
:py:obj:`~zennit.core.ParamMod` instance to specify which parameters should be modified, whether they are
required, and which should be set to zero.
output_modifiers: list[callable], optional
A list of functions to modify the module's output computed using the modified parameters before gradient
computation for each input produced with `input_modifier`. Default is the identity for each output.
A list of functions ``(input: torch.Tensor) -> torch.Tensor`` to modify the module's output computed using the
modified parameters before gradient computation for each input produced with ``input_modifiers``. Default is the
identity for each output.
gradient_mapper: callable, optional
Function to modify upper relevance. Call signature is of form `(grad_output, outputs)` and a tuple of
the same size as outputs is expected to be returned. `outputs` has the same size as `input_modifiers` and
`param_modifiers`. Default is a stabilized normalization by each of the outputs, multiplied with the output
gradient.
Function ``(out_grad: torch.Tensor, outputs: list[torch.Tensor]) -> list[torch.Tensor]`` to modify upper
relevance. A list or tuple of the same size as ``outputs`` is expected to be returned. ``outputs`` has the same
size as ``input_modifiers`` and ``param_modifiers``. Default is a stabilized normalization by each of the
outputs, multiplied with the output gradient.
reducer: callable
Function to reduce all the inputs and gradients produced through `input_modifiers` and `param_modifiers`.
Call signature is of form `(inputs, gradients)`, where `inputs` and `gradients` have the same as
`input_modifiers` and `param_modifiers`. Default is the sum of the multiplications of each input and its
corresponding gradient.
param_keys: list[str], optional
A list of parameter names that shall be modified. If `None` (default), all parameters are modified (which may be
none). If `[]`, no parameters are modified and `modifier` is ignored.
zero_params: list[str], optional
A list of parameter names that shall set to zero. If `None` (default), no parameters are set to zero.
require_params: bool, optional
Whether existence of `module`'s params is mandatory (True by default). If the attribute exists but is `None`,
it is not considered missing, and the modifier is not applied.
Function ``(inputs: list[torch.Tensor], gradients: list[torch.Tensor]) -> torch.Tensor`` to reduce all the
inputs and gradients produced through ``input_modifiers`` and ``param_modifiers``. ``inputs`` and ``gradients``
have the same size as ``input_modifiers`` and ``param_modifiers``. Default is the sum of the multiplications of each
input and its corresponding gradient.
'''
def __init__(
self,
@@ -358,9 +396,6 @@ def __init__(
output_modifiers=None,
gradient_mapper=None,
reducer=None,
param_keys=None,
zero_params=None,
require_params=True
):
super().__init__()
modifiers = {
@@ -383,12 +418,6 @@ def __init__(
self.gradient_mapper = gradient_mapper
self.reducer = reducer

self.param_kwargs = {
'param_keys': param_keys,
'zero_params': zero_params,
'require_params': require_params
}

def forward(self, module, input, output):
'''Forward hook to save module in-/outputs.'''
self.stored_tensors['input'] = input
Expand All @@ -400,7 +429,7 @@ def backward(self, module, grad_input, grad_output):
outputs = []
for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
input = in_mod(original_input).requires_grad_()
with mod_params(module, param_mod, **self.param_kwargs) as modified, torch.autograd.enable_grad():
with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
output = modified.forward(input)
output = out_mod(output)
inputs.append(input)
@@ -422,7 +451,6 @@ def copy(self):
self.output_modifiers,
self.gradient_mapper,
self.reducer,
**self.param_kwargs,
)
return copy

Expand Down