feat(core/context): support pp for initializing isp/msp/fsp process group
huangting4201 committed Dec 19, 2023
1 parent 76be8c2 commit d30aecd
Showing 6 changed files with 348 additions and 247 deletions.
20 changes: 12 additions & 8 deletions configs/7B_sft.py
@@ -152,21 +152,25 @@
     2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
-    2. sp: str, the sequence parallel mode, should be in ['none', 'megatron', 'flash-attn', 'intern'],
-        defaults to 'none', means the sequence parallel will be disabled.
-    3. intern_overlap: bool, enable/disable all_gather/reduce_scatter communication overlap when using 'intern' mode sp,
-        defaults to False.
+    2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],
+        defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel.
+        msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size.
+        fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size.
+        isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel.
 pipeline parallel (dict):
     1. size: int, the size of pipeline parallel.
     2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
         defaults to False.
+weight parallel (dict):
+    1. size: int, the size of weight parallel.
+    2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
+    3. memory_pool: bool, enable/disable memory pool, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=2, fsdp=False),
-    tensor=dict(size=1, sp="intern", intern_overlap=False, memory_pool=False),
-    pipeline=dict(size=1, interleaved_overlap=True),
-    weight=dict(size=8, overlap=True, memory_pool=True),
-    sequence=4,
+    tensor=dict(size=4, mode="mtp"),
+    pipeline=dict(size=2, interleaved_overlap=True),
+    weight=dict(size=1, overlap=True, memory_pool=True),
 )
 
 cudnn_deterministic = False
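For illustration, a minimal Python sketch (not part of this commit) of how the example settings above decompose a job, assuming a hypothetical 32-GPU world size; the size formulas mirror the init_parallel_groups() changes in parallel_context.py below.

# Illustration only, assuming a hypothetical 32-GPU run (not part of this commit).
world_size = 32
tensor_size, pipeline_size, weight_size, zero1_size = 4, 2, 1, 2  # values from the example config above

# For mtp/msp/fsp, the sequence parallel size follows the tensor parallel size.
sequence_size = tensor_size

data_parallel_size = world_size // pipeline_size // sequence_size       # 32 // 2 // 4 = 4
weight_data_parallel_size = world_size // pipeline_size // weight_size  # 32 // 2 // 1 = 16

# For non-isp modes, zero1 shards optimizer states within the data parallel group.
assert zero1_size <= data_parallel_size
print(data_parallel_size, weight_data_parallel_size)  # 4 16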
2 changes: 0 additions & 2 deletions internlm/core/context/__init__.py
@@ -10,7 +10,6 @@
 )
 from .process_group_initializer import (
     Initializer_Data,
-    Initializer_Model,
     Initializer_Nettest,
     Initializer_Pipeline,
     Initializer_Tensor,
@@ -44,7 +43,6 @@
     "Initializer_Nettest",
     "Initializer_Zero3_dp",
     "ProcessGroupInitializer",
-    "Initializer_Model",
     "seed",
     "set_mode",
     "add_seed",
31 changes: 21 additions & 10 deletions internlm/core/context/parallel_context.py
@@ -476,18 +476,24 @@ def init_parallel_groups(self):
         parallel_config = self.config.get("parallel", None)
         if parallel_config is not None:
             self._set_parallel_size_from_config(parallel_config, "weight", "weight_parallel_size")
-            self._set_parallel_size_from_config(parallel_config, "sequence", "sequence_parallel_size")
-            self._set_parallel_size_from_config(parallel_config, "pipeline", "pipeline_parallel_size")
             self._set_parallel_size_from_config(parallel_config, "tensor", "tensor_parallel_size")
+            self._set_parallel_size_from_config(parallel_config, "pipeline", "pipeline_parallel_size")
             self._set_parallel_size_from_config(parallel_config, "zero1", "zero1_parallel_size")
 
         # the user should not set the data parallel size manually
         # instead, it should be calculated based on other parallel config
-        assert self.tensor_parallel_size == 1
-        assert self.pipeline_parallel_size == 1
         assert self.zero1_parallel_size >= 1
-        self.data_parallel_size = self.world_size // self.sequence_parallel_size
-        self.weight_data_parallel_size = self.world_size // self.weight_parallel_size
+        self.sequence_parallel_size = self.tensor_parallel_size
+        self.data_parallel_size = self.world_size // self.pipeline_parallel_size // self.sequence_parallel_size
+        self.weight_data_parallel_size = self.world_size // self.pipeline_parallel_size // self.weight_parallel_size
+        if parallel_config["tensor"]["mode"] != "isp":
+            assert (
+                self.zero1_parallel_size <= self.data_parallel_size
+            ), f"zero1_size:{self.zero1_parallel_size} should be less than dp_size:{self.data_parallel_size}"
+        else:
+            assert (
+                self.zero1_parallel_size <= self.weight_data_parallel_size
+            ), f"zero1_size:{self.zero1_parallel_size} should be less than wdp_size:{self.weight_data_parallel_size}"
 
         # the recommended nettest_parallel_size is 32 GPUs
         self.nettest_parallel_size = 32
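For illustration, the mode-dependent zero1 bound added above can be sketched in isolation; the helper and the sizes below are assumptions for this example, not code from the repository.

# Sketch of the zero1 bound introduced above, with assumed, illustrative sizes.
def check_zero1_size(mode: str, zero1_size: int, dp_size: int, wdp_size: int) -> None:
    # non-isp modes bound zero1 by the data parallel size;
    # isp bounds it by the weight data parallel size instead.
    bound = dp_size if mode != "isp" else wdp_size
    assert zero1_size <= bound, f"zero1_size:{zero1_size} exceeds allowed bound:{bound} for mode {mode}"

check_zero1_size("mtp", zero1_size=2, dp_size=4, wdp_size=16)  # ok
check_zero1_size("isp", zero1_size=8, dp_size=4, wdp_size=16)  # ok: isp uses wdp_size as the bound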
@@ -508,6 +514,7 @@ def init_parallel_groups(self):
             rank,
             world_size,
             self.weight_parallel_size,
+            self.weight_data_parallel_size,
             self.sequence_parallel_size,
             self.data_parallel_size,
             self.pipeline_parallel_size,
@@ -520,12 +527,16 @@
         # run initialization of different process groups
         initializers = []
         initializers.append(pgroup_initializer.Initializer_Weight(*initializer_args))
-        initializers.append(pgroup_initializer.Initializer_Sequence(*initializer_args))
+        if parallel_config["tensor"]["mode"] == "isp":
+            initializers.append(pgroup_initializer.Initializer_Weight_Data(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Data(*initializer_args))
-        initializers.append(pgroup_initializer.Initializer_Weight_Data(*initializer_args))
-        initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
+        # if self.weight_parallel_size <= 1:
+        #     initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
-        initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
+        if parallel_config["tensor"]["mode"] != "isp":
+            initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
+        else:
+            initializers.append(pgroup_initializer.Initializer_Zero1_ISP(*initializer_args))
         if isinstance(self.config.parallel.zero1, dict) and self.config.parallel.zero1.get("fsdp", False):
             initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))