diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py
index 12327f8fdd1..9139405e252 100644
--- a/torch_xla/distributed/zero_redundancy_optimizer.py
+++ b/torch_xla/distributed/zero_redundancy_optimizer.py
@@ -1,10 +1,6 @@
 from copy import deepcopy
-from typing import (
-    Any,
-    Iterator,
-    Optional,
-    Type,
-)
+from math import inf
+from typing import (Any, Iterator, Optional, Type, Union)
 
 import torch
 import torch.nn as nn
@@ -14,6 +10,8 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 
+from .fsdp.xla_fully_sharded_data_parallel import _calc_grad_norm
+
 
 class ZeroRedundancyOptimizer(Optimizer):
   r"""
@@ -56,6 +54,8 @@ def __init__(
   ):
     self.params = list(params)
     super().__init__(self.params, defaults)
+    if isinstance(self.params[0], dict):
+      self.params = [p for pg in self.params for p in pg['params']]
 
     self.device = self.params[0].device
 
@@ -96,6 +96,41 @@ def _shard_parameters(self):
       shard = nn.Parameter(shard_data, requires_grad=param.requires_grad)
       self.sharded_params.append(shard)
 
+  @torch.no_grad()
+  def _clip_grad_norm(
+      self,
+      max_norm: Union[float, int],
+      norm_type: Union[float, int] = 2.0,
+  ) -> None:
+    """
+    Clip all gradients at this point in time. The norm is computed over all
+    gradients together, as if they were concatenated into a single vector.
+    Gradients are modified in-place.
+    """
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    params_with_grad = [p for p in self.sharded_params if p.grad is not None]
+    # Compute the norm of this shard's gradients and sync it across workers
+    local_norm = _calc_grad_norm(params_with_grad, norm_type)
+    if norm_type == inf:
+      total_norm = xm.all_reduce(
+          xm.REDUCE_MAX,
+          local_norm,
+          groups=self.cc_op_groups,
+          pin_layout=self.pin_layout)
+    else:
+      total_norm = xm.all_reduce(
+          xm.REDUCE_SUM,
+          local_norm**norm_type,
+          groups=self.cc_op_groups,
+          pin_layout=self.pin_layout)
+      total_norm = total_norm**(1.0 / norm_type)
+
+    # Now multiply each grad by (max_norm/total_norm), same as torch 1.7 https://tinyurl.com/3wtxhhqq
+    clip_coef = torch.clip(max_norm / (total_norm + 1e-6), 0.0, 1.0)
+    for p in params_with_grad:
+      p.grad.detach().mul_(clip_coef)
+
   @torch.no_grad()
   def step(self, closure=None, **kwargs):
     """
@@ -126,8 +161,7 @@ def step(self, closure=None, **kwargs):
 
     if self.grad_clipping:
      # Update unscale/clip with sub partitions
-      torch.nn.utils.clip_grad_norm_(
-          self.sharded_params, max_norm=self.max_norm)
+      self._clip_grad_norm(max_norm=self.max_norm)
 
     # Step the wrapped optimizer
     loss = self.base_optimizer.step(closure=closure, **kwargs)
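
To make the reduction in the new `_clip_grad_norm` easier to verify, here is a minimal, XLA-free sketch of the math it relies on: each worker computes the norm of only its own gradient shard, and the global norm is recovered by all-reducing `local_norm**p` and taking the p-th root (for the inf-norm, an all-reduce with `max` suffices, which is why that branch skips the final root). The helper `combine_shard_norms` below is hypothetical and only illustrates the identity; it is not part of the patch.

```python
from math import inf

import torch


def combine_shard_norms(local_norms, norm_type=2.0):
  """Combine per-shard gradient norms into the norm of the full gradient.

  Mirrors the REDUCE_MAX / REDUCE_SUM branches of _clip_grad_norm, but with
  a local stack+reduce standing in for xm.all_reduce across workers.
  """
  norms = torch.stack(local_norms)
  if norm_type == inf:
    return norms.max()  # REDUCE_MAX branch: global max of per-shard maxima
  # REDUCE_SUM branch: sum of local_norm**p across shards, then p-th root
  return (norms**norm_type).sum()**(1.0 / norm_type)


# Example: two shards of the same "global" gradient vector.
full_grad = torch.randn(10)
shards = [full_grad[:5], full_grad[5:]]
local_norms = [s.norm(2.0) for s in shards]
assert torch.allclose(
    combine_shard_norms(local_norms, 2.0), full_grad.norm(2.0))
```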
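
And a hedged usage sketch of how the new clipping path would be exercised end to end. The constructor keywords (`optimizer_class` as the second positional argument, `grad_clipping`, `max_norm`) and the `lr` forwarded via `**defaults` are assumptions inferred from the `self.*` attributes and the `super().__init__(self.params, defaults)` call visible in the hunks above; they are not confirmed by this patch, and an XLA device is assumed to be available.

```python
# Hedged usage sketch (not part of the patch).
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

device = xm.xla_device()
model = torch.nn.Linear(16, 4).to(device)

opt = ZeroRedundancyOptimizer(
    model.parameters(),
    torch.optim.SGD,     # wrapped ("base") optimizer class -- assumed kwarg order
    lr=0.01,             # assumed to be forwarded to the base optimizer via **defaults
    grad_clipping=True,  # step() now routes through self._clip_grad_norm
    max_norm=1.0)

loss = model(torch.randn(8, 16, device=device)).sum()
loss.backward()
opt.step()      # gradients are clipped against the global (all-shard) norm
xm.mark_step()
```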