Merge pull request #5 from yingtongxiong/fstp/refactor-hook-handle
feat(model/overlap_handler.py): refactor overlap hook handle
huangting4201 authored Oct 23, 2023
2 parents 1804d01 + b2c1a70 commit b48687a
Showing 10 changed files with 450 additions and 379 deletions.
2 changes: 1 addition & 1 deletion configs/7B_sft.py
@@ -163,7 +163,7 @@
"""
parallel = dict(
zero1=dict(size=-1, fsdp=False),
tensor=dict(size=8, sp="megatron", intern_overlap=True),
tensor=dict(size=8, sp="intern", intern_overlap=True),
pipeline=dict(size=1, interleaved_overlap=True),
)

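For context, the change above only flips the sequence-parallel mode from "megatron" to "intern". Below is a minimal sketch of the surrounding parallel block after this commit; the zero1 and pipeline values come from the hunk itself, while the inline notes about what each key controls are assumptions based on the rest of this diff (the sp modes "none", "flash-attn", "megatron" and "intern" appear in modeling_internlm.py, and intern_overlap gates the new FSTPOverlapHandler).

# Sketch of the parallel section of configs/7B_sft.py after this change (annotations assumed).
parallel = dict(
    zero1=dict(size=-1, fsdp=False),  # -1: one zero1 group spanning all ranks (assumed)
    tensor=dict(
        size=8,                # tensor-parallel world size
        sp="intern",           # sequence-parallel mode: "none" | "flash-attn" | "megatron" | "intern"
        intern_overlap=True,   # enable communication/computation overlap for the "intern" mode
    ),
    pipeline=dict(size=1, interleaved_overlap=True),
)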
297 changes: 73 additions & 224 deletions internlm/model/linear.py

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions internlm/model/modeling_internlm.py
@@ -14,12 +14,9 @@
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
from internlm.model.embedding import Embedding1D
from internlm.model.linear import (
FeedForward,
MegatronFeedForward,
FSTPFeedForward,
MegatronScaleColumnParallelLinear,
RewardModelLinear,
ScaleColumnParallelLinear,
MegatronScaleColumnParallelLinear,
get_mlp_cls,
)
from internlm.model.multi_head_attention import MHA
@@ -309,7 +306,11 @@ def __init__(
if is_reward:
head_cls = RewardModelLinear
else:
head_cls = ScaleColumnParallelLinear if self.sp_mode in ["flash-attn", "none", "intern"] else MegatronScaleColumnParallelLinear
head_cls = (
ScaleColumnParallelLinear
if self.sp_mode in ["flash-attn", "none", "intern"]
else MegatronScaleColumnParallelLinear
)
if first:
if embed_split_hidden:
self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
1 change: 0 additions & 1 deletion internlm/model/moe.py
@@ -53,7 +53,6 @@ def __init__(
device=None,
dtype=None,
):

super().__init__()

assert (
11 changes: 2 additions & 9 deletions internlm/model/multi_head_attention.py
@@ -38,14 +38,7 @@
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.embedding import DynamicNTKScalingRotaryEmbedding, RotaryEmbedding
from internlm.model.linear import (
ColumnParallelLinearTorch,
FSTPLinear,
RowParallelLinearTorch,
MegatronColumnParallelLinearTorch,
MegatronRowParallelLinearTorch,
get_linear_cls,
)
from internlm.model.linear import get_linear_cls


# adpated from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
@@ -227,7 +220,7 @@ def __init__(
self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)

# output projection always has the bias (for now)
out_proj_cls = get_linear_cls(sp_mode, 'row')
out_proj_cls = get_linear_cls(sp_mode, "row")
self.out_proj = out_proj_cls(
embed_dim,
embed_dim,
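The per-class imports removed above are collapsed into the single get_linear_cls helper. As an illustration only, a dispatcher along the following lines would cover the classes that used to be imported here; the actual helper ships with this PR's linear.py refactor and may differ, so the function name and mapping below are assumptions.

# Hypothetical sketch of a get_linear_cls-style dispatcher (not the actual helper).
# It assumes the parallel linear classes still live in internlm.model.linear.
from internlm.model.linear import (
    ColumnParallelLinearTorch,
    FSTPLinear,
    MegatronColumnParallelLinearTorch,
    MegatronRowParallelLinearTorch,
    RowParallelLinearTorch,
)


def get_linear_cls_sketch(sp_mode: str, parallel_mode: str):
    """Map (sp_mode, "column"/"row") to a linear implementation class."""
    if parallel_mode == "column":
        cls_map = {
            "none": ColumnParallelLinearTorch,
            "flash-attn": ColumnParallelLinearTorch,
            "megatron": MegatronColumnParallelLinearTorch,
            "intern": FSTPLinear,
        }
    elif parallel_mode == "row":
        cls_map = {
            "none": RowParallelLinearTorch,
            "flash-attn": RowParallelLinearTorch,
            "megatron": MegatronRowParallelLinearTorch,
            "intern": FSTPLinear,
        }
    else:
        raise ValueError(f"unknown parallel_mode: {parallel_mode}")
    return cls_map[sp_mode]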
283 changes: 283 additions & 0 deletions internlm/model/overlap_handler.py
@@ -0,0 +1,283 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Any, Union

import torch
from torch import nn

from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.model.embedding import Embedding1D
from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear
from internlm.model.utils import (
all_gather_raw_bias_memory_pool,
all_gather_raw_memory_pool,
)
from internlm.utils.common import get_current_device


class FSTPOverlapHandler:
"""
FSTP overlap handler for overlapping all-gather and reduce-scatter communication with computation.
"""

def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
self.process_group = process_group
self.fstp_outs = []
self.fstp_modules = []
self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
self.fstp_global_handle = dict() # key: fstp module; value: module global all-gather op handle
self.bias_global_handle = dict() # key: fstp module; value: module bias global all-gather op handle
self.module_to_index = dict() # key: fstp module; value: transformer block index
self.index_to_fstp_modules = dict()  # key: transformer block index; value: fstp modules
self.head = []
self.embedding = []

self.reduce_scatter_handlers = {}
self.zero_const_pool = {}

# wrap a bare Module in a list so the same loop handles both Module and ModuleList
if not isinstance(model, nn.ModuleList):
model = [model]

for _chunk in model:
if isinstance(_chunk, NaiveAMPModel):
_chunk = _chunk.model

for _chunk_name, children in _chunk.named_children():
if isinstance(children, ScaleColumnParallelLinear):
self.head.append(children)
elif isinstance(children, Embedding1D):
self.embedding.append(children)
elif isinstance(children, nn.ModuleList):
for idx, block in enumerate(children):
self.index_to_fstp_modules[idx] = []
for _sub_name, sub in block.named_children():
sub_modules = list(sub.children())
if len(sub_modules) > 0:
for name, child in sub.named_children():
if name == "out_proj":
self.fstp_outs.append(child)
self.module_to_index[child] = idx
if isinstance(child, FSTPLinear):
self.module_to_index[child] = idx
self.fstp_modules.append(child)
self.index_to_fstp_modules[idx].append(child)

setattr(child, "_fstp_name", name)

_full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
if child.bias is not None:
setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")

self._initialize_memory_pool()
self._register_sync_parameters_hook()

def get_zero_by_shape(self, size: tuple, dtype, device) -> torch.Tensor:
if size not in self.zero_const_pool:
self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous()

return self.zero_const_pool[size]

def _initialize_module_shape(self):
hidden_size = gpc.config.HIDDEN_SIZE
mlp_ratio = gpc.config.MLP_RATIO
mlp_hidden_size = int(hidden_size * mlp_ratio)
mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)

self.module_shape["Wqkv"] = (3 * hidden_size, hidden_size)
self.module_shape["out_proj"] = (hidden_size, hidden_size)
self.module_shape["w1"] = (mlp_hidden_size, hidden_size)
self.module_shape["w2"] = (mlp_hidden_size, hidden_size)
self.module_shape["w3"] = (hidden_size, mlp_hidden_size)

def _initialize_memory_pool(self) -> None:
# allocate memory pool
self.all_gather_memory_pool = []
self.all_gather_bias_memory_pool = []
self.reduce_scatter_memory_pool = {}
self.module_shape = {}

self._initialize_module_shape()
dtype = gpc.config.model.get("dtype", torch.half)
device = get_current_device()

for _ in range(2):
weight = {}
for name in self.module_name:
weight[name] = torch.zeros(self.module_shape[name], dtype=dtype, device=device).contiguous()
self.all_gather_memory_pool.append(weight)  # double buffer: two groups of per-block weights

def clear_memory_pool(self) -> None:
self.zero_const_pool = {}
self.reduce_scatter_memory_pool = {}

def get_all_gather_memory(self, module):
block_index = self.module_to_index[module]
return self.all_gather_memory_pool[block_index % 2][module._fstp_name]

def get_bias_memory(self, module: nn.Module):
block_index = self.module_to_index[module]
# if the bias memory pool is empty or the module has not been allocated bias memory yet
if len(self.all_gather_bias_memory_pool) == 0:
for _ in range(2):
weight = {}
weight[module._fstp_name] = torch.zeros(
self.module_shape[module._fstp_name][0],
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device(),
).contiguous()
self.all_gather_bias_memory_pool.append(weight)
elif module._fstp_name not in self.all_gather_bias_memory_pool[0]:
for i in range(2):
self.all_gather_bias_memory_pool[i][module._fstp_name] = torch.zeros(
self.module_shape[module._fstp_name][0],
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device(),
).contiguous()

return self.all_gather_bias_memory_pool[block_index % 2][module._fstp_name]

def get_reduce_scatter_memory(self, key):
return_idx = 0

# if key not in dict
if key not in self.reduce_scatter_memory_pool:
self.reduce_scatter_memory_pool[key] = []

# if the pool for this key is empty
if len(self.reduce_scatter_memory_pool[key]) == 0:
self.reduce_scatter_memory_pool[key].append(
torch.zeros(
key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
).contiguous()
)
setattr(self.reduce_scatter_memory_pool[key][return_idx], "idle", False)
setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
return self.reduce_scatter_memory_pool[key][return_idx]
else: # if not empty
for index, mem_item in enumerate(self.reduce_scatter_memory_pool[key]):
if mem_item.idle is True:
self.reduce_scatter_memory_pool[key][index].idle = False
return_idx = index
return self.reduce_scatter_memory_pool[key][return_idx]
# if the memory pool is all used
cur_len = len(self.reduce_scatter_memory_pool[key])
self.reduce_scatter_memory_pool[key].append(
torch.zeros(
key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
).contiguous()
)
setattr(self.reduce_scatter_memory_pool[key][cur_len], "idle", False)
return_idx = cur_len
setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
return self.reduce_scatter_memory_pool[key][return_idx]

def release_reduce_scatter_memory(self, key, index):
self.reduce_scatter_memory_pool[key][index].idle = True

def _all_gather_block_weight_memory_pool(self, block_index: int):
fstp_modules = self.index_to_fstp_modules[block_index]
for module in fstp_modules:
if module.bias is not None:
bias_handle = all_gather_raw_bias_memory_pool(
module.bias,
self.process_group,
async_op=True,
module=module,
)
self.bias_global_handle[module] = bias_handle

weight_handle = all_gather_raw_memory_pool(
module.weight,
self.process_group,
async_op=True,
module=module,
)
self.fstp_global_handle[module] = weight_handle

def _register_sync_parameters_hook(self) -> None:
"""
register forward hooks and backward hooks for fstp modules.
"""

def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any): # pylint: disable=W0613
self._all_gather_block_weight_memory_pool(0)

def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any): # pylint: disable=W0613
block_index = self.module_to_index[module]
# start the all-gather for next block
if block_index + 1 < gpc.config.NUM_LAYER:
self._all_gather_block_weight_memory_pool(block_index + 1)

def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): # pylint: disable=W0613
handle = self.fstp_global_handle[module]
handle.wait()
if module.bias is not None:
bias_handle = self.bias_global_handle[module]
bias_handle.wait()

def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any): # pylint: disable=W0613
if module in self.fstp_global_handle:
del self.fstp_global_handle[module]

def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output): # pylint: disable=W0613
first_backward_module = self.fstp_modules[-1]
weight_handle = all_gather_raw_memory_pool(
first_backward_module.weight,
self.process_group,
async_op=True,
module=first_backward_module,
)
self.fstp_global_handle[first_backward_module] = weight_handle

def _pre_backward_hook_for_module(module: nn.Module, grad_output): # pylint: disable=W0613
# wait handle for current module
weight_handle = self.fstp_global_handle[module]
weight_handle.wait()

# start the all-gather for next module
module_index = self.fstp_modules.index(module)
if module_index - 1 >= 0:
next_module = self.fstp_modules[module_index - 1]
weight_handle = all_gather_raw_memory_pool(
next_module.weight,
self.process_group,
async_op=True,
module=next_module,
)
self.fstp_global_handle[next_module] = weight_handle

def _post_backward_hook_for_module(module, grad_input, grad_output): # pylint: disable=W0613
if module in self.fstp_global_handle:
del self.fstp_global_handle[module]

# register forward hooks
# 1. register a post_forward_hook on the embedding module to prefetch the weights for block 0
# 2. register a pre_forward_hook on each out_proj module to prefetch the weights for the next block;
# note that the next block's all_gather op must be issued after the current block's all_to_all op
# 3. register a pre_forward_hook on each fstp_module to wait on its all-gather handle
# 4. register a post_forward_hook on each fstp_module to release the handle
for embedding in self.embedding:
embedding.register_forward_hook(_post_forward_hook_for_embedding)

for out_proj in self.fstp_outs:
out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)

for module in self.fstp_modules:
module.register_forward_pre_hook(_pre_forward_hook_for_module)
module.register_forward_hook(_post_forward_hook_for_module)

# register backward hooks
# 1. register a post_backward_hook on the head module to prefetch the weights of the last block's last module
# 2. register a pre_backward_hook on each fstp_module to wait on its all-gather handle and prefetch the next module in backward order
# 3. register a post_backward_hook on each fstp_module to release the handle
for head in self.head:
head.register_full_backward_hook(_post_backward_hook_for_head)

for module in self.fstp_modules:
module.register_full_backward_pre_hook(_pre_backward_hook_for_module)
module.register_full_backward_hook(_post_backward_hook_for_module)
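
Putting the new handler in context: below is a hypothetical sketch of how FSTPOverlapHandler might be instantiated during setup when intern_overlap is enabled. Only FSTPOverlapHandler, gpc and ParallelMode come from the code in this PR and repository; the function name and the exact config access are assumptions, and the real integration point in the training code is not part of this diff.

# Hypothetical wiring of FSTPOverlapHandler (sketch, not the actual trainer code).
from torch import nn

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.overlap_handler import FSTPOverlapHandler


def maybe_create_overlap_handler(model: nn.Module):
    """Build the overlap handler only when tensor.intern_overlap is enabled."""
    tensor_cfg = gpc.config.parallel["tensor"]
    if isinstance(tensor_cfg, dict) and tensor_cfg.get("intern_overlap", False):
        # constructing the handler registers all forward/backward hooks, so the
        # all-gathers are prefetched block by block from the shared memory pool
        return FSTPOverlapHandler(model, gpc.get_group(ParallelMode.TENSOR))
    return None

After each training step, clear_memory_pool() can be called on the returned handler to reset the reduce-scatter and zero-constant pools.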