From d91a5d9d9ec8c7b0444b533a6b44be4430c7c199 Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Fri, 20 Oct 2023 15:59:40 +0800
Subject: [PATCH 1/3] feat(initialize/launch.py): refactor config for fstp

---
 configs/7B_sft.py                          | 10 ++---
 internlm/initialize/launch.py              | 23 ++++++----
 internlm/model/modeling_internlm.py        | 14 +++---
 internlm/model/multi_head_attention.py     |  8 ++--
 .../solver/optimizer/hybrid_zero_optim.py  | 45 ++++++++++++-------
 internlm/train/training_internlm.py        |  3 +-
 internlm/utils/evaluation.py               |  4 +-
 7 files changed, 63 insertions(+), 44 deletions(-)

diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 6ea8b96e..c51c8129 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -152,19 +152,19 @@
     2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
-    2. mode: str, the mode should be 'origin_tp' or 'fstp', defaults to 'origin_tp'. If the mode is 'fstp',
-        the sequence_parallel should be True.
+    2. sp: str, the sequence parallel mode, should be in ['none', 'megatron', 'flash-attn', 'intern'],
+        defaults to 'none', which means sequence parallel is disabled.
+    3. intern_overlap: bool, enable/disable all_gather/reduce_scatter communication overlap when using the 'intern' sp mode,
+        defaults to False.
 pipeline parallel (dict):
     1. size: int, the size of pipeline parallel.
     2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
         defaults to False.
-sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, mode="fstp", overlap=True),
+    tensor=dict(size=8, sp="intern", intern_overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
-    sequence_parallel=True,
 )

 cudnn_deterministic = False

diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 80611fee..0e74f76b 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -306,15 +306,20 @@ def args_sanity_check():
         ), "sequence parallel does not support use_flash_attn=False"

     if isinstance(gpc.config.parallel["tensor"], int):
-        gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode="origin_tp")
-
-    if gpc.config.parallel["tensor"].get("mode", None) is None:
-        gpc.config.parallel["tensor"]["mode"] = "origin_tp"
-
-    if gpc.config.parallel["tensor"].get("mode", None) == "fstp":
-        assert (
-            gpc.config.parallel.sequence_parallel is True
-        ), "when the tp_mode is fstp, the sequence_parallel should be True."
+ gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], sp="none", intern_overlap=False) + if gpc.config.parallel["tensor"].get("sp", None) is None: + gpc.config.parallel["tensor"]["sp"] = "none" + if gpc.config.parallel["tensor"].get("intern_overlap", None) is None: + gpc.config.parallel["tensor"]["intern_overlap"] = False + assert gpc.config.parallel["tensor"].get("sp", None) in [ + "none", + "megatron", + "flash-attn", + "intern", + ], "invalid sp mode, only ['none', 'megatron', 'flash-attn', 'intern'] is supported" + # adapt to old version's sequence parallel config + if gpc.config.parallel["tensor"].get("sp", None) in ["megatron", "flash-attn", "intern"]: + gpc.config.parallel.sequence_parallel = True # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1: diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index 0df2b60e..9b6420d4 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -77,7 +77,7 @@ def __init__( use_scaled_init: bool = True, use_swiglu: bool = True, use_flash_attn: bool = True, - tp_mode: str = "origin_tp", + sp_mode: str = "none", ): super().__init__() self.checkpoint = checkpoint @@ -102,7 +102,7 @@ def __init__( use_flash_attn=use_flash_attn, device=device, dtype=dtype, - tp_mode=tp_mode, + sp_mode=sp_mode, ) self.dropout1 = nn.Dropout(drop_rate) @@ -114,7 +114,7 @@ def __init__( self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) if use_swiglu: - mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward + mlp_cls = FSTPFeedForward if sp_mode == "intern" else FeedForward self.mlp = mlp_cls( hidden_size, int(hidden_size * mlp_ratio), @@ -297,7 +297,7 @@ def __init__( super().__init__() checkpoint_layer_num = int(num_layers * checkpoint) - self.tp_mode = gpc.config.parallel["tensor"]["mode"] + self.sp_mode = gpc.config.parallel["tensor"]["sp"] if is_reward: head_cls = RewardModelLinear @@ -343,7 +343,7 @@ def __init__( use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, use_flash_attn=use_flash_attn, - tp_mode=self.tp_mode, + sp_mode=self.sp_mode, ) for lid in range(num_layers) ] @@ -389,8 +389,8 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N assert len(indexes) == 1 # The indexes are used to indicate the actual position IDs of each token in the packed input. indexes = indexes[0] - # if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension. - if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp": + # if the sequence parallel mode is 'intern', the indexes should also be split in sequence dimension. 
+ if gpc.config.parallel.sequence_parallel and self.sp_mode == "intern": indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py index 8dcd3f96..cb0efb85 100644 --- a/internlm/model/multi_head_attention.py +++ b/internlm/model/multi_head_attention.py @@ -175,7 +175,7 @@ def __init__( use_flash_attn: bool = True, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - tp_mode: str = "origin_tp", + sp_mode: str = "none", ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -203,7 +203,7 @@ def __init__( self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device) # notice here should change bias=True - Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear + Wqkv_cls = FSTPLinear if sp_mode == "intern" else ColumnParallelLinearTorch self.Wqkv = Wqkv_cls( embed_dim, 3 * embed_dim, @@ -219,12 +219,12 @@ def __init__( self.inner_cross_attn = inner_cross_attn_cls( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout ) - if tp_mode == "fstp": + if sp_mode == "intern": self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group) self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group) # output projection always have the bias (for now) - out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear + out_proj_cls = FSTPLinear if sp_mode == "intern" else RowParallelLinearTorch self.out_proj = out_proj_cls( embed_dim, embed_dim, diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 96a54c01..a4b31737 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -10,7 +10,10 @@ from internlm.core.context import Config, ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.utils import split_forward_gather_backward, release_reduce_scatter_memory_pool +from internlm.model.utils import ( + release_reduce_scatter_memory_pool, + split_forward_gather_backward, +) from internlm.monitor import send_alert_message from internlm.solver.optimizer.store import ( BucketStore, @@ -40,8 +43,20 @@ inf = math.inf logger = get_logger(__file__) + def print_memory(msg): - print(msg, " rank = ", gpc.get_global_rank(), " memory allocated: ", torch.cuda.memory_allocated() / 1024 / 1024 / 1024, " reverved memory: ", torch.cuda.memory_reserved() / 1024 / 1024 / 1024, " max memory: ", torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, flush=True) + print( + msg, + " rank = ", + gpc.get_global_rank(), + " memory allocated: ", + torch.cuda.memory_allocated() / 1024 / 1024 / 1024, + " reverved memory: ", + torch.cuda.memory_reserved() / 1024 / 1024 / 1024, + " max memory: ", + torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, + flush=True, + ) print("===========================================") @@ -69,8 +84,8 @@ def __init__( backoff_factor = grad_scal_cfg.backoff_factor hysteresis = grad_scal_cfg.hysteresis max_scale = grad_scal_cfg.max_scale - - if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True: + + if gpc.config.parallel["tensor"]["sp"] == "intern" and 
gpc.config.parallel["tensor"]["intern_overlap"] is True: self._fstp_handler = gpc.config.fstp_handler # Zero related args @@ -306,7 +321,7 @@ def _define_and_attach(param, reduce_rank=None): param=param, reduce_rank=reduce_rank, ) - + reduce_scatter_checker = partial( self._wait_reduce_scatter_and_accumulate_grad, param=param, @@ -354,7 +369,7 @@ def reset_reduce_bucket(self) -> None: _param.grad.add_(_grad) # self._fstp_handler.reduce_scatter_handlers[key] = None # del _grad - release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index) + release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index) del self._fstp_handler.reduce_scatter_handlers[key] self._fstp_handler.reduce_scatter_handlers[key] = None assert key in self._fstp_handler.reduce_scatter_handlers @@ -374,7 +389,7 @@ def reset_reduce_bucket(self) -> None: # assert key in self._fstp_handler.all_reduce_handlers bucket.reset_by_rank(rank) - + def _wait_reduce_scatter_and_accumulate_grad(self, param, reduce_rank=None): param_size = param.numel() @@ -397,11 +412,11 @@ def _wait_reduce_scatter_and_accumulate_grad(self, param, reduce_rank=None): _param.grad.add_(_grad) # self._fstp_handler.reduce_scatter_handlers[key] = None # del _grad - release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index) + release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index) del self._fstp_handler.reduce_scatter_handlers[key] self._fstp_handler.reduce_scatter_handlers[key] = None assert key in self._fstp_handler.reduce_scatter_handlers - + # if not hasattr(_param, "_fstp_all_reduce_str"): # continue @@ -418,7 +433,7 @@ def _wait_reduce_scatter_and_accumulate_grad(self, param, reduce_rank=None): # assert key in self._fstp_handler.all_reduce_handlers current_bucket.reset_by_rank(reduce_rank) - + current_bucket.add_num_elements_in_bucket(param_size, reduce_rank) current_bucket.add_param(param, reduce_rank) @@ -685,16 +700,16 @@ def step(self, closure=None): timer("sync_grad").start() self._sync_grad() timer("sync_grad").stop() - + print_memory("No 4") - + try: - res = self._step(closure=closure, norms=total_norms) + res = self._step(closure=closure, norms=total_norms) except torch.cuda.OutOfMemoryError as e: print(e, flush=True) print(torch.cuda.memory_summary(), flush=True) torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") - + return res def _step(self, closure=None, norms=None): @@ -822,7 +837,7 @@ def _step(self, closure=None, norms=None): torch.cuda.synchronize() with torch.cuda.stream(self._comm_bcast_stream): self.broadcast_params() - + timer("step").stop() # update gradients may not be needed here, because the sync_params function is used in initialization, diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 5205ba5b..53996b38 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -110,9 +110,8 @@ def initialize_model(): gpc.config.fstp_handler = None - if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True: + if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True: handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR)) - # handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR)) handler._register_sync_parameters_hook() gpc.config.fstp_handler = handler diff --git a/internlm/utils/evaluation.py 
b/internlm/utils/evaluation.py index 968a1db1..f708fa78 100644 --- a/internlm/utils/evaluation.py +++ b/internlm/utils/evaluation.py @@ -54,7 +54,7 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape def switch_sequence_parallel_mode(): prev_mode = gpc.config.parallel.sequence_parallel try: - if gpc.config.parallel["tensor"]["mode"] == "fstp": + if gpc.config.parallel["tensor"]["sp"] == "intern": gpc.config.parallel.sequence_parallel = True else: gpc.config.parallel.sequence_parallel = False @@ -106,7 +106,7 @@ def evaluate_on_val_dls( total_val_bsz = len(batch[1]) assert total_val_bsz % data_cfg.micro_bsz == 0 num_microbatches = total_val_bsz // data_cfg.micro_bsz - if gpc.config.parallel["tensor"]["mode"] == "fstp": + if gpc.config.parallel["tensor"]["sp"] == "intern": sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR) tensor_shape = torch.Size( [ From eac382ad0a0ed6075b31fbdb8a56d42239fa9f4f Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Fri, 20 Oct 2023 16:22:29 +0800 Subject: [PATCH 2/3] feat(optimizer/hybrid_zero_optim.py): fix lint error --- internlm/model/utils.py | 5 ++--- internlm/solver/optimizer/hybrid_zero_optim.py | 5 +---- internlm/solver/optimizer/store.py | 2 +- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/internlm/model/utils.py b/internlm/model/utils.py index b9c7c03a..19531e4a 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -1,12 +1,12 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Any, Optional, Union +from typing import Optional import fused_dense_lib as fused_dense_cuda import torch import torch.nn.functional as F -from flash_attn.utils.distributed import all_reduce_raw # , reduce_scatter_raw +from flash_attn.utils.distributed import all_reduce_raw from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd from torch.distributed import ProcessGroup @@ -397,7 +397,6 @@ def backward(ctx, grad_output, *args): grad_input = grad_input.contiguous() process_group = ctx.process_group all_gather_handler = ctx.all_gather_handler - module = ctx.module block_index = ctx.block_index module_name = ctx.module_name diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index d5fec315..cb8aa659 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -11,10 +11,7 @@ from internlm.core.context import Config, ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.utils import ( - release_reduce_scatter_memory_pool, - split_forward_gather_backward, -) +from internlm.model.utils import release_reduce_scatter_memory_pool from internlm.monitor import send_alert_message from internlm.solver.optimizer.store import ( BucketStore, diff --git a/internlm/solver/optimizer/store.py b/internlm/solver/optimizer/store.py index 228045ed..f486ccec 100644 --- a/internlm/solver/optimizer/store.py +++ b/internlm/solver/optimizer/store.py @@ -45,7 +45,7 @@ def __init__(self, group_id, dp_parallel_mode): def num_elements_in_bucket(self, reduce_rank: int = None): return self._num_elements_in_bucket[reduce_rank] - + def num_params_in_bucket(self, reduce_rank: int = None): return len(self._params[reduce_rank]) From 2acf9b817f6888e73c3606ddc6549f8c95694b27 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Fri, 20 Oct 2023 16:25:08 +0800 Subject: [PATCH 3/3] feat(utils/gputest.py): fix lint error --- 
internlm/utils/gputest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internlm/utils/gputest.py b/internlm/utils/gputest.py index 52d96385..bf4cf1c9 100644 --- a/internlm/utils/gputest.py +++ b/internlm/utils/gputest.py @@ -45,7 +45,7 @@ def empty_cache_and_diag(batch_count, interval=50): # # import time # # time.sleep(10) # print(e, "rank = ", gpc.get_global_rank(), flush=True) - # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") + # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") # do empty_cache after the bench torch.cuda.empty_cache()
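
Note: with this series applied, sequence parallelism is configured entirely through the tensor dict in the parallel config; the old standalone sequence_parallel flag and the tensor mode/overlap keys are gone, and args_sanity_check() derives gpc.config.parallel.sequence_parallel automatically from the chosen sp mode. A minimal sketch of the new-style config, mirroring the values this series sets in configs/7B_sft.py (all other config sections omitted):

# Sketch of the new parallel config schema; values are the ones used in configs/7B_sft.py above.
# sp selects the sequence parallel mode, and intern_overlap is only meaningful when sp == "intern",
# where it enables the all_gather/reduce_scatter communication overlap handler.
parallel = dict(
    zero1=dict(size=-1, fsdp=False),
    tensor=dict(size=8, sp="intern", intern_overlap=True),  # sp in ["none", "megatron", "flash-attn", "intern"]
    pipeline=dict(size=1, interleaved_overlap=True),
)

Passing a plain int for tensor is still accepted: args_sanity_check() normalizes it to dict(size=<int>, sp="none", intern_overlap=False), so existing configs that only set a tensor parallel size keep working.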