diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index abe491bfaace..be9d1c58b79e 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -103,21 +103,21 @@ def module_policy(self):
                                                         target_key=OPTDecoderLayer)
 
         # use flash attention
-        if self.shard_config.enable_flash_attention:
-            self.append_or_create_method_replacement(description={
-                'forward': get_opt_flash_attention_forward(),
-            },
-                                                     policy=policy,
-                                                     target_key=OPTAttention)
+        # if self.shard_config.enable_flash_attention:
+        #     self.append_or_create_method_replacement(description={
+        #         'forward': get_opt_flash_attention_forward(),
+        #     },
+        #                                              policy=policy,
+        #                                              target_key=OPTAttention)
 
         # use jit fused operator
-        if self.shard_config.enable_jit_fused:
-            self.append_or_create_method_replacement(description={
-                'forward': get_jit_fused_opt_decoder_layer_forward(),
-                'dropout_add': get_jit_fused_dropout_add_func(),
-            },
-                                                     policy=policy,
-                                                     target_key=OPTDecoderLayer)
+        # if self.shard_config.enable_jit_fused:
+        #     self.append_or_create_method_replacement(description={
+        #         'forward': get_jit_fused_opt_decoder_layer_forward(),
+        #         'dropout_add': get_jit_fused_dropout_add_func(),
+        #     },
+        #                                              policy=policy,
+        #                                              target_key=OPTDecoderLayer)
 
         return policy
 
diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py
index ad344585e8ce..71483b752c34 100644
--- a/tests/test_shardformer/test_model/test_shard_opt.py
+++ b/tests/test_shardformer/test_model/test_shard_opt.py
@@ -137,7 +137,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'initial_scale': 1
 }])
 def run_opt_test(test_config):
-
     sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py
index 011fb8d238cc..6eaed7d37e47 100644
--- a/tests/test_shardformer/test_model/test_shard_whisper.py
+++ b/tests/test_shardformer/test_model/test_shard_whisper.py
@@ -89,7 +89,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     org_optimizer.step()
     sharded_optimizer.step()
     if test_config['precision'] == 'fp32':
-        atol, rtol = 2e-4, 2e-4
+        atol, rtol = 5e-4, 5e-4
     else:
         atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():