diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 6ea8b96e..c51c8129 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -152,19 +152,19 @@
     2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
-    2. mode: str, the mode should be 'origin_tp' or 'fstp', defaults to 'origin_tp'. If the mode is 'fstp',
-        the sequence_parallel should be True.
+    2. sp: str, the sequence parallel mode, should be in ['none', 'megatron', 'flash-attn', 'intern'],
+        defaults to 'none', which means sequence parallel is disabled.
+    3. intern_overlap: bool, enable/disable all_gather/reduce_scatter communication overlap when the sp mode is 'intern',
+        defaults to False.
 pipeline parallel (dict):
     1. size: int, the size of pipeline parallel.
     2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
         defaults to False.
-sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, mode="fstp", overlap=True),
+    tensor=dict(size=8, sp="intern", intern_overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
-    sequence_parallel=True,
 )
 
 cudnn_deterministic = False
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 80611fee..0e74f76b 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -306,15 +306,20 @@ def args_sanity_check():
         ), "sequence parallel does not support use_flash_attn=False"
 
     if isinstance(gpc.config.parallel["tensor"], int):
-        gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode="origin_tp")
-
-    if gpc.config.parallel["tensor"].get("mode", None) is None:
-        gpc.config.parallel["tensor"]["mode"] = "origin_tp"
-
-    if gpc.config.parallel["tensor"].get("mode", None) == "fstp":
-        assert (
-            gpc.config.parallel.sequence_parallel is True
-        ), "when the tp_mode is fstp, the sequence_parallel should be True."
+        gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], sp="none", intern_overlap=False)
+    if gpc.config.parallel["tensor"].get("sp", None) is None:
+        gpc.config.parallel["tensor"]["sp"] = "none"
+    if gpc.config.parallel["tensor"].get("intern_overlap", None) is None:
+        gpc.config.parallel["tensor"]["intern_overlap"] = False
+    assert gpc.config.parallel["tensor"].get("sp", None) in [
+        "none",
+        "megatron",
+        "flash-attn",
+        "intern",
+    ], "invalid sp mode, only ['none', 'megatron', 'flash-attn', 'intern'] are supported"
+    # adapt to the old version's sequence parallel config
+    if gpc.config.parallel["tensor"].get("sp", None) in ["megatron", "flash-attn", "intern"]:
+        gpc.config.parallel.sequence_parallel = True
 
     # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
     if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 0df2b60e..9b6420d4 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -77,7 +77,7 @@ def __init__(
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
-        tp_mode: str = "origin_tp",
+        sp_mode: str = "none",
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -102,7 +102,7 @@ def __init__(
             use_flash_attn=use_flash_attn,
             device=device,
             dtype=dtype,
-            tp_mode=tp_mode,
+            sp_mode=sp_mode,
         )
 
         self.dropout1 = nn.Dropout(drop_rate)
@@ -114,7 +114,7 @@ def __init__(
             self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
 
         if use_swiglu:
-            mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward
+            mlp_cls = FSTPFeedForward if sp_mode == "intern" else FeedForward
             self.mlp = mlp_cls(
                 hidden_size,
                 int(hidden_size * mlp_ratio),
@@ -297,7 +297,7 @@ def __init__(
         super().__init__()
 
         checkpoint_layer_num = int(num_layers * checkpoint)
-        self.tp_mode = gpc.config.parallel["tensor"]["mode"]
+        self.sp_mode = gpc.config.parallel["tensor"]["sp"]
 
         if is_reward:
             head_cls = RewardModelLinear
@@ -343,7 +343,7 @@ def __init__(
                     use_scaled_init=use_scaled_init,
                     use_swiglu=use_swiglu,
                     use_flash_attn=use_flash_attn,
-                    tp_mode=self.tp_mode,
+                    sp_mode=self.sp_mode,
                 )
                 for lid in range(num_layers)
             ]
@@ -389,8 +389,8 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
             assert len(indexes) == 1
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
-            # if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension.
-            if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp":
+            # if the sequence parallel mode is 'intern', the indexes should also be split along the sequence dimension.
+            if gpc.config.parallel.sequence_parallel and self.sp_mode == "intern":
                 indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)
 
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index 8dcd3f96..cb0efb85 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -175,7 +175,7 @@ def __init__(
         use_flash_attn: bool = True,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
-        tp_mode: str = "origin_tp",
+        sp_mode: str = "none",
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -203,7 +203,7 @@ def __init__(
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
 
         # notice here should change bias=True
-        Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
+        Wqkv_cls = FSTPLinear if sp_mode == "intern" else ColumnParallelLinearTorch
         self.Wqkv = Wqkv_cls(
             embed_dim,
             3 * embed_dim,
@@ -219,12 +219,12 @@ def __init__(
         self.inner_cross_attn = inner_cross_attn_cls(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
         )
-        if tp_mode == "fstp":
+        if sp_mode == "intern":
             self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
             self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)
 
         # output projection always have the bias (for now)
-        out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
+        out_proj_cls = FSTPLinear if sp_mode == "intern" else RowParallelLinearTorch
         self.out_proj = out_proj_cls(
             embed_dim,
             embed_dim,
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index b9c7c03a..19531e4a 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -1,12 +1,12 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Any, Optional, Union
+from typing import Optional
 
 import fused_dense_lib as fused_dense_cuda
 import torch
 import torch.nn.functional as F
-from flash_attn.utils.distributed import all_reduce_raw  # , reduce_scatter_raw
+from flash_attn.utils.distributed import all_reduce_raw
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup
@@ -397,7 +397,6 @@ def backward(ctx, grad_output, *args):
             grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         all_gather_handler = ctx.all_gather_handler
-        module = ctx.module
         block_index = ctx.block_index
         module_name = ctx.module_name
 
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 2c14c65d..cb8aa659 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -2,8 +2,8 @@
 # -*- encoding: utf-8 -*-
 
 import math
-from typing import Optional, List
 from functools import partial
+from typing import List, Optional
 
 import torch
 import torch.distributed as dist
@@ -11,7 +11,7 @@
 
 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import split_forward_gather_backward, release_reduce_scatter_memory_pool
+from internlm.model.utils import release_reduce_scatter_memory_pool
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,
@@ -83,7 +83,7 @@ def __init__(
         hysteresis = grad_scal_cfg.hysteresis
         max_scale = grad_scal_cfg.max_scale
 
-        if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
+        if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
             self._fstp_handler = gpc.config.fstp_handler
 
         # Zero related args
@@ -366,8 +366,8 @@ def _accum_grads_store_in_bucket(self, bucket: BucketStore, reduce_rank: Optiona
                 _param.grad.add_(_grad)
 
                 # release cuda memory.
+                release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
                 self._fstp_handler.reduce_scatter_handlers[_key] = None
-                _grad = None
 
         bucket.reset_by_rank(reduce_rank)
 
diff --git a/internlm/solver/optimizer/store.py b/internlm/solver/optimizer/store.py
index 228045ed..f486ccec 100644
--- a/internlm/solver/optimizer/store.py
+++ b/internlm/solver/optimizer/store.py
@@ -45,7 +45,7 @@ def __init__(self, group_id, dp_parallel_mode):
 
     def num_elements_in_bucket(self, reduce_rank: int = None):
         return self._num_elements_in_bucket[reduce_rank]
-    
+
     def num_params_in_bucket(self, reduce_rank: int = None):
         return len(self._params[reduce_rank])
 
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 5205ba5b..53996b38 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -110,9 +110,8 @@ def initialize_model():
 
     gpc.config.fstp_handler = None
 
-    if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
+    if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
         handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
-        # handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
         handler._register_sync_parameters_hook()
         gpc.config.fstp_handler = handler
 
diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py
index 968a1db1..f708fa78 100644
--- a/internlm/utils/evaluation.py
+++ b/internlm/utils/evaluation.py
@@ -54,7 +54,7 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
 def switch_sequence_parallel_mode():
     prev_mode = gpc.config.parallel.sequence_parallel
     try:
-        if gpc.config.parallel["tensor"]["mode"] == "fstp":
+        if gpc.config.parallel["tensor"]["sp"] == "intern":
             gpc.config.parallel.sequence_parallel = True
         else:
             gpc.config.parallel.sequence_parallel = False
@@ -106,7 +106,7 @@ def evaluate_on_val_dls(
         total_val_bsz = len(batch[1])
         assert total_val_bsz % data_cfg.micro_bsz == 0
         num_microbatches = total_val_bsz // data_cfg.micro_bsz
-        if gpc.config.parallel["tensor"]["mode"] == "fstp":
+        if gpc.config.parallel["tensor"]["sp"] == "intern":
             sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
             tensor_shape = torch.Size(
                 [
diff --git a/internlm/utils/gputest.py b/internlm/utils/gputest.py
index 52d96385..bf4cf1c9 100644
--- a/internlm/utils/gputest.py
+++ b/internlm/utils/gputest.py
@@ -45,7 +45,7 @@ def empty_cache_and_diag(batch_count, interval=50):
     # # import time
     # # time.sleep(10)
     # print(e, "rank = ", gpc.get_global_rank(), flush=True)
-    # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
+    # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
 
     # do empty_cache after the bench
     torch.cuda.empty_cache()
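
For reference, under the new schema the whole sequence parallel behaviour is configured through the tensor dict and the standalone sequence_parallel flag is gone. A minimal sketch of a user config (only the parallel block is shown, mirroring the configs/7B_sft.py change above; any other keys in a real config are unchanged):

# hypothetical user config excerpt under the new schema (only the parallel block shown)
# sp selects the sequence parallel mode; intern_overlap is only meaningful when sp == "intern"
parallel = dict(
    zero1=dict(size=-1, fsdp=False),
    # old style was: tensor=dict(size=8, mode="fstp", overlap=True) together with sequence_parallel=True
    tensor=dict(size=8, sp="intern", intern_overlap=True),
    pipeline=dict(size=1, interleaved_overlap=True),
)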
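
The args_sanity_check() changes above both normalize the tensor config and derive the legacy sequence_parallel flag from the new sp field. A standalone sketch of that logic, assuming a hypothetical helper name normalize_tensor_cfg (the real code mutates gpc.config.parallel in place and only forces sequence_parallel to True for non-'none' modes):

from typing import Tuple, Union

_VALID_SP_MODES = ("none", "megatron", "flash-attn", "intern")

def normalize_tensor_cfg(tensor_cfg: Union[int, dict]) -> Tuple[dict, bool]:
    # a bare int is shorthand for dict(size=<int>) with sequence parallel disabled
    if isinstance(tensor_cfg, int):
        tensor_cfg = dict(size=tensor_cfg, sp="none", intern_overlap=False)
    if tensor_cfg.get("sp", None) is None:
        tensor_cfg["sp"] = "none"
    if tensor_cfg.get("intern_overlap", None) is None:
        tensor_cfg["intern_overlap"] = False
    assert tensor_cfg["sp"] in _VALID_SP_MODES, (
        f"invalid sp mode {tensor_cfg['sp']!r}, only {list(_VALID_SP_MODES)} are supported"
    )
    # any mode other than 'none' implies the legacy boolean sequence_parallel flag
    sequence_parallel = tensor_cfg["sp"] in ("megatron", "flash-attn", "intern")
    return tensor_cfg, sequence_parallel

# usage sketch
cfg, seq_par = normalize_tensor_cfg(8)
assert cfg == dict(size=8, sp="none", intern_overlap=False) and seq_par is False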
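
In the modeling_internlm.py change above, the packed position indexes are split along the sequence dimension when sp is 'intern', so each tensor parallel rank keeps only its local slice. A plain-torch illustration of that slicing (the real code uses split_forward_gather_backward so the split stays differentiable; world_size and rank here are hypothetical stand-ins for the TENSOR parallel group):

import torch

def split_along_sequence(indexes: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    # dim=0 is the packed sequence dimension of `indexes`
    return torch.chunk(indexes, world_size, dim=0)[rank]

indexes = torch.arange(8)  # position ids of a packed sequence of length 8
local = split_along_sequence(indexes, world_size=4, rank=1)
assert torch.equal(local, torch.tensor([2, 3]))  # rank 1 keeps its contiguous slice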