From 229cc5c68c518734edfc01d36e6bd616d32a7224 Mon Sep 17 00:00:00 2001
From: "chenxun.p"
Date: Tue, 17 Oct 2023 11:15:54 +0800
Subject: [PATCH 1/3] impl reduce scatter async

---
 .../core/scheduler/no_pipeline_scheduler.py |  1 +
 internlm/model/linear.py                    | 23 +++++++++--
 internlm/model/utils.py                     | 28 ++++++++------
 .../solver/optimizer/hybrid_zero_optim.py   | 38 +++++++++++++++++--
 4 files changed, 71 insertions(+), 19 deletions(-)

diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py
index 56661d8c..f0caf05c 100644
--- a/internlm/core/scheduler/no_pipeline_scheduler.py
+++ b/internlm/core/scheduler/no_pipeline_scheduler.py
@@ -194,6 +194,7 @@ def forward_backward_step(
             _output, _loss, _moe_loss = self._train_one_batch(
                 _data, _label, engine, forward_only, return_loss, self._grad_accum_size
             )
+            engine.optimizer.reset_reduce_bucket()
 
             if return_loss:
                 loss += _loss
diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index 8e19ab69..b141829e 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -329,6 +329,8 @@ def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> Non
         self.block_module = dict()  # key: transformer block index; value: {name_index: FSTP module}
         self.module_name_index = dict()  # key: FSTP module; value: the name in index in self.module_name
 
+        self.reduce_scatter_handlers = {}
+
         # just want to share same for loop for ModuleList and Module
         if not isinstance(model, nn.ModuleList):
             model = [model]
@@ -337,16 +339,22 @@ def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> Non
             if isinstance(_chunk, NaiveAMPModel):
                 _chunk = _chunk.model
 
-            for _, children in _chunk.named_children():
+            for _chunk_name, children in _chunk.named_children():
                 if isinstance(children, nn.ModuleList):
                     for idx, block in enumerate(children):
                         index = 0
                         self.block_module[idx] = {}
-                        for _, sub in block.named_children():
+                        for _sub_name, sub in block.named_children():
                             sub_modules = list(sub.children())
                             if len(sub_modules) > 0:
                                 for name, child in sub.named_children():
                                     if isinstance(child, FSTPLinear):
+
+                                        _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
+                                        setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
+                                        if child.bias is not None:
+                                            setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
+
                                         self.FSTP_modules.append(child)
                                         self.module_block[child] = idx
                                         self.block_module[idx][index] = child
@@ -450,6 +458,8 @@ def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> Non
         self.module_name_index = dict()  # key: FSTP module; value: the name in index in self.module_name
         self.block_module = dict()  # key: transformer block index; value: {name_index: FSTP module}
 
+        self.reduce_scatter_handlers = {}
+
         # just want to share same for loop for ModuleList and Module
         if not isinstance(model, nn.ModuleList):
             model = [model]
@@ -458,7 +468,7 @@ def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> Non
             if isinstance(_chunk, NaiveAMPModel):
                 _chunk = _chunk.model
 
-            for _, children in _chunk.named_children():
+            for _chunk_name, children in _chunk.named_children():
                 if isinstance(children, nn.ModuleList):
                     for idx, block in enumerate(children):
                         index = 0
@@ -467,7 +477,7 @@ def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> Non
                         self.block_to_index[block] = idx
                         self.index_to_block[idx] = block
                         self.index_to_fsdp_modules[idx] = []
-                        for _, sub in block.named_children():
+                        for _sub_name, sub in block.named_children():
                             sub_modules = list(sub.children())
                             if len(sub_modules) > 0:
                                 for name, child in sub.named_children():
@@ -485,6 +495,11 @@ def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> Non
                                         self.index_to_fsdp_modules[idx].append(child)
                                         self.module_name_index[child] = index
                                         index = index + 1
+
+                                        _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
+                                        setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
+                                        if child.bias is not None:
+                                            setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
                             else:
                                 continue
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 97319d98..78ad456d 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -324,9 +324,9 @@ def forward(ctx, x, weight, bias, return_residual=False, process_group=None, mod
             raise RuntimeError("fused_dense only supports matrix dims <= 2M")
         output = F.linear(total_x, total_weight, total_bias)
         if ctx.compute_weight_gradient:
-            ctx.save_for_backward(x, weight)
+            ctx.save_for_backward(x, weight, bias)
         else:
-            ctx.save_for_backward(weight)
+            ctx.save_for_backward(weight, bias)
         return output if not return_residual else (output, x)
 
     @staticmethod
@@ -340,10 +340,10 @@ def backward(ctx, grad_output, *args):
         all_gather_handler = ctx.all_gather_handler
         module = ctx.module
         if ctx.compute_weight_gradient:
-            x, weight = ctx.saved_tensors
+            x, weight, bias = ctx.saved_tensors
             total_x = x
         else:
-            (weight,) = ctx.saved_tensors
+            weight, bias = ctx.saved_tensors
             total_x = None
         batch_shape = grad_output.shape[:-1]
         batch_dim = batch_shape.numel()
@@ -368,9 +368,15 @@ def backward(ctx, grad_output, *args):
                 total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
             )
             if world_size > 1:
-                grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
+                grad_weight_async, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
+                assert hasattr(weight, "_fstp_reduce_scatter_str")
+                all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
+                grad_weight = torch.empty(grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:], dtype=grad_weight.dtype, device=grad_weight.device)
                 if grad_bias is not None:
-                    grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
+                    grad_bias_async, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
+                    assert hasattr(bias, "_fstp_reduce_scatter_str")
+                    all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
+                    grad_bias = torch.empty(grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:], dtype=grad_bias.dtype, device=grad_bias.device)
         else:
             grad_weight = None
             grad_bias = grad_output if ctx.needs_input_grad[2] else None
@@ -384,11 +390,11 @@ def backward(ctx, grad_output, *args):
         else:
             grad_input = None
 
-        if ctx.needs_input_grad[1]:
-            if world_size > 1:
-                handle_grad_weight.wait()
-                if grad_bias is not None:
-                    handle_grad_bias.wait()
+        # if ctx.needs_input_grad[1]:
+        #     if world_size > 1:
+        #         handle_grad_weight.wait()
+        #         if grad_bias is not None:
+        #             handle_grad_bias.wait()
 
         return grad_input, grad_weight, grad_bias, None, None, None, None
 
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 97004eb9..c6e9aaba 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -65,6 +65,8 @@ def __init__(
         hysteresis = grad_scal_cfg.hysteresis
         max_scale = grad_scal_cfg.max_scale
 
+        self._fstp_handler = gpc.config.fstp_handler
+
         # Zero related args
         reduce_bucket_size = zero_cfg.reduce_bucket_size
         clip_grad_norm = zero_cfg.clip_grad_norm
@@ -301,8 +303,7 @@ def _define_and_attach(param, reduce_rank=None):
             # NOT IMPORTANT BUT GOOD TO KNOW:
             # args here is not grad, but allow_unreacable and accumulate_grad
             def reduce_grad_hook(*args):  # pylint: disable=W0613
-                if self.skip_grad_reduce is False:
-                    reduction_func()
+                reduction_func()
 
             accum_grad_obj.register_hook(reduce_grad_hook)
 
@@ -322,6 +323,20 @@ def belongs_to_current_rank(self, param) -> bool:
         group_id = getattr(param, "group_id")
         return tensor_rank == gpc.get_local_rank(self._broadcast_parallel_mode[group_id])
 
+    def reset_reduce_bucket(self) -> None:
+        for bucket in self._bucket_store:
+            for rank, params in bucket._params.items():
+                for _param in params:
+                    if not hasattr(_param, "_fstp_reduce_scatter_str"):
+                        continue
+
+                    key = getattr(_param, "_fstp_reduce_scatter_str")
+                    comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
+                    comm_handle.wait()
+                    _param.grad = _grad
+
+                bucket.reset_by_rank(rank)
+
     def _store_and_try_reduce_grads_by_bucket(self, param, reduce_rank=None):
         param_size = param.numel()
 
@@ -332,11 +347,26 @@ def _store_and_try_reduce_grads_by_bucket(self, param, reduce_rank=None):
         current_bucket = self._bucket_store[group_id]
 
         if current_bucket.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
-            self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank, last_bucket=False)
+            # wait reduce scatter communication
+            params = current_bucket.get_param(reduce_rank)
+            for _param in params:
+                if not hasattr(_param, "_fstp_reduce_scatter_str"):
+                    continue
+
+                key = getattr(_param, "_fstp_reduce_scatter_str")
+                comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
+                comm_handle.wait()
+                _param.grad = _grad
+
+            # reduce grad
+            if self.skip_grad_reduce is False:
+                self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank, last_bucket=False)
+            else:
+                current_bucket.reset_by_rank(reduce_rank)
 
         # the param must not be reduced to ensure correctness
         is_param_reduced = self._param_store.is_param_reduced(param)
-        if is_param_reduced:
+        if is_param_reduced and self.skip_grad_reduce is False:
             msg = (
                 f"Parameter of size ({param.size()}) has already been reduced, "
                 + "duplicate reduction will lead to arithmetic incorrectness"

From 4e99a7fdbc88e398255d63a9b22854b5ded5deb3 Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Tue, 17 Oct 2023 11:30:44 +0800
Subject: [PATCH 2/3] feat(train/training_internlm.py): remove abnormal tgs when calculating avg tgs

---
 internlm/train/training_internlm.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 24040a02..cc310a21 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -576,4 +576,8 @@ def record_current_batch_training_metrics(
             tgs_list.append(tgs_origin)
             if batch_count == gpc.config.data.total_steps - 1:
                 print(tgs_list, flush=True)
+                avg_tgs = sum(tgs_list) / len(tgs_list)
+                for tgs in tgs_list.copy():
+                    if abs(tgs - avg_tgs) > 1000:
+                        tgs_list.remove(tgs)
                 print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)
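
The filtering step added in PATCH 2 can be exercised on its own. Below is a minimal standalone sketch of the same idea; the sample values and the helper name filter_abnormal_tgs are illustrative assumptions (the patch itself filters tgs_list in place with a copy-and-remove loop), and only the threshold of 1000 is taken from the patch.

# Standalone sketch of PATCH 2's outlier filtering. The sample numbers below
# are made up; the threshold of 1000 mirrors the patch.
def filter_abnormal_tgs(tgs_list, threshold=1000):
    """Keep only samples within `threshold` of the mean tgs."""
    avg_tgs = sum(tgs_list) / len(tgs_list)
    # Same effect as the patch's copy-and-remove loop, written as a comprehension.
    return [tgs for tgs in tgs_list if abs(tgs - avg_tgs) <= threshold]


if __name__ == "__main__":
    tgs_list = [4200.0, 4180.0, 4210.0, 150.0, 4195.0]  # 150 is an abnormal step
    kept = filter_abnormal_tgs(tgs_list)
    print(f"avg_tgs: {sum(kept) / len(kept):.1f}", flush=True)
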
From 6682f5d92a02111777f5c1fbc8c0765c9770ffa2 Mon Sep 17 00:00:00 2001
From: "chenxun.p"
Date: Tue, 17 Oct 2023 15:10:07 +0800
Subject: [PATCH 3/3] fix reduce scatter async bug

---
 internlm/model/utils.py                        | 4 ++--
 internlm/solver/optimizer/hybrid_zero_optim.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 78ad456d..0194e84a 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -371,12 +371,12 @@ def backward(ctx, grad_output, *args):
                 grad_weight_async, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
                 assert hasattr(weight, "_fstp_reduce_scatter_str")
                 all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
-                grad_weight = torch.empty(grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:], dtype=grad_weight.dtype, device=grad_weight.device)
+                grad_weight = torch.zeros(grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:], dtype=grad_weight.dtype, device=grad_weight.device)
                 if grad_bias is not None:
                     grad_bias_async, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
                     assert hasattr(bias, "_fstp_reduce_scatter_str")
                     all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
-                    grad_bias = torch.empty(grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:], dtype=grad_bias.dtype, device=grad_bias.device)
+                    grad_bias = torch.zeros(grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:], dtype=grad_bias.dtype, device=grad_bias.device)
         else:
             grad_weight = None
             grad_bias = grad_output if ctx.needs_input_grad[2] else None
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index c6e9aaba..950d35e8 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -333,7 +333,7 @@ def reset_reduce_bucket(self) -> None:
                     key = getattr(_param, "_fstp_reduce_scatter_str")
                     comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
                     comm_handle.wait()
-                    _param.grad = _grad
+                    _param.grad += _grad
 
                 bucket.reset_by_rank(rank)
 
@@ -356,7 +356,7 @@ def _store_and_try_reduce_grads_by_bucket(self, param, reduce_rank=None):
                 key = getattr(_param, "_fstp_reduce_scatter_str")
                 comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
                 comm_handle.wait()
-                _param.grad = _grad
+                _param.grad += _grad
 
             # reduce grad
             if self.skip_grad_reduce is False:
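
Taken together, PATCH 1 and PATCH 3 overlap the reduce-scatter of FSTP weight/bias gradients with the rest of the backward pass and only wait on the communication handles when the ZeRO optimizer drains its reduce bucket, accumulating the sharded result with +=. The sketch below is an illustrative, self-contained version of that deferred-wait pattern, not the repository's API: the handlers dict stands in for reduce_scatter_handlers, the key and tensor shapes are made up, and it assumes a recent PyTorch with NCCL, launched via torchrun --nproc_per_node=2 on two GPUs.

# Illustrative sketch: start a reduce-scatter asynchronously during backward,
# record (handle, sharded_grad), then wait and accumulate later when draining
# the bucket (PATCH 3 switched the accumulation to +=).
import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    # Full gradient produced during backward; the shape is a made-up example.
    full_grad = torch.ones(world_size * 4, 8, device="cuda")

    # Launch the reduce-scatter without blocking, as the FSTP backward does.
    shard = torch.empty(full_grad.shape[0] // world_size, *full_grad.shape[1:], device="cuda")
    handle = dist.reduce_scatter_tensor(shard, full_grad, op=dist.ReduceOp.SUM, async_op=True)

    # Toy stand-in for reduce_scatter_handlers keyed by a parameter name.
    handlers = {"blocks.0.mlp.w1.weight": (handle, shard)}

    # ... unrelated backward computation would overlap with the communication here ...

    # Later, when the optimizer drains its bucket: wait, then accumulate.
    param_grad = torch.zeros_like(shard)
    comm_handle, sharded_grad = handlers["blocks.0.mlp.w1.weight"]
    comm_handle.wait()
    param_grad += sharded_grad

    if rank == 0:
        # Each element equals world_size because every rank contributed ones.
        print(param_grad[0, :4], flush=True)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
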