Skip to content

Commit

Permalink
Check for requires_grad in pre_forward
Browse files Browse the repository at this point in the history
Previously, when checking whether the gradient was required to determine
whether to apply hooks, only the existance of a grad_fn of the input was
checked. This was insufficient, since the first layer input may not have
a grad_fn yet, but still require a gradient. Now, the input is checked
for requires_grad instead, since a grad_fn is not needed, because of the
subsequent `Identity.apply`.
It is still sufficient to check for grad_fn for the output, since the
output will always have a grad_fn if a gradient is required.
  • Loading branch information
chr5tphr committed Jul 27, 2021
1 parent 783e169 commit 13a91af
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion zennit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def wrapper(grad_input, grad_output):
if not isinstance(input, tuple):
input = (input,)

if input[0].grad_fn is not None:
if input[0].requires_grad:
# only if gradient required
post_input = Identity.apply(*input)
post_input[0].grad_fn.register_hook(wrapper)
Expand Down

0 comments on commit 13a91af

Please sign in to comment.