feat(initialize/launch.py): refactor config for fstp #4

Merged: 4 commits, Oct 20, 2023
10 changes: 5 additions & 5 deletions configs/7B_sft.py
@@ -152,19 +152,19 @@
2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
tensor parallel (dict):
1. size: int, the size of tensor parallel.
- 2. mode: str, the mode should be 'origin_tp' or 'fstp', defaults to 'origin_tp'. If the mode is 'fstp',
- the sequence_parallel should be True.
+ 2. sp: str, the sequence parallel mode, should be in ['none', 'megatron', 'flash-attn', 'intern'],
+ defaults to 'none', which means sequence parallel is disabled.
+ 3. intern_overlap: bool, enable/disable all_gather/reduce_scatter communication overlap when using the
+ 'intern' sp mode, defaults to False.
pipeline parallel (dict):
1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
defaults to False.
sequence parallel (bool): enable/disable sequence parallel, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1, fsdp=False),
- tensor=dict(size=8, mode="fstp", overlap=True),
+ tensor=dict(size=8, sp="intern", intern_overlap=True),
pipeline=dict(size=1, interleaved_overlap=True),
sequence_parallel=True,
)

cudnn_deterministic = False
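In short, `tensor=dict(size=8, mode="fstp", overlap=True)` becomes `tensor=dict(size=8, sp="intern", intern_overlap=True)`. A minimal sketch of the configurations the new schema admits (field names and mode strings are taken from this diff; the size values are illustrative, and per the docstring above `intern_overlap` only applies to the 'intern' mode):

```python
# Pure tensor parallel, no sequence parallelism (the defaults).
tensor = dict(size=8, sp="none", intern_overlap=False)

# Megatron-style or flash-attn sequence parallelism on top of TP.
tensor = dict(size=8, sp="megatron", intern_overlap=False)
tensor = dict(size=8, sp="flash-attn", intern_overlap=False)

# FSTP ('intern' mode) with all_gather/reduce_scatter overlap enabled.
tensor = dict(size=8, sp="intern", intern_overlap=True)
```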
23 changes: 14 additions & 9 deletions internlm/initialize/launch.py
@@ -306,15 +306,20 @@ def args_sanity_check():
), "sequence parallel does not support use_flash_attn=False"

if isinstance(gpc.config.parallel["tensor"], int):
gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode="origin_tp")

if gpc.config.parallel["tensor"].get("mode", None) is None:
gpc.config.parallel["tensor"]["mode"] = "origin_tp"

if gpc.config.parallel["tensor"].get("mode", None) == "fstp":
assert (
gpc.config.parallel.sequence_parallel is True
), "when the tp_mode is fstp, the sequence_parallel should be True."
gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], sp="none", intern_overlap=False)
if gpc.config.parallel["tensor"].get("sp", None) is None:
gpc.config.parallel["tensor"]["sp"] = "none"
if gpc.config.parallel["tensor"].get("intern_overlap", None) is None:
gpc.config.parallel["tensor"]["intern_overlap"] = False
assert gpc.config.parallel["tensor"].get("sp", None) in [
"none",
"megatron",
"flash-attn",
"intern",
], "invalid sp mode, only ['none', 'megatron', 'flash-attn', 'intern'] is supported"
# adapt to old version's sequence parallel config
if gpc.config.parallel["tensor"].get("sp", None) in ["megatron", "flash-attn", "intern"]:
gpc.config.parallel.sequence_parallel = True

# currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
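Taken together, the new checks normalize any legacy config into the new schema before training starts. A standalone sketch of that logic with a plain dict standing in for `gpc.config` (a hypothetical helper for illustration, not code from this PR):

```python
VALID_SP_MODES = ("none", "megatron", "flash-attn", "intern")

def normalize_tensor_config(tensor, parallel):
    """Sketch of what args_sanity_check now does to parallel['tensor']."""
    # A bare int is shorthand for pure tensor parallelism of that size.
    if isinstance(tensor, int):
        tensor = dict(size=tensor, sp="none", intern_overlap=False)
    # Fill defaults for keys an old-style config may omit.
    tensor.setdefault("sp", "none")
    tensor.setdefault("intern_overlap", False)
    assert tensor["sp"] in VALID_SP_MODES, f"invalid sp mode {tensor['sp']!r}"
    # Any named sp mode implies sequence parallelism, so old configs that
    # never set sequence_parallel still end up consistent.
    if tensor["sp"] != "none":
        parallel["sequence_parallel"] = True
    return tensor
```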
14 changes: 7 additions & 7 deletions internlm/model/modeling_internlm.py
@@ -77,7 +77,7 @@ def __init__(
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
- tp_mode: str = "origin_tp",
+ sp_mode: str = "none",
):
super().__init__()
self.checkpoint = checkpoint
@@ -102,7 +102,7 @@ def __init__(
use_flash_attn=use_flash_attn,
device=device,
dtype=dtype,
- tp_mode=tp_mode,
+ sp_mode=sp_mode,
)

self.dropout1 = nn.Dropout(drop_rate)
@@ -114,7 +114,7 @@ def __init__(
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

if use_swiglu:
- mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward
+ mlp_cls = FSTPFeedForward if sp_mode == "intern" else FeedForward
self.mlp = mlp_cls(
hidden_size,
int(hidden_size * mlp_ratio),
@@ -297,7 +297,7 @@ def __init__(
super().__init__()

checkpoint_layer_num = int(num_layers * checkpoint)
- self.tp_mode = gpc.config.parallel["tensor"]["mode"]
+ self.sp_mode = gpc.config.parallel["tensor"]["sp"]

if is_reward:
head_cls = RewardModelLinear
@@ -343,7 +343,7 @@ def __init__(
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
- tp_mode=self.tp_mode,
+ sp_mode=self.sp_mode,
)
for lid in range(num_layers)
]
@@ -389,8 +389,8 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
assert len(indexes) == 1
# The indexes are used to indicate the actual position IDs of each token in the packed input.
indexes = indexes[0]
- # if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension.
- if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp":
+ # if the sequence parallel mode is 'intern', the indexes should also be split along the sequence dimension.
+ if gpc.config.parallel.sequence_parallel and self.sp_mode == "intern":
indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)

max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
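The `indexes` handling is the subtle change in this file: under 'intern' sequence parallelism each tensor-parallel rank holds only a `1/world_size` slice of the packed sequence, so the position indexes must be sliced the same way. A sketch of the forward-pass effect of a dim-0 split like `split_forward_gather_backward` (the real function is an autograd op that also all-gathers gradients in backward; explicit rank/world_size arguments stand in for `ParallelMode.TENSOR`):

```python
import torch

def split_seq_dim_forward(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Each rank keeps its contiguous chunk of dim 0 (the sequence dimension
    # for packed inputs); e.g. world_size=8 turns 4096 position ids into a
    # 512-id slice that lines up with the rank's local hidden states.
    return torch.chunk(t, world_size, dim=0)[rank].contiguous()
```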
8 changes: 4 additions & 4 deletions internlm/model/multi_head_attention.py
@@ -175,7 +175,7 @@ def __init__(
use_flash_attn: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
- tp_mode: str = "origin_tp",
+ sp_mode: str = "none",
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
@@ -203,7 +203,7 @@ def __init__(
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)

# notice here should change bias=True
- Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
+ Wqkv_cls = FSTPLinear if sp_mode == "intern" else ColumnParallelLinearTorch
self.Wqkv = Wqkv_cls(
embed_dim,
3 * embed_dim,
@@ -219,12 +219,12 @@ def __init__(
self.inner_cross_attn = inner_cross_attn_cls(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
)
if tp_mode == "fstp":
if sp_mode == "intern":
self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)

# output projection always have the bias (for now)
- out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
+ out_proj_cls = FSTPLinear if sp_mode == "intern" else RowParallelLinearTorch
self.out_proj = out_proj_cls(
embed_dim,
embed_dim,
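Net effect: for `sp="intern"` both projections become `FSTPLinear` and both attention cores are wrapped in `DistributedAttention`; every other mode keeps the Megatron-style column/row-parallel linears. Condensed from the hunks above (names as in this file, constructor arguments elided; a restatement of the diff, not standalone code):

```python
if sp_mode == "intern":
    # FSTP: fully sharded projection weights, plus a wrapper that lets the
    # attention core operate across the sequence process group.
    Wqkv_cls, out_proj_cls = FSTPLinear, FSTPLinear
    inner_attn = DistributedAttention(inner_attn, sequence_process_group=process_group)
    inner_cross_attn = DistributedAttention(inner_cross_attn, sequence_process_group=process_group)
else:
    Wqkv_cls, out_proj_cls = ColumnParallelLinearTorch, RowParallelLinearTorch
```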
5 changes: 2 additions & 3 deletions internlm/model/utils.py
@@ -1,12 +1,12 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

- from typing import Any, Optional, Union
+ from typing import Optional

import fused_dense_lib as fused_dense_cuda
import torch
import torch.nn.functional as F
- from flash_attn.utils.distributed import all_reduce_raw # , reduce_scatter_raw
+ from flash_attn.utils.distributed import all_reduce_raw
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup
@@ -397,7 +397,6 @@ def backward(ctx, grad_output, *args):
grad_input = grad_input.contiguous()
process_group = ctx.process_group
all_gather_handler = ctx.all_gather_handler
- module = ctx.module
block_index = ctx.block_index
module_name = ctx.module_name

8 changes: 4 additions & 4 deletions internlm/solver/optimizer/hybrid_zero_optim.py
@@ -2,16 +2,16 @@
# -*- encoding: utf-8 -*-

import math
- from typing import Optional, List
from functools import partial
+ from typing import List, Optional

import torch
import torch.distributed as dist
from torch.optim import Optimizer

from internlm.core.context import Config, ParallelMode
from internlm.core.context import global_context as gpc
- from internlm.model.utils import split_forward_gather_backward, release_reduce_scatter_memory_pool
+ from internlm.model.utils import release_reduce_scatter_memory_pool
from internlm.monitor import send_alert_message
from internlm.solver.optimizer.store import (
BucketStore,
@@ -83,7 +83,7 @@ def __init__(
hysteresis = grad_scal_cfg.hysteresis
max_scale = grad_scal_cfg.max_scale

if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
self._fstp_handler = gpc.config.fstp_handler

# Zero related args
@@ -366,8 +366,8 @@ def _accum_grads_store_in_bucket(self, bucket: BucketStore, reduce_rank: Optiona
_param.grad.add_(_grad)

# release cuda memory.
- release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
self._fstp_handler.reduce_scatter_handlers[_key] = None
_grad = None

bucket.reset_by_rank(reduce_rank)

2 changes: 1 addition & 1 deletion internlm/solver/optimizer/store.py
@@ -45,7 +45,7 @@ def __init__(self, group_id, dp_parallel_mode):

def num_elements_in_bucket(self, reduce_rank: int = None):
return self._num_elements_in_bucket[reduce_rank]

def num_params_in_bucket(self, reduce_rank: int = None):
return len(self._params[reduce_rank])

3 changes: 1 addition & 2 deletions internlm/train/training_internlm.py
@@ -110,9 +110,8 @@ def initialize_model():

gpc.config.fstp_handler = None

if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
# handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
handler._register_sync_parameters_hook()
gpc.config.fstp_handler = handler

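This gate is the same two-key condition used in hybrid_zero_optim.py above; a small hypothetical helper (not part of the PR) would keep the two call sites in sync:

```python
def fstp_overlap_enabled(parallel_cfg) -> bool:
    # Mirrors the gate in initialize_model() and HybridZeroOptimizer:
    # overlap handling only exists for 'intern' sp with overlap enabled.
    tensor = parallel_cfg["tensor"]
    return tensor["sp"] == "intern" and tensor["intern_overlap"] is True
```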
4 changes: 2 additions & 2 deletions internlm/utils/evaluation.py
@@ -54,7 +54,7 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
def switch_sequence_parallel_mode():
prev_mode = gpc.config.parallel.sequence_parallel
try:
if gpc.config.parallel["tensor"]["mode"] == "fstp":
if gpc.config.parallel["tensor"]["sp"] == "intern":
gpc.config.parallel.sequence_parallel = True
else:
gpc.config.parallel.sequence_parallel = False
@@ -106,7 +106,7 @@ def evaluate_on_val_dls(
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
num_microbatches = total_val_bsz // data_cfg.micro_bsz
if gpc.config.parallel["tensor"]["mode"] == "fstp":
if gpc.config.parallel["tensor"]["sp"] == "intern":
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
tensor_shape = torch.Size(
[
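`switch_sequence_parallel_mode` flips the flag for the duration of evaluation and restores it on exit. A minimal sketch of the pattern, assuming it is a `contextmanager` whose restore sits in the `finally` implied by the `prev_mode`/`try` lines above (`gpc` as imported by this file):

```python
from contextlib import contextmanager

@contextmanager
def switch_sequence_parallel_mode():
    prev_mode = gpc.config.parallel.sequence_parallel
    try:
        # 'intern' sp shards activations along the sequence dimension, so
        # evaluation keeps sequence parallelism on; other modes turn it off.
        gpc.config.parallel.sequence_parallel = (
            gpc.config.parallel["tensor"]["sp"] == "intern"
        )
        yield
    finally:
        gpc.config.parallel.sequence_parallel = prev_mode
```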
2 changes: 1 addition & 1 deletion internlm/utils/gputest.py
@@ -45,7 +45,7 @@ def empty_cache_and_diag(batch_count, interval=50):
# # import time
# # time.sleep(10)
# print(e, "rank = ", gpc.get_global_rank(), flush=True)
- # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
+ # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")

# do empty_cache after the bench
torch.cuda.empty_cache()