From 7596e9ae08e32a386d11e896b08c9e15fd120c0b Mon Sep 17 00:00:00 2001
From: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Date: Fri, 11 Aug 2023 10:32:53 +0800
Subject: [PATCH] [pipeline] rewrite bert tests and fix some bugs (#4409)

* add pipeline policy and bert forward to be done
* add bertmodel pipeline forward and make tests
* add Bert_Policy and test for policy
* update formatting
* update formatting
* update the code
* fix bugs
* fix name conflict
* add bloom model and policy, revise the base class of policy
* revise
* revision
* add bert_for_pretraining
* add bert_for_pretraining forward and policy
* fix typos
* cancel warning
* change the immediate output to default dict
* change the default output of get_shared_params
* rewrite bert test
* rewrite bert test
* fix some bugs
* del pipeline tests
* del pipeline tests
* del useless print
* del useless print
* rewrite data repeats
---
 tests/kit/model_zoo/transformers/bert.py     |   3 +-
 tests/test_shardformer/test_model/_utils.py  |   8 +-
 .../test_model/test_shard_bert.py            | 129 +++++++++++-------
 .../test_model/test_shard_bert_pipeline.py   | 107 ---------------
 4 files changed, 88 insertions(+), 159 deletions(-)
 delete mode 100644 tests/test_shardformer/test_model/test_shard_bert_pipeline.py

diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py
index 9834f5425027..52158596bcf8 100644
--- a/tests/kit/model_zoo/transformers/bert.py
+++ b/tests/kit/model_zoo/transformers/bert.py
@@ -104,7 +104,8 @@ def data_gen_for_qa():
 output_transform_fn = lambda x: x
 
 # define loss funciton
-loss_fn_for_bert_model = lambda x: x.pooler_output.sum()
+loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state
+                                                                                                     ))
 loss_fn = lambda x: x.loss
 
 config = transformers.BertConfig(hidden_size=128,
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index cce21809d829..c9da9d32e554 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -131,6 +131,8 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
 def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Module, sharded_optimizer: Optimizer,
                                             data_gen_fn: Callable, output_transform_fn: Callable, criterion: Callable,
                                             booster: Booster):
+    org_model.cuda()
+    sharded_model.cuda()
 
     def _criterion(outputs, inputs):
         outputs = output_transform_fn(outputs)
@@ -141,7 +143,8 @@ def _criterion(outputs, inputs):
     sharded_model.train()
     if booster.plugin.stage_manager is not None:
         data = {
-            k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
+            k: v.to('cuda').repeat(*([4] + [1] *
+                                     (v.dim() - 1))) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
             for k, v in data.items()
         }
         data_iter = iter([data])
@@ -162,6 +165,7 @@ def _criterion(outputs, inputs):
     org_model.train()
     data = {k: v.cuda() for k, v in data.items()}
     org_output = org_model(**data)
+
     org_loss = criterion(org_output)
     org_loss.backward()
 
@@ -226,7 +230,6 @@ def check_grad(org_model: Module,
                atol: float = 1e-5,
                rtol: float = 1e-3,
                verbose: bool = False):
-
     for suffix in layer_suffix:
         org_grad = getattr_(org_model, suffix).weight.grad
         shard_grad = getattr_(sharded_model, suffix).weight.grad
@@ -242,7 +245,6 @@ def check_grad(org_model: Module,
         # embedding may be resized when using tensor parallel
         if shard_grad.shape[0] > org_grad.shape[0]:
             shard_grad = shard_grad[:org_grad.shape[0], :]
-
         if verbose and dist.get_rank() == 0:
             print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
 
diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
index afc1507e8b24..fdbcd014e1b8 100644
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -1,65 +1,98 @@
 import pytest
 import torch
+from torch import distributed as dist
 
 import colossalai
-from colossalai.cluster import ProcessGroupMesh
 from colossalai.logging import disable_existing_loggers
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.policies.auto_policy import get_autopolicy
-from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
-from colossalai.testing import (
-    assert_hf_output_close,
-    clear_cache_before_run,
-    parameterize,
-    rerun_if_address_is_in_use,
-    spawn,
-)
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
+from tests.test_shardformer.test_model._utils import (
+    build_model_from_hybrid_plugin,
+    check_grad,
+    check_loss,
+    check_output_hidden_state,
+    check_weight,
+    run_forward_backward_with_hybrid_plugin,
+)
 
 
-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
-    # unwarp model
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+
+    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
+        build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
+
+    org_loss, org_output, sharded_loss, sharded_output = \
+        run_forward_backward_with_hybrid_plugin(
+            org_model,
+            sharded_model,
+            sharded_optimizer,
+            data_gen_fn,
+            output_transform_fn,
+            criterion,
+            booster)
+    stage_manager = booster.plugin.stage_manager
+    tp_group = booster.plugin.tp_group
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if org_model.__class__.__name__ == 'BertModel':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+
+        check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
+    # unwrap model
     if org_model.__class__.__name__ == 'BertModel':
         bert = org_model
-        sharded_bert = sharded_model
+        sharded_bert = sharded_model.unwrap()
     else:
         bert = org_model.bert
-        sharded_bert = sharded_model.bert
-
-    # check forward
-    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
-                                                                  output_transform_fn, loss_fn)
-    assert_hf_output_close(org_output, shard_output)
-
-    # do backward
-    org_loss.backward()
-    shard_loss.backward()
-
-    assert torch.allclose(org_loss, shard_loss,
-                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
-
-    # check grad
-    col_layer_for_check = ['encoder.layer[0].attention.self.query', 'embeddings.word_embeddings']
-    row_layer_for_check = ['encoder.layer[0].attention.output.dense']
-    check_grad(bert, sharded_bert, col_layer_for_check, atol=1e-7, rtol=1e-3, dim=0, verbose=False)
-    check_grad(bert, sharded_bert, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False)
-
-
-@parameterize('enable_fused_normalization', [True, False])
-@parameterize('enable_tensor_parallelism', [True, False])
-@parameterize('enable_flash_attention', [True, False])
-@parameterize('enable_jit_fused', [True, False])
-@parameterize('use_lazy_init', [False, True])
-def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
-                  use_lazy_init):
+        sharded_bert = sharded_model.unwrap().bert
+
+    col_layer_for_check = ['encoder.layer[0].output.dense']
+    row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense']
+
+    if stage_manager is None or stage_manager.is_first_stage():
+        #check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3)
+        #check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3)
+        check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
+        check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
+
+    # check weights after optimizer.step()
+    org_optimizer.step()
+    sharded_optimizer.step()
+    if stage_manager is None or stage_manager.is_first_stage():
+        check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False)
+
+    torch.cuda.empty_cache()
+
+
+@parameterize('test_config', [{
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'use_lazy_init': True
+}, {
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_fused_normalization': False,
+    'use_lazy_init': False
+}, {
+    'tp_size': 4,
+    'pp_size': 1,
+    'enable_fused_normalization': True,
+    'use_lazy_init': False
+}])
+def run_bert_test(test_config):
+
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
+    test_config['precision'] = 'float'
+
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
-                                               enable_flash_attention, enable_jit_fused, use_lazy_init)
-        check_state_dict(org_model, sharded_model, name=name)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
 
