[pipeline] rewrite bert tests and fix some bugs (hpcaitech#4409)
* add pipeline policy and bert forward to be done

* add bertmodel pipeline forward and make tests

* add Bert_Policy and test for policy

* update formatting

* update formatting

* update the code

* fix bugs

* fix name conflict

* add bloom model and policy, revise the base class of policy

* revise

* revision

* add bert_for_pretraining

* add bert_for_pretraining forward and policy

* fix typos

* cancel warning

* change the intermediate output to default dict

* change the default output of get_shared_params

* rewrite bert test

* rewrite bert test

* fix some bugs

* del pipeline tests

* del pipeline tests

* del useless print

* del useless print

* rewrite data repeats
CjhHa1 authored and ver217 committed Aug 15, 2023
1 parent 5269a24 commit 4fa0b1f
Showing 4 changed files with 88 additions and 159 deletions.
tests/kit/model_zoo/transformers/bert.py: 3 changes (2 additions, 1 deletion)
@@ -104,7 +104,8 @@ def data_gen_for_qa():
output_transform_fn = lambda x: x

# define loss function
loss_fn_for_bert_model = lambda x: x.pooler_output.sum()
loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state))
loss_fn = lambda x: x.loss

config = transformers.BertConfig(hidden_size=128,
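
The loss for `BertModel` switches from summing the pooler output to a mean-squared error against an all-ones target on `last_hidden_state`. A minimal, self-contained sketch of how the new loss behaves (the tensor shapes below are illustrative, not the test's actual configuration):

import torch

# Stand-in for BertModel's output: (batch, seq_len, hidden_size); shapes are illustrative.
last_hidden_state = torch.randn(2, 8, 128, requires_grad=True)

# MSE against a fixed all-ones target gives a scalar loss that depends on every
# element of the hidden state, so the backward pass exercises every parameter
# feeding into it.
loss = torch.nn.functional.mse_loss(last_hidden_state, torch.ones_like(last_hidden_state))
loss.backward()
print(loss.item(), last_hidden_state.grad.shape)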
tests/test_shardformer/test_model/_utils.py: 8 changes (5 additions, 3 deletions)
@@ -131,6 +131,8 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Module, sharded_optimizer: Optimizer,
data_gen_fn: Callable, output_transform_fn: Callable, criterion: Callable,
booster: Booster):
org_model.cuda()
sharded_model.cuda()

def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
@@ -141,7 +143,8 @@ def _criterion(outputs, inputs):
sharded_model.train()
if booster.plugin.stage_manager is not None:
data = {
k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
k: v.to('cuda').repeat(*([4] + [1] * (v.dim() - 1))) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
for k, v in data.items()
}
data_iter = iter([data])
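
The old `repeat(4, 1)` assumed every tensor in the batch was 2-D; the rewritten expression derives the repeat factors from the tensor's own rank, so only the leading (batch) dimension is multiplied by 4, matching the number of microbatches. A small illustration of the expression (the helper name and example shapes are mine, not the test's):

import torch

def repeat_batch(v: torch.Tensor, times: int = 4) -> torch.Tensor:
    # Multiply only the leading (batch) dimension; leave every other dimension unchanged.
    return v.repeat(*([times] + [1] * (v.dim() - 1)))

input_ids = torch.zeros(2, 16, dtype=torch.long)   # 2-D, e.g. (batch, seq_len)
extra = torch.zeros(2, 16, 16)                      # 3-D tensor, would raise an error with repeat(4, 1)
print(repeat_batch(input_ids).shape)  # torch.Size([8, 16])
print(repeat_batch(extra).shape)      # torch.Size([8, 16, 16])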
@@ -162,6 +165,7 @@ def _criterion(outputs, inputs):
org_model.train()
data = {k: v.cuda() for k, v in data.items()}
org_output = org_model(**data)

org_loss = criterion(org_output)
org_loss.backward()

@@ -226,7 +230,6 @@ def check_grad(org_model: Module,
atol: float = 1e-5,
rtol: float = 1e-3,
verbose: bool = False):

for suffix in layer_suffix:
org_grad = getattr_(org_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
@@ -242,7 +245,6 @@ def check_grad(org_model: Module,
# embedding may be resized when using tensor parallel
if shard_grad.shape[0] > org_grad.shape[0]:
shard_grad = shard_grad[:org_grad.shape[0], :]

if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")

tests/test_shardformer/test_model/test_shard_bert.py: 129 changes (81 additions, 48 deletions)
@@ -1,65 +1,98 @@
import pytest
import torch
from torch import distributed as dist

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.auto_policy import get_autopolicy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_loss,
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
)


def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# unwarp model
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):

org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)

org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if org_model.__class__.__name__ == 'BertModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)

check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
# unwrap model
if org_model.__class__.__name__ == 'BertModel':
bert = org_model
sharded_bert = sharded_model
sharded_bert = sharded_model.unwrap()
else:
bert = org_model.bert
sharded_bert = sharded_model.bert

# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output)

# do backward
org_loss.backward()
shard_loss.backward()

assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"

# check grad
col_layer_for_check = ['encoder.layer[0].attention.self.query', 'embeddings.word_embeddings']
row_layer_for_check = ['encoder.layer[0].attention.output.dense']
check_grad(bert, sharded_bert, col_layer_for_check, atol=1e-7, rtol=1e-3, dim=0, verbose=False)
check_grad(bert, sharded_bert, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False)


@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
@parameterize('use_lazy_init', [False, True])
def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
use_lazy_init):
sharded_bert = sharded_model.unwrap().bert

col_layer_for_check = ['encoder.layer[0].output.dense']
row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense']

if stage_manager is None or stage_manager.is_first_stage():
#check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3)
#check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3)
check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)

# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False)

torch.cuda.empty_cache()
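
The guards around the checks above exist because, under pipeline parallelism, each rank only holds part of what is being compared: the loss and final hidden state live on the last stage, while the embeddings and first encoder layers live on the first stage. A small sketch of that guard pattern, with the check callables left abstract:

def run_stage_checks(stage_manager, output_checks, embedding_checks):
    # No stage manager means no pipeline parallelism: every rank holds the full model.
    if stage_manager is None or stage_manager.is_last_stage():
        output_checks()      # e.g. check_loss / check_output_hidden_state
    if stage_manager is None or stage_manager.is_first_stage():
        embedding_checks()   # e.g. check_grad / check_weight on first-stage layers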


@parameterize('test_config', [{
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'use_lazy_init': True
}, {
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': False,
'use_lazy_init': False
}, {
'tp_size': 4,
'pp_size': 1,
'enable_fused_normalization': True,
'use_lazy_init': False
}])
def run_bert_test(test_config):

sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
test_config['precision'] = 'float'

for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
enable_flash_attention, enable_jit_fused, use_lazy_init)
check_state_dict(org_model, sharded_model, name=name)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
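
Each `test_config` dict is consumed by `build_model_from_hybrid_plugin`, whose body is not part of this diff. As a rough sketch of what such a config drives, assuming ColossalAI's `HybridParallelPlugin`/`Booster` API at the time of this commit; the exact keyword handling inside the helper may differ:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# One of the parameterized configurations: 2-way tensor parallel x 2-way pipeline parallel.
test_config = {'tp_size': 2, 'pp_size': 2, 'num_microbatches': 4}

# The plugin splits the 4 spawned ranks into a tp_size x pp_size mesh; with pp_size > 1
# the forward/backward pass is scheduled over num_microbatches chunks.
plugin = HybridParallelPlugin(tp_size=test_config['tp_size'],
                              pp_size=test_config['pp_size'],
                              num_microbatches=test_config.get('num_microbatches'))
booster = Booster(plugin=plugin)
# The sharded model and optimizer used by the test would then come from something like:
#   sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)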


@@ -73,7 +106,7 @@ def check_bert(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bert():
spawn(check_bert, 2)
spawn(check_bert, 4)


if __name__ == "__main__":
tests/test_shardformer/test_model/test_shard_bert_pipeline.py: 107 changes (0 additions, 107 deletions)

This file was deleted.
