From d79f3cf9c57f830dc803a415b41415470661b30c Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Fri, 11 Aug 2023 16:40:06 +0800
Subject: [PATCH] [shardformer] update tests for all optimization (#4413)

[shardformer] update tests for all optimization
---
 colossalai/shardformer/modeling/bert.py      |  5 ++-
 tests/kit/model_zoo/transformers/bert.py     | 29 +++++++++-----
 .../test_model/test_shard_bert.py            | 39 +++++++++++++------
 3 files changed, 50 insertions(+), 23 deletions(-)

diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index b9d4b5fda7af..eaafd67b8968 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -1048,9 +1048,12 @@ def forward(
             final_attention_mask = final_attention_mask * scale + attention_mask
         else:
             final_attention_mask = attention_mask
+
+        if final_attention_mask is not None:
             batch_size, src_len = query_layer.size()[0], query_layer.size()[2]
             tgt_len = key_layer.size()[2]
-            final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len, tgt_len)
+            final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len,
+                                                               tgt_len).contiguous()
 
     query_layer = query_layer.permute(0, 2, 1, 3).contiguous()
     key_layer = key_layer.permute(0, 2, 1, 3).contiguous()
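
Why this hunk matters: Tensor.expand returns a broadcasted view with stride 0 along the expanded dimensions, so the result is not contiguous in memory; the added .contiguous() materializes a dense copy before the mask reaches kernels that assume dense layout, and the new guard keeps the expansion from running when no attention mask exists. A minimal standalone sketch of the view-vs-copy behavior (toy shapes as illustrative assumptions, not the actual ColossalAI attention path):

import torch

# Toy shapes standing in for (batch, heads, query_len, key_len); values are
# assumptions for illustration, not taken from the patch.
batch_size, num_heads, src_len, tgt_len = 2, 4, 8, 8

# An additive mask broadcastable over heads and query positions.
mask = torch.zeros(batch_size, 1, 1, tgt_len)

expanded = mask.expand(batch_size, num_heads, src_len, tgt_len)
print(expanded.is_contiguous())    # False: expand yields a stride-0 view

# .contiguous() materializes a dense copy, which kernels that index raw
# memory (e.g. fused attention paths) require.
dense = expanded.contiguous()
print(dense.is_contiguous())       # True
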
diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py
index 52158596bcf8..e16d3b269ba0 100644
--- a/tests/kit/model_zoo/transformers/bert.py
+++ b/tests/kit/model_zoo/transformers/bert.py
@@ -69,21 +69,30 @@ def data_gen_for_mcq():
     # data['labels'] = torch.tensor([0], dtype=torch.int64)
     input_ids = torch.tensor([[[
         101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591,
-        4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102
+        4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102, 5442,
+        1012, 102, 102
     ],
                               [
                                   101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037,
                                   4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096,
-                                  2218, 1999, 1996, 2192, 1012, 102, 0, 0
+                                  2218, 1999, 1996, 2192, 1012, 102, 0, 0, 1012, 102, 0, 0
                               ]]])
-    token_type_ids = torch.tensor(
-        [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
-         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
-          0]]])
-    attention_mask = torch.tensor(
-        [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
-         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
-          0]]])
+    token_type_ids = torch.tensor([[[
+        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+        1, 1, 1
+    ],
+                                   [
+                                       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
+                                       1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0
+                                   ]]])
+    attention_mask = torch.tensor([[[
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+        1, 1, 1
+    ],
+                                   [
+                                       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                                       1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0
+                                   ]]])
     labels = torch.tensor([0], dtype=torch.int64)
     return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
 
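
Why this hunk matters: the multiple-choice fixture grows each choice from 36 to 40 tokens, and token_type_ids and attention_mask must track input_ids position for position, which is easy to break when editing token lists by hand. A small consistency check one could run against such fixtures (hypothetical helper, not part of the test suite):

import torch

def check_mcq_fixture(batch: dict) -> None:
    """Sanity-check a hand-written multiple-choice batch like the one above."""
    shape = batch['input_ids'].shape    # expected [1, num_choices, seq_len]
    for key in ('token_type_ids', 'attention_mask'):
        assert batch[key].shape == shape, f'{key} has shape {batch[key].shape}, expected {shape}'
    # Positions masked out of attention should carry the pad token id (0 for BERT).
    pad_positions = batch['attention_mask'] == 0
    assert torch.all(batch['input_ids'][pad_positions] == 0), 'pad positions must hold token id 0'

Called as check_mcq_fixture(data_gen_for_mcq()), both assertions hold for the tensors in this patch.
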
diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
index fdbcd014e1b8..0a24e46d28f2 100644
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -36,10 +36,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     tp_group = booster.plugin.tp_group
     # check last hidden state & loss
     if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
         if org_model.__class__.__name__ == 'BertModel':
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
 
-        check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
     # unwrap model
     if org_model.__class__.__name__ == 'BertModel':
         bert = org_model
@@ -51,17 +55,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 
     col_layer_for_check = ['encoder.layer[0].output.dense']
    row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense']
+    if test_config['precision'] == 'fp32':
+        atol, rtol = 1e-4, 1e-3
+    else:
+        atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
         #check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3)
         #check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3)
-        check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
-        check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
+        check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
+        check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
 
     # check weights after optimizer.step()
     org_optimizer.step()
     sharded_optimizer.step()
+    if test_config['precision'] == 'fp32':
+        atol, rtol = 5e-3, 1e-3
+    else:
+        atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
-        check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False)
+        check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
 
     torch.cuda.empty_cache()
 
@@ -70,23 +82,26 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'tp_size': 1,
     'pp_size': 2,
     'num_microbatches': 4,
-    'use_lazy_init': True
+    'use_lazy_init': True,
+    'precision': 'fp32',
 }, {
     'tp_size': 2,
     'pp_size': 2,
-    'num_microbatches': 4,
-    'enable_fused_normalization': False,
-    'use_lazy_init': False
+    'num_microbatches': 2,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'precision': 'fp16',
+    'initial_scale': 1,
 }, {
     'tp_size': 4,
     'pp_size': 1,
-    'enable_fused_normalization': True,
-    'use_lazy_init': False
+    'enable_all_optimization': True,
+    'use_lazy_init': False,
+    'precision': 'fp32',
 }])
 def run_bert_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
-    test_config['precision'] = 'float'
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
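
Why these hunks matter: every tolerance change follows one pattern. Bounds are now chosen per precision, because fp16 arithmetic accumulates far more rounding error than fp32 and the old fixed fp32-level tolerances would flag healthy fp16 runs as failures; the fp16 config also pins 'initial_scale': 1 so loss scaling starts from a deterministic value. A condensed sketch of the tolerance table the tests now encode (helper name is illustrative, not from the repo; the values are the ones in the hunks above):

# Illustrative helper encoding the patch's per-precision tolerance table.
def tolerances_for(precision: str, check: str) -> tuple:
    """Return (atol, rtol) for a given precision and comparison stage."""
    if precision == 'fp32':
        return {
            'output': (1e-5, 1e-3),    # hidden states / loss
            'grad': (1e-4, 1e-3),      # gradients after backward
            'weight': (5e-3, 1e-3),    # weights after optimizer.step()
        }[check]
    return (5e-3, 5e-3)                # fp16: uniformly looser bounds

atol, rtol = tolerances_for('fp16', 'grad')    # -> (5e-3, 5e-3)
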