feat(model/linear.py): support block allgather overlap
huangting4201 committed Oct 12, 2023
1 parent 5fd5a8a commit d0b1346
Showing 3 changed files with 212 additions and 48 deletions.
207 changes: 181 additions & 26 deletions internlm/model/linear.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Optional, Union, Any
from typing import Any, Optional, Union

import torch
import torch.nn.functional as F
@@ -12,7 +12,12 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.model.utils import Silu, fstp_fused_dense_func, fused_dense_func_torch, all_gather_raw
from internlm.model.utils import (
Silu,
all_gather_raw,
fstp_fused_dense_func,
fused_dense_func_torch,
)


class ScaleColumnParallelLinear(nn.Linear):
@@ -212,7 +217,9 @@ def forward(self, x):

class FSTPLinear(ColumnParallelLinear):
def forward(self, x):
return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group, module=self, handler=gpc.config.fstp_handler)
return fstp_fused_dense_func(
x, self.weight, self.bias, process_group=self.process_group, module=self, handler=gpc.config.fstp_handler
)


class FSTPFeedForward(nn.Module):
@@ -280,31 +287,31 @@ def forward(self, x):
out = self.w3(F.silu(w1_o) * w2_o)
return out

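For reference, the feed-forward above computes the SwiGLU variant used throughout this file: out = w3(SiLU(w1 · x) ⊙ (w2 · x)), where each wi is an FSTPLinear whose full weight is all-gathered on demand.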

class FSTPAllGatherSyncHandler:
"""
All-gather handler for overlapping the all-gathers of adjacent FSTPLinear modules.
"""

def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:

# import pdb; pdb.set_trace()
self.process_group = process_group
self.FSTP_modules = []
self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
self.FSTP_global_weights = dict()  # key: FSTP module; value: module global weight for forward
self.module_handler = dict()  # key: FSTP module; value: all-gather handler
self.module_block = dict()  # key: FSTP module; value: transformer block index
self.block_module = dict()  # key: transformer block index; value: {name_index: FSTP module}
self.module_name_index = dict()  # key: FSTP module; value: index of the module's name in self.module_name

# share the same for-loop between ModuleList and Module
if not isinstance(model, nn.ModuleList):
model = [model]

for _chunk in model:
if isinstance(_chunk, NaiveAMPModel):
_chunk = _chunk.model

for _, children in _chunk.named_children():
if isinstance(children, nn.ModuleList):
for idx, block in enumerate(children):
@@ -322,33 +329,36 @@ def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
index = index + 1
else:
continue



def _register_sync_parameters_hook(self) -> None:
"""
Register a pre-forward hook and a pre-backward hook for each FSTPLinear.
"""

def _pre_forward_hook(module: nn.Module, inputs: Any):
block_index = self.module_block[module]
name_index = self.module_name_index[module]
if name_index == 0:
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler.wait()
self.FSTP_global_weights[module] = total_weight

# start the all-gather for next module
next_module = self.block_module[block_index][name_index + 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.module_handler[next_module] = weights_handler
else:
handler = self.module_handler[module]
handler.wait()
if name_index != 4:  # index 4 (w3) is the last linear in the block
next_module = self.block_module[block_index][name_index + 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.module_handler[next_module] = weights_handler

def _post_forward_hook(module: nn.Module, input, output):
del self.FSTP_global_weights[module]
del self.module_handler[module]
@@ -360,27 +370,172 @@ def _pre_backward_hook(module: nn.Module, grad_input, grad_output):
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler.wait()
self.FSTP_global_weights[module] = total_weight

# start the all-gather for next module
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.module_handler[next_module] = weights_handler
else:
handler = self.module_handler[module]
handler.wait()
if name_index != 0:  # index 0 (Wqkv) is the first linear in the block
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.module_handler[next_module] = weights_handler

def _post_backward_hook(module, grad_input, grad_output):
del self.FSTP_global_weights[module]

for module in self.FSTP_modules:
# import pdb; pdb.set_trace()
module.register_forward_pre_hook(_pre_forward_hook)
module.register_forward_hook(_post_forward_hook)
# module.register_backward_pre_hook(_pre_backward_hook)
# module.register_backward_hook(_post_backward_hook)
module.register_full_backward_pre_hook(_pre_backward_hook)



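For context, a minimal usage sketch of this fine-grained handler. The trainer-side wiring here is an assumption (it is not part of this diff): the handler is presumably built once after model construction and exposed via gpc.config so that FSTPLinear.forward can find it.

# Hypothetical setup code, not part of this commit; `model` is the
# (possibly NaiveAMPModel-wrapped) transformer built by the trainer.
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.linear import FSTPAllGatherSyncHandler

handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
handler._register_sync_parameters_hook()
gpc.config.fstp_handler = handler  # read by FSTPLinear.forward -> fstp_fused_dense_func
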
class CoarseGrainedFSTPAllGatherSyncHandler:
"""
All-gather handler for overlapping the all-gathers of adjacent FSTP blocks.
"""

def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
# import pdb; pdb.set_trace()
self.process_group = process_group
self.FSTP_blocks = []
self.FSTP_outs = []
self.FSTP_wqkvs = []
self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
self.FSTP_global_handle = dict() # key: FSTP module; value: module global all-gather op handle
self.FSTP_global_weights = dict() # key: FSTP module; value: module global weight for forward
self.block_handles = dict() # key: transformer block; value: all-gather handles
self.module_to_index = dict() # key: FSTP module; value: transformer block index
self.block_to_index = dict() # key: transformer block; value: transformer block index
self.index_to_block = dict() # key: transformer block index; value: transformer block
self.index_to_fsdp_modules = dict() # key: transformer block index; value: fsdp modules

# share the same for-loop between ModuleList and Module
if not isinstance(model, nn.ModuleList):
model = [model]

for _chunk in model:
if isinstance(_chunk, NaiveAMPModel):
_chunk = _chunk.model

for _, children in _chunk.named_children():
if isinstance(children, nn.ModuleList):
for idx, block in enumerate(children):
self.FSTP_blocks.append(block)
self.block_to_index[block] = idx
self.index_to_block[idx] = block
self.index_to_fsdp_modules[idx] = []
for _, sub in block.named_children():
sub_modules = list(sub.children())
if len(sub_modules) > 0:
for name, child in sub.named_children():
# print(f"name: {name}", flush=True)
if name == "out_proj":
self.FSTP_outs.append(child)
self.module_to_index[child] = idx
if name == "Wqkv":
self.FSTP_wqkvs.append(child)
self.module_to_index[child] = idx
if isinstance(child, FSTPLinear):
self.index_to_fsdp_modules[idx].append(child)
else:
continue

def _all_gather_block_weight(self, block_index: int):
block = self.index_to_block[block_index]
fsdp_modules = self.index_to_fsdp_modules[block_index]
self.block_handles[block] = []
for module in fsdp_modules:
total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
self.FSTP_global_weights[module] = total_weight
self.block_handles[block].append(weight_handle)

def _register_sync_parameters_hook(self) -> None:
"""
Register pre-forward and pre-backward hooks for each FSTP block.
Note that the next block's all_gather op should start only after the current block's all_to_all op, so we:
1. register a pre-forward hook on the out_proj module to prefetch weights for the next block
2. register a pre-forward hook on the block module to wait on the current block's handles
3. register a pre-backward hook on the Wqkv module to prefetch weights for the previous block
4. register a pre-backward hook on the block module to wait on the current block's handles
(see the timeline sketch after this file's diff)
"""

def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):
block_index = self.module_to_index[module]
# start the all-gather for next block
if block_index + 1 < gpc.config.NUM_LAYER:
self._all_gather_block_weight(block_index + 1)

def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):
block_index = self.block_to_index[block]
if block_index == 0:
# all gather weight for block 0
fsdp_modules = self.index_to_fsdp_modules[block_index]
for module in fsdp_modules:
total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handle.wait()
self.FSTP_global_weights[module] = total_weight
else:
# wait handle for current block
handles = self.block_handles[block]
for handle in handles:
handle.wait()

def _post_forward_hook_for_block(block: nn.Module, input, output):
block_index = self.block_to_index[block]
fsdp_modules = self.index_to_fsdp_modules[block_index]
if block in self.block_handles:
del self.block_handles[block]
for module in fsdp_modules:
del self.FSTP_global_weights[module]

def _pre_backward_hook_for_wqkv(module: nn.Module, grad_output):
block_index = self.module_to_index[module]
# start the all-gather for next block
if block_index - 1 >= 0:
self._all_gather_block_weight(block_index - 1)

def _pre_backward_hook_for_block(block: nn.Module, grad_output):
block_index = self.block_to_index[block]
if block_index == gpc.config.NUM_LAYER - 1:
# all gather weight for the last block
fsdp_modules = self.index_to_fsdp_modules[block_index]
for module in fsdp_modules:
total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handle.wait()
self.FSTP_global_weights[module] = total_weight
else:
# wait handle for current block
handles = self.block_handles[block]
for handle in handles:
handle.wait()

def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
block_index = self.block_to_index[block]
fsdp_modules = self.index_to_fsdp_modules[block_index]
if block in self.block_handles:
del self.block_handles[block]
for module in fsdp_modules:
del self.FSTP_global_weights[module]

for block in self.FSTP_blocks:
block.register_forward_pre_hook(_pre_forward_hook_for_block)
block.register_forward_hook(_post_forward_hook_for_block)
block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
block.register_full_backward_hook(_post_backward_hook_for_block)

for out_proj in self.FSTP_outs:
out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)

for wqkv in self.FSTP_wqkvs:
wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)
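
The hook schedule above implies the following per-block timeline (a sketch inferred from the hooks; N = gpc.config.NUM_LAYER):

# Forward, block i:
#   block pre-hook    -> wait on block i's all-gather handles (block 0 gathers synchronously)
#   out_proj pre-hook -> launch async all-gather of block i+1's weights
#   block post-hook   -> free block i's gathered weights and handles
#
# Backward, block i:
#   block pre-hook    -> wait on block i's handles (block N-1 gathers synchronously)
#   Wqkv pre-hook     -> launch async all-gather of block i-1's weights
#   block post-hook   -> free block i's gathered weights and handles

Wiring is presumably analogous to the fine-grained handler sketched earlier, with CoarseGrainedFSTPAllGatherSyncHandler substituted.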
45 changes: 26 additions & 19 deletions internlm/model/utils.py
@@ -284,7 +284,6 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, all_gather_handler=None):

ctx.compute_weight_gradient = weight.requires_grad
ctx.return_residual = return_residual
ctx.process_group = process_group
@@ -297,16 +296,18 @@ def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, all_gather_handler=None):

world_size = gpc.get_world_size(ParallelMode.TENSOR)
if world_size > 1:
total_weight = all_gather_handler.FSTP_global_weights[module]
total_bias = bias
# # do all_gather for weight and bias before actual computation
# total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
# if bias is not None:
# total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
# handle_bias.wait()
# else:
# total_bias = bias
# handle_weight.wait()
# do all_gather for weight and bias before actual computation
if module in all_gather_handler.FSTP_global_weights:
total_weight = all_gather_handler.FSTP_global_weights[module]
else:
total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
handle_weight.wait()

if bias is not None:
total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
handle_bias.wait()
else:
total_bias = bias
else:
total_weight = weight
total_bias = bias
@@ -351,12 +352,14 @@ def backward(ctx, grad_output, *args):
world_size = gpc.get_world_size(ParallelMode.TENSOR)
if world_size > 1:
# do all-gather for weight before backward
# total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
# handle_weight.wait()
total_weight = all_gather_handler.FSTP_global_weights[module]
if module in all_gather_handler.FSTP_global_weights:
total_weight = all_gather_handler.FSTP_global_weights[module]
else:
total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
handle_weight.wait()
else:
total_weight = weight

# compute weight grad
if ctx.needs_input_grad[1]:
assert ctx.compute_weight_gradient
@@ -380,7 +383,7 @@ def backward(ctx, grad_output, *args):
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
else:
grad_input = None

if ctx.needs_input_grad[1]:
if world_size > 1:
handle_grad_weight.wait()
@@ -408,7 +411,13 @@ def fused_dense_func_torch(


def fstp_fused_dense_func(
x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, return_residual: bool = False, process_group=None, module=None, handler=None
x: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
return_residual: bool = False,
process_group=None,
module=None,
handler=None,
):
dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
x.dtype == torch.float32 and torch.is_autocast_enabled()
@@ -460,5 +469,3 @@ def Silu(w1_o, w2_o):


Silu = torch.jit.script(Silu)
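
Both handlers depend on all_gather_raw, which is defined elsewhere in internlm/model/utils.py but collapsed in this diff. A plausible implementation, following the flash-attn-style pattern this code appears to assume (a sketch, not the verbatim InternLM helper):

import torch
import torch.distributed as dist

def all_gather_raw(input_: torch.Tensor, process_group, async_op: bool = False):
    # Each rank holds a 1/world_size shard along dim 0; gather the full tensor.
    world_size = dist.get_world_size(process_group)
    output = torch.empty(
        world_size * input_.shape[0],
        *input_.shape[1:],
        dtype=input_.dtype,
        device=input_.device,
    )
    handle = dist.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    # With async_op=True, callers must handle.wait() before reading `output`.
    return output, handle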


