support megatron sp implementation
yingtongxiong committed Oct 19, 2023
1 parent 4742271 commit dba909d
Showing 7 changed files with 250 additions and 132 deletions.
configs/20B_sft.py (3 changes: 2 additions & 1 deletion)
@@ -162,9 +162,10 @@
"""
parallel = dict(
zero1=dict(size=-1, fsdp=False),
tensor=dict(size=8, mode="fstp", overlap=True),
tensor=dict(size=8, mode="origin_tp", overlap=False),
pipeline=dict(size=1, interleaved_overlap=True),
sequence_parallel=True,
sp_megatron=True,
)
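Note (illustrative, not part of this commit): sp_megatron is the flag this commit introduces. Judging from the modeling_internlm.py changes further below, setting it to True swaps the MLP and the output head for the Megatron-style sequence-parallel implementations. A minimal sketch of how the flag is consumed, using names from this diff (the local variable use_megatron_sp is illustrative only):

    # Sketch only: mirrors the class-selection logic added in modeling_internlm.py below.
    use_megatron_sp = gpc.config.parallel["sp_megatron"]
    mlp_cls = MegatronFeedForward if use_megatron_sp else FeedForward
    head_cls = MegatronScaleColumnParallelLinear if use_megatron_sp else ScaleColumnParallelLinear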

cudnn_deterministic = False
configs/30B_sft.py (1 change: 1 addition & 0 deletions)
@@ -165,6 +165,7 @@
tensor=dict(size=8, mode="origin_tp", overlap=False),
pipeline=dict(size=1, interleaved_overlap=True),
sequence_parallel=True,
sp_megatron=False,
)

cudnn_deterministic = False
internlm/model/linear.py (246 changes: 133 additions & 113 deletions)
@@ -19,6 +19,7 @@
all_gather_raw_memory_pool,
fstp_fused_dense_func,
fused_dense_func_torch,
megatron_fused_dense_func,
)


@@ -217,6 +218,71 @@ def forward(self, x):
out = self.w3(Silu(w1_o, w2_o))
return out

class MegatronFeedForward(nn.Module):
"""
FeedForward.
Args:
in_features (int): size of each input sample
hidden_features (int): size of hidden state of FFN
out_features (int): size of each output sample
process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
in the config.
device (Optional[Union[str, torch.device]]): The device will be used.
dtype (Optional[torch.dtype]): The type of data.
multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default.
"""

def __init__(
self,
in_features: int,
hidden_features: int,
out_features: int = None,
process_group: Optional[torch.distributed.ProcessGroup] = None,
bias: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
multiple_of: int = 256,
block_idx: int = 0,
):
super().__init__()

hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)

self.w1 = MegatronColumnParallelLinear(
in_features,
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w2 = MegatronColumnParallelLinear(
in_features,
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w3 = MegatronRowParallelLinear(
hidden_features,
out_features,
process_group,
bias=bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)

def forward(self, x):
w1_o = self.w1(x)
w2_o = self.w2(x)
out = self.w3(Silu(w1_o, w2_o))
return out
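Note (illustrative, not part of this commit): like the other FeedForward variants in this file, the forward pass computes a SwiGLU, out = w3(silu(w1(x)) * w2(x)); the FSTP variant further down spells the same expression as F.silu(w1_o) * w2_o. Only the parallel layout of the three projections differs here. A minimal single-process sketch of that computation with made-up shapes (in the real class w1/w2 are column-parallel and w3 is row-parallel across the tensor-parallel group):

    import torch
    import torch.nn.functional as F

    # Shapes are illustrative; the constructor rounds hidden_features up to a
    # multiple of `multiple_of` (256 by default), e.g. 1280 = 256 * 5.
    x = torch.randn(2, 8, 512)        # (batch, seq_len, hidden)
    w1 = torch.nn.Linear(512, 1280)   # column-parallel in MegatronFeedForward
    w2 = torch.nn.Linear(512, 1280)   # column-parallel in MegatronFeedForward
    w3 = torch.nn.Linear(1280, 512)   # row-parallel in MegatronFeedForward
    out = w3(F.silu(w1(x)) * w2(x))   # same math as w3(Silu(w1_o, w2_o)) above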

class FSTPLinear(ColumnParallelLinear):
def forward(self, x):
@@ -295,129 +361,83 @@ def forward(self, x):
out = self.w3(F.silu(w1_o) * w2_o)
return out

class MegatronColumnParallelLinear(ColumnParallelLinear):
def forward(self, x, gather_dim=0):
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered.

class FSTPAllGatherSyncHandler:
"""
All-gather handler for overlapping the all-gather in adjacent FSTP linears.
"""

def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
# import pdb; pdb.set_trace()
self.process_group = process_group
self.FSTP_modules = []
self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
self.FSTP_global_weights = dict() # key: FSTP module; value: module global weight for forward
self.module_handler = dict() # key: FSTP module; value: all-gather handler
self.module_block = dict() # key: FSTP module; value: transformer block index
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
self.module_name_index = dict() # key: FSTP module; value: the index of the module's name in self.module_name

self.reduce_scatter_handlers = {}
self.all_reduce_handlers = {}

# just want to share same for loop for ModuleList and Module
if not isinstance(model, nn.ModuleList):
model = [model]

for _chunk in model:
if isinstance(_chunk, NaiveAMPModel):
_chunk = _chunk.model

for _chunk_name, children in _chunk.named_children():
if isinstance(children, nn.ModuleList):
for idx, block in enumerate(children):
index = 0
self.block_module[idx] = {}
for _sub_name, sub in block.named_children():
sub_modules = list(sub.children())
if len(sub_modules) > 0:
for name, child in sub.named_children():
if isinstance(child, FSTPLinear):

_full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
if child.bias is not None:
setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
return megatron_fused_dense_func(
x,
self.weight,
self.bias,
process_group=self.process_group,
sequence_parallel=self.sequence_parallel,
gather_dim=gather_dim,
)

self.FSTP_modules.append(child)
self.module_block[child] = idx
self.block_module[idx][index] = child
self.module_name_index[child] = index
index = index + 1
else:
continue

def _register_sync_parameters_hook(self) -> None:
class MegatronRowParallelLinear(RowParallelLinear):
def forward(self, x):
"""
register pre_forward_hook and pre_backward_hook for FSTPLinear.
We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
a reduce_scatter of the result.
"""
out = megatron_fused_dense_func(x, self.weight, self.bias)
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
return reduce_fn(out, self.process_group)

def _pre_forward_hook(module: nn.Module, inputs: Any):
block_index = self.module_block[module]
name_index = self.module_name_index[module]
if name_index == 0:
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler.wait()
self.FSTP_global_weights[module] = total_weight

# start the all-gather for next module
next_module = self.block_module[block_index][name_index + 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.module_handler[next_module] = weights_handler
else:
handler = self.module_handler[module]
handler.wait()
if name_index != 4:
next_module = self.block_module[block_index][name_index + 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.module_handler[next_module] = weights_handler

def _post_forward_hook(module: nn.Module, input, output):
if module in self.FSTP_global_weights:
del self.FSTP_global_weights[module]
if module in self.module_handler:
del self.module_handler[module]

def _pre_backward_hook(module: nn.Module, grad_output):
block_index = self.module_block[module]
name_index = self.module_name_index[module]
if name_index == 4:
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler.wait()
self.FSTP_global_weights[module] = total_weight

# start the all-gather for next module
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.module_handler[next_module] = weights_handler
else:
handler = self.module_handler[module]
handler.wait()
if name_index != 0:
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.module_handler[next_module] = weights_handler
class MegatronScaleColumnParallelLinear(nn.Linear):
"""
def _post_backward_hook(module, grad_input, grad_output):
del self.FSTP_global_weights[module]
Args:
in_features (int): size of each input sample
out_features (int): size of each output sample
process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
in the config.
sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
we do an all_gather of x before doing the matmul.
If not, then the input is already gathered.
device (Optional[Union[str, torch.device]]): The device will be used.
dtype (Optional[torch.dtype]): The type of data.
weight_scale (int): For training stability. 1 by default.
"""

for module in self.FSTP_modules:
# import pdb; pdb.set_trace()
module.register_forward_pre_hook(_pre_forward_hook)
module.register_forward_hook(_post_forward_hook)
# module.register_backward_pre_hook(_pre_backward_hook)
# module.register_backward_hook(_post_backward_hook)
module.register_full_backward_pre_hook(_pre_backward_hook)
module.register_full_backward_hook(_post_backward_hook)
def __init__(
self,
in_features: int,
out_features: int,
process_group: Optional[torch.distributed.ProcessGroup],
bias: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
weight_scale: int = 1,
) -> None:
world_size = torch.distributed.get_world_size(process_group)
if out_features % world_size != 0:
raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})")
super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype)
self.process_group = process_group
self.weight_scale = weight_scale

def forward(self, input, gather_dim=0): # pylint: disable=W0622
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered.
if self.weight_scale != 1:
weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
else:
weight = self.weight
return megatron_fused_dense_func(
input,
weight,
self.bias,
process_group=self.process_group,
sequence_parallel=gpc.config.parallel.sequence_parallel,
gather_dim=gather_dim,
)
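Note (illustrative, not part of this commit): megatron_fused_dense_func is imported from internlm/model/utils.py and is not shown in this diff; the comments above describe the intended dataflow. A rough, forward-only sketch of that dataflow with plain torch.distributed collectives (the real function additionally handles the backward-pass communication), assuming the sequence dimension is dim 0 and divisible by the group size:

    import torch
    import torch.nn.functional as F
    import torch.distributed as dist

    def column_parallel_forward(x, weight, bias, group, gather_dim=0):
        # Sequence-parallel column-parallel linear: all-gather the per-rank sequence
        # shards of x before the matmul, so each rank multiplies the full sequence
        # by its local (column) slice of the weight.
        world_size = dist.get_world_size(group)
        shards = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(shards, x, group=group)
        full_x = torch.cat(shards, dim=gather_dim)
        return F.linear(full_x, weight, bias)

    def row_parallel_forward(x, weight, bias, group, sequence_parallel=True):
        # Row-parallel linear: each rank multiplies by its local (row) slice of the
        # weight, producing a partial sum of the full output.
        partial = F.linear(x, weight, bias)
        if sequence_parallel:
            # Sum the partial results and keep only this rank's sequence shard.
            world_size = dist.get_world_size(group)
            out = torch.empty_like(partial.chunk(world_size, dim=0)[0])
            dist.reduce_scatter(out, list(partial.chunk(world_size, dim=0)), group=group)
            return out
        # Without sequence parallelism, every rank keeps the full (summed) output.
        dist.all_reduce(partial, group=group)
        return partial

On the weight_scale branch of MegatronScaleColumnParallelLinear above: weight * s + (1 - s) * weight.detach() evaluates to weight in the forward pass, but only the first term carries gradient, so the gradient reaching the weight is scaled by weight_scale; presumably this is the training-stability knob the docstring mentions.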

class CoarseGrainedFSTPAllGatherSyncHandler:
"""
internlm/model/modeling_internlm.py (14 changes: 4 additions & 10 deletions)
@@ -16,8 +16,10 @@
from internlm.model.linear import (
FeedForward,
FSTPFeedForward,
MegatronFeedForward,
RewardModelLinear,
ScaleColumnParallelLinear,
MegatronScaleColumnParallelLinear,
)
from internlm.model.multi_head_attention import MHA
from internlm.model.utils import (
@@ -77,7 +79,6 @@ def __init__(
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
tp_mode: str = "origin_tp",
block_idx: int = 0,
):
super().__init__()
@@ -103,8 +104,6 @@ def __init__(
use_flash_attn=use_flash_attn,
device=device,
dtype=dtype,
tp_mode=tp_mode,
block_idx=block_idx,
)

self.dropout1 = nn.Dropout(drop_rate)
@@ -116,7 +115,7 @@ def __init__(
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

if use_swiglu:
mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward
mlp_cls = MegatronFeedForward if gpc.config.parallel["sp_megatron"] else FeedForward
self.mlp = mlp_cls(
hidden_size,
int(hidden_size * mlp_ratio),
@@ -299,12 +298,11 @@ def __init__(
super().__init__()

checkpoint_layer_num = int(num_layers * checkpoint)
self.tp_mode = gpc.config.parallel["tensor"]["mode"]

if is_reward:
head_cls = RewardModelLinear
else:
head_cls = ScaleColumnParallelLinear
head_cls = MegatronScaleColumnParallelLinear if gpc.config.parallel["sp_megatron"] else ScaleColumnParallelLinear
if first:
if embed_split_hidden:
self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
Expand Down Expand Up @@ -345,7 +343,6 @@ def __init__(
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
tp_mode=self.tp_mode,
block_idx=lid,
)
for lid in range(num_layers)
@@ -392,9 +389,6 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
assert len(indexes) == 1
# The indexes are used to indicate the actual position IDs of each token in the packed input.
indexes = indexes[0]
# if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension.
if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp":
indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)

max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None

(The diff for the remaining 3 changed files was not loaded.)