Skip to content

Commit

Permalink
Fix bug of clip_grad_norm_ for xla fsdp (#2941)
Browse files Browse the repository at this point in the history
* fix bug of clip_grad_norm_ for xla

* modify
  • Loading branch information
hanwen-sun authored Aug 1, 2024
1 parent 83b0610 commit 288accc
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2300,6 +2300,12 @@ def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
xm.all_reduce("sum", gradients, scale=1.0 / self.num_processes)
# Set is_xla_gradients_synced to True to avoid all-reduce twice in the AcceleratedOptimizer step.
acc_opt.gradient_state.is_xla_gradients_synced = True
if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true":
self.unscale_gradients()
parameters = [p for p in parameters]
for model in self._models:
if parameters == [p for p in model.parameters()]:
return model.clip_grad_norm_(max_norm, norm_type)
self.unscale_gradients()
return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type)

Expand Down

0 comments on commit 288accc

Please sign in to comment.