Fix the Norm and bias of AlphaBeta
- the normalizing terms of AlphaBeta were implemented incorrectly: the positive and negative terms were each split into their two parts, so every output was normalized by its own part instead of by the full positive or negative sum; this is fixed (see the sketch right after this list)
- the negative bias contributed twice in AlphaBeta; the bias is now zeroed in the affected parameter modifier so it only enters once
chr5tphr committed Oct 18, 2021
1 parent 3a12762 commit 4611ddf
Showing 1 changed file with 7 additions and 2 deletions.
zennit/rules.py — 9 changes: 7 additions & 2 deletions
@@ -110,10 +110,15 @@ def __init__(self, alpha=2., beta=1.):
                 lambda param, _: param.clamp(min=0),
                 lambda param, name: param.clamp(max=0) if name != 'bias' else torch.zeros_like(param),
                 lambda param, _: param.clamp(max=0),
-                lambda param, _: param.clamp(min=0),
+                lambda param, name: param.clamp(min=0) if name != 'bias' else torch.zeros_like(param),
             ],
             output_modifiers=[lambda output: output] * 4,
-            gradient_mapper=(lambda out_grad, outputs: [out_grad / stabilize(output) for output in outputs]),
+            gradient_mapper=(
+                lambda out_grad, outputs: [
+                    out_grad / stabilize(denom)
+                    for output, denom in zip(outputs, [sum(outputs[:2])] * 2 + [sum(outputs[2:])] * 2)
+                ]
+            ),
             reducer=(
                 lambda inputs, gradients: (
                     alpha * (inputs[0] * gradients[0] + inputs[1] * gradients[1])
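For context, a brief usage sketch of the fixed rule on a toy layer, assuming zennit's usual hook interface (rule.register returning removable handles, relevance obtained via torch.autograd.grad); this is not part of the commit:

import torch
from zennit.rules import AlphaBeta

# toy linear layer with a bias, so both the normalization and the bias handling are exercised
layer = torch.nn.Linear(4, 3)
rule = AlphaBeta(alpha=2., beta=1.)
handles = rule.register(layer)  # attach the rule's forward/backward hooks to the layer

input = torch.randn(1, 4, requires_grad=True)
output = layer(input)
# backpropagating through the hooked layer yields the AlphaBeta relevance at the input
relevance, = torch.autograd.grad(output, input, grad_outputs=torch.ones_like(output))
handles.remove()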
