support fine-grained overlap
yingtongxiong committed Oct 11, 2023
1 parent 792b066 commit 5fd5a8a
Showing 5 changed files with 86 additions and 40 deletions.
configs/7B_sft.py (1 addition, 1 deletion)
@@ -162,7 +162,7 @@
 """
 parallel = dict(
     zero1=dict(size=1, fsdp=False),
-    tensor=dict(size=2, mode='fstp'),  # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
+    tensor=dict(size=8, mode='fstp'),  # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
 )
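Side note: the comment above encodes a real constraint (mode 'fstp' must go with sequence_parallel=True). A minimal, hypothetical validation helper, not part of this commit, might look like:

# Hypothetical config check (illustrative only, not InternLM API).
def check_parallel_config(parallel: dict) -> None:
    tensor = parallel.get("tensor", {})
    mode = tensor.get("mode", "origin_tp")
    assert mode in ("origin_tp", "fstp"), f"unknown tensor mode: {mode}"
    if mode == "fstp":
        # fstp splits activations along the sequence dimension, so
        # sequence parallelism must be switched on as well
        assert parallel.get("sequence_parallel", False), "mode 'fstp' requires sequence_parallel=True"

check_parallel_config(dict(
    zero1=dict(size=1, fsdp=False),
    tensor=dict(size=8, mode="fstp"),
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=True,
))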
internlm/model/linear.py (56 additions, 22 deletions)
@@ -11,7 +11,8 @@

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import Silu, fstp_fused_dense_func, fused_dense_func_torch
+from internlm.core.naive_amp import NaiveAMPModel
+from internlm.model.utils import Silu, fstp_fused_dense_func, fused_dense_func_torch, all_gather_raw


 class ScaleColumnParallelLinear(nn.Linear):
@@ -211,8 +212,7 @@ def forward(self, x):

 class FSTPLinear(ColumnParallelLinear):
     def forward(self, x):
-        import pdb; pdb.set_trace()
-        return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
+        return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group, module=self, handler=gpc.config.fstp_handler)


 class FSTPFeedForward(nn.Module):
@@ -287,6 +287,7 @@ class FSTPAllGatherSyncHandler:

     def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:

+        # import pdb; pdb.set_trace()
         self.process_group = process_group
         self.FSTP_modules = []
         self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
@@ -306,47 +307,80 @@ def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:

         for _, children in _chunk.named_children():
             if isinstance(children, nn.ModuleList):
-                for _, block in enumerate(children):
+                for idx, block in enumerate(children):
                     index = 0
-                    sub_modules = list(block.children())
-                    if len(sub_modules) > 0:
-                        for name, child in block.named_children():
-                            if isinstance(child, FSTPLinear):
-                                self.FSTP_modules.append(child)
-                                self.module_block[child] = _
-                                self.block_module[_][index] = child
-                                self.module_name_index[child] = index
-                                index = index + 1
-                    else:
-                        continue
+                    self.block_module[idx] = {}
+                    for _, sub in block.named_children():
+                        sub_modules = list(sub.children())
+                        if len(sub_modules) > 0:
+                            for name, child in sub.named_children():
+                                if isinstance(child, FSTPLinear):
+                                    self.FSTP_modules.append(child)
+                                    self.module_block[child] = idx
+                                    self.block_module[idx][index] = child
+                                    self.module_name_index[child] = index
+                                    index = index + 1
+                        else:
+                            continue

     def _register_sync_parameters_hook(self) -> None:
         """
         register pre_forward_hook and pre_backward_hook for FSTPLinear.
         """

-        def _hook(module: nn.Module):
+        def _pre_forward_hook(module: nn.Module, inputs: Any):
             block_index = self.module_block[module]
             name_index = self.module_name_index[module]
             if name_index == 0:
+                total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
+                weight_handler.wait()
+                self.FSTP_global_weights[module] = total_weight
+
                 # start the all-gather for next module
                 next_module = self.block_module[block_index][name_index + 1]
-                self.FSTP_global_weights, weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
                 self.module_handler[next_module] = weights_handler
             else:
                 handler = self.module_handler[module]
                 handler.wait()
                 if name_index != 4:
                     next_module = self.block_module[block_index][name_index + 1]
-                    self.FSTP_global_weights, weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
                     self.module_handler[next_module] = weights_handler

-        def _pre_forward_hook(module: nn.Module, inputs: Any):
-            _hook(module)
+        def _post_forward_hook(module: nn.Module, input, output):
+            del self.FSTP_global_weights[module]
+            del self.module_handler[module]

         def _pre_backward_hook(module: nn.Module, grad_input, grad_output):
-            _hook(module)
+            block_index = self.module_block[module]
+            name_index = self.module_name_index[module]
+            if name_index == 4:
+                total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
+                weight_handler.wait()
+                self.FSTP_global_weights[module] = total_weight
+
+                # start the all-gather for next module
+                next_module = self.block_module[block_index][name_index - 1]
+                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                self.module_handler[next_module] = weights_handler
+            else:
+                handler = self.module_handler[module]
+                handler.wait()
+                if name_index != 0:
+                    next_module = self.block_module[block_index][name_index - 1]
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                    self.module_handler[next_module] = weights_handler
+
+        def _post_backward_hook(module, grad_input, grad_output):
+            del self.FSTP_global_weights[module]

         for module in self.FSTP_modules:
+            # import pdb; pdb.set_trace()
             module.register_forward_pre_hook(_pre_forward_hook)
-            module.register_backward_pre_hook(_pre_backward_hook)
+            module.register_forward_hook(_post_forward_hook)
+            # module.register_backward_pre_hook(_pre_backward_hook)
+            # module.register_backward_hook(_post_backward_hook)
+            module.register_module_full_backward_pre_hook(_pre_backward_hook)
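For orientation, the control flow above is: each FSTPLinear waits for its own gathered weight, then immediately launches the asynchronous all-gather for the next linear in the block (index order 0..4 in forward, 4..0 in backward), while the post-forward hook frees the gathered copy. Below is a self-contained, single-process sketch of that prefetch pattern; PrefetchHandler and the toy blocks are illustrative names, not InternLM code, and only the forward path is shown. (In recent PyTorch, the per-module backward pre-hook is registered with register_full_backward_pre_hook; register_module_full_backward_pre_hook is the global variant in torch.nn.modules.module.)

import os
import torch
import torch.distributed as dist
import torch.nn as nn

# Single-process "gloo" group so the sketch runs anywhere; with world_size == 1
# the all-gather degenerates to a copy, but the scheduling logic is unchanged.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

def all_gather_weight(weight):
    # asynchronous all-gather of a (conceptually sharded) weight; caller waits
    shards = [torch.empty_like(weight) for _ in range(dist.get_world_size())]
    work = dist.all_gather(shards, weight.detach().contiguous(), async_op=True)
    return shards, work

class PrefetchHandler:
    """Hypothetical stand-in for FSTPAllGatherSyncHandler (forward path only)."""

    def __init__(self, modules):
        self.modules = list(modules)
        self.index = {m: i for i, m in enumerate(self.modules)}
        self.gathered = {}  # module -> gathered weight shards
        self.handle = {}    # module -> pending all-gather work object

    def _launch(self, module):
        self.gathered[module], self.handle[module] = all_gather_weight(module.weight)

    def pre_forward(self, module, inputs):
        i = self.index[module]
        if i == 0:
            self._launch(module)       # the first module cannot be prefetched
        self.handle[module].wait()     # gathered weight must be ready to compute
        if i + 1 < len(self.modules):
            self._launch(self.modules[i + 1])  # overlap: prefetch the next one

    def post_forward(self, module, inputs, output):
        # free the gathered copy as soon as this module is done with it
        del self.gathered[module], self.handle[module]

blocks = nn.ModuleList(nn.Linear(16, 16) for _ in range(3))
handler = PrefetchHandler(blocks)
for m in blocks:
    m.register_forward_pre_hook(handler.pre_forward)
    m.register_forward_hook(handler.post_forward)

x = torch.randn(4, 16)
for m in blocks:
    x = m(x)  # real FSTP would compute with handler.gathered[m] here
dist.destroy_process_group()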

internlm/model/multi_head_attention.py (2 additions, 1 deletion)
@@ -210,7 +210,7 @@ def __init__(
             embed_dim,
             3 * embed_dim,
             process_group,
-            bias=True,
+            bias=False,
             sequence_parallel=gpc.config.parallel.sequence_parallel,
             **factory_kwargs,
         )  # according to https://spaces.ac.cn/archives/9577
@@ -231,6 +231,7 @@ def __init__(
             embed_dim,
             embed_dim,
             process_group,
+            bias=False,
             sequence_parallel=gpc.config.parallel.sequence_parallel,
             **factory_kwargs,
         )
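The cited analysis (https://spaces.ac.cn/archives/9577) motivates dropping the bias terms from the attention projections; with this change both Wqkv and out_proj are bias-free. A toy sketch of the resulting layers (hypothetical dimensions, plain nn.Linear standing in for the parallel linears):

import torch.nn as nn

embed_dim = 64  # toy size
wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)  # fused q/k/v projection
out_proj = nn.Linear(embed_dim, embed_dim, bias=False)  # attention output projection
assert wqkv.bias is None and out_proj.bias is None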
internlm/model/utils.py (21 additions, 14 deletions)
@@ -283,26 +283,30 @@ class FSTPFusedDenseFunc(torch.autograd.Function):

     @staticmethod
     @custom_fwd
-    def forward(ctx, x, weight, bias, return_residual=False, process_group=None):
+    def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, all_gather_handler=None):

         ctx.compute_weight_gradient = weight.requires_grad
         ctx.return_residual = return_residual
         ctx.process_group = process_group
+        ctx.all_gather_handler = all_gather_handler
+        ctx.module = module

         if torch.is_autocast_enabled():
             x = x.to(dtype=torch.get_autocast_gpu_dtype())
         total_x = x.contiguous()

         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
-            # do all_gather for weight and bias before actual computation
-            total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
-            if bias is not None:
-                total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
-                handle_bias.wait()
-            else:
-                total_bias = bias
-            handle_weight.wait()
+            total_weight = all_gather_handler.FSTP_global_weights[module]
+            total_bias = bias
+            # # do all_gather for weight and bias before actual computation
+            # total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+            # if bias is not None:
+            #     total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
+            #     handle_bias.wait()
+            # else:
+            #     total_bias = bias
+            # handle_weight.wait()
         else:
             total_weight = weight
             total_bias = bias
@@ -332,6 +336,8 @@ def backward(ctx, grad_output, *args):
         (grad_input,) = args
         grad_input = grad_input.contiguous()
         process_group = ctx.process_group
+        all_gather_handler = ctx.all_gather_handler
+        module = ctx.module
         if ctx.compute_weight_gradient:
             x, weight = ctx.saved_tensors
             total_x = x
@@ -345,8 +351,9 @@ def backward(ctx, grad_output, *args):
         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
             # do all-gather for weight before backward
-            total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
-            handle_weight.wait()
+            # total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+            # handle_weight.wait()
+            total_weight = all_gather_handler.FSTP_global_weights[module]
         else:
             total_weight = weight

@@ -379,7 +386,7 @@ def backward(ctx, grad_output, *args):
             handle_grad_weight.wait()
             if grad_bias is not None:
                 handle_grad_bias.wait()
-        return grad_input, grad_weight, grad_bias, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None


 def fused_dense_func_torch(
@@ -401,13 +408,13 @@ def fused_dense_func_torch(


 def fstp_fused_dense_func(
-    x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, return_residual: bool = False, process_group=None
+    x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, return_residual: bool = False, process_group=None, module=None, handler=None
 ):
     dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
         x.dtype == torch.float32 and torch.is_autocast_enabled()
     )
     if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
-        return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
+        return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, module, handler)
     else:
         assert process_group is None
         out = F.linear(x, weight, bias)
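One rule worth spelling out here: torch.autograd.Function.backward must return exactly one value per argument of forward, which is why the two new forward arguments (module, all_gather_handler) add two more trailing None values to the return above. A minimal, hypothetical illustration of that rule:

import torch

class ScaleBy(torch.autograd.Function):
    """Toy autograd.Function: forward takes a tensor plus two non-tensor
    arguments, so backward must return two matching None gradients."""

    @staticmethod
    def forward(ctx, x, factor, tag=None):
        ctx.factor = factor  # stash non-tensor state on ctx, like module/handler above
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # one return value per forward input: grad for x, None for factor and tag
        return grad_output * ctx.factor, None, None

x = torch.ones(3, requires_grad=True)
ScaleBy.apply(x, 2.0, "demo").sum().backward()
print(x.grad)  # tensor([2., 2., 2.])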
internlm/train/training_internlm.py (6 additions, 2 deletions)
@@ -39,6 +39,7 @@
     FeedForward,
     RewardModelLinear,
     ScaleColumnParallelLinear,
+    FSTPAllGatherSyncHandler,
 )
 from internlm.model.multi_head_attention import MHA
 from internlm.model.utils import try_import_RMSNorm
@@ -106,10 +107,13 @@ def initialize_model():

     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)
-

+    if gpc.config.parallel["tensor"]["mode"] == "fstp":
+        handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        handler._register_sync_parameters_hook()
+        gpc.config.fstp_handler = handler
     return model


 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     if gpc.config.parallel.zero1.fsdp:
         # set wrap_policy for fsdp wrap
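Design note: the handler is stored on gpc.config so that FSTPLinear.forward can reach it through the global context rather than every layer holding its own reference. A tiny, hypothetical sketch of the same pattern:

from types import SimpleNamespace

config = SimpleNamespace(fstp_handler=None)  # stand-in for gpc.config

class Handler:
    def __init__(self):
        self.FSTP_global_weights = {}  # module key -> gathered weight

def initialize():
    # set once at startup, mirroring initialize_model() above
    config.fstp_handler = Handler()

def layer_forward(key):
    # any layer can look up its gathered weight via the shared handler
    return config.fstp_handler.FSTP_global_weights.get(key)

initialize()
config.fstp_handler.FSTP_global_weights["w1"] = "gathered weight"
print(layer_forward("w1"))  # gathered weight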
