diff --git a/docs/source/how-to/use-rules-composites-and-canonizers.rst b/docs/source/how-to/use-rules-composites-and-canonizers.rst index 45cee3f..55a517d 100644 --- a/docs/source/how-to/use-rules-composites-and-canonizers.rst +++ b/docs/source/how-to/use-rules-composites-and-canonizers.rst @@ -107,14 +107,69 @@ layers: # remove the hooks handles.remove() +Furthermore, a ``stabilizer`` argument may be specified for all **Rules** based +on :py:class:`~zennit.core.BasicHook`, which is used to stabilize the +denominator in the respective **Rule**. It is the same as the ``epsilon`` +argument for :py:class:`~zennit.rules.Epsilon`, which uses a different +argument name due to the different intentions behind ``stabilizer`` and +``epsilon``. A ``float`` can be supplied to use the default stabilizer, which +adds the value of ``epsilon`` while conserving the sign of the input. For more +control over the stabilization, a ``callable`` with signature ``(input: +torch.Tensor) -> torch.Tensor`` may be supplied. For this, Zennit provides the +class :py:class:`zennit.core.Stabilizer` with a few options (follow the link to +the API reference for an overview). However, if +:py:class:`~zennit.core.Stabilizer` does not provide the desired stabilization, +a custom function can be supplied instead. + +.. code-block:: python + + import torch + from torch.nn import Linear + from zennit.rules import ZPlus, Epsilon + from zennit.core import Stabilizer + + # instantiate a few rules + rules = [ + # specifying a float results in the default stabilizer, which adds an + # epsilon value that conserves the sign of the input + ZPlus(stabilizer=1e-3), + Epsilon(epsilon=1e-3), + # a callable can be supplied for a custom stabilizer; Zennit provides + # zennit.core.Stabilizer with a few choices for the stabilization + ZPlus(stabilizer=Stabilizer(epsilon=1e-4, clip=True)), + Epsilon(epsilon=Stabilizer(epsilon=1e-4, norm_scale=True)), + # if Stabilizer does not provide the desired stabilization, a simple + # function (input: torch.Tensor) -> torch.Tensor may be supplied + ZPlus(stabilizer=lambda x: x + 1e-4), + Epsilon(epsilon=lambda x: ((x == 0.) + x.sign()) * x.abs().clip(min=1e-3)), + ] + dense_layer = Linear(3 * 32 * 32, 32 * 32) + input = torch.randn(1, 3 * 32 * 32, requires_grad=True) + + # generate an attribution for each rule + attributions = [] + for rule in rules: + handles = rule.register(dense_layer) + + output = dense_layer(input) + attribution, = torch.autograd.grad( + output, input, grad_outputs=torch.ones_like(output) + ) + attributions.append(attribution) + + # be sure to remove the hook before registering a new one + handles.remove() + See :doc:`/how-to/write-custom-rules` for further technical detail on how to write custom rules. -Note that some rules, in particular the ones that modify parameters (e.g. -:py:class:`~zennit.rules.ZPlus`, :py:class:`~zennit.rules.AlphaBeta`, ...) -are not thread-safe in the backward-phase, because they modify the model -parameters for a brief moment. For most users, this is unlikely to cause any -problems, and may be avoided by using locks in appropriate locations. +.. note:: + + Some rules, in particular the ones that modify parameters (e.g. + :py:class:`~zennit.rules.ZPlus`, :py:class:`~zennit.rules.AlphaBeta`, ...) + are not thread-safe in the backward-phase, because they modify the model + parameters for a brief moment. For most users, this is unlikely to cause + any problems, and may be avoided by using locks in appropriate locations. .. _use-composites: @@ -168,9 +223,13 @@ be used (which is done for all activations in all LRP **Composites**, but not in :py:class:`~zennit.composites.GuidedBackprop` or :py:class:`~zennit.composites.ExcitationBackprop`). -Note on **MaxPool**: For LRP, the gradient of MaxPool assigns values only to the -*largest* inputs (winner-takes-all), which is already the expected behaviour for -LRP rules. +.. note:: + + For LRP, the gradient of **MaxPool** assigns values only to the *largest* + inputs (winner-takes-all), which is already the intended behaviour for LRP + rules. Other operations for which the gradient is already the intended + behaviour for LRP are, for example, *constant padding*, *concatenation*, + *cropping*, *indexing* and *slicing*. Composites may require arguments, e.g. :py:class:`~zennit.composites.EpsilonGammaBox` requires keyword arguments @@ -208,13 +267,16 @@ Composites may require arguments, e.g. Some built-in rules also expose some of the parameters of their respective **Rules**, like the ``epsilon`` for :py:class:`~zennit.rules.Epsilon`, the -``gamma`` for :py:class:`~zennit.rules.Gamma`, and ``zero_params`` for all -**Rules** based on :py:class:`~zennit.core.BasicHook`, which can for example be -used to set the bias to zero during the layer-wise relevance computation: +``gamma`` for :py:class:`~zennit.rules.Gamma`, ``stabilizer`` for the +denominator stabilization of all rules different from +:py:class:`~zennit.rules.Epsilon`, and ``zero_params`` for all **Rules** based +on :py:class:`~zennit.core.BasicHook`, which can for example be used to set the +bias to zero during the layer-wise relevance computation: .. code-block:: python from zennit.composites import EpsilonGammaBox + from zennit.core import Stabilizer # built-in Composites pass some parameters to the respective rules, which # can be used for some simple modifications; zero_params is applied to all @@ -225,6 +287,7 @@ used to set the bias to zero during the layer-wise relevance computation: high=3., epsilon=1e-4, gamma=2., + stabilizer=Stabilizer(epsilon=1e-5, clip=True), zero_params='bias', ) @@ -248,10 +311,9 @@ and using :py:func:`~zennit.core.Composite.context`: .. code-block:: python - # register hooks for rules to all modules that apply within the context + # register hooks for rules to all modules that apply within the context; # note that model and modified_model are the same model, the context - # variable is purely visual - # hooks are removed when the context is exited + # variable is purely visual; hooks are removed when the context is exited with composite.context(model) as modified_model: # execute the hooked/modified model output = modified_model(input) diff --git a/src/zennit/composites.py b/src/zennit/composites.py index f9a322e..06acfe2 100644 --- a/src/zennit/composites.py +++ b/src/zennit/composites.py @@ -21,7 +21,7 @@ from .core import Composite from .layer import Sum from .rules import Gamma, Epsilon, ZBox, ZPlus, AlphaBeta, Flat, Pass, Norm, ReLUDeconvNet, ReLUGuidedBackprop -from .types import Convolution, Linear, AvgPool, Activation +from .types import Convolution, Linear, AvgPool, Activation, BatchNorm class LayerMapComposite(Composite): @@ -29,7 +29,7 @@ class LayerMapComposite(Composite): Parameters ---------- - layer_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]` + layer_map: list[tuple[tuple[torch.nn.Module, ...], Hook]] A mapping as a list of tuples, with a tuple of applicable module types and a Hook. canonizers: list[:py:class:`zennit.canonizers.Canonizer`], optional List of canonizer instances to be applied before applying hooks. @@ -64,7 +64,7 @@ class SpecialFirstLayerMapComposite(LayerMapComposite): Parameters ---------- - layer_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]` + layer_map: list[tuple[tuple[torch.nn.Module, ...], Hook]] A mapping as a list of tuples, with a tuple of applicable module types and a Hook. first_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]` Applicable mapping for the first layer, same format as `layer_map`. @@ -150,11 +150,27 @@ def wrapped(composite): return wrapped -LAYER_MAP_BASE = [ - (Activation, Pass()), - (Sum, Norm()), - (AvgPool, Norm()) -] +def layer_map_base(stabilizer=1e-6): + '''Return a basic layer map (list of 2-tuples) shared by all built-in LayerMapComposites. + + Parameters + ---------- + stabilizer: callable or float, optional + Stabilization parameter for rules other than ``Epsilon``. If ``stabilizer`` is a float, it will be added to the + denominator with the same sign as each respective entry. If it is callable, a function ``(input: torch.Tensor) + -> torch.Tensor`` is expected, of which the output corresponds to the stabilized denominator. + + Returns + ------- + list[tuple[tuple[torch.nn.Module, ...], Hook]] + Basic ayer map shared by all built-in LayerMapComposites. + ''' + return [ + (Activation, Pass()), + (Sum, Norm(stabilizer=stabilizer)), + (AvgPool, Norm(stabilizer=stabilizer)), + (BatchNorm, Pass()), + ] @register_composite('epsilon_gamma_box') @@ -168,11 +184,18 @@ class EpsilonGammaBox(SpecialFirstLayerMapComposite): A tensor with the same size as the input, describing the lowest possible pixel values. high: obj:`torch.Tensor` A tensor with the same size as the input, describing the highest possible pixel values. - epsilon: float - Epsilon parameter for the epsilon rule. - gamma: float + epsilon: callable or float, optional + Stabilization parameter for the ``Epsilon`` rule. If ``epsilon`` is a float, it will be added to the + denominator with the same sign as each respective entry. If it is callable, a function ``(input: torch.Tensor) + -> torch.Tensor`` is expected, of which the output corresponds to the stabilized denominator. Note that this is + called ``stabilizer`` for all other rules. + gamma: float, optional Gamma parameter for the gamma rule. - layer_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]` + stabilizer: callable or float, optional + Stabilization parameter for rules other than ``Epsilon``. If ``stabilizer`` is a float, it will be added to the + denominator with the same sign as each respective entry. If it is callable, a function ``(input: torch.Tensor) + -> torch.Tensor`` is expected, of which the output corresponds to the stabilized denominator. + layer_map: list[tuple[tuple[torch.nn.Module, ...], Hook]] A mapping as a list of tuples, with a tuple of applicable module types and a Hook. This will be prepended to the ``layer_map`` defined by the composite. first_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]` @@ -189,6 +212,7 @@ def __init__( high, epsilon=1e-6, gamma=0.25, + stabilizer=1e-6, layer_map=None, first_map=None, zero_params=None, @@ -200,12 +224,12 @@ def __init__( first_map = [] rule_kwargs = {'zero_params': zero_params} - layer_map = layer_map + LAYER_MAP_BASE + [ - (Convolution, Gamma(gamma=gamma, **rule_kwargs)), + layer_map = layer_map + layer_map_base(stabilizer) + [ + (Convolution, Gamma(gamma=gamma, stabilizer=stabilizer, **rule_kwargs)), (torch.nn.Linear, Epsilon(epsilon=epsilon, **rule_kwargs)), ] first_map = first_map + [ - (Convolution, ZBox(low=low, high=high, **rule_kwargs)) + (Convolution, ZBox(low=low, high=high, stabilizer=stabilizer, **rule_kwargs)) ] super().__init__(layer_map=layer_map, first_map=first_map, canonizers=canonizers) @@ -217,9 +241,16 @@ class EpsilonPlus(LayerMapComposite): Parameters ---------- - epsilon: float - Epsilon parameter for the epsilon rule. - layer_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]` + epsilon: callable or float, optional + Stabilization parameter for the ``Epsilon`` rule. If ``epsilon`` is a float, it will be added to the + denominator with the same sign as each respective entry. If it is callable, a function ``(input: torch.Tensor) + -> torch.Tensor`` is expected, of which the output corresponds to the stabilized denominator. Note that this is + called ``stabilizer`` for all other rules. + stabilizer: callable or float, optional + Stabilization parameter for rules other than ``Epsilon``. If ``stabilizer`` is a float, it will be added to the + denominator with the same sign as each respective entry. If it is callable, a function ``(input: torch.Tensor) + -> torch.Tensor`` is expected, of which the output corresponds to the stabilized denominator. + layer_map: list[tuple[tuple[torch.nn.Module, ...], Hook]] A mapping as a list of tuples, with a tuple of applicable module types and a Hook. This will be prepended to the ``layer_map`` defined by the composite. zero_params: list[str], optional @@ -227,13 +258,13 @@ class EpsilonPlus(LayerMapComposite): canonizers: list[:py:class:`zennit.canonizers.Canonizer`], optional List of canonizer instances to be applied before applying hooks. ''' - def __init__(self, epsilon=1e-6, layer_map=None, zero_params=None, canonizers=None): + def __init__(self, epsilon=1e-6, stabilizer=1e-6, layer_map=None, zero_params=None, canonizers=None): if layer_map is None: layer_map = [] rule_kwargs = {'zero_params': zero_params} - layer_map = layer_map + LAYER_MAP_BASE + [ - (Convolution, ZPlus(**rule_kwargs)), + layer_map = layer_map + layer_map_base(stabilizer) + [ + (Convolution, ZPlus(stabilizer=stabilizer, **rule_kwargs)), (torch.nn.Linear, Epsilon(epsilon=epsilon, **rule_kwargs)), ] super().__init__(layer_map=layer_map, canonizers=canonizers) @@ -246,9 +277,16 @@ class EpsilonAlpha2Beta1(LayerMapComposite): Parameters ---------- - epsilon: float - Epsilon parameter for the epsilon rule. - layer_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]` + epsilon: callable or float, optional + Stabilization parameter for the ``Epsilon`` rule. If ``epsilon`` is a float, it will be added to the + denominator with the same sign as each respective entry. If it is callable, a function ``(input: torch.Tensor) + -> torch.Tensor`` is expected, of which the output corresponds to the stabilized denominator. Note that this is + called ``stabilizer`` for all other rules. + stabilizer: callable or float, optional + Stabilization parameter for rules other than ``Epsilon``. If ``stabilizer`` is a float, it will be added to the + denominator with the same sign as each respective entry. If it is callable, a function ``(input: torch.Tensor) + -> torch.Tensor`` is expected, of which the output corresponds to the stabilized denominator. + layer_map: list[tuple[tuple[torch.nn.Module, ...], Hook]] A mapping as a list of tuples, with a tuple of applicable module types and a Hook. This will be prepended to the ``layer_map`` defined by the composite. zero_params: list[str], optional @@ -256,13 +294,13 @@ class EpsilonAlpha2Beta1(LayerMapComposite): canonizers: list[:py:class:`zennit.canonizers.Canonizer`], optional List of canonizer instances to be applied before applying hooks. ''' - def __init__(self, epsilon=1e-6, layer_map=None, zero_params=None, canonizers=None): + def __init__(self, epsilon=1e-6, stabilizer=1e-6, layer_map=None, zero_params=None, canonizers=None): if layer_map is None: layer_map = [] rule_kwargs = {'zero_params': zero_params} - layer_map = layer_map + LAYER_MAP_BASE + [ - (Convolution, AlphaBeta(alpha=2, beta=1, **rule_kwargs)), + layer_map = layer_map + layer_map_base(stabilizer) + [ + (Convolution, AlphaBeta(alpha=2, beta=1, stabilizer=stabilizer, **rule_kwargs)), (torch.nn.Linear, Epsilon(epsilon=epsilon, **rule_kwargs)), ] super().__init__(layer_map=layer_map, canonizers=canonizers) @@ -275,9 +313,16 @@ class EpsilonPlusFlat(SpecialFirstLayerMapComposite): Parameters ---------- - epsilon: float - Epsilon parameter for the epsilon rule. - layer_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]` + epsilon: callable or float, optional + Stabilization parameter for the ``Epsilon`` rule. If ``epsilon`` is a float, it will be added to the + denominator with the same sign as each respective entry. If it is callable, a function ``(input: torch.Tensor) + -> torch.Tensor`` is expected, of which the output corresponds to the stabilized denominator. Note that this is + called ``stabilizer`` for all other rules. + stabilizer: callable or float, optional + Stabilization parameter for rules other than ``Epsilon``. If ``stabilizer`` is a float, it will be added to the + denominator with the same sign as each respective entry. If it is callable, a function ``(input: torch.Tensor) + -> torch.Tensor`` is expected, of which the output corresponds to the stabilized denominator. + layer_map: list[tuple[tuple[torch.nn.Module, ...], Hook]] A mapping as a list of tuples, with a tuple of applicable module types and a Hook. This will be prepended to the ``layer_map`` defined by the composite. first_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]` @@ -288,19 +333,21 @@ class EpsilonPlusFlat(SpecialFirstLayerMapComposite): canonizers: list[:py:class:`zennit.canonizers.Canonizer`], optional List of canonizer instances to be applied before applying hooks. ''' - def __init__(self, epsilon=1e-6, layer_map=None, first_map=None, zero_params=None, canonizers=None): + def __init__( + self, epsilon=1e-6, stabilizer=1e-6, layer_map=None, first_map=None, zero_params=None, canonizers=None + ): if layer_map is None: layer_map = [] if first_map is None: first_map = [] rule_kwargs = {'zero_params': zero_params} - layer_map = layer_map + LAYER_MAP_BASE + [ - (Convolution, ZPlus(**rule_kwargs)), + layer_map = layer_map + layer_map_base(stabilizer) + [ + (Convolution, ZPlus(stabilizer=stabilizer, **rule_kwargs)), (torch.nn.Linear, Epsilon(epsilon=epsilon, **rule_kwargs)), ] first_map = first_map + [ - (Linear, Flat(**rule_kwargs)) + (Linear, Flat(stabilizer=stabilizer, **rule_kwargs)) ] super().__init__(layer_map=layer_map, first_map=first_map, canonizers=canonizers) @@ -312,9 +359,16 @@ class EpsilonAlpha2Beta1Flat(SpecialFirstLayerMapComposite): Parameters ---------- - epsilon: float - Epsilon parameter for the epsilon rule. - layer_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]` + epsilon: callable or float, optional + Stabilization parameter for the ``Epsilon`` rule. If ``epsilon`` is a float, it will be added to the + denominator with the same sign as each respective entry. If it is callable, a function ``(input: torch.Tensor) + -> torch.Tensor`` is expected, of which the output corresponds to the stabilized denominator. Note that this is + called ``stabilizer`` for all other rules. + stabilizer: callable or float, optional + Stabilization parameter for rules other than ``Epsilon``. If ``stabilizer`` is a float, it will be added to the + denominator with the same sign as each respective entry. If it is callable, a function ``(input: torch.Tensor) + -> torch.Tensor`` is expected, of which the output corresponds to the stabilized denominator. + layer_map: list[tuple[tuple[torch.nn.Module, ...], Hook]] A mapping as a list of tuples, with a tuple of applicable module types and a Hook. This will be prepended to the ``layer_map`` defined by the composite. first_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]` @@ -325,19 +379,21 @@ class EpsilonAlpha2Beta1Flat(SpecialFirstLayerMapComposite): canonizers: list[:py:class:`zennit.canonizers.Canonizer`], optional List of canonizer instances to be applied before applying hooks. ''' - def __init__(self, epsilon=1e-6, layer_map=None, first_map=None, zero_params=None, canonizers=None): + def __init__( + self, epsilon=1e-6, stabilizer=1e-6, layer_map=None, first_map=None, zero_params=None, canonizers=None + ): if layer_map is None: layer_map = [] if first_map is None: first_map = [] rule_kwargs = {'zero_params': zero_params} - layer_map = layer_map + LAYER_MAP_BASE + [ - (Convolution, AlphaBeta(alpha=2, beta=1, **rule_kwargs)), + layer_map = layer_map + layer_map_base(stabilizer) + [ + (Convolution, AlphaBeta(alpha=2, beta=1, stabilizer=stabilizer, **rule_kwargs)), (torch.nn.Linear, Epsilon(epsilon=epsilon, **rule_kwargs)), ] first_map = first_map + [ - (Linear, Flat(**rule_kwargs)) + (Linear, Flat(stabilizer=stabilizer, **rule_kwargs)) ] super().__init__(layer_map=layer_map, first_map=first_map, canonizers=canonizers) @@ -349,7 +405,7 @@ class DeconvNet(LayerMapComposite): Parameters ---------- - layer_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]` + layer_map: list[tuple[tuple[torch.nn.Module, ...], Hook]] A mapping as a list of tuples, with a tuple of applicable module types and a Hook. This will be prepended to the ``layer_map`` defined by the composite. canonizers: list[:py:class:`zennit.canonizers.Canonizer`], optional @@ -372,7 +428,7 @@ class GuidedBackprop(LayerMapComposite): Parameters ---------- - layer_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]` + layer_map: list[tuple[tuple[torch.nn.Module, ...], Hook]] A mapping as a list of tuples, with a tuple of applicable module types and a Hook. This will be prepended to the ``layer_map`` defined by the composite. canonizers: list[:py:class:`zennit.canonizers.Canonizer`], optional @@ -394,7 +450,7 @@ class ExcitationBackprop(LayerMapComposite): Parameters ---------- - layer_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]` + layer_map: list[tuple[tuple[torch.nn.Module, ...], Hook]] A mapping as a list of tuples, with a tuple of applicable module types and a Hook. This will be prepended to the ``layer_map`` defined by the composite. zero_params: list[str], optional @@ -402,14 +458,13 @@ class ExcitationBackprop(LayerMapComposite): canonizers: list[:py:class:`zennit.canonizers.Canonizer`], optional List of canonizer instances to be applied before applying hooks. ''' - def __init__(self, layer_map=None, zero_params=None, canonizers=None): + def __init__(self, stabilizer=1e-6, layer_map=None, zero_params=None, canonizers=None): if layer_map is None: layer_map = [] - rule_kwargs = {'zero_params': zero_params} layer_map = layer_map + [ - (Sum, Norm()), - (AvgPool, Norm()), - (Linear, ZPlus(**rule_kwargs)), + (Sum, Norm(stabilizer=stabilizer)), + (AvgPool, Norm(stabilizer=stabilizer)), + (Linear, ZPlus(stabilizer=stabilizer, zero_params=zero_params)), ] super().__init__(layer_map=layer_map, canonizers=canonizers) diff --git a/src/zennit/core.py b/src/zennit/core.py index df406c4..6e53ad0 100644 --- a/src/zennit/core.py +++ b/src/zennit/core.py @@ -23,23 +23,106 @@ import torch -def stabilize(input, epsilon=1e-6): - '''Stabilize input for safe division. This shifts zero-elements by ``+ epsilon``. For the sake of the - *epsilon rule*, this also shifts positive values by ``+ epsilon`` and negative values by ``- epsilon``. +class Stabilizer: + '''Class to create a stabilizer callable. + + Parameters + ---------- + epsilon: float, optional + Value by which to shift/clip elements of ``input``. + clip: bool, optional + If ``False`` (default), add ``epsilon`` multiplied by each entry's sign (+1 for 0). If ``True``, instead clip + the absolute value of ``input`` and multiply it by each entry's original sign. + norm_scale: bool, optional + If ``False`` (default), ``epsilon`` is added to/used to clip ``input``. If ``True``, scale ``epsilon`` by the + square root of the mean over the squared elements of the specified dimensions ``dim``. + dim: tuple[int], optional + If ``norm_scale`` is ``True``, specifies the dimension over which the scaled norm should be computed (all + except dimension 0 by default). + + ''' + def __init__(self, epsilon=1e-6, clip=False, norm_scale=False, dim=None): + self.epsilon = epsilon + self.clip = clip + self.norm_scale = norm_scale + self.dim = dim + + def __call__(self, input): + '''Stabilize input for safe division. This shifts zero-elements by ``+ epsilon``. For the sake of the + *epsilon rule*, this also shifts positive values by ``+ epsilon`` and negative values by ``- epsilon``. + + Parameters + ---------- + input: :py:obj:`torch.Tensor` + Tensor to stabilize. + + Returns + ------- + :py:obj:`torch.Tensor` + Stabilized ``input``. + ''' + return stabilize(input, self.epsilon, self.clip, self.norm_scale, self.dim) + + @classmethod + def ensure(cls, value): + '''Given a value, return a stabilizer. If ``value`` is a float, a Stabilizer with that epsilon ``value`` is + returned. If ``value`` is callable, it will be used directly as a stabilizer. Otherwise a TypeError will be + raised. + + Parameters + ---------- + value: float, int, or callable + The value used to produce a valid stabilizer function. + + Returns + ------- + callable or Stabilizer + A callable to be used as a stabilizer. + + Raises + ------ + TypeError + If no valid stabilizer could be produced from ``value``. + ''' + if isinstance(value, (float, int)): + return cls(epsilon=float(value)) + if callable(value): + return value + raise TypeError(f'Value {value} is not a valid stabilizer!') + + +def stabilize(input, epsilon=1e-6, clip=False, norm_scale=False, dim=None): + '''Stabilize input for safe division. Parameters ---------- input: :py:obj:`torch.Tensor` Tensor to stabilize. epsilon: float, optional - Value by which to shift elements. + Value by which to shift/clip elements of ``input``. + clip: bool, optional + If ``False`` (default), add ``epsilon`` multiplied by each entry's sign (+1 for 0). If ``True``, instead clip + the absolute value of ``input`` and multiply it by each entry's original sign. + norm_scale: bool, optional + If ``False`` (default), ``epsilon`` is added to/used to clip ``input``. If ``True``, scale ``epsilon`` by the + square root of the mean over the squared elements of the specified dimensions ``dim``. + dim: tuple[int], optional + If ``norm_scale`` is ``True``, specifies the dimension over which the scaled norm should be computed. Defaults + to all except dimension 0. Returns ------- :py:obj:`torch.Tensor` New Tensor copied from `input` with values shifted by epsilon. ''' - return input + ((input == 0.).to(input) + input.sign()) * epsilon + sign = ((input == 0.).to(input) + input.sign()) + if norm_scale: + if dim is None: + dim = tuple(range(1, input.ndim)) + epsilon = epsilon * ((input ** 2).mean(dim=dim, keepdim=True) ** .5) + if clip: + return sign * input.abs().clip(min=epsilon) + return input + sign * epsilon def expand(tensor, shape, cut_batch_dim=False): @@ -396,6 +479,7 @@ def __init__( output_modifiers=None, gradient_mapper=None, reducer=None, + stabilizer=1e-6, ): super().__init__() modifiers = { @@ -434,8 +518,8 @@ def backward(self, module, grad_input, grad_output): output = out_mod(output) inputs.append(input) outputs.append(output) - gradients = torch.autograd.grad(outputs, inputs, grad_outputs=self.gradient_mapper(grad_output[0], outputs)) - # relevance = self.reducer([input.detach() for input in inputs], [gradient.detach() for gradient in gradients]) + grad_outputs = self.gradient_mapper(grad_output[0], outputs) + gradients = torch.autograd.grad(outputs, inputs, grad_outputs=grad_outputs) relevance = self.reducer(inputs, gradients) return tuple(relevance if original.shape == relevance.shape else None for original in grad_input) diff --git a/src/zennit/rules.py b/src/zennit/rules.py index 485723c..f8cd4f1 100644 --- a/src/zennit/rules.py +++ b/src/zennit/rules.py @@ -18,7 +18,7 @@ '''Rules based on Hooks''' import torch -from .core import Hook, BasicHook, stabilize, expand, ParamMod +from .core import Hook, BasicHook, Stabilizer, expand, ParamMod def zero_bias(zero_params=None): @@ -103,17 +103,21 @@ class Epsilon(BasicHook): Parameters ---------- - epsilon: float, optional - Stabilization parameter. + epsilon: callable or float, optional + Stabilization parameter. If ``epsilon`` is a float, it will be added to the denominator with the same sign as + each respective entry. If it is callable, a function ``(input: torch.Tensor) -> torch.Tensor`` is expected, of + which the output corresponds to the stabilized denominator. Note that this is called ``stabilizer`` for all + other rules. zero_params: list[str], optional A list of parameter names that shall set to zero. If `None` (default), no parameters are set to zero. ''' def __init__(self, epsilon=1e-6, zero_params=None): + stabilizer_fn = Stabilizer.ensure(epsilon) super().__init__( input_modifiers=[lambda input: input], param_modifiers=[NoMod(zero_params=zero_params)], output_modifiers=[lambda output: output], - gradient_mapper=(lambda out_grad, outputs: out_grad / stabilize(outputs[0], epsilon)), + gradient_mapper=(lambda out_grad, outputs: out_grad / stabilizer_fn(outputs[0])), reducer=(lambda inputs, gradients: inputs[0] * gradients[0]), ) @@ -129,12 +133,17 @@ class Gamma(BasicHook): ---------- gamma: float, optional Multiplier for added positive weights. + stabilizer: callable or float, optional + Stabilization parameter. If ``stabilizer`` is a float, it will be added to the denominator with the same sign + as each respective entry. If it is callable, a function ``(input: torch.Tensor) -> torch.Tensor`` is expected, + of which the output corresponds to the stabilized denominator. zero_params: list[str], optional A list of parameter names that shall set to zero. If `None` (default), no parameters are set to zero. ''' - def __init__(self, gamma=0.25, zero_params=None): + def __init__(self, gamma=0.25, stabilizer=1e-6, zero_params=None): mod_kwargs = {'zero_params': zero_params} mod_kwargs_nobias = {'zero_params': zero_bias(zero_params)} + stabilizer_fn = Stabilizer.ensure(stabilizer) super().__init__( input_modifiers=[ lambda input: input.clamp(min=0), @@ -153,7 +162,7 @@ def __init__(self, gamma=0.25, zero_params=None): output_modifiers=[lambda output: output] * 5, gradient_mapper=( lambda out_grad, outputs: [ - output * out_grad / stabilize(denom) + output * out_grad / stabilizer_fn(denom) for output, denom in ( [(outputs[4] > 0., sum(outputs[:2]))] * 2 + [(outputs[4] < 0., sum(outputs[2:4]))] * 2 @@ -172,6 +181,10 @@ class ZPlus(BasicHook): Parameters ---------- + stabilizer: callable or float, optional + Stabilization parameter. If ``stabilizer`` is a float, it will be added to the denominator with the same sign + as each respective entry. If it is callable, a function ``(input: torch.Tensor) -> torch.Tensor`` is expected, + of which the output corresponds to the stabilized denominator. zero_params: list[str], optional A list of parameter names that shall set to zero. If `None` (default), no parameters are set to zero. @@ -181,7 +194,8 @@ class ZPlus(BasicHook): :cite:p:`montavon2017explaining` only considers positive inputs, as they are used in ReLU Networks. This implementation is effectively alpha=1, beta=0, where negative inputs are allowed. ''' - def __init__(self, zero_params=None): + def __init__(self, stabilizer=1e-6, zero_params=None): + stabilizer_fn = Stabilizer.ensure(stabilizer) super().__init__( input_modifiers=[ lambda input: input.clamp(min=0), @@ -192,7 +206,7 @@ def __init__(self, zero_params=None): ClampMod(max=0., zero_params=zero_bias(zero_params)), ], output_modifiers=[lambda output: output] * 2, - gradient_mapper=(lambda out_grad, outputs: [out_grad / stabilize(sum(outputs))] * 2), + gradient_mapper=(lambda out_grad, outputs: [out_grad / stabilizer_fn(sum(outputs))] * 2), reducer=(lambda inputs, gradients: inputs[0] * gradients[0] + inputs[1] * gradients[1]), ) @@ -209,17 +223,22 @@ class AlphaBeta(BasicHook): Multiplier for the positive output term. beta: float, optional Multiplier for the negative output term. + stabilizer: callable or float, optional + Stabilization parameter. If ``stabilizer`` is a float, it will be added to the denominator with the same sign + as each respective entry. If it is callable, a function ``(input: torch.Tensor) -> torch.Tensor`` is expected, + of which the output corresponds to the stabilized denominator. zero_params: list[str], optional A list of parameter names that shall set to zero. If `None` (default), no parameters are set to zero. ''' - def __init__(self, alpha=2., beta=1., zero_params=None): + def __init__(self, alpha=2., beta=1., stabilizer=1e-6, zero_params=None): if alpha < 0 or beta < 0: raise ValueError("Both alpha and beta parameters must be positive!") if (alpha - beta) != 1.: raise ValueError("The difference of parameters alpha - beta must equal 1!") mod_kwargs = {'zero_params': zero_params} mod_kwargs_nobias = {'zero_params': zero_bias(zero_params)} + stabilizer_fn = Stabilizer.ensure(stabilizer) super().__init__( input_modifiers=[ @@ -237,7 +256,7 @@ def __init__(self, alpha=2., beta=1., zero_params=None): output_modifiers=[lambda output: output] * 4, gradient_mapper=( lambda out_grad, outputs: [ - out_grad / stabilize(denom) + out_grad / stabilizer_fn(denom) for denom in ([sum(outputs[:2])] * 2 + [sum(outputs[2:])] * 2) ] ), @@ -266,15 +285,20 @@ class ZBox(BasicHook): Lowest pixel values of input. Subject to broadcasting. high: :py:class:`torch.Tensor` or float Highest pixel values of input. Subject to broadcasting. + stabilizer: callable or float, optional + Stabilization parameter. If ``stabilizer`` is a float, it will be added to the denominator with the same sign + as each respective entry. If it is callable, a function ``(input: torch.Tensor) -> torch.Tensor`` is expected, + of which the output corresponds to the stabilized denominator. zero_params: list[str], optional A list of parameter names that shall set to zero. If `None` (default), no parameters are set to zero. ''' - def __init__(self, low, high, zero_params=None): + def __init__(self, low, high, stabilizer=1e-6, zero_params=None): def sub(positive, *negatives): return positive - sum(negatives) mod_kwargs = {'zero_params': zero_params} + stabilizer_fn = Stabilizer.ensure(stabilizer) super().__init__( input_modifiers=[ @@ -288,7 +312,7 @@ def sub(positive, *negatives): ClampMod(max=0., **mod_kwargs), ], output_modifiers=[lambda output: output] * 3, - gradient_mapper=(lambda out_grad, outputs: (out_grad / stabilize(sub(*outputs)),) * 3), + gradient_mapper=(lambda out_grad, outputs: (out_grad / stabilizer_fn(sub(*outputs)),) * 3), reducer=(lambda inputs, gradients: sub(*(input * gradient for input, gradient in zip(inputs, gradients)))), ) @@ -309,12 +333,13 @@ class Norm(BasicHook): epsilon only used as a stabilizer, and without the need of the attached layer to have parameters ``weight`` and ``bias``. ''' - def __init__(self): + def __init__(self, stabilizer=1e-6): + stabilizer_fn = Stabilizer.ensure(stabilizer) super().__init__( input_modifiers=[lambda input: input], param_modifiers=[NoMod(param_keys=[])], output_modifiers=[lambda output: output], - gradient_mapper=(lambda out_grad, outputs: out_grad / stabilize(outputs[0])), + gradient_mapper=(lambda out_grad, outputs: out_grad / stabilizer_fn(outputs[0])), reducer=(lambda inputs, gradients: inputs[0] * gradients[0]), ) @@ -325,17 +350,22 @@ class WSquare(BasicHook): Parameters ---------- + stabilizer: callable or float, optional + Stabilization parameter. If ``stabilizer`` is a float, it will be added to the denominator with the same sign + as each respective entry. If it is callable, a function ``(input: torch.Tensor) -> torch.Tensor`` is expected, + of which the output corresponds to the stabilized denominator. zero_params: list[str], optional A list of parameter names that shall set to zero. If `None` (default), no parameters are set to zero. ''' - def __init__(self, zero_params=None): + def __init__(self, stabilizer=1e-6, zero_params=None): + stabilizer_fn = Stabilizer.ensure(stabilizer) super().__init__( input_modifiers=[torch.ones_like], param_modifiers=[ ParamMod((lambda param, _: param ** 2), zero_params=zero_params), ], output_modifiers=[lambda output: output], - gradient_mapper=(lambda out_grad, outputs: out_grad / stabilize(outputs[0])), + gradient_mapper=(lambda out_grad, outputs: out_grad / stabilizer_fn(outputs[0])), reducer=(lambda inputs, gradients: gradients[0]), ) @@ -346,18 +376,23 @@ class Flat(BasicHook): Parameters ---------- + stabilizer: callable or float, optional + Stabilization parameter. If ``stabilizer`` is a float, it will be added to the denominator with the same sign + as each respective entry. If it is callable, a function ``(input: torch.Tensor) -> torch.Tensor`` is expected, + of which the output corresponds to the stabilized denominator. zero_params: list[str], optional A list of parameter names that shall set to zero. If `None` (default), no parameters are set to zero. ''' - def __init__(self, zero_params=None): + def __init__(self, stabilizer=1e-6, zero_params=None): mod_kwargs = {'zero_params': zero_bias(zero_params), 'require_params': False} + stabilizer_fn = Stabilizer.ensure(stabilizer) super().__init__( input_modifiers=[torch.ones_like], param_modifiers=[ ParamMod((lambda param, name: torch.ones_like(param)), **mod_kwargs), ], output_modifiers=[lambda output: output], - gradient_mapper=(lambda out_grad, outputs: out_grad / stabilize(outputs[0])), + gradient_mapper=(lambda out_grad, outputs: out_grad / stabilizer_fn(outputs[0])), reducer=(lambda inputs, gradients: gradients[0]), ) diff --git a/tests/test_core.py b/tests/test_core.py index dda4437..46980e4 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -3,26 +3,100 @@ import torch import pytest -from helpers import nograd +from helpers import nograd, prodict -from zennit.core import stabilize, expand, ParamMod, collect_leaves +from zennit.core import stabilize, expand, ParamMod, collect_leaves, Stabilizer from zennit.core import Identity, Hook, BasicHook, RemovableHandle, RemovableHandleList, Composite -@pytest.mark.parametrize('input,epsilon,expected', [ - ([0., -0., 1., -1.], None, [1e-6, 1e-6, 1. + 1e-6, -1. - 1e-6]), - ([0., -0., 1., -1.], 1e-3, [1e-3, 1e-3, 1. + 1e-3, -1. - 1e-3]), - ([0., -0., 1., -1.], 1., [1., 1., 2., -2.]), +@pytest.mark.parametrize('kwargs,input,expected', [ + ( + {}, + [0., -0., 1., -1.], + [1e-6, 1e-6, 1. + 1e-6, -1. - 1e-6] + ), ( + {'epsilon': 1e-3, 'clip': False, 'norm_scale': False}, + [0., -0., 1., -1.], + [1e-3, 1e-3, 1. + 1e-3, -1. - 1e-3] + ), ( + {'epsilon': 1., 'clip': False, 'norm_scale': False}, + [0., -0., 1., -1.], + [1., 1., 2., -2.] + ), ( + {'epsilon': 1e-6, 'clip': True, 'norm_scale': False}, + [0., -0., 1., -1.], + [1e-6, 1e-6, 1., -1.] + ), ( + {'epsilon': 1e-3, 'clip': True, 'norm_scale': False}, + [0., -0., 1., -1.], + [1e-3, 1e-3, 1., -1.] + ), ( + {'epsilon': 1., 'clip': True, 'norm_scale': False}, + [0., -0., 1., -1.], + [1., 1., 1., -1.] + ), ( + {'epsilon': 1e-6, 'clip': False, 'norm_scale': True}, + [0., -0., 2., -2.], + [1.4142e-6, 1.4142e-6, 2. + 1.4142e-6, -2. - 1.4142e-6] + ), ( + {'epsilon': 1e-3, 'clip': False, 'norm_scale': True}, + [0., -0., 2., -2.], + [1.4142e-3, 1.4142e-3, 2. + 1.4142e-3, -2. - 1.4142e-3] + ), ( + {'epsilon': 1., 'clip': False, 'norm_scale': True}, + [0., -0., 2., -2.], + [1.4142, 1.4142, 3.4142, -3.4142] + ), ( + {'epsilon': 1e-6, 'clip': True, 'norm_scale': True}, + [0., -0., 2., -2.], + [1.4142e-6, 1.4142e-6, 2., -2.] + ), ( + {'epsilon': 1e-3, 'clip': True, 'norm_scale': True}, + [0., -0., 2., -2.], + [1.4142e-3, 1.4142e-3, 2., -2.] + ), ( + {'epsilon': 1., 'clip': True, 'norm_scale': True}, + [0., -0., 2., -2.], + [1.4142, 1.4142, 2., -2.] + ), ]) -def test_stabilize(input, epsilon, expected): +def test_stabilize(kwargs, input, expected): '''Test whether stabilize produces the expected outputs given some inputs.''' - kwargs = {} if epsilon is None else {'epsilon': epsilon} input_tensor = torch.tensor(input, dtype=torch.float64) - output = stabilize(input_tensor, **kwargs) + output = stabilize(input_tensor, dim=0, **kwargs) expected_tensor = torch.tensor(expected, dtype=torch.float64) assert torch.allclose(expected_tensor, output) +@pytest.mark.parametrize('kwargs', prodict( + epsilon=[1e-6, 1e-3, 1.], + clip=[True, False], + norm_scale=[True, False], + dim=[None, (0,), (1,), (0, 1)] +)) +def test_stabilizer_match(kwargs, data_simple): + '''Test whether stabilize and Stabilizer produce the same output.''' + stabilizer = Stabilizer(**kwargs) + stabilizer_out = stabilizer(data_simple) + stabilize_out = stabilize(data_simple, **kwargs) + assert torch.allclose(stabilizer_out, stabilize_out) + + +@pytest.mark.parametrize('value', [1., 1, Stabilizer(epsilon=1.), lambda x: x]) +def test_stabilizer_ensure(value): + '''Test whether Stabilizer.ensure produces a stabilizer with the correct epsilon, or returns callables as-is.''' + ensured = Stabilizer.ensure(value) + assert not isinstance(value, float) or isinstance(ensured, Stabilizer) and ensured.epsilon == value + assert not callable(value) or value is ensured + + +@pytest.mark.parametrize('value', [None, 'wow']) +def test_stabilizer_ensure_fail(value): + '''Test whether Stabilizer.ensure fails on unsupported types.''' + with pytest.raises(TypeError): + Stabilizer.ensure(value) + + @pytest.mark.parametrize('input_shape,target_shape,cut_batch_dim', [ ((), (), False), ((), (2,), False),