diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index 42bd9f03..3e37863d 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Optional, Union, Any
+from typing import Any, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -12,7 +12,12 @@ from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
-from internlm.model.utils import Silu, fstp_fused_dense_func, fused_dense_func_torch, all_gather_raw
+from internlm.model.utils import (
+    Silu,
+    all_gather_raw,
+    fstp_fused_dense_func,
+    fused_dense_func_torch,
+)
 
 
 class ScaleColumnParallelLinear(nn.Linear):
@@ -212,7 +217,9 @@ def forward(self, x):
 
 class FSTPLinear(ColumnParallelLinear):
     def forward(self, x):
-        return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group, module=self, handler=gpc.config.fstp_handler)
+        return fstp_fused_dense_func(
+            x, self.weight, self.bias, process_group=self.process_group, module=self, handler=gpc.config.fstp_handler
+        )
 
 
 class FSTPFeedForward(nn.Module):
@@ -280,31 +287,31 @@ def forward(self, x):
         out = self.w3(F.silu(w1_o) * w2_o)
         return out
 
+
 class FSTPAllGatherSyncHandler:
     """
     All-gather handler for overlapping the all-gather in adjcent FSTP linear.
     """
-
+
     def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
-        # import pdb; pdb.set_trace()
         self.process_group = process_group
         self.FSTP_modules = []
         self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
-        self.FSTP_global_weights = dict() # key: FSTP module; value: module global weight for forward
-        self.module_handler = dict() # key: FSTP module; value: all-gather handler
-        self.module_block = dict() # key: FSTP module; value: transformer block index
-        self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
-        self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name
-
+        self.FSTP_global_weights = dict()  # key: FSTP module; value: module global weight for forward
+        self.module_handler = dict()  # key: FSTP module; value: all-gather handler
+        self.module_block = dict()  # key: FSTP module; value: transformer block index
+        self.block_module = dict()  # key: transformer block index; value: {name_index: FSTP module}
+        self.module_name_index = dict()  # key: FSTP module; value: the index of the name in self.module_name
+
         # just want to share same for loop for ModuleList and Module
         if not isinstance(model, nn.ModuleList):
             model = [model]
-
+
        for _chunk in model:
             if isinstance(_chunk, NaiveAMPModel):
                 _chunk = _chunk.model
-
+
             for _, children in _chunk.named_children():
                 if isinstance(children, nn.ModuleList):
                     for idx, block in enumerate(children):
@@ -322,13 +329,12 @@ def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> Non
                                        index = index + 1
                            else:
                                continue
-
-
+
     def _register_sync_parameters_hook(self) -> None:
         """
         register pre_forward_hook and pre_backward_hook for FSTPLinear.
         """
-
+
         def _pre_forward_hook(module: nn.Module, inputs: Any):
             block_index = self.module_block[module]
             name_index = self.module_name_index[module]
@@ -336,19 +342,23 @@ def _pre_forward_hook(module: nn.Module, inputs: Any):
                 total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
                 weight_handler.wait()
                 self.FSTP_global_weights[module] = total_weight
-
+
                 # start the all-gather for next module
                 next_module = self.block_module[block_index][name_index + 1]
-                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                    next_module.weight, self.process_group, async_op=True
+                )
                 self.module_handler[next_module] = weights_handler
             else:
                 handler = self.module_handler[module]
                 handler.wait()
                 if name_index != 4:
                     next_module = self.block_module[block_index][name_index + 1]
-                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                        next_module.weight, self.process_group, async_op=True
+                    )
                     self.module_handler[next_module] = weights_handler
-
+
         def _post_forward_hook(module: nn.Module, input, output):
             del self.FSTP_global_weights[module]
             del self.module_handler[module]
@@ -360,22 +370,26 @@ def _pre_backward_hook(module: nn.Module, grad_input, grad_output):
                 total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
                 weight_handler.wait()
                 self.FSTP_global_weights[module] = total_weight
-
+
                 # start the all-gather for next module
                 next_module = self.block_module[block_index][name_index - 1]
-                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                    next_module.weight, self.process_group, async_op=True
+                )
                 self.module_handler[next_module] = weights_handler
             else:
                 handler = self.module_handler[module]
                 handler.wait()
                 if name_index != 0:
                     next_module = self.block_module[block_index][name_index - 1]
-                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                        next_module.weight, self.process_group, async_op=True
+                    )
                     self.module_handler[next_module] = weights_handler
-
+
         def _post_backward_hook(module, grad_input, grad_output):
             del self.FSTP_global_weights[module]
-
+
         for module in self.FSTP_modules:
             # import pdb; pdb.set_trace()
             module.register_forward_pre_hook(_pre_forward_hook)
@@ -383,4 +397,145 @@ def _post_backward_hook(module, grad_input, grad_output):
             # module.register_backward_pre_hook(_pre_backward_hook)
             # module.register_backward_hook(_post_backward_hook)
             module.register_module_full_backward_pre_hook(_pre_backward_hook)
-    
\ No newline at end of file
+
+
+class CoarseGrainedFSTPAllGatherSyncHandler:
+    """
+    All-gather handler for overlapping the all-gather in adjacent FSTP blocks.
+    """
+
+    def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
+        # import pdb; pdb.set_trace()
+        self.process_group = process_group
+        self.FSTP_blocks = []
+        self.FSTP_outs = []
+        self.FSTP_wqkvs = []
+        self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
+        self.FSTP_global_handle = dict()  # key: FSTP module; value: module global all-gather op handle
+        self.FSTP_global_weights = dict()  # key: FSTP module; value: module global weight for forward
+        self.block_handles = dict()  # key: transformer block; value: all-gather handles
+        self.module_to_index = dict()  # key: FSTP module; value: transformer block index
+        self.block_to_index = dict()  # key: transformer block; value: transformer block index
+        self.index_to_block = dict()  # key: transformer block index; value: transformer block
+        self.index_to_fsdp_modules = dict()  # key: transformer block index; value: fsdp modules
+
+        # just want to share same for loop for ModuleList and Module
+        if not isinstance(model, nn.ModuleList):
+            model = [model]
+
+        for _chunk in model:
+            if isinstance(_chunk, NaiveAMPModel):
+                _chunk = _chunk.model
+
+            for _, children in _chunk.named_children():
+                if isinstance(children, nn.ModuleList):
+                    for idx, block in enumerate(children):
+                        self.FSTP_blocks.append(block)
+                        self.block_to_index[block] = idx
+                        self.index_to_block[idx] = block
+                        self.index_to_fsdp_modules[idx] = []
+                        for _, sub in block.named_children():
+                            sub_modules = list(sub.children())
+                            if len(sub_modules) > 0:
+                                for name, child in sub.named_children():
+                                    # print(f"name: {name}", flush=True)
+                                    if name == "out_proj":
+                                        self.FSTP_outs.append(child)
+                                        self.module_to_index[child] = idx
+                                    if name == "Wqkv":
+                                        self.FSTP_wqkvs.append(child)
+                                        self.module_to_index[child] = idx
+                                    if isinstance(child, FSTPLinear):
+                                        self.index_to_fsdp_modules[idx].append(child)
+                            else:
+                                continue
+
+    def _all_gather_block_weight(self, block_index: int):
+        block = self.index_to_block[block_index]
+        fsdp_modules = self.index_to_fsdp_modules[block_index]
+        self.block_handles[block] = []
+        for module in fsdp_modules:
+            total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
+            self.FSTP_global_weights[module] = total_weight
+            self.block_handles[block].append(weight_handle)
+
+    def _register_sync_parameters_hook(self) -> None:
+        """
+        register pre_forward_hook and pre_backward_hook for FSTP block.
+
+        Notice that the next block's all_gather op should be issued after the current block's all_to_all op, so we:
+        1. register pre_forward_hook @out_proj module to prefetch for the next block
+        2. register pre_forward_hook @block module to wait on the all-gather handles of the current block
+        3. register pre_backward_hook @wqkv module to prefetch for the next block in the backward pass
+        4. register pre_backward_hook @block module to wait on the all-gather handles of the current block
+        """
+
+        def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):
+            block_index = self.module_to_index[module]
+            # start the all-gather for next block
+            if block_index + 1 < gpc.config.NUM_LAYER:
+                self._all_gather_block_weight(block_index + 1)
+
+        def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):
+            block_index = self.block_to_index[block]
+            if block_index == 0:
+                # all gather weight for block 0
+                fsdp_modules = self.index_to_fsdp_modules[block_index]
+                for module in fsdp_modules:
+                    total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
+                    weight_handle.wait()
+                    self.FSTP_global_weights[module] = total_weight
+            else:
+                # wait handle for current block
+                handles = self.block_handles[block]
+                for handle in handles:
+                    handle.wait()
+
+        def _post_forward_hook_for_block(block: nn.Module, input, output):
+            block_index = self.block_to_index[block]
+            fsdp_modules = self.index_to_fsdp_modules[block_index]
+            if block in self.block_handles:
+                del self.block_handles[block]
+            for module in fsdp_modules:
+                del self.FSTP_global_weights[module]
+
+        def _pre_backward_hook_for_wqkv(module: nn.Module, grad_output):
+            block_index = self.module_to_index[module]
+            # start the all-gather for next block
+            if block_index - 1 >= 0:
+                self._all_gather_block_weight(block_index - 1)
+
+        def _pre_backward_hook_for_block(block: nn.Module, grad_output):
+            block_index = self.block_to_index[block]
+            if block_index == gpc.config.NUM_LAYER - 1:
+                # all gather weight for the last block
+                fsdp_modules = self.index_to_fsdp_modules[block_index]
+                for module in fsdp_modules:
+                    total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
+                    weight_handle.wait()
+                    self.FSTP_global_weights[module] = total_weight
+            else:
+                # wait handle for current block
+                handles = self.block_handles[block]
+                for handle in handles:
+                    handle.wait()
+
+        def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
+            block_index = self.block_to_index[block]
+            fsdp_modules = self.index_to_fsdp_modules[block_index]
+            if block in self.block_handles:
+                del self.block_handles[block]
+            for module in fsdp_modules:
+                del self.FSTP_global_weights[module]
+
+        for block in self.FSTP_blocks:
+            block.register_forward_pre_hook(_pre_forward_hook_for_block)
+            block.register_forward_hook(_post_forward_hook_for_block)
+            block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
+            block.register_full_backward_hook(_post_backward_hook_for_block)
+
+        for out_proj in self.FSTP_outs:
+            out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
+
+        for wqkv in self.FSTP_wqkvs:
+            wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 50b9bbd7..97319d98 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -284,7 +284,6 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, all_gather_handler=None):
-
         ctx.compute_weight_gradient = weight.requires_grad
         ctx.return_residual = return_residual
         ctx.process_group = process_group
@@ -297,16 +296,18 @@ def forward(ctx, x, weight, bias, return_residual=False, process_group=None, mod
         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
-            total_weight = all_gather_handler.FSTP_global_weights[module]
-            total_bias = bias
-            # # do all_gather for weight and bias before actual computation
-            # total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
-            # if bias is not None:
-            #     total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
-            #     handle_bias.wait()
-            # else:
-            #     total_bias = bias
-            # handle_weight.wait()
+            # do all_gather for weight and bias before actual computation
+            if module in all_gather_handler.FSTP_global_weights:
+                total_weight = all_gather_handler.FSTP_global_weights[module]
+            else:
+                total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+                handle_weight.wait()
+
+            if bias is not None:
+                total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
+                handle_bias.wait()
+            else:
+                total_bias = bias
         else:
             total_weight = weight
             total_bias = bias
@@ -351,12 +352,14 @@ def backward(ctx, grad_output, *args):
         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
             # do all-gather for weight before backward
-            # total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
-            # handle_weight.wait()
-            total_weight = all_gather_handler.FSTP_global_weights[module]
+            if module in all_gather_handler.FSTP_global_weights:
+                total_weight = all_gather_handler.FSTP_global_weights[module]
+            else:
+                total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+                handle_weight.wait()
         else:
             total_weight = weight
-
+
         # compute weight grad
         if ctx.needs_input_grad[1]:
             assert ctx.compute_weight_gradient
@@ -380,7 +383,7 @@ def backward(ctx, grad_output, *args):
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
         else:
             grad_input = None
-
+
         if ctx.needs_input_grad[1]:
             if world_size > 1:
                 handle_grad_weight.wait()
@@ -408,7 +411,13 @@ def fused_dense_func_torch(
 
 
 def fstp_fused_dense_func(
-    x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, return_residual: bool = False, process_group=None, module=None, handler=None
+    x: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    return_residual: bool = False,
+    process_group=None,
+    module=None,
+    handler=None,
 ):
     dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
         x.dtype == torch.float32 and torch.is_autocast_enabled()
@@ -460,5 +469,3 @@ def Silu(w1_o, w2_o):
 
 
 Silu = torch.jit.script(Silu)
-
-
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 5deb0233..da59803c 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -36,10 +36,11 @@ from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
+    CoarseGrainedFSTPAllGatherSyncHandler,
     FeedForward,
+    FSTPAllGatherSyncHandler,
     RewardModelLinear,
     ScaleColumnParallelLinear,
-    FSTPAllGatherSyncHandler,
 )
 from internlm.model.multi_head_attention import MHA
 from internlm.model.utils import try_import_RMSNorm
@@ -107,13 +108,14 @@ def initialize_model():
 
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)
-
+
     if gpc.config.parallel["tensor"]["mode"] == "fstp":
-        handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
         handler._register_sync_parameters_hook()
         gpc.config.fstp_handler = handler
     return model
 
+
 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     if gpc.config.parallel.zero1.fsdp:
         # set wrap_policy for fsdp wrap
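
The block-level prefetch that CoarseGrainedFSTPAllGatherSyncHandler registers reduces to a wait-then-prefetch pattern on block hooks. The sketch below is a simplified, single-process illustration of that pattern only, not part of the patch: DummyHandle, fake_all_gather and BlockPrefetchHandler are placeholders invented here, whereas the real handler calls all_gather_raw(..., async_op=True) on the tensor-parallel process group and the gathered weights are consumed by FSTPFusedDenseFunc rather than merely stored.

# Hedged sketch (assumed names): block-level weight prefetch via forward pre-hooks.
import torch
from torch import nn


class DummyHandle:
    def wait(self):
        # stands in for the async work handle returned by torch.distributed
        pass


def fake_all_gather(weight: torch.Tensor):
    # placeholder for all_gather_raw(weight, process_group, async_op=True)
    return weight.clone(), DummyHandle()


class BlockPrefetchHandler:
    def __init__(self, blocks: nn.ModuleList):
        self.blocks = list(blocks)
        self.gathered = {}  # module -> gathered (full) weight
        self.handles = {}   # block index -> outstanding all-gather handles

    def _all_gather_block_weight(self, idx: int):
        self.handles[idx] = []
        for m in self.blocks[idx].modules():
            if isinstance(m, nn.Linear):
                full_weight, handle = fake_all_gather(m.weight)
                self.gathered[m] = full_weight  # in the patch this is read by FSTPFusedDenseFunc
                self.handles[idx].append(handle)

    def register(self):
        for idx, block in enumerate(self.blocks):
            def _pre_forward(module, inputs, idx=idx):
                if idx == 0 and idx not in self.handles:
                    self._all_gather_block_weight(idx)  # block 0 cannot be prefetched
                for handle in self.handles.pop(idx, []):
                    handle.wait()                       # wait for this block's weights
                if idx + 1 < len(self.blocks):
                    self._all_gather_block_weight(idx + 1)  # prefetch the next block
            block.register_forward_pre_hook(_pre_forward)


blocks = nn.ModuleList(nn.Sequential(nn.Linear(8, 8), nn.ReLU()) for _ in range(3))
BlockPrefetchHandler(blocks).register()
_ = nn.Sequential(*blocks)(torch.randn(2, 8))

The patch goes one step further than this sketch: the prefetch for block i + 1 is launched from block i's out_proj forward pre-hook (and, in backward, from Wqkv's backward pre-hook), so the next block's all-gather overlaps with the tail of the current block's compute instead of starting only when the next block begins.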