[shardformer] fix gathering output when using tensor parallelism (#5431)
* fix

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert
flybird11111 authored Mar 18, 2024
1 parent f2e8b9e commit 5e16bf7
Showing 6 changed files with 32 additions and 13 deletions.
10 changes: 9 additions & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -199,7 +199,12 @@ def get_param_info(optim: Optimizer):
 
     if optim is None:
         return {}
-    param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
+    param_info = {
+        "param_groups": [],
+        "param2id": {},
+        "id2param": {},
+        "param2shape": {},
+    }
     start_index = 0
     for group in optim.param_groups:
         packed_group = {k: v for k, v in group.items() if k != "params"}
@@ -899,6 +904,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
         enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
+        parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
         num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
         microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
             Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
@@ -939,6 +945,7 @@ def __init__(
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
         enable_sequence_overlap: bool = False,
+        parallel_output: bool = True,
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         initial_scale: float = 2**16,
@@ -1035,6 +1042,7 @@ def __init__(
             enable_jit_fused=self.enable_jit_fused,
             enable_sequence_parallelism=enable_sequence_parallelism,
             enable_sequence_overlap=enable_sequence_overlap,
+            parallel_output=parallel_output,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
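For orientation (not part of the commit), a minimal usage sketch of the new flag: passing parallel_output=False asks the plugin to gather the tensor-parallel-sharded logits back to the full vocabulary dimension instead of leaving them split across ranks. The launch/booster boilerplate below is illustrative and assumes a torchrun-launched distributed environment.

```python
# Illustrative sketch only; assumes `torchrun` has set up the distributed environment.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

# parallel_output=False -> the model forwards gather lm_logits across the
# tensor-parallel group (see the gather_forward_split_backward calls below).
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, parallel_output=False)
booster = Booster(plugin=plugin)
```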
16 changes: 12 additions & 4 deletions colossalai/shardformer/modeling/gpt2.py
@@ -25,6 +25,7 @@
 from colossalai.shardformer.shard import ShardConfig
 
 from ..layer import cross_entropy_1d
+from ..layer._operation import gather_forward_split_backward
 
 
 class GPT2PipelineForwards:
@@ -337,6 +338,9 @@ def gpt2_lmhead_model_forward(
             else:
                 loss = loss_fct(shift_logits, shift_labels)
 
+        if not shard_config.parallel_output:
+            lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
+
         if not return_dict:
             output = (lm_logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -793,11 +797,12 @@ def forward(
             scale = scale * (1 / float(self.layer_idx + 1))
 
         # use coloattention
-        attention = ColoAttention(
-            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
-        )
+        if not hasattr(self, "attention"):
+            self.attention = ColoAttention(
+                embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
+            )
 
-        attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
+        attn_output = self.attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
 
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
@@ -1083,6 +1088,9 @@ def forward(
             else:
                 loss = loss_fct(shift_logits, shift_labels)
 
+        if not shard_config.parallel_output:
+            lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
+
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
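gather_forward_split_backward, imported above from ..layer._operation, is what keeps the gathered logits differentiable: it all-gathers the sharded tensor along the given dimension in the forward pass and hands each rank only its own slice of the gradient in the backward pass. The snippet below is a simplified conceptual sketch of that behaviour, not ColossalAI's actual implementation.

```python
import torch
import torch.distributed as dist


class _GatherForwardSplitBackward(torch.autograd.Function):
    """Simplified sketch: all-gather along `dim` in forward, keep only the local gradient slice in backward."""

    @staticmethod
    def forward(ctx, x, dim, group):
        ctx.dim = dim
        ctx.group = group
        world_size = dist.get_world_size(group)
        if world_size == 1:
            return x
        buffers = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(buffers, x.contiguous(), group=group)
        return torch.cat(buffers, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        if world_size == 1:
            return grad_output, None, None
        rank = dist.get_rank(ctx.group)
        # Return only the gradient chunk that corresponds to this rank's shard.
        local_grad = grad_output.chunk(world_size, dim=ctx.dim)[rank]
        return local_grad, None, None
```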
11 changes: 6 additions & 5 deletions colossalai/shardformer/modeling/llama.py
@@ -16,7 +16,7 @@
 from colossalai.shardformer.shard import ShardConfig
 
 from ..layer import cross_entropy_1d
-from ..layer._operation import _gather
+from ..layer._operation import gather_forward_split_backward
 
 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -290,7 +290,7 @@ def llama_for_causal_lm_forward(
                 loss = loss_fct(shift_logits, shift_labels)
 
         if not shard_config.parallel_output:
-            logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+            logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -485,8 +485,9 @@ def forward(
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
             attn_mask_type = AttnMaskType.paddedcausal
 
-        attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
-        attn_output = attention(
+        if not hasattr(self, "attention"):
+            self.attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
+        attn_output = self.attention(
             query_states,
             key_states,
             value_states,
@@ -593,7 +594,7 @@ def forward(
             loss = loss_fct(shift_logits, shift_labels)
 
         if not shard_config.parallel_output:
-            logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+            logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
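The attention changes in gpt2.py and llama.py above share one idea: build the parameter-free ColoAttention helper once on the first forward call and cache it on the module, instead of constructing a new instance on every step. A generic sketch of that lazy-initialization idiom, with purely illustrative names:

```python
import math

import torch
import torch.nn as nn


class CachedScaledDotProduct:
    """Parameter-free attention helper, analogous in spirit to ColoAttention."""

    def __init__(self, scale: float):
        self.scale = scale

    def __call__(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        return torch.matmul(scores.softmax(dim=-1), v)


class LazyAttention(nn.Module):
    def __init__(self, head_dim: int):
        super().__init__()
        self.head_dim = head_dim

    def forward(self, q, k, v):
        # Build the helper once on first use and cache it on the module,
        # mirroring the `if not hasattr(self, "attention")` check in the diff.
        if not hasattr(self, "attention"):
            self.attention = CachedScaledDotProduct(scale=1.0 / math.sqrt(self.head_dim))
        return self.attention(q, k, v)
```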
2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/base_policy.py
@@ -242,4 +242,4 @@ def get_stage_index(
             end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
             stage_indices.append([start_idx, end_idx])
 
-        return stage_indices[0] if num_model_chunks == 1 else stage_indices
+        return stage_indices[0] if num_model_chunks == 1 else stage_indices
4 changes: 3 additions & 1 deletion colossalai/shardformer/shard/shard_config.py
@@ -34,8 +34,10 @@ class ShardConfig:
     enable_all_optimization: bool = False
     enable_sequence_parallelism: bool = False
     enable_sequence_overlap: bool = False
-    parallel_output = True
+    parallel_output: bool = True
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
+    # TODO padding vocab
+    # make_vocab_size_divisible_by: int = 128
     # pipeline_parallel_size: int
     # data_parallel_size: int
     # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
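Since ShardConfig is a dataclass, the annotated field can now be set by keyword when sharding a model manually, and the model forwards above consult it to decide whether to gather logits. A hedged construction sketch follows; the process-group placeholder and surrounding setup are assumptions, not part of the commit.

```python
# Illustrative only; requires torch.distributed to be initialized beforehand.
import torch.distributed as dist

from colossalai.shardformer import ShardConfig, ShardFormer

tp_group = dist.group.WORLD  # placeholder; normally a dedicated tensor-parallel subgroup

shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,
    enable_tensor_parallelism=True,
    parallel_output=False,  # gather lm_logits to the full vocab dim on every rank
)
shardformer = ShardFormer(shard_config=shard_config)
```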
2 changes: 1 addition & 1 deletion tests/test_booster/test_plugin/test_3d_plugin.py
@@ -260,7 +260,7 @@ def run_grad_acc_test(test_args):
         origin_model, origin_optimizer, dataloader=dataloader
     )
     for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()):
-        assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
+        assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
 
 
 def run_dist(rank, world_size, port, early_stop: bool = True):
