From b20f47a1fe5fb446f2d9df5a83b31cb6033579f0 Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Mon, 23 Oct 2023 12:02:32 +0800
Subject: [PATCH] feat(model/overlap_handler.py): move handler to gpc

---
 internlm/model/linear.py                      |  5 +---
 internlm/model/overlap_handler.py             | 16 ++++------
 internlm/model/utils.py                       | 29 ++++++-------------
 .../solver/optimizer/hybrid_zero_optim.py     |  4 +--
 internlm/train/training_internlm.py           |  4 +--
 train.py                                      |  6 ++--
 6 files changed, 23 insertions(+), 41 deletions(-)

diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index 6cd3b9c8..b92b2ee5 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -352,16 +352,13 @@ def __init__(
 
 class FSTPLinear(ColumnParallelLinear):
     def forward(self, x):
-        block_index = gpc.config.fstp_handler.module_to_index[self]
         return fstp_fused_dense_func(
             x,
             self.weight,
             self.bias,
             process_group=self.process_group,
             module=self,
-            handler=gpc.config.fstp_handler,
-            block_index=block_index,
-            module_name=self._fstp_name,
+            handler=gpc.fstp_handler,
         )
 
 
diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py
index cafb8183..b6877234 100644
--- a/internlm/model/overlap_handler.py
+++ b/internlm/model/overlap_handler.py
@@ -116,8 +116,9 @@ def _initialize_memory_pool(self) -> None:
             self.all_gather_memory_pool.append(weight)  # containing two groups of block weight
 
-    def get_all_gather_memory(self, index, module_name):
-        return self.all_gather_memory_pool[index % 2][module_name]
+    def get_all_gather_memory(self, module):
+        block_index = self.module_to_index[module]
+        return self.all_gather_memory_pool[block_index % 2][module._fstp_name]
 
     def get_reduce_scatter_memory(self, key):
         return_idx = 0
@@ -163,8 +164,7 @@ def _all_gather_block_weight_memory_pool(self, block_index: int):
                 module.weight,
                 self.process_group,
                 async_op=True,
-                block_index=block_index,
-                module_name=getattr(module, "_fstp_name"),
+                module=module,
             )
             self.fstp_global_handle[module] = weight_handle
 
@@ -192,13 +192,11 @@ def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):
 
         def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
             first_backward_module = self.fstp_modules[-1]
-            block_index = self.module_to_index[first_backward_module]
             weight_handle = all_gather_raw_memory_pool(
                 first_backward_module.weight,
                 self.process_group,
                 async_op=True,
-                block_index=block_index,
-                module_name=getattr(first_backward_module, "_fstp_name"),
+                module=first_backward_module,
             )
             self.fstp_global_handle[first_backward_module] = weight_handle
 
@@ -211,13 +209,11 @@ def _pre_backward_hook_for_module(module: nn.Module, grad_output):
             module_index = self.fstp_modules.index(module)
             if module_index - 1 >= 0:
                 next_module = self.fstp_modules[module_index - 1]
-                block_index = self.module_to_index[next_module]
                 weight_handle = all_gather_raw_memory_pool(
                     next_module.weight,
                     self.process_group,
                     async_op=True,
-                    block_index=block_index,
-                    module_name=getattr(next_module, "_fstp_name"),
+                    module=next_module,
                 )
                 self.fstp_global_handle[next_module] = weight_handle
 
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index ccdca481..cdbed954 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -7,13 +7,12 @@
 import torch
 import torch.nn.functional as F
 from flash_attn.utils.distributed import all_reduce_raw
-from torch import Tensor
+from torch import Tensor, nn
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 
 logger = get_logger(__file__)
@@ -131,11 +130,10 @@ def all_gather_raw_memory_pool(
     process_group: ProcessGroup,
     async_op: bool = False,
     gather_dim: int = 0,
-    block_index: int = None,
-    module_name: str = None,
+    module: nn.Module = None,
 ):
     handle = torch.distributed.all_gather_into_tensor(
-        gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name),
+        gpc.fstp_handler.get_all_gather_memory(module=module),
         input_.contiguous(),
         group=process_group,
         async_op=async_op,
@@ -166,8 +164,8 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup,
     world_size = torch.distributed.get_world_size(process_group)
     assert input_.shape[0] % world_size == 0
     size = (input_.shape[0] // world_size, *input_.shape[1:])
-    index = gpc.config.fstp_handler.get_reduce_scatter_memory(size)
-    output = gpc.config.fstp_handler.reduce_scatter_memory_pool[size]["data"][index]
+    index = gpc.fstp_handler.get_reduce_scatter_memory(size)
+    output = gpc.fstp_handler.reduce_scatter_memory_pool[size]["data"][index]
     setattr(output, "index", index)
     handle = torch.distributed.reduce_scatter_tensor(
         output, input_.contiguous(), group=process_group, async_op=async_op
@@ -469,16 +467,12 @@ def forward(
         process_group=None,
         module=None,
         overlap_handler=None,
-        block_index=None,
-        module_name=None,
     ):
         ctx.compute_weight_gradient = weight.requires_grad
         ctx.return_residual = return_residual
         ctx.process_group = process_group
         ctx.overlap_handler = overlap_handler
         ctx.module = module
-        ctx.block_index = block_index
-        ctx.module_name = module_name
 
         if torch.is_autocast_enabled():
             x = x.to(dtype=torch.get_autocast_gpu_dtype())
@@ -488,7 +482,7 @@
         if world_size > 1:
             # do all_gather for weight and bias before actual computation
             if overlap_handler is not None:
-                total_weight = gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name)
+                total_weight = gpc.fstp_handler.get_all_gather_memory(module=module)
             else:
                 total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
                 handle_weight.wait()
@@ -531,8 +525,7 @@ def backward(ctx, grad_output, *args):
             grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         overlap_handler = ctx.overlap_handler
-        block_index = ctx.block_index
-        module_name = ctx.module_name
+        module = ctx.module
 
         if ctx.compute_weight_gradient:
             x, weight, bias = ctx.saved_tensors
@@ -547,7 +540,7 @@
         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
             if overlap_handler is not None:
-                total_weight = gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name)
+                total_weight = gpc.fstp_handler.get_all_gather_memory(module=module)
             else:
                 total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
                 handle_weight.wait()
@@ -669,16 +662,12 @@ def fstp_fused_dense_func(
     process_group=None,
     module=None,
    handler=None,
-    block_index=None,
-    module_name=None,
 ):
     dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
         x.dtype == torch.float32 and torch.is_autocast_enabled()
     )
     if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
-        return FSTPFusedDenseFunc.apply(
-            x, weight, bias, return_residual, process_group, module, handler, block_index, module_name
-        )
+        return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, module, handler)
     else:
         assert process_group is None
         out = F.linear(x, weight, bias)
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index e2ec7efd..08d97229 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -68,7 +68,7 @@ def __init__(
 
         self._fstp_handler = None
         if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
-            self._fstp_handler = gpc.config.fstp_handler
+            self._fstp_handler = gpc.fstp_handler
 
         # Zero related args
         reduce_bucket_size = zero_cfg.reduce_bucket_size
@@ -350,7 +350,7 @@ def _accum_grads_store_in_bucket(self, bucket: BucketStore, reduce_rank: Optiona
                     _param.grad.add_(_grad)
 
                     # release cuda memory.
-                    gpc.config.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index)
+                    gpc.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index)
                     self._fstp_handler.reduce_scatter_handlers[_key] = None
 
         bucket.reset_by_rank(reduce_rank)
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index cabb7ebd..b05611bc 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -108,9 +108,9 @@ def initialize_model():
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)
 
-    gpc.config.fstp_handler = None
+    gpc.fstp_handler = None
     if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
-        gpc.config.fstp_handler = FSTPOverlapHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        gpc.fstp_handler = FSTPOverlapHandler(model, gpc.get_group(ParallelMode.TENSOR))
 
     return model
 
diff --git a/train.py b/train.py
index 5066960e..96dc24d1 100644
--- a/train.py
+++ b/train.py
@@ -297,9 +297,9 @@ def main(args):
             prof.step()
 
-        if gpc.config.fstp_handler is not None:
-            gpc.config.fstp_handler.zero_const_pool = {}
-            gpc.config.fstp_handler.reduce_scatter_memory_pool = {}
+        if gpc.fstp_handler is not None:
+            gpc.fstp_handler.zero_const_pool = {}
+            gpc.fstp_handler.reduce_scatter_memory_pool = {}
 
         # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
         torch.cuda.reset_peak_memory_stats()
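
Reviewer sketch of the API change, for orientation: the handler now hangs directly off gpc (gpc.fstp_handler) rather than gpc.config, and get_all_gather_memory takes the module itself instead of a (block_index, module_name) pair, resolving both internally via module_to_index and the module's _fstp_name attribute. The class below is a simplified stand-in for FSTPOverlapHandler (only the names shown in the hunks above are real; the rest is illustrative), not the actual implementation:

    # Simplified stand-in for FSTPOverlapHandler, showing why callers can now
    # pass only `module` instead of threading (block_index, module_name) around.
    class HandlerSketch:
        def __init__(self):
            self.module_to_index = {}               # nn.Module -> owning block index
            self.all_gather_memory_pool = [{}, {}]  # two groups of block weight buffers

        # old signature: get_all_gather_memory(self, index, module_name)
        def get_all_gather_memory(self, module):
            # both pool coordinates are derived from the module itself
            block_index = self.module_to_index[module]
            return self.all_gather_memory_pool[block_index % 2][module._fstp_name]

With the lookup centralized in the handler, the forward/backward hooks in overlap_handler.py and FSTPFusedDenseFunc in utils.py no longer carry block_index/module_name through their signatures, which is what shrinks the argument lists across this diff; the two pool groups, indexed by block_index % 2, presumably double-buffer so one block's all-gather can overlap the adjacent block's compute.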