From 1195bbcd9346fd1afa527a2610ee9c9676d9c3ee Mon Sep 17 00:00:00 2001
From: janEbert
Date: Wed, 29 May 2024 10:41:18 +0200
Subject: [PATCH] Fix FSDP gradient reduction with orig params

The `param.grad is not None` check also fixes gradient reduction in the
case of parameters not having acquired gradients (as parameters could
become empty tensors in FSDP).

Thanks to @ofivite for suggesting that `use_orig_params=True` could be
the cause of the issue, which greatly helped with analysis.

Signed-off-by: janEbert
---
 .../nlp/models/language_modeling/megatron_gpt_model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 111aae6d37e2..7091f6c03f14 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -1018,8 +1018,8 @@ def allreduce_fsdp_sharding_omitted_gradients(self):
         """All-reduce gradients of FSDP-sharding-omitted parameters in sharding domain (data-parallel domain)."""
         assert isinstance(self.model, torch.nn.Module)
         grads = []
-        for param in self.model.parameters():
-            if not isinstance(param, torch.distributed.fsdp.FlatParameter) and param.requires_grad:
+        for param in self.model._ignored_params:
+            if param.requires_grad and param.grad is not None:
                 grad = param.grad
                 grads.append(grad.data)
         if len(grads) > 0:
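
Note: the following is a minimal, standalone sketch (not the NeMo implementation) of the reduction pattern the patch adopts. It iterates only the parameters FSDP ignored for sharding, skips any whose `.grad` is `None` (with `use_orig_params=True`, some ranks can hold empty parameter tensors that never receive gradients), and all-reduces the remaining gradients over the data-parallel group. The helper name, the `process_group` argument, and the per-tensor `all_reduce` calls are illustrative assumptions; the actual method also handles the sharding-domain group and may coalesce the collective.

import torch
import torch.distributed as dist


def allreduce_ignored_param_grads(ignored_params, process_group=None):
    """All-reduce gradients of parameters omitted from FSDP sharding.

    Sketch only: `ignored_params` corresponds to `self.model._ignored_params`
    in the patched method; `process_group` stands in for the data-parallel
    (sharding-domain) group.
    """
    # Skip parameters that never acquired a gradient (e.g. empty tensors
    # under use_orig_params=True), mirroring the `param.grad is not None` check.
    grads = [
        param.grad.data
        for param in ignored_params
        if param.requires_grad and param.grad is not None
    ]
    if len(grads) > 0:
        for grad in grads:
            # Sum across data-parallel ranks; whether to divide by the group
            # size afterwards depends on the trainer's averaging convention.
            dist.all_reduce(grad, group=process_group)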