diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index d88661953a29..30855a622adb 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -187,6 +187,9 @@ def bert_model_forward(
         hidden_states = split_forward_gather_backward(hidden_states,
                                                       dim=1,
                                                       process_group=shard_config.tensor_parallel_process_group)
+        if encoder_hidden_states is not None:
+            encoder_hidden_states = split_forward_gather_backward(
+                encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group)

     for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
         if stage_manager.is_first_stage() and idx == 0:
@@ -1241,6 +1244,9 @@ def forward(
             embedding_output = split_forward_gather_backward(embedding_output,
                                                              dim=1,
                                                              process_group=shard_config.tensor_parallel_process_group)
+            if encoder_hidden_states is not None:
+                encoder_hidden_states = split_forward_gather_backward(
+                    encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group)

         encoder_outputs = self.encoder(
             embedding_output,
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index ccf7764079a9..c417e5d017bd 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Union

@@ -35,6 +36,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:

         policy = {}

+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[LlamaDecoderLayer] = ModulePolicyDescription(
                 attribute_replacement={
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index 58663553b922..abe491bfaace 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -104,16 +104,20 @@ def module_policy(self):

         # use flash attention
         if self.shard_config.enable_flash_attention:
-            policy[OPTAttention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_opt_flash_attention_forward(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=OPTAttention)

         # use jit fused operator
         if self.shard_config.enable_jit_fused:
-            policy[OPTDecoderLayer] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_jit_fused_opt_decoder_layer_forward(),
                 'dropout_add': get_jit_fused_dropout_add_func(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=OPTDecoderLayer)

         return policy
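The opt.py change above replaces the plain `policy[OPTAttention] = ModulePolicyDescription(...)` assignment with `self.append_or_create_method_replacement(...)`. The point of the helper is that several features (flash attention, jit fusion, tensor parallelism) can target the same module class, and a direct assignment overwrites whatever description was registered earlier. A minimal sketch of that merge behaviour, using a made-up `SimpleDescription` stand-in rather than the real `ModulePolicyDescription`:

    from dataclasses import dataclass, field
    from typing import Callable, Dict

    @dataclass
    class SimpleDescription:
        # stripped-down stand-in for ModulePolicyDescription: only method replacements
        method_replacement: Dict[str, Callable] = field(default_factory=dict)

    def append_or_create_method_replacement(description, policy, target_key):
        # merge into an existing description instead of clobbering it
        if target_key in policy:
            policy[target_key].method_replacement.update(description)
        else:
            policy[target_key] = SimpleDescription(method_replacement=dict(description))

    class DummyAttention:    # hypothetical placeholder for OPTAttention / WhisperAttention
        pass

    policy = {}
    append_or_create_method_replacement({'forward': lambda *a: 'flash forward'}, policy, DummyAttention)
    append_or_create_method_replacement({'dropout_add': lambda *a: 'fused add'}, policy, DummyAttention)
    assert set(policy[DummyAttention].method_replacement) == {'forward', 'dropout_add'}

With the old assignment style the second registration would drop the first; the helper keeps both, which is also why the whisper policy below switches to it for flash attention.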
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index 651883d35b87..192a1b8472fc 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Optional, Tuple

@@ -59,6 +60,10 @@ def module_policy(self):

         policy = {}

+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn("T5 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[
                 SubModuleReplacementDescription(
diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index 757bab95f273..b4fb8692e684 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import Callable, Dict, List, Union

 import torch.nn as nn
@@ -32,6 +33,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:

         policy = {}

+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn("ViT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={},
                                                             param_replacement=[],
diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index a33f929f1e48..bffb624d0d1a 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Tuple

@@ -33,7 +34,6 @@ def preprocess(self):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
-        # TODO:
         vocab_size = self.model.config.vocab_size
         world_size = self.shard_config.tensor_parallel_size
         if vocab_size % world_size != 0:
@@ -52,6 +52,14 @@ def module_policy(self):

         policy = {}

+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn(
+                "Whisper doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+        if self.shard_config.enable_jit_fused:
+            self.shard_config.enable_jit_fused = False
+            warnings.warn("Whisper doesn't support jit fused operator now, will ignore the jit fused flag.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
                 "self_attn.embed_dim":
@@ -198,20 +206,11 @@ def module_policy(self):

         # enable flash attention
         if self.shard_config.enable_flash_attention:
-            policy[WhisperAttention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_whisper_flash_attention_forward(),
-            })
-
-        # use jit fused operator
-        if self.shard_config.enable_jit_fused:
-            policy[WhisperEncoderLayer] = ModulePolicyDescription(method_replacement={
-                'forward': get_jit_fused_whisper_encoder_layer_forward(),
-                'dropout_add': get_jit_fused_dropout_add_func(),
-            })
-            policy[WhisperDecoderLayer] = ModulePolicyDescription(method_replacement={
-                'forward': get_jit_fused_whisper_decoder_layer_forward(),
-                'dropout_add': get_jit_fused_dropout_add_func(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=WhisperAttention)

         return policy
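Sequence parallelism keeps activations split along the sequence dimension, which is why the bert.py change at the top of this diff splits `encoder_hidden_states` with `split_forward_gather_backward` whenever cross-attention inputs are passed in. A rough, self-contained illustration of that layout, with plain `torch.chunk` standing in for `split_forward_gather_backward` (the real op also gathers gradients in the backward pass) and made-up sizes:

    import torch

    tp_size, rank = 4, 1          # hypothetical 4-way tensor-parallel group, viewed from rank 1
    batch, seq_len, hidden = 2, 16, 32

    hidden_states = torch.randn(batch, seq_len, hidden)
    encoder_hidden_states = torch.randn(batch, seq_len, hidden)

    # each rank keeps only its slice of the sequence dimension (dim=1)
    local_hidden = torch.chunk(hidden_states, tp_size, dim=1)[rank]
    local_encoder = torch.chunk(encoder_hidden_states, tp_size, dim=1)[rank]

    print(local_hidden.shape, local_encoder.shape)    # torch.Size([2, 4, 32]) for both

Llama, T5, ViT and Whisper do not implement this layout yet, so their policies above reset `enable_sequence_parallelism` and emit a warning instead of attempting it.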
diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py
index 6445b314dc97..011fb8d238cc 100644
--- a/tests/test_shardformer/test_model/test_shard_whisper.py
+++ b/tests/test_shardformer/test_model/test_shard_whisper.py
@@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check last hidden state & loss
     if stage_manager is None or stage_manager.is_last_stage():
         if test_config['precision'] == 'fp32':
-            atol, rtol = 1e-3, 1e-3
+            atol, rtol = 2e-4, 2e-4
         else:
             atol, rtol = 5e-3, 5e-3

@@ -77,7 +77,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

     # check weights and gradients
     if test_config['precision'] == 'fp32':
-        atol, rtol = 1e-3, 1e-3
+        atol, rtol = 2e-4, 2e-4
     else:
         atol, rtol = 5e-3, 5e-3

@@ -89,7 +89,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     org_optimizer.step()
     sharded_optimizer.step()
     if test_config['precision'] == 'fp32':
-        atol, rtol = 1e-3, 1e-3
+        atol, rtol = 2e-4, 2e-4
     else:
         atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
@@ -114,6 +114,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

 # TODO(jianghai) fix fp16
+# TODO: fix WhisperForConditionalGeneration to enable the jit fused operator
 @parameterize('test_config', [{
     'tp_size': 2,
     'pp_size': 2,