
sync new features to master
hgt312 committed May 28, 2024
1 parent 7770a49 commit 5ef1b50
Showing 2 changed files with 139 additions and 22 deletions.
2 changes: 1 addition & 1 deletion torch_xla/core/xla_model.py
@@ -967,7 +967,7 @@ def reduce_scatter_bucketized(reduce_type,
see reduce_scatter for reduce_type, scale, scatter_dim, shard_count, groups, pin_layout
input_list: List of input tensors
output: Optional list of output torch.Tensor
bucket_cap_mb: Number of MegaBytes of the tensor bucket to fill before doing all-gather.
bucket_cap_mb: Number of MegaBytes of the tensor bucket to fill before doing reduce-scatter.
Returns:
A list of `torch.Tensors` with all the values reduced across replicas. Each process
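For context, a hedged sketch of calling the bucketized reduce-scatter described above. The keyword names (`reduce_type`, `input_list`, `scale`, `scatter_dim`, `shard_count`, `pin_layout`, `bucket_cap_mb`) follow the docstring in this hunk; the exact signature defaults and the tensor shapes are assumptions.

```python
# Hedged sketch: bucketized reduce-scatter over a list of gradient tensors.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
grads = [torch.randn(1024, 1024, device=device) for _ in range(4)]

reduced_shards = xm.reduce_scatter_bucketized(
    xm.REDUCE_SUM,                     # reduce_type
    input_list=grads,
    scale=1.0 / xm.xrt_world_size(),
    scatter_dim=0,
    shard_count=xm.xrt_world_size(),
    pin_layout=True,
    bucket_cap_mb=160,                 # fill ~160 MB buckets before each reduce-scatter
)
```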
159 changes: 138 additions & 21 deletions torch_xla/distributed/zero_redundancy_optimizer.py
@@ -1,4 +1,5 @@
import copy
import logging
from typing import (Any, Iterator, Optional, Type, Union, List, Dict)

import torch
@@ -32,17 +33,28 @@ class ZeroRedundancyOptimizer(Optimizer):
collective ops (all_gather and reduce_scatter). See `xm.all_reduce`
for details on pinning layout. Default: True
sharding_groups (list, Optional):
If specified, ZeRO-1 will use this ``sharding_groups`` for all-gather
and reduce-scatter ops in full parameter construction and gradient
sharding. This can be useful for mixing ZeRO-1 with model parallelism
such as Megatron.
If specified, ZeRO-1 will use this ``sharding_groups`` for all-gather
and reduce-scatter ops in full parameter construction and gradient
sharding. This can be useful for mixing ZeRO-1 with model parallelism
such as Megatron.
grad_norm_groups (list, Optional):
If specified, ZeRO-1 will use this ``grad_norm_groups`` for the
EXTRA all-reduce op in grad norm calculation. This can be model parallel
groups when mixing ZeRO-1 with model parallelism such as Megatron.
bucket_cap_mb:
If non-zero, specifies the maximum number of megabytes to combine tensors
before doing the all-gather/reduce-scatter operations.
If specified, ZeRO-1 will use this ``grad_norm_groups`` for the
EXTRA all-reduce op in grad norm calculation. This can be model parallel
groups when mixing ZeRO-1 with model parallelism such as Megatron.
lazy_init (bool, Optional): if ``True``, the class will not shard parameters
during initialization. Users need to call ``optimizer.init_zero()`` themselves.
Default: False
bucket_cap_mb_all_gather (int, Optional): Number of MegaBytes of the tensor bucket to fill before
doing all-gather. Default: 0
bucket_cap_mb_reduce_scatter (int, Optional): Number of MegaBytes of the tensor bucket to fill before
doing reduce-scatter. Default: 0
use_grad_acc_hook (bool, Optional): if ``True``, use hooks for gradient accumulation, so that
gradients are accumulated in the same ``dtype`` as ``optimizer_dtype``. Set this to True to
accumulate gradients at higher precision. Default: False
save_master_weights (bool, Optional):
if ``True``, also save sharded master weights. Default: False
higher_cc_precision (bool, Optional): if ``True``, use higher precision (the same as
``optimizer_dtype``) for collective communication operators. Default: False
**defaults: any trailing arguments, which are forwarded to the local
optimizer.
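Taken together, a hedged sketch of how these constructor options might be used. The keyword arguments named in the diff (`lazy_init`, `bucket_cap_mb_all_gather`, `bucket_cap_mb_reduce_scatter`, `use_grad_acc_hook`, `save_master_weights`, `higher_cc_precision`) come from this commit; the positional local-optimizer class, `optimizer_dtype`, the placeholder model, and the numeric values are illustrative assumptions.

```python
# Hedged sketch: wrapping a local optimizer in ZeRO-1 with the new options.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

model = torch.nn.Linear(4096, 4096).to(xm.xla_device())  # placeholder model

optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    torch.optim.AdamW,                  # local optimizer class (assumed positional)
    optimizer_dtype=torch.float32,      # dtype for master weights / grad accumulation
    lazy_init=True,                     # shard later via optimizer.init_zero()
    bucket_cap_mb_all_gather=128,       # coalesce all-gathers into ~128 MB buckets
    bucket_cap_mb_reduce_scatter=128,   # coalesce reduce-scatters the same way
    use_grad_acc_hook=True,             # accumulate grads in optimizer_dtype via hooks
    save_master_weights=True,           # include sharded master weights in state_dict()
    higher_cc_precision=False,
    lr=1e-4,                            # forwarded to the local optimizer via **defaults
)
optimizer.init_zero()                   # required because lazy_init=True
```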
@@ -65,8 +77,16 @@ def __init__(
lazy_init: bool = False,
bucket_cap_mb_all_gather: int = 0,
bucket_cap_mb_reduce_scatter: int = 0,
use_grad_acc_hook: bool = False,
save_master_weights: bool = False,
higher_cc_precision: bool = False,
**defaults: Any,
):
if not save_master_weights:
logging.warning(
'Not saving master weights may have accuracy issues when resuming training!'
)

super().__init__(params, defaults)

self.global_world_size = xm.xrt_world_size()
@@ -85,6 +105,10 @@ def __init__(
self.bucket_cap_mb_reduce_scatter = bucket_cap_mb_reduce_scatter
self.coalesce_cc_all_gather = bucket_cap_mb_all_gather > 0
self.coalesce_cc_reduce_scatter = bucket_cap_mb_reduce_scatter > 0
self.higher_cc_precision = higher_cc_precision
self.use_grad_acc_hook = use_grad_acc_hook
self.grad_accs = []
self.grad_acc_hooks = []

self._grad_norm = None

@@ -93,6 +117,7 @@ def __init__(
self.init_zero()

def init_zero(self):
self.remove_hooks()
self.local_world_size = len(self.sharding_groups[0])
# Infer the local rank from the group
self.local_rank = None
@@ -169,6 +194,41 @@ def _shard_tensor(self, tensor: torch.Tensor):
tensor = tensor.chunk(self.local_world_size)[self.local_rank]
return tensor

def _make_param_hook(self, param, shard):
"""
Create the grad accumulation hook for backprop.
"""

def _param_hook(*unused):
# Accumulate the gradient into the main (higher-precision) gradient buffer
if param.grad is not None:
if not hasattr(shard, 'main_grad'):
# Create main gradients
shard.main_grad = torch.zeros(
param.shape,
dtype=self.optimizer_dtype,
device=self.device,
requires_grad=False)
param.main_grad = shard.main_grad
shard.main_grad.add_(param.grad.data.to(dtype=self.optimizer_dtype))
# Deallocate grad memory.
param.grad = None

return _param_hook

def _register_hook(self, param, shard):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
hook = grad_acc.register_hook(self._make_param_hook(param, shard))
self.grad_acc_hooks.append(hook)
self.grad_accs.append(grad_acc)

def remove_hooks(self):
for hook in self.grad_acc_hooks:
hook.remove()
self.grad_acc_hooks = []
self.grad_accs = []
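
The hook registration above attaches to the parameter's ``AccumulateGrad`` autograd node, reached through ``param.expand_as(param).grad_fn.next_functions[0][0]``. A minimal standalone sketch of that pattern, independent of ZeRO-1 (all names here are illustrative):

```python
# Hedged sketch of the AccumulateGrad-hook pattern used by _register_hook above.
import torch

param = torch.nn.Parameter(torch.randn(8, 8))

# expand_as produces a non-leaf view whose grad_fn's first next_function
# is the AccumulateGrad node of the underlying parameter.
grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]

def on_grad_accumulated(*unused):
    # Fires after autograd has written param.grad for this backward pass.
    if param.grad is not None:
        print("grad norm:", param.grad.norm().item())

handle = grad_acc.register_hook(on_grad_accumulated)
# Keep grad_acc referenced (the optimizer stores it in self.grad_accs) so the
# node is not garbage-collected while the hook is in use.

loss = (param * 2.0).sum()
loss.backward()   # triggers on_grad_accumulated
handle.remove()   # detach the hook when it is no longer needed
```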

def _shard_parameters(self):
"""
Shard all parameters.
@@ -196,6 +256,8 @@ def _shard_parameters(self):
shard_data = shard_data.to(dtype=self.optimizer_dtype)
shard_data = shard_data.to(device=self.device) # move to xla device
shard = nn.Parameter(shard_data, requires_grad=param.requires_grad)
if self.use_grad_acc_hook:
self._register_hook(param, shard)
sharded_params.append(shard)
sharded_params_group = copy.copy(param_group)
sharded_params_group['params'] = sharded_params
@@ -282,9 +344,13 @@ def step(self, closure=None, **kwargs):
self.param_groups, self.base_optimizer.param_groups):
for param, shard in zip(param_group['params'],
sharded_param_group['params']):
if param.grad is not None:
padded_grad = self._pad_to_world_size(param.grad,
self.local_world_size)
if param.grad is not None or (self.use_grad_acc_hook and
hasattr(shard, 'main_grad')):
padded_grad = self._pad_to_world_size(
shard.main_grad if self.use_grad_acc_hook else param.grad,
self.local_world_size)
if self.higher_cc_precision:
padded_grad = padded_grad.to(dtype=self.optimizer_dtype)
if self.coalesce_cc_reduce_scatter:
padded_grads.append(padded_grad)
else:
@@ -317,7 +383,8 @@ def step(self, closure=None, **kwargs):
self.param_groups, self.base_optimizer.param_groups):
for param, shard in zip(param_group['params'],
sharded_param_group['params']):
if param.grad is not None:
if param.grad is not None or (self.use_grad_acc_hook and
hasattr(shard, 'main_grad')):
grad_shard = grad_shards[index]
if grad_shard.dtype != self.optimizer_dtype:
grad_shard = grad_shard.to(dtype=self.optimizer_dtype)
@@ -334,15 +401,24 @@ def step(self, closure=None, **kwargs):
# Remove shards' grads
self.base_optimizer.zero_grad(set_to_none=True)

self.allgather_weights_and_update_full_parameter()

# sync back
self._sync_param_groups(self.base_optimizer.param_groups, self.param_groups)

return loss

def allgather_weights_and_update_full_parameter(self):
# All-gather the new weights across the ranks and assign them to the full parameters
sharded_data = []
for param_group, sharded_param_group in zip(
self.param_groups, self.base_optimizer.param_groups):
for param, shard in zip(param_group['params'],
sharded_param_group['params']):
if param.grad is not None:
if param.grad is not None or (self.use_grad_acc_hook and
hasattr(shard, 'main_grad')):
shard_data = shard.data
if param.dtype != self.optimizer_dtype:
if not self.higher_cc_precision:
shard_data = shard_data.to(dtype=param.dtype)
if self.coalesce_cc_all_gather:
sharded_data.append(shard_data)
@@ -353,6 +429,8 @@ def step(self, closure=None, **kwargs):
pin_layout=self.pin_layout,
groups=self.sharding_groups,
)
if padded_param.dtype != param.dtype:
padded_param = padded_param.to(dtype=param.dtype)
param.data.copy_(padded_param.data[:param.size(0)])

if self.coalesce_cc_all_gather:
@@ -368,21 +446,36 @@ def step(self, closure=None, **kwargs):
self.param_groups, self.base_optimizer.param_groups):
for param, shard in zip(param_group['params'],
sharded_param_group['params']):
if param.grad is not None:
if param.grad is not None or (self.use_grad_acc_hook and
hasattr(shard, 'main_grad')):
padded_param = padded_params[index]
if padded_param.dtype != param.dtype:
padded_param = padded_params[index].to(dtype=param.dtype)
param.data.copy_(padded_param.data[:param.size(0)])
index += 1

# sync back
self._sync_param_groups(self.base_optimizer.param_groups, self.param_groups)

return loss
def zero_grad(self, set_to_none: bool = False):
super().zero_grad(set_to_none=set_to_none)
if self.use_grad_acc_hook:
for sharded_param_group in self.base_optimizer.param_groups:
for shard in sharded_param_group['params']:
if hasattr(shard, 'main_grad'):
shard.main_grad.zero_()
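
With ``use_grad_acc_hook=True``, gradients accumulate into ``shard.main_grad`` in ``optimizer_dtype``, and ``zero_grad()`` clears those buffers as well. A hedged sketch of a training step, assuming the ``model`` and ``optimizer`` from the constructor sketch above and a placeholder synthetic batch:

```python
# Hedged sketch: gradient accumulation with the hooks enabled.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(16, 4096, device=device)   # placeholder batch
y = torch.randn(16, 4096, device=device)

for micro_step in range(4):                # placeholder accumulation window
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()                        # hooks fold param.grad into shard.main_grad
optimizer.step()                           # reduce-scatter, local step, all-gather
optimizer.zero_grad()                      # also zeros the main_grad buffers
xm.mark_step()
```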

def state_dict(self):
state_dict = super().state_dict()
base_state = self.base_optimizer.state_dict()['state']
state_dict['base_state'] = base_state
state_dict['shape_info'] = self.get_shape_info()
if self.save_master_weights:
index = 0
master_weights = {}
for sharded_param_group in self.base_optimizer.param_groups:
for shard in sharded_param_group['params']:
master_weights[index] = shard.data
index += 1
state_dict['sharded_master_weights'] = master_weights

return state_dict

def load_state_dict(self, state_dict):
@@ -396,6 +489,30 @@ def load_state_dict(self, state_dict):
tmp = self.base_optimizer.state_dict()
tmp['state'] = base_state
self.base_optimizer.load_state_dict(tmp)
if 'sharded_master_weights' in state_dict:
master_weights = state_dict['sharded_master_weights']
index = 0
for param_group, sharded_param_group in zip(
self.param_groups, self.base_optimizer.param_groups):
for param, shard in zip(param_group['params'],
sharded_param_group['params']):
shard.data.copy_(master_weights[index])
# Set a dummy gradient so that the all-gather is triggered.
if self.use_grad_acc_hook:
# Create main gradients
shard.main_grad = torch.zeros(
param.shape,
dtype=self.optimizer_dtype,
device=self.device,
requires_grad=False)
param.main_grad = shard.main_grad
else:
param.grad = torch.zeros_like(param.data)
index += 1
xm.mark_step()
# Add mark_step around the all-gather to avoid a large number of compilations.
self.allgather_weights_and_update_full_parameter()
xm.mark_step()

def get_shape_info(self):
shape_info = {}
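Finally, a hedged sketch of checkpointing with the new master-weight support, assuming the ``optimizer`` from the sketches above; the per-rank path and the use of ``xm.save`` are illustrative choices, not part of this commit.

```python
# Hedged sketch: saving and restoring ZeRO-1 state, including sharded master
# weights when the optimizer was built with save_master_weights=True.
import torch
import torch_xla.core.xla_model as xm

path = f'/tmp/zero1_rank{xm.get_ordinal()}.pt'

# state_dict() now embeds 'base_state', 'shape_info' and, optionally,
# 'sharded_master_weights'; save one shard per rank.
xm.save(optimizer.state_dict(), path, master_only=False)

# load_state_dict() copies the master weights back into the shards and
# all-gathers them into the full parameters (wrapped in mark_step calls).
state = torch.load(path)
optimizer.load_state_dict(state)
```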
