Add weight test
Linden Li committed Nov 19, 2023
1 parent e547a28 commit daf5a4a
Showing 1 changed file with 44 additions and 15 deletions.
tests/test_model.py: 59 changes (44 additions & 15 deletions)
@@ -1806,8 +1806,14 @@ def test_head_dim_8_triton_mqa_attn(batch_size: int = 2):
 @pytest.mark.world_size(2)
 @pytest.mark.gpu
 def test_tp_qkvo():
+    # Note: we need the RNG state in this test to ensure that weights
+    # are initialized with the same values in both models. Without it,
+    # even with a random seed, the weights will be different since the
+    # RNG state changes with each init.
+    rng_state = reproducibility.get_rng_state()

     local_world_size = dist.get_local_world_size()
-    model_cfg = {
+    sharded_model_cfg = {
         'name': 'mpt_causal_lm',
         'init_device': 'cpu',
         'd_model': 128,
@@ -1825,36 +1831,59 @@ def test_tp_qkvo():
         }
     }

-    model_cfg = om.create(model_cfg)
+    # Create the same model config, but with TP turned off
+    full_model_cfg = copy.deepcopy(sharded_model_cfg)
+    full_model_cfg['attn_config']['tensor_parallel_qkvo'] = False
+    del full_model_cfg['attn_config']['tp_world_size']

+    sharded_model_cfg = om.create(sharded_model_cfg)
+    full_model_cfg = om.create(full_model_cfg)

+    sharded_model = COMPOSER_MODEL_REGISTRY[sharded_model_cfg.name](sharded_model_cfg)
+    reproducibility.load_rng_state(rng_state)

+    full_model = COMPOSER_MODEL_REGISTRY[full_model_cfg.name](full_model_cfg)
+    reproducibility.load_rng_state(rng_state)

     fsdp_config = {
         'sharding_strategy': 'NO_SHARD',
         'mixed_precision': 'DEFAULT'
     }

-    model = COMPOSER_MODEL_REGISTRY[model_cfg.name](model_cfg)

     # The trainer is used to wrap the model in FSDP, which can be used
     # alongside with TP for 2D parallelism
     trainer = Trainer(
-        model=model,
+        model=sharded_model,
         fsdp_config=fsdp_config,
         seed=0
     )

+    trainer = Trainer(
+        model=full_model,
+        fsdp_config=fsdp_config,
+        seed=0
+    )

-    transformer_blocks = model.model.transformer.blocks
-    for block in transformer_blocks:
-        attn_module = block._fsdp_wrapped_module.attn
+    sharded_transformer_blocks = sharded_model.model.transformer.blocks
+    full_transformer_blocks = full_model.model.transformer.blocks
+    for sharded_block, full_block in zip(sharded_transformer_blocks, full_transformer_blocks):
+        sharded_attn_module = sharded_block._fsdp_wrapped_module.attn
+        full_attn_module = full_block._fsdp_wrapped_module.attn

         # Check that all attention module weights are DTensors
-        assert isinstance(attn_module.Wqkv.weight, DTensor)
-        assert isinstance(attn_module.out_proj.weight, DTensor)
+        assert isinstance(sharded_attn_module.Wqkv.weight, DTensor)
+        assert isinstance(sharded_attn_module.out_proj.weight, DTensor)

-        Wqkv_local = attn_module.Wqkv.weight._local_tensor
-        out_proj_local = attn_module.out_proj.weight._local_tensor
+        Wqkv_local = sharded_attn_module.Wqkv.weight._local_tensor
+        out_proj_local = sharded_attn_module.out_proj.weight._local_tensor

         # Wqkv is colwise-sharded, so its output dimension (dim 0 since torch
         # stores everything along the transpose) is sharded along the device mesh
-        assert Wqkv_local.shape[0] * local_world_size == model_cfg.d_model * 3
+        assert Wqkv_local.shape[0] * local_world_size == sharded_model_cfg.d_model * 3

         # The out projection is row-wise sharded, so its input dimension (dim 1)
         # is sharded along the device mesh
-        assert out_proj_local.shape[1] * local_world_size == model_cfg.d_model
+        assert out_proj_local.shape[1] * local_world_size == sharded_model_cfg.d_model

+        # Check that the sharded output weights are the same as the full model
+        # weights
+        assert torch.equal(out_proj_local, full_attn_module.out_proj.weight[:, :out_proj_local.shape[1]])
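
The shape assertions and the final torch.equal check above follow from how column-wise and row-wise sharding split an nn.Linear weight. A rough standalone sketch of that arithmetic, using plain tensors and torch.chunk as a stand-in for DTensor sharding (d_model=128 and a tensor-parallel world size of 2 are taken from the test config; nothing distributed is needed to run it):

import torch

d_model = 128
tp_world_size = 2

# nn.Linear stores its weight as (out_features, in_features).
# Wqkv fuses the Q, K, and V projections, so its full weight is (3 * d_model, d_model).
full_wqkv = torch.randn(3 * d_model, d_model)

# Column-wise (output-dim) sharding splits dim 0 across the tensor-parallel ranks.
wqkv_local = full_wqkv.chunk(tp_world_size, dim=0)[0]
assert wqkv_local.shape[0] * tp_world_size == 3 * d_model

# The out projection has full weight (d_model, d_model); row-wise (input-dim)
# sharding splits dim 1 across the ranks.
full_out_proj = torch.randn(d_model, d_model)
out_proj_local = full_out_proj.chunk(tp_world_size, dim=1)[0]
assert out_proj_local.shape[1] * tp_world_size == d_model

# Rank 0's row-wise shard is the leading column slice of the full weight,
# mirroring the final torch.equal assertion in the test.
assert torch.equal(out_proj_local, full_out_proj[:, :out_proj_local.shape[1]])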

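On the RNG note at the top of the test: capturing the RNG state once and restoring it before each model construction is what makes the two initializations identical. A minimal sketch of the same pattern with plain torch (not the composer reproducibility helpers the test uses):

import torch

rng_state = torch.get_rng_state()   # capture the RNG state once

layer_a = torch.nn.Linear(8, 8)     # first init advances the RNG

torch.set_rng_state(rng_state)      # rewind before the second init
layer_b = torch.nn.Linear(8, 8)

# Both layers drew their parameters from the same RNG state, so they match exactly.
assert torch.equal(layer_a.weight, layer_b.weight)
assert torch.equal(layer_a.bias, layer_b.bias)
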