Minor fix: ZeRO-1 grad clipping (#4796)
The implementation and test landed in the previous PR #4648.

This change reduces each shard's local gradient norm across all shards, so that ZeRO-1 gradient clipping uses the global gradient norm instead of a per-shard one.
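As a standalone illustration (plain PyTorch, made-up shard contents, not code from this commit), the sketch below shows why the reduction is needed: for a p-norm the global norm is the p-th root of the sum of the per-shard norms raised to the p-th power, and for the inf-norm it is the maximum over shards. Clipping each shard by its local norm alone would clip too little and by a different factor on every rank.

# Standalone illustration (not from the commit): combining per-shard local
# norms reproduces the norm of the full, concatenated gradient vector.
import torch

# Made-up gradient shards, as if split across three ZeRO-1 ranks.
shards = [torch.randn(5), torch.randn(7), torch.randn(3)]
full = torch.cat(shards)

for p in (2.0, float("inf")):
  local_norms = torch.stack([s.norm(p) for s in shards])
  if p == float("inf"):
    # corresponds to an all-reduce with REDUCE_MAX across shards
    global_norm = local_norms.max()
  else:
    # all-reduce the p-th powers with REDUCE_SUM, then take the p-th root
    global_norm = (local_norms**p).sum()**(1.0 / p)
  assert torch.allclose(global_norm, full.norm(p))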
hgt312 authored May 2, 2023
1 parent 37bfcd8 commit 71f9a35
Showing 1 changed file with 42 additions and 8 deletions:
torch_xla/distributed/zero_redundancy_optimizer.py
@@ -1,10 +1,6 @@
 from copy import deepcopy
-from typing import (
-    Any,
-    Iterator,
-    Optional,
-    Type,
-)
+from math import inf
+from typing import (Any, Iterator, Optional, Type, Union)
 
 import torch
 import torch.nn as nn
@@ -14,6 +10,8 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 
+from .fsdp.xla_fully_sharded_data_parallel import _calc_grad_norm
+
 
 class ZeroRedundancyOptimizer(Optimizer):
   r"""
@@ -56,6 +54,8 @@ def __init__(
   ):
     self.params = list(params)
     super().__init__(self.params, defaults)
+    if isinstance(self.params[0], dict):
+      self.params = [p for pg in self.params for p in pg['params']]
 
     self.device = self.params[0].device
 
@@ -96,6 +96,41 @@ def _shard_parameters(self):
       shard = nn.Parameter(shard_data, requires_grad=param.requires_grad)
       self.sharded_params.append(shard)
 
+  @torch.no_grad()
+  def _clip_grad_norm(
+      self,
+      max_norm: Union[float, int],
+      norm_type: Union[float, int] = 2.0,
+  ) -> torch.Tensor:
+    """
+    Clip all gradients at this point in time. The norm is computed over all
+    gradients together, as if they were concatenated into a single vector.
+    Gradients are modified in-place.
+    """
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    params_with_grad = [p for p in self.sharded_params if p.grad is not None]
+    # Compute the norm of this shard's gradients and sync it across workers
+    local_norm = _calc_grad_norm(params_with_grad, norm_type)
+    if norm_type == inf:
+      total_norm = xm.all_reduce(
+          xm.REDUCE_MAX,
+          local_norm,
+          groups=self.cc_op_groups,
+          pin_layout=self.pin_layout)
+    else:
+      total_norm = xm.all_reduce(
+          xm.REDUCE_SUM,
+          local_norm**norm_type,
+          groups=self.cc_op_groups,
+          pin_layout=self.pin_layout)
+      total_norm = total_norm**(1.0 / norm_type)
+
+    # Now multiply each grad by (max_norm/total_norm), same as torch 1.7 (https://tinyurl.com/3wtxhhqq)
+    clip_coef = torch.clip(max_norm / (total_norm + 1e-6), 0.0, 1.0)
+    for p in params_with_grad:
+      p.grad.detach().mul_(clip_coef)
+
   @torch.no_grad()
   def step(self, closure=None, **kwargs):
     """
@@ -126,8 +161,7 @@ def step(self, closure=None, **kwargs):
 
     if self.grad_clipping:
       # Update unscale/clip with sub partitions
-      torch.nn.utils.clip_grad_norm_(
-          self.sharded_params, max_norm=self.max_norm)
+      self._clip_grad_norm(max_norm=self.max_norm)
 
     # Step the wrapped optimizer
     loss = self.base_optimizer.step(closure=closure, **kwargs)
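For context, a hedged usage sketch (not part of this commit) of wrapping an optimizer with ZeroRedundancyOptimizer so that step() takes the new clipping path. Only grad_clipping and max_norm are visible in this diff; the optimizer_class positional argument, the forwarded lr default, and the training snippet are assumptions about the wrapper's API.

# Hedged usage sketch; constructor arguments other than grad_clipping and
# max_norm are assumptions, since the full signature is not shown in the diff.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

device = xm.xla_device()
model = torch.nn.Linear(128, 128).to(device)

optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    torch.optim.SGD,     # assumed: inner optimizer class whose state gets sharded
    lr=1e-2,             # assumed: forwarded to the inner optimizer
    grad_clipping=True,  # from the diff: gates the clipping branch in step()
    max_norm=1.0,        # from the diff: passed to self._clip_grad_norm
)

loss = model(torch.randn(8, 128, device=device)).sum()
loss.backward()
optimizer.step()  # clips using the norm reduced across all shards, then steps
xm.mark_step()

Because each rank holds only its shard of the gradients, the all_reduce inside _clip_grad_norm is what makes the clipping coefficient identical on every rank.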
