Fix FSDP gradient reduction with orig params
The `param.grad is not None` check also fixes gradient reduction for parameters
that never acquired gradients (parameters can become empty tensors in FSDP).

Thanks to @ofivite for suggesting that `use_orig_params=True` could be
the cause of the issue, which greatly helped with analysis.

Signed-off-by: janEbert <[email protected]>
janEbert committed Jun 19, 2024
1 parent 61a1a44 commit 1195bbc
Showing 1 changed file with 2 additions and 2 deletions.
@@ -1018,8 +1018,8 @@ def allreduce_fsdp_sharding_omitted_gradients(self):
         """All-reduce gradients of FSDP-sharding-omitted parameters in sharding domain (data-parallel domain)."""
         assert isinstance(self.model, torch.nn.Module)
         grads = []
-        for param in self.model.parameters():
-            if not isinstance(param, torch.distributed.fsdp.FlatParameter) and param.requires_grad:
+        for param in self.model._ignored_params:
+            if param.requires_grad and param.grad is not None:
                 grad = param.grad
                 grads.append(grad.data)
         if len(grads) > 0:
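For context, here is a minimal self-contained sketch of the pattern the new hunk implements: collect only gradients that actually exist and all-reduce them over the data-parallel (sharding) group. This is not the NeMo method itself; `ignored_params` and `data_parallel_group` are hypothetical stand-ins for `self.model._ignored_params` and the sharding process group, and the summing/averaging step is an assumption about what follows the truncated hunk.

from typing import Iterable, Optional

import torch
import torch.distributed as dist


def allreduce_omitted_gradients(
    ignored_params: Iterable[torch.nn.Parameter],
    data_parallel_group: Optional[dist.ProcessGroup] = None,
) -> None:
    """Sketch: all-reduce gradients of parameters that FSDP left unsharded."""
    grads = []
    for param in ignored_params:
        # With `use_orig_params=True`, FSDP may leave an original parameter as
        # an empty tensor on a given rank, so it never acquires a gradient;
        # guarding on `param.grad is not None` skips those parameters.
        if param.requires_grad and param.grad is not None:
            grads.append(param.grad.data)
    if len(grads) > 0:
        world_size = dist.get_world_size(group=data_parallel_group)
        for grad in grads:
            # Sum across the data-parallel group, then average (assumed reduction).
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=data_parallel_group)
            grad.div_(world_size)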

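And a hedged sketch of how such sharding-omitted parameters can arise in the first place, assuming `torch.distributed` is already initialized: parameters of modules passed via `ignored_modules` stay out of sharding (FSDP tracks them in its private `_ignored_params`, the set the fixed loop iterates), while `use_orig_params=True` makes the wrapper expose original parameters rather than `FlatParameter`s, which is presumably why the old `isinstance` filter stopped excluding sharded parameters. The function and argument names here are illustrative only.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def wrap_model(model: torch.nn.Module, unsharded_submodule: torch.nn.Module) -> FSDP:
    """Sketch: wrap a model so one submodule's parameters are omitted from sharding."""
    return FSDP(
        model,
        # Parameters of this submodule are left unsharded and must be
        # all-reduced manually in the data-parallel domain.
        ignored_modules=[unsharded_submodule],
        # Expose original parameters instead of FlatParameters.
        use_orig_params=True,
    )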