From 5ef1b509ddfbd12ebae2a0e4c56cace5b9020d22 Mon Sep 17 00:00:00 2001 From: guangtai Date: Wed, 1 May 2024 16:05:09 -0700 Subject: [PATCH] sync new features to master --- torch_xla/core/xla_model.py | 2 +- .../distributed/zero_redundancy_optimizer.py | 159 +++++++++++++++--- 2 files changed, 139 insertions(+), 22 deletions(-) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 149aa99b67d..4fb7cc1316a 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -967,7 +967,7 @@ def reduce_scatter_bucketized(reduce_type, see reduce_scatter for reduce_type, scale, scatter_dim, shard_count, groups, pin_layout input_list: List of input tensors output: Optional list of output torch.Tensor - bucket_cap_mb: Number of MegaBytes of the tensor bucket to fill before doing all-gather. + bucket_cap_mb: Number of MegaBytes of the tensor bucket to fill before doing reduce-scatter. Returns: A list of `torch.Tensors` with all the values reduced across replicas. Each process diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py index 9b21fe4ead8..2299714271a 100644 --- a/torch_xla/distributed/zero_redundancy_optimizer.py +++ b/torch_xla/distributed/zero_redundancy_optimizer.py @@ -1,4 +1,5 @@ import copy +import logging from typing import (Any, Iterator, Optional, Type, Union, List, Dict) import torch @@ -32,17 +33,28 @@ class ZeroRedundancyOptimizer(Optimizer): collective ops (all_gather and reduce_scatter). See `xm.all_reduce` for details on pinning layout. Default: True sharding_groups (list, Optional): - If specified, ZeRO-1 will use this ``sharding_groups`` for all-gather - and reduce-scatter ops in full parameter construction and gradient - sharding. This can be useful for mixing ZeRO-1 with model parallelism - such as Megatron. + If specified, ZeRO-1 will use this ``sharding_groups`` for all-gather + and reduce-scatter ops in full parameter construction and gradient + sharding. This can be useful for mixing ZeRO-1 with model parallelism + such as Megatron. grad_norm_groups (list, Optional): - If specified, ZeRO-1 will use this ``grad_norm_groups`` for the - EXTRA all-reduce op in grad norm calculation. This can be model parallel - groups when mixing ZeRO-1 with model parallelism such as Megatron. - bucket_cap_mb: - If non-zero, specifies the maximum number of megabytes to combine tensors - before doing the all-gather/reduce-scatter operations. + If specified, ZeRO-1 will use this ``grad_norm_groups`` for the + EXTRA all-reduce op in grad norm calculation. This can be model parallel + groups when mixing ZeRO-1 with model parallelism such as Megatron. + lazy_init (bool, Optional): if ``True``, the class will not shard paramaters + during initialization. Users need to call ``optimizer.init_zero()`` by themselves. + Default: False + bucket_cap_mb_all_gather (int, Optional): Number of MegaBytes of the tensor bucket to fill before + doing all-gather. Default: False + bucket_cap_mb_reduce_scatter (int, Optional): Number of MegaBytes of the tensor bucket to fill before + doing reduce-scatter. Default: False + use_grad_acc_hook (bool, Optional): if ``True``, use hooks for gradients accumulation, then + ``dtype`` of grad accumulation will be the same as ``optimizer_dtype``. Users can set this + to True to use higher precision for gradients accumulation. Default: False + save_master_weights (bool, Optional): + if ``True``, also save sharded master weights. Default: False + higher_cc_precision (bool, Optional): if ``True``, use higher precision for collective communication + operators (the same as ``optimizer_dtype``). Default: False **defaults: any trailing arguments, which are forwarded to the local optimizer. @@ -65,8 +77,16 @@ def __init__( lazy_init: bool = False, bucket_cap_mb_all_gather: int = 0, bucket_cap_mb_reduce_scatter: int = 0, + use_grad_acc_hook: bool = False, + save_master_weights: bool = False, + higher_cc_precision: bool = False, **defaults: Any, ): + if not save_master_weights: + logging.warning( + 'Not saving master weights may have accuracy issues when resuming training!' + ) + super().__init__(params, defaults) self.global_world_size = xm.xrt_world_size() @@ -85,6 +105,10 @@ def __init__( self.bucket_cap_mb_reduce_scatter = bucket_cap_mb_reduce_scatter self.coalesce_cc_all_gather = bucket_cap_mb_all_gather > 0 self.coalesce_cc_reduce_scatter = bucket_cap_mb_reduce_scatter > 0 + self.higher_cc_precision = higher_cc_precision + self.use_grad_acc_hook = use_grad_acc_hook + self.grad_accs = [] + self.grad_acc_hooks = [] self._grad_norm = None @@ -93,6 +117,7 @@ def __init__( self.init_zero() def init_zero(self): + self.remove_hooks() self.local_world_size = len(self.sharding_groups[0]) # Infer the local rank from the group self.local_rank = None @@ -169,6 +194,41 @@ def _shard_tensor(self, tensor: torch.Tensor): tensor = tensor.chunk(self.local_world_size)[self.local_rank] return tensor + def _make_param_hook(self, param, shard): + """ + Create the grad accumulation hook for backprop. + """ + + def _param_hook(*unused): + # Accumulates gradients on main gradients + if param.grad is not None: + if not hasattr(shard, 'main_grad'): + # Create main gradients + shard.main_grad = torch.zeros( + param.shape, + dtype=self.optimizer_dtype, + device=self.device, + requires_grad=False) + param.main_grad = shard.main_grad + shard.main_grad.add_(param.grad.data.to(dtype=self.optimizer_dtype)) + # Deallocate grad memory. + param.grad = None + + return _param_hook + + def _register_hook(self, param, shard): + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + hook = grad_acc.register_hook(self._make_param_hook(param, shard)) + self.grad_acc_hooks.append(hook) + self.grad_accs.append(grad_acc) + + def remove_hooks(self): + for hook in self.grad_acc_hooks: + hook.remove() + self.grad_acc_hooks = [] + self.grad_accs = [] + def _shard_parameters(self): """ Shard all parameters. @@ -196,6 +256,8 @@ def _shard_parameters(self): shard_data = shard_data.to(dtype=self.optimizer_dtype) shard_data = shard_data.to(device=self.device) # move to xla device shard = nn.Parameter(shard_data, requires_grad=param.requires_grad) + if self.use_grad_acc_hook: + self._register_hook(param, shard) sharded_params.append(shard) sharded_params_group = copy.copy(param_group) sharded_params_group['params'] = sharded_params @@ -282,9 +344,13 @@ def step(self, closure=None, **kwargs): self.param_groups, self.base_optimizer.param_groups): for param, shard in zip(param_group['params'], sharded_param_group['params']): - if param.grad is not None: - padded_grad = self._pad_to_world_size(param.grad, - self.local_world_size) + if param.grad is not None or (self.use_grad_acc_hook and + hasattr(shard, 'main_grad')): + padded_grad = self._pad_to_world_size( + shard.main_grad if self.use_grad_acc_hook else param.grad, + self.local_world_size) + if self.higher_cc_precision: + padded_grad = padded_grad.to(dtype=self.optimizer_dtype) if self.coalesce_cc_reduce_scatter: padded_grads.append(padded_grad) else: @@ -317,7 +383,8 @@ def step(self, closure=None, **kwargs): self.param_groups, self.base_optimizer.param_groups): for param, shard in zip(param_group['params'], sharded_param_group['params']): - if param.grad is not None: + if param.grad is not None or (self.use_grad_acc_hook and + hasattr(shard, 'main_grad')): grad_shard = grad_shards[index] if grad_shard.dtype != self.optimizer_dtype: grad_shard = grad_shard.to(dtype=self.optimizer_dtype) @@ -334,15 +401,24 @@ def step(self, closure=None, **kwargs): # Remove shards' grads self.base_optimizer.zero_grad(set_to_none=True) + self.allgather_weights_and_update_full_parameter() + + # sync back + self._sync_param_groups(self.base_optimizer.param_groups, self.param_groups) + + return loss + + def allgather_weights_and_update_full_parameter(self): # All gather the new weights across the ranks and assign them to the full parameters sharded_data = [] for param_group, sharded_param_group in zip( self.param_groups, self.base_optimizer.param_groups): for param, shard in zip(param_group['params'], sharded_param_group['params']): - if param.grad is not None: + if param.grad is not None or (self.use_grad_acc_hook and + hasattr(shard, 'main_grad')): shard_data = shard.data - if param.dtype != self.optimizer_dtype: + if not self.higher_cc_precision: shard_data = shard_data.to(dtype=param.dtype) if self.coalesce_cc_all_gather: sharded_data.append(shard_data) @@ -353,6 +429,8 @@ def step(self, closure=None, **kwargs): pin_layout=self.pin_layout, groups=self.sharding_groups, ) + if padded_param.dtype != param.dtype: + padded_param = padded_param.to(dtype=param.dtype) param.data.copy_(padded_param.data[:param.size(0)]) if self.coalesce_cc_all_gather: @@ -368,21 +446,36 @@ def step(self, closure=None, **kwargs): self.param_groups, self.base_optimizer.param_groups): for param, shard in zip(param_group['params'], sharded_param_group['params']): - if param.grad is not None: + if param.grad is not None or (self.use_grad_acc_hook and + hasattr(shard, 'main_grad')): padded_param = padded_params[index] + if padded_param.dtype != param.dtype: + padded_param = padded_params[index].to(dtype=param.dtype) param.data.copy_(padded_param.data[:param.size(0)]) index += 1 - # sync back - self._sync_param_groups(self.base_optimizer.param_groups, self.param_groups) - - return loss + def zero_grad(self, set_to_none: bool = False): + super().zero_grad(set_to_none=set_to_none) + if self.use_grad_acc_hook: + for sharded_param_group in self.base_optimizer.param_groups: + for shard in sharded_param_group['params']: + if hasattr(shard, 'main_grad'): + shard.main_grad.zero_() def state_dict(self): state_dict = super().state_dict() base_state = self.base_optimizer.state_dict()['state'] state_dict['base_state'] = base_state state_dict['shape_info'] = self.get_shape_info() + if self.save_master_weights: + index = 0 + master_weights = {} + for sharded_param_group in self.base_optimizer.param_groups: + for shard in sharded_param_group['params']: + master_weights[index] = shard.data + index += 1 + state_dict['sharded_master_weights'] = master_weights + return state_dict def load_state_dict(self, state_dict): @@ -396,6 +489,30 @@ def load_state_dict(self, state_dict): tmp = self.base_optimizer.state_dict() tmp['state'] = base_state self.base_optimizer.load_state_dict(tmp) + if 'sharded_master_weights' in state_dict: + master_weights = state_dict['sharded_master_weights'] + index = 0 + for param_group, sharded_param_group in zip( + self.param_groups, self.base_optimizer.param_groups): + for param, shard in zip(param_group['params'], + sharded_param_group['params']): + shard.data.copy_(master_weights[index]) + # set dummy gradient for allgather to be triggered. + if self.use_grad_acc_hook: + # Create main gradients + shard.main_grad = torch.zeros( + param.shape, + dtype=self.optimizer_dtype, + device=self.device, + requires_grad=False) + param.main_grad = shard.main_grad + else: + param.grad = torch.zeros_like(param.data) + index += 1 + xm.mark_step() + # add mark_step around allgather to avoid large number of compilation + self.allgather_weights_and_update_full_parameter() + xm.mark_step() def get_shape_info(self): shape_info = {}