Support 4d parallel + flash attention (hpcaitech#5789)
* support tp + sp + pp

* remove comments

---------

Co-authored-by: Edenzzzz <[email protected]>
Edenzzzz authored Jun 17, 2024
1 parent 2ddf624 commit 8795bb2
Showing 5 changed files with 192 additions and 373 deletions.
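In effect, this commit lets tensor, pipeline, and sequence parallelism be enabled at once in `HybridParallelPlugin`. Below is a minimal configuration sketch: the parameter names follow the plugin's constructor, but the 16-GPU sizes and the surrounding setup are assumptions for illustration, not part of this commit.

```python
# Hypothetical 16-GPU job: dp=2 x pp=2 x tp=2 x sp=2 (all sizes assumed).
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,                               # tensor parallelism
    pp_size=2,                               # pipeline parallelism
    sp_size=2,                               # sequence parallelism, now composable with tp/pp
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",  # the new default when left unset
    enable_flash_attention=True,
)
booster = Booster(plugin=plugin)  # dp_size=2 is inferred from the world size
```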
29 changes: 18 additions & 11 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
```diff
@@ -999,7 +999,9 @@ def __init__(
         ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
 
         if enable_sequence_parallelism:
-            self.sequence_parallelism_mode = sequence_parallelism_mode if sequence_parallelism_mode is not None else "1"
+            self.sequence_parallelism_mode = (
+                sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
+            )
             assert (
                 self.sequence_parallelism_mode in SUPPORT_SP_MODE
             ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
```
```diff
@@ -1014,19 +1016,13 @@ def __init__(
                 self.sp_size = 1
                 self.dp_size = dist.get_world_size() // (tp_size * pp_size)
             elif self.sequence_parallelism_mode in ["all_to_all"]:
-                assert (
-                    tp_size == 1
-                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with tensor parallelism"
-                assert (
-                    pp_size == 1
-                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with pipeline parallelism"
-                self.sp_size = dist.get_world_size() if sp_size is None else sp_size
-                self.dp_size = dist.get_world_size() // (self.sp_size * pp_size)
+                self.sp_size = 1 if sp_size is None else sp_size
+                self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
             else:
                 self.dp_size = dist.get_world_size() // (tp_size * pp_size)
                 assert (
                     sp_size == 1 or sp_size is None
-                ), f"sp_size can only be set to a >1 number when enable_sequence_parallelism is True"
+                ), f"You should not set sp_size when sequence parallelism is not enabled."
                 self.sp_size = 1
 
         self.tp_size = tp_size
```
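The second hunk is the core of the change: in "all_to_all" mode the plugin no longer asserts `tp_size == 1` and `pp_size == 1`, the default `sp_size` drops from the whole world size to 1, and the data-parallel size is now derived from the full product of the other three dimensions. A plain-Python sketch of that arithmetic (the 16-rank world size and the explicit divisibility check are assumptions added for illustration):

```python
# New dp_size derivation under all_to_all sequence parallelism.
world_size, tp_size, pp_size, sp_size = 16, 2, 2, 2  # assumed sizes

# The plugin divides the world size by the product of the other axes;
# a launch whose world size does not divide evenly is misconfigured.
assert world_size % (tp_size * pp_size * sp_size) == 0
dp_size = world_size // (tp_size * pp_size * sp_size)
print(dp_size)  # 2 -> a (dp=2, pp=2, tp=2, sp=2) 4D mesh
```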
```diff
@@ -1040,11 +1036,22 @@ def __init__(
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
         if dp_outside:
-            self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
+            (
+                self.dp_axis,
+                self.pp_axis,
+                self.tp_axis,
+                self.sp_axis,
+            ) = (
+                0,
+                1,
+                2,
+                3,
+            )
             self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
         else:
             self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
             self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
 
         self.stage_manager = None
         self.schedule = None
         self.custom_policy = custom_policy
```
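The third hunk only reflows the axis-constant assignment, but those constants fix how a flat rank is folded into (dp, pp, tp, sp) coordinates by `ProcessGroupMesh`. A ColossalAI-independent illustration of that mapping, assuming row-major rank ordering (an assumption about `ProcessGroupMesh`, not stated in the diff):

```python
# Fold flat ranks into 4D mesh coordinates, dp-outside layout (dp, pp, tp, sp).
# Row-major ordering is assumed here; sizes are hypothetical.
import numpy as np

mesh_shape = (2, 2, 2, 2)  # (dp_size, pp_size, tp_size, sp_size)

for rank in range(int(np.prod(mesh_shape))):
    dp, pp, tp, sp = np.unravel_index(rank, mesh_shape)
    print(f"rank {rank:2d} -> dp={dp} pp={pp} tp={tp} sp={sp}")
```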