+    clear_layout_converter()
+    Randomizer.reset_index()
     torch.cuda.empty_cache()
 
 
@@ -73,7 +106,7 @@ def check_bert(rank, world_size, port):
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_bert():
-    spawn(check_bert, 2)
+    spawn(check_bert, 4)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py
deleted file mode 100644
index 3170b58a1175..000000000000
--- a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py
+++ /dev/null
@@ -1,107 +0,0 @@
-import pytest
-import torch
-
-import colossalai
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.logging import disable_existing_loggers
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.policies.auto_policy import get_autopolicy
-from colossalai.shardformer.shard import ShardConfig
-from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
-from colossalai.testing import (
-    assert_hf_output_close,
-    clear_cache_before_run,
-    parameterize,
-    rerun_if_address_is_in_use,
-    spawn,
-)
-from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
-
-
-def check_bert_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
-    stage_manager = stage_manager
-    policy = get_autopolicy(model)
-    policy.set_model(model)
-    model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
-    policy.set_shard_config(model_config)
-    layers = policy.get_held_layers()
-    if stage_manager.is_first_stage():
-        assert len(layers) == 1 + 1
-    else:
-        if name == "transformers_bert":
-            assert len(layers) == 1 + 1
-        elif name in [
-                "transformers_bert_for_sequence_classification", "transformers_bert_for_token_classification",
-                "transformers_bert_for_mcq"
-        ]:
-            assert len(layers) == 1 + 3
-        else:
-            assert len(layers) == 1 + 2
-
-
-def check_bert_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
-    if name == 'transformers_bert_for_mcq':
-        x = torch.randint(0, 1000, (2, 3, 3)).cuda()
-        attention_mask = torch.ones_like(x).cuda()
-        if stage_manager.stage == 0:
-            output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
-            assert output['hidden_states'].shape == (6, 3, 128)
-        else:
-            hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda()
-            output = sharded_model(input_ids=x,
-                                   hidden_states=hidden_states,
-                                   attention_mask=attention_mask,
-                                   stage_manager=stage_manager)
-            assert output[0].shape == (2, 3)
-    else:
-        x = torch.randint(0, 1000, (2, 3)).cuda()
-        # one batch, 2 single sentences, each sentence has 3 tokens
-        hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
-        if stage_manager.stage == 0:
-            attention_mask = torch.ones_like(x).cuda()
-            output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
-            assert output['hidden_states'].shape == (2, 3, 128)
-        else:
-            attention_mask = torch.ones((2, 3)).cuda()
-            output = sharded_model(hidden_states=hidden_states,
-                                   attention_mask=attention_mask,
-                                   stage_manager=stage_manager)
-            assert output[0].shape[0] == 2
-
-
-@parameterize('enable_fused_normalization', [False])
-@parameterize('enable_tensor_parallelism', [False])
-@parameterize('use_lazy_init', [False])
-#TODO: merge this into test_shard_bert
-def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
-    PP_DIM = 0
-    PP_SIZE = 2
-    pg_mesh = ProcessGroupMesh(PP_SIZE)
-    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
-                                                        enable_tensor_parallelism, use_lazy_init)
-        check_bert_model_policy(name, org_model, stage_manager)
-        check_bert_model_pipeline_forward(name, sharded_model, stage_manager)
-
-    torch.cuda.empty_cache()
-
-
-def check_bert(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_bert_test()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_bert():
-    spawn(check_bert, 2)
-
-
-if __name__ == "__main__":
-    test_bert()