From 13a91afb2203e6eadeec8fdcf26ac8fe7f52a6ea Mon Sep 17 00:00:00 2001 From: chrstphr Date: Tue, 27 Jul 2021 10:59:55 +0200 Subject: [PATCH] Check for requires_grad in pre_forward Previously, when checking whether the gradient was required to determine whether to apply hooks, only the existance of a grad_fn of the input was checked. This was insufficient, since the first layer input may not have a grad_fn yet, but still require a gradient. Now, the input is checked for requires_grad instead, since a grad_fn is not needed, because of the subsequent `Identity.apply`. It is still sufficient to check for grad_fn for the output, since the output will always have a grad_fn if a gradient is required. --- zennit/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zennit/core.py b/zennit/core.py index e1eeb00..37552ae 100644 --- a/zennit/core.py +++ b/zennit/core.py @@ -143,7 +143,7 @@ def wrapper(grad_input, grad_output): if not isinstance(input, tuple): input = (input,) - if input[0].grad_fn is not None: + if input[0].requires_grad: # only if gradient required post_input = Identity.apply(*input) post_input[0].grad_fn.register_hook(wrapper)