feat(initialize/launch.py): refactor config for fstp
huangting4201 committed Oct 20, 2023
1 parent 815a584 commit d91a5d9
Showing 7 changed files with 63 additions and 44 deletions.
10 changes: 5 additions & 5 deletions configs/7B_sft.py
@@ -152,19 +152,19 @@
2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
tensor parallel (dict):
1. size: int, the size of tensor parallel.
2. mode: str, the mode should be 'origin_tp' or 'fstp', defaults to 'origin_tp'. If the mode is 'fstp',
the sequence_parallel should be True.
2. sp: str, the sequence parallel mode, should be in ['none', 'megatron', 'flash-attn', 'intern'],
defaults to 'none', which means sequence parallel is disabled.
3. intern_overlap: bool, enable/disable all_gather/reduce_scatter communication overlap when the sp mode is 'intern',
defaults to False.
pipeline parallel (dict):
1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
defaults to False.
sequence parallel (bool): enable/disable sequence parallel, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1, fsdp=False),
tensor=dict(size=8, mode="fstp", overlap=True),
tensor=dict(size=8, sp="intern", intern_overlap=True),
pipeline=dict(size=1, interleaved_overlap=True),
sequence_parallel=True,
)

cudnn_deterministic = False
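
The `tensor` block drops the old `mode`/`overlap` keys in favour of `sp` and `intern_overlap`. As a rough guide to the rename, here is a hypothetical migration helper (not part of the repo; it only covers the fields visible in this diff and maps the old 'fstp' mode to the new 'intern' sp mode):

# Hypothetical helper sketching the old -> new tensor-parallel config mapping.
def migrate_tensor_cfg(old):
    if isinstance(old, int):  # a bare int used to mean "just a TP size"
        return dict(size=old, sp="none", intern_overlap=False)
    mode = old.get("mode", "origin_tp")
    return dict(
        size=old["size"],
        # the old "fstp" mode corresponds to the new "intern" sequence parallel mode
        sp="intern" if mode == "fstp" else "none",
        # "overlap" is renamed to "intern_overlap" and only matters when sp == "intern"
        intern_overlap=old.get("overlap", False),
    )

print(migrate_tensor_cfg(dict(size=8, mode="fstp", overlap=True)))
# {'size': 8, 'sp': 'intern', 'intern_overlap': True}
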
23 changes: 14 additions & 9 deletions internlm/initialize/launch.py
@@ -306,15 +306,20 @@ def args_sanity_check():
), "sequence parallel does not support use_flash_attn=False"

if isinstance(gpc.config.parallel["tensor"], int):
gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode="origin_tp")

if gpc.config.parallel["tensor"].get("mode", None) is None:
gpc.config.parallel["tensor"]["mode"] = "origin_tp"

if gpc.config.parallel["tensor"].get("mode", None) == "fstp":
assert (
gpc.config.parallel.sequence_parallel is True
), "when the tp_mode is fstp, the sequence_parallel should be True."
gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], sp="none", intern_overlap=False)
if gpc.config.parallel["tensor"].get("sp", None) is None:
gpc.config.parallel["tensor"]["sp"] = "none"
if gpc.config.parallel["tensor"].get("intern_overlap", None) is None:
gpc.config.parallel["tensor"]["intern_overlap"] = False
assert gpc.config.parallel["tensor"].get("sp", None) in [
"none",
"megatron",
"flash-attn",
"intern",
], "invalid sp mode, only ['none', 'megatron', 'flash-attn', 'intern'] is supported"
# adapt to old version's sequence parallel config
if gpc.config.parallel["tensor"].get("sp", None) in ["megatron", "flash-attn", "intern"]:
gpc.config.parallel.sequence_parallel = True

# currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
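
For readers who want the new sanity-check rules in one place, here is a self-contained sketch; the function name and plain-dict arguments are illustrative, while the committed code mutates `gpc.config.parallel` in place:

VALID_SP_MODES = ("none", "megatron", "flash-attn", "intern")

def normalize_tensor_cfg(tensor_cfg, parallel_cfg):
    # a bare int is promoted to the full dict form with conservative defaults
    if isinstance(tensor_cfg, int):
        tensor_cfg = dict(size=tensor_cfg, sp="none", intern_overlap=False)
    tensor_cfg.setdefault("sp", "none")
    tensor_cfg.setdefault("intern_overlap", False)
    assert tensor_cfg["sp"] in VALID_SP_MODES, (
        f"invalid sp mode, only {list(VALID_SP_MODES)} is supported"
    )
    # adapt old-style configs: any real sp mode implies sequence parallelism
    if tensor_cfg["sp"] != "none":
        parallel_cfg["sequence_parallel"] = True
    return tensor_cfg
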
14 changes: 7 additions & 7 deletions internlm/model/modeling_internlm.py
@@ -77,7 +77,7 @@ def __init__(
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
tp_mode: str = "origin_tp",
sp_mode: str = "none",
):
super().__init__()
self.checkpoint = checkpoint
@@ -102,7 +102,7 @@ def __init__(
use_flash_attn=use_flash_attn,
device=device,
dtype=dtype,
tp_mode=tp_mode,
sp_mode=sp_mode,
)

self.dropout1 = nn.Dropout(drop_rate)
@@ -114,7 +114,7 @@ def __init__(
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

if use_swiglu:
mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward
mlp_cls = FSTPFeedForward if sp_mode == "intern" else FeedForward
self.mlp = mlp_cls(
hidden_size,
int(hidden_size * mlp_ratio),
@@ -297,7 +297,7 @@ def __init__(
super().__init__()

checkpoint_layer_num = int(num_layers * checkpoint)
self.tp_mode = gpc.config.parallel["tensor"]["mode"]
self.sp_mode = gpc.config.parallel["tensor"]["sp"]

if is_reward:
head_cls = RewardModelLinear
@@ -343,7 +343,7 @@ def __init__(
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
tp_mode=self.tp_mode,
sp_mode=self.sp_mode,
)
for lid in range(num_layers)
]
@@ -389,8 +389,8 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
assert len(indexes) == 1
# The indexes are used to indicate the actual position IDs of each token in the packed input.
indexes = indexes[0]
# if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension.
if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp":
# if the sequence parallel mode is 'intern', the indexes should also be split in sequence dimension.
if gpc.config.parallel.sequence_parallel and self.sp_mode == "intern":
indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)

max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
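
The decoder layer now keys its choices off `sp_mode` instead of the old `tp_mode`: the swiglu MLP becomes `FSTPFeedForward` and the packed position indexes are split along the sequence dimension only when `sp_mode == "intern"`. A hedged stand-in for that index split, using plain `torch.chunk` in place of `split_forward_gather_backward` (which additionally gathers gradients in the backward pass) and assuming the sequence length divides evenly across the tensor-parallel group:

import torch

def shard_packed_indexes(indexes: torch.Tensor, sp_mode: str, world_size: int, rank: int) -> torch.Tensor:
    # Keep only this rank's slice of the packed position indexes under 'intern' sp.
    if sp_mode != "intern":
        return indexes
    return indexes.chunk(world_size, dim=0)[rank]
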
8 changes: 4 additions & 4 deletions internlm/model/multi_head_attention.py
@@ -175,7 +175,7 @@ def __init__(
use_flash_attn: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
tp_mode: str = "origin_tp",
sp_mode: str = "none",
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
@@ -203,7 +203,7 @@ def __init__(
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)

# notice here should change bias=True
Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
Wqkv_cls = FSTPLinear if sp_mode == "intern" else ColumnParallelLinearTorch
self.Wqkv = Wqkv_cls(
embed_dim,
3 * embed_dim,
@@ -219,12 +219,12 @@
self.inner_cross_attn = inner_cross_attn_cls(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
)
if tp_mode == "fstp":
if sp_mode == "intern":
self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)

# output projection always have the bias (for now)
out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
out_proj_cls = FSTPLinear if sp_mode == "intern" else RowParallelLinearTorch
self.out_proj = out_proj_cls(
embed_dim,
embed_dim,
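
The attention module applies the same rename: FSTP linears replace the column/row-parallel ones and both inner attention modules are wrapped in `DistributedAttention` when `sp_mode == "intern"`. A simplified sketch of that wiring (the class object and process group are passed in, so no import paths beyond the diff are assumed):

def wire_attention_for_sp(sp_mode, inner_attn, inner_cross_attn, DistributedAttention, process_group):
    # mirrors the branch in __init__: only the "intern" sp mode wraps attention
    if sp_mode == "intern":
        inner_attn = DistributedAttention(inner_attn, sequence_process_group=process_group)
        inner_cross_attn = DistributedAttention(inner_cross_attn, sequence_process_group=process_group)
    return inner_attn, inner_cross_attn
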
45 changes: 30 additions & 15 deletions internlm/solver/optimizer/hybrid_zero_optim.py
@@ -10,7 +10,10 @@

from internlm.core.context import Config, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import split_forward_gather_backward, release_reduce_scatter_memory_pool
from internlm.model.utils import (
release_reduce_scatter_memory_pool,
split_forward_gather_backward,
)
from internlm.monitor import send_alert_message
from internlm.solver.optimizer.store import (
BucketStore,
@@ -40,8 +40,20 @@
inf = math.inf
logger = get_logger(__file__)


def print_memory(msg):
print(msg, " rank = ", gpc.get_global_rank(), " memory allocated: ", torch.cuda.memory_allocated() / 1024 / 1024 / 1024, " reserved memory: ", torch.cuda.memory_reserved() / 1024 / 1024 / 1024, " max memory: ", torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, flush=True)
print(
msg,
" rank = ",
gpc.get_global_rank(),
" memory allocated: ",
torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
" reverved memory: ",
torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
" max memory: ",
torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
flush=True,
)
print("===========================================")


@@ -69,8 +84,8 @@ def __init__(
backoff_factor = grad_scal_cfg.backoff_factor
hysteresis = grad_scal_cfg.hysteresis
max_scale = grad_scal_cfg.max_scale
if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:

if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
self._fstp_handler = gpc.config.fstp_handler

# Zero related args
@@ -306,7 +321,7 @@ def _define_and_attach(param, reduce_rank=None):
param=param,
reduce_rank=reduce_rank,
)

reduce_scatter_checker = partial(
self._wait_reduce_scatter_and_accumulate_grad,
param=param,
@@ -354,7 +369,7 @@ def reset_reduce_bucket(self) -> None:
_param.grad.add_(_grad)
# self._fstp_handler.reduce_scatter_handlers[key] = None
# del _grad
release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index)
release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
del self._fstp_handler.reduce_scatter_handlers[key]
self._fstp_handler.reduce_scatter_handlers[key] = None
assert key in self._fstp_handler.reduce_scatter_handlers
@@ -374,7 +389,7 @@ def reset_reduce_bucket(self) -> None:
# assert key in self._fstp_handler.all_reduce_handlers

bucket.reset_by_rank(rank)

def _wait_reduce_scatter_and_accumulate_grad(self, param, reduce_rank=None):
param_size = param.numel()

@@ -397,11 +412,11 @@ def _wait_reduce_scatter_and_accumulate_grad(self, param, reduce_rank=None):
_param.grad.add_(_grad)
# self._fstp_handler.reduce_scatter_handlers[key] = None
# del _grad
release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index)
release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
del self._fstp_handler.reduce_scatter_handlers[key]
self._fstp_handler.reduce_scatter_handlers[key] = None
assert key in self._fstp_handler.reduce_scatter_handlers

# if not hasattr(_param, "_fstp_all_reduce_str"):
# continue

@@ -418,7 +433,7 @@ def _wait_reduce_scatter_and_accumulate_grad(self, param, reduce_rank=None):
# assert key in self._fstp_handler.all_reduce_handlers

current_bucket.reset_by_rank(reduce_rank)

current_bucket.add_num_elements_in_bucket(param_size, reduce_rank)
current_bucket.add_param(param, reduce_rank)

@@ -685,16 +700,16 @@ def step(self, closure=None):
timer("sync_grad").start()
self._sync_grad()
timer("sync_grad").stop()

print_memory("No 4")

try:
res = self._step(closure=closure, norms=total_norms)
res = self._step(closure=closure, norms=total_norms)
except torch.cuda.OutOfMemoryError as e:
print(e, flush=True)
print(torch.cuda.memory_summary(), flush=True)
torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")

return res

def _step(self, closure=None, norms=None):
@@ -822,7 +837,7 @@ def _step(self, closure=None, norms=None):
torch.cuda.synchronize()
with torch.cuda.stream(self._comm_bcast_stream):
self.broadcast_params()

timer("step").stop()

# update gradients may not be needed here, because the sync_params function is used in initialization,
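
The new `print_memory` helper reports allocated, reserved, and peak-allocated CUDA memory in GiB. A compact, illustrative variant of the same idea (not repo code; the committed helper prints the values inline):

import torch

GIB = 1024 ** 3  # bytes per GiB

def cuda_memory_gib() -> dict:
    # Return current allocated / reserved / peak-allocated CUDA memory in GiB.
    return {
        "allocated": torch.cuda.memory_allocated() / GIB,
        "reserved": torch.cuda.memory_reserved() / GIB,
        "max_allocated": torch.cuda.max_memory_allocated() / GIB,
    }
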
3 changes: 1 addition & 2 deletions internlm/train/training_internlm.py
@@ -110,9 +110,8 @@ def initialize_model():

gpc.config.fstp_handler = None

if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
# handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
handler._register_sync_parameters_hook()
gpc.config.fstp_handler = handler

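The overlap handler is now gated on the renamed keys. An illustrative guard showing the condition on its own (the handler constructor is passed in as a callable, so nothing beyond what the diff shows is assumed):

def maybe_build_fstp_handler(tensor_cfg, build_handler):
    # Build and register the all-gather sync handler only for sp="intern" with overlap enabled.
    if tensor_cfg.get("sp") == "intern" and tensor_cfg.get("intern_overlap") is True:
        handler = build_handler()  # stands in for CoarseGrainedFSTPAllGatherSyncHandler(model, tp_group)
        handler._register_sync_parameters_hook()
        return handler
    return None
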
4 changes: 2 additions & 2 deletions internlm/utils/evaluation.py
@@ -54,7 +54,7 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
def switch_sequence_parallel_mode():
prev_mode = gpc.config.parallel.sequence_parallel
try:
if gpc.config.parallel["tensor"]["mode"] == "fstp":
if gpc.config.parallel["tensor"]["sp"] == "intern":
gpc.config.parallel.sequence_parallel = True
else:
gpc.config.parallel.sequence_parallel = False
@@ -106,7 +106,7 @@ def evaluate_on_val_dls(
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
num_microbatches = total_val_bsz // data_cfg.micro_bsz
if gpc.config.parallel["tensor"]["mode"] == "fstp":
if gpc.config.parallel["tensor"]["sp"] == "intern":
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
tensor_shape = torch.Size(
[
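
Both evaluation hunks replace the old `mode == "fstp"` check with `sp == "intern"`. For context, a hedged reconstruction of `switch_sequence_parallel_mode` with the new condition; the `@contextmanager` decorator and the restore step in `finally` are assumptions inferred from the `prev_mode`/`try` pattern visible above, and a plain dict stands in for `gpc.config.parallel`:

from contextlib import contextmanager

@contextmanager
def switch_sequence_parallel_mode(parallel_cfg):
    prev_mode = parallel_cfg["sequence_parallel"]
    try:
        # sequence parallelism is forced on for the "intern" sp mode during evaluation
        parallel_cfg["sequence_parallel"] = parallel_cfg["tensor"]["sp"] == "intern"
        yield
    finally:
        parallel_cfg["sequence_parallel"] = prev_mode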
