
Commit

[shardformer] vit/llama/t5 ignore the sequence parallelism flag and some fixes (#4498)

* [shardformer] chatglm support sequence parallel

* fix

* [shardformer] jit fused fix

* activate checks
flybird11111 authored Aug 24, 2023
1 parent e04436a commit 3353e55
Showing 7 changed files with 46 additions and 21 deletions.
6 changes: 6 additions & 0 deletions colossalai/shardformer/modeling/bert.py
@@ -187,6 +187,9 @@ def bert_model_forward(
             hidden_states = split_forward_gather_backward(hidden_states,
                                                           dim=1,
                                                           process_group=shard_config.tensor_parallel_process_group)
+            if encoder_hidden_states is not None:
+                encoder_hidden_states = split_forward_gather_backward(
+                    encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group)
 
         for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
             if stage_manager.is_first_stage() and idx == 0:
@@ -1241,6 +1244,9 @@ def forward(
             embedding_output = split_forward_gather_backward(embedding_output,
                                                              dim=1,
                                                              process_group=shard_config.tensor_parallel_process_group)
+            if encoder_hidden_states is not None:
+                encoder_hidden_states = split_forward_gather_backward(
+                    encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group)
 
         encoder_outputs = self.encoder(
             embedding_output,
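The two hunks above make Bert's cross-attention input follow the same sequence-parallel layout as the hidden states: encoder_hidden_states is now split along the sequence dimension across the tensor-parallel group. As a rough sketch of what a split-forward/gather-backward primitive does (simplified, assuming an initialized process group and a sequence length divisible by the group size; not the actual ColossalAI implementation):

import torch
import torch.distributed as dist


class _SplitForwardGatherBackward(torch.autograd.Function):
    """Keep only this rank's slice of the sequence in forward; all-gather the gradient in backward."""

    @staticmethod
    def forward(ctx, x, dim, process_group):
        ctx.dim = dim
        ctx.process_group = process_group
        rank = dist.get_rank(process_group)
        world_size = dist.get_world_size(process_group)
        # assumes x.size(dim) is divisible by the tensor-parallel world size
        return x.chunk(world_size, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.process_group)
        grads = [torch.empty_like(grad_output) for _ in range(world_size)]
        # rebuild the full-sequence gradient from every rank's slice
        dist.all_gather(grads, grad_output.contiguous(), group=ctx.process_group)
        return torch.cat(grads, dim=ctx.dim), None, None


def split_forward_gather_backward_sketch(x, dim, process_group):
    return _SplitForwardGatherBackward.apply(x, dim, process_group)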
5 changes: 5 additions & 0 deletions colossalai/shardformer/policies/llama.py
@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Union
 
@@ -35,6 +36,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
 
         policy = {}
 
+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[LlamaDecoderLayer] = ModulePolicyDescription(
                 attribute_replacement={
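T5, ViT, and Whisper below get the same guard: turn the unsupported flag off and warn instead of failing. A minimal standalone sketch of the pattern (the helper name and signature are hypothetical, shown only to make the idiom explicit):

import warnings


def ignore_unsupported_flag(shard_config, flag_name: str, model_name: str) -> None:
    # Illustrative helper: reset a shard_config flag the model cannot honor yet and tell the user why.
    if getattr(shard_config, flag_name, False):
        setattr(shard_config, flag_name, False)
        warnings.warn(f"{model_name} doesn't support {flag_name} yet; the flag will be ignored.")


# e.g. ignore_unsupported_flag(self.shard_config, "enable_sequence_parallelism", "Llama")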
12 changes: 8 additions & 4 deletions colossalai/shardformer/policies/opt.py
@@ -104,16 +104,20 @@ def module_policy(self):
 
         # use flash attention
         if self.shard_config.enable_flash_attention:
-            policy[OPTAttention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_opt_flash_attention_forward(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=OPTAttention)
 
         # use jit fused operator
         if self.shard_config.enable_jit_fused:
-            policy[OPTDecoderLayer] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_jit_fused_opt_decoder_layer_forward(),
                 'dropout_add': get_jit_fused_dropout_add_func(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=OPTDecoderLayer)
 
         return policy
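The change above swaps a direct policy[...] = ModulePolicyDescription(...) assignment, which would overwrite anything already registered for that module, for the policy helper append_or_create_method_replacement. As a rough sketch of what such a helper does (simplified; the real ColossalAI helper may differ in details, and the import path is assumed):

# import path assumed; adjust to the local shardformer package layout
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription


def append_or_create_method_replacement(description, policy, target_key):
    # Merge the new method replacements into an existing policy entry,
    # or create a fresh ModulePolicyDescription if target_key is not registered yet.
    if target_key in policy:
        if policy[target_key].method_replacement is None:
            policy[target_key].method_replacement = {}
        policy[target_key].method_replacement.update(description)
    else:
        policy[target_key] = ModulePolicyDescription(method_replacement=description)
    return policy

The likely motivation: when tensor parallelism has already registered a description for OPTAttention or OPTDecoderLayer, enabling flash attention or the jit fused path now extends that entry instead of clobbering it.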
5 changes: 5 additions & 0 deletions colossalai/shardformer/policies/t5.py
@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Optional, Tuple
 
@@ -59,6 +60,10 @@ def module_policy(self):
 
         policy = {}
 
+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn("T5 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[
                 SubModuleReplacementDescription(
5 changes: 5 additions & 0 deletions colossalai/shardformer/policies/vit.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import Callable, Dict, List, Union
 
 import torch.nn as nn
@@ -32,6 +33,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
 
         policy = {}
 
+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn("ViT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={},
                                                             param_replacement=[],
27 changes: 13 additions & 14 deletions colossalai/shardformer/policies/whisper.py
@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Tuple
 
@@ -33,7 +34,6 @@ def preprocess(self):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
-        # TODO:
         vocab_size = self.model.config.vocab_size
         world_size = self.shard_config.tensor_parallel_size
         if vocab_size % world_size != 0:
@@ -52,6 +52,14 @@ def module_policy(self):
 
         policy = {}
 
+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn(
+                "Whisper doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+        if self.shard_config.enable_jit_fused:
+            self.shard_config.enable_jit_fused = False
+            warnings.warn("Whisper doesn't support jit fused operator now, will ignore the jit fused flag.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
                 "self_attn.embed_dim":
@@ -198,20 +206,11 @@ def module_policy(self):
 
         # enable flash attention
         if self.shard_config.enable_flash_attention:
-            policy[WhisperAttention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_whisper_flash_attention_forward(),
-            })
-
-        # use jit fused operator
-        if self.shard_config.enable_jit_fused:
-            policy[WhisperEncoderLayer] = ModulePolicyDescription(method_replacement={
-                'forward': get_jit_fused_whisper_encoder_layer_forward(),
-                'dropout_add': get_jit_fused_dropout_add_func(),
-            })
-            policy[WhisperDecoderLayer] = ModulePolicyDescription(method_replacement={
-                'forward': get_jit_fused_whisper_decoder_layer_forward(),
-                'dropout_add': get_jit_fused_dropout_add_func(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=WhisperAttention)
 
         return policy
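For context on what the now-removed jit path was wiring in: get_jit_fused_dropout_add_func style helpers generally reduce to a TorchScript-fused dropout plus residual add. A sketch of the idea (an illustration, not the exact ColossalAI kernel):

import torch
import torch.nn.functional as F


@torch.jit.script
def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    # dropout on the layer output, then the residual connection, fused by TorchScript
    return residual + F.dropout(x, p=prob, training=training)

Whisper keeps the eager layer forwards for now because the jit flag is turned off above, while OPT continues to use the fused forward and dropout_add registered in policies/opt.py.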
7 changes: 4 additions & 3 deletions tests/test_shardformer/test_model/test_shard_whisper.py
@@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check last hidden state & loss
     if stage_manager is None or stage_manager.is_last_stage():
         if test_config['precision'] == 'fp32':
-            atol, rtol = 1e-3, 1e-3
+            atol, rtol = 2e-4, 2e-4
         else:
             atol, rtol = 5e-3, 5e-3
 
@@ -77,7 +77,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 
     # check weights and gradients
     if test_config['precision'] == 'fp32':
-        atol, rtol = 1e-3, 1e-3
+        atol, rtol = 2e-4, 2e-4
     else:
         atol, rtol = 5e-3, 5e-3
 
@@ -89,7 +89,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     org_optimizer.step()
     sharded_optimizer.step()
    if test_config['precision'] == 'fp32':
-        atol, rtol = 1e-3, 1e-3
+        atol, rtol = 2e-4, 2e-4
     else:
         atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
@@ -114,6 +114,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 
 
 # TODO(jianghai) fix fp16
+#TODO fix WhisperForConditionalGeneration enable jit fused operator
 @parameterize('test_config', [{
     'tp_size': 2,
     'pp_size': 2,
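The tightened fp32 tolerances amount to checks of roughly this shape when outputs, gradients, and updated weights of the original and sharded models are compared (a schematic helper, not the actual test utilities):

import torch


def assert_tensors_close(org: torch.Tensor, sharded: torch.Tensor, precision: str) -> None:
    # fp32 runs are now held to 2e-4 (previously 1e-3); mixed precision keeps the looser 5e-3
    atol, rtol = (2e-4, 2e-4) if precision == 'fp32' else (5e-3, 5e-3)
    torch.testing.assert_close(org, sharded, atol=atol, rtol=rtol)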
