
Add shapes test
Linden Li committed Nov 19, 2023
1 parent c4da786 commit e547a28
Showing 4 changed files with 62 additions and 18 deletions.
2 changes: 2 additions & 0 deletions llmfoundry/models/layers/blocks.py
@@ -35,6 +35,8 @@
         'type': 'no_scaling',
         'factor': 1.0,
     },
+    'tensor_parallel_qkvo': False,
+    'tp_world_size': None,
 }


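The two keys added above default tensor-parallel attention to off. A minimal opt-in sketch (only the key names come from this diff; the surrounding config shape is assumed and mirrors the test added below):

attn_config = {
    'attn_type': 'multihead_attention',
    'attn_impl': 'torch',
    'tensor_parallel_qkvo': True,  # shard Wqkv and out_proj across TP ranks
    'tp_world_size': 2,            # assumed value: number of tensor-parallel ranks
}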
15 changes: 0 additions & 15 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -22,21 +22,6 @@
     'ffn_type': 'mptmlp',
 }
 
-attn_config_defaults: Dict = {
-    'attn_type': 'multihead_attention',
-    'attn_pdrop': 0.0,
-    'attn_impl': 'triton',
-    'qk_ln': False,
-    'clip_qkv': None,
-    'tensor_parallel_qkvo': False,
-    'tp_world_size': None,
-    'softmax_scale': None,
-    'prefix_lm': False,
-    'attn_uses_sequence_id': False,
-    'alibi': False,
-    'alibi_bias_max': 8,
-}
-
 init_config_defaults: Dict = {
     'name': 'kaiming_normal_',
     'fan_mode': 'fan_in',
5 changes: 2 additions & 3 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -8,7 +8,7 @@
 
 import math
 import warnings
-from functools import cached_property, partial
+from functools import partial
 from typing import (Any, Dict, List, Mapping, MutableMapping, Optional, Tuple,
                     Union)
 
@@ -38,7 +38,7 @@
 from omegaconf import OmegaConf as om
 from torch.distributed._tensor import (DeviceMesh, Shard, distribute_module,
                                        distribute_tensor)
-from torch.distributed.tensor.parallel import (ColwiseParallel, RowwiseParallel,
+from torch.distributed.tensor.parallel import (RowwiseParallel,
                                                make_input_replicate_1d,
                                                make_sharded_output_tensor,
                                                parallelize_module)
@@ -266,7 +266,6 @@ def __init__(self, config: MPTConfig):
                 mesh_dim_names=['ep', 'tp'],
             )
             new_blocks = nn.ModuleList()
-            torch.set_printoptions(profile='full', sci_mode=False)
             for block in self.blocks:
                 qkv_module = block.get_submodule('attn.Wqkv')
                 oned_mesh = _create_1d_device_mesh(twod_mesh, tp_mesh_dim=1)
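The constructor hunk above builds a 2-D ('ep', 'tp') device mesh and re-wraps each block's attn.Wqkv submodule. As a hedged sketch of the DTensor mechanics involved (not this commit's exact code path; the mesh construction and variable names here are assumptions), column-wise sharding of a QKV weight with the imports shown above looks roughly like:

# Sketch: column-wise sharding of a fused QKV weight with DTensor primitives.
# Assumes torch.distributed is already initialized (e.g. under torchrun) with
# world_size ranks and one GPU per rank; illustrative only, not the commit's code.
import torch
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

d_model, world_size = 128, 2
mesh = DeviceMesh('cuda', torch.arange(world_size))

wqkv = nn.Linear(d_model, 3 * d_model, bias=False)
# Shard(0) splits dim 0 (the linear's output features) across the mesh,
# i.e. column-wise parallelism for the QKV projection.
wqkv_sharded = distribute_tensor(wqkv.weight, mesh, placements=[Shard(0)])

# Each rank holds a (3 * d_model // world_size, d_model) local shard.
print(wqkv_sharded.to_local().shape)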
58 changes: 58 additions & 0 deletions tests/test_model.py
@@ -12,7 +12,9 @@
 import pytest
 import torch
 import torch.nn as nn
+from torch.distributed._tensor.api import DTensor
 from accelerate import init_empty_weights
+from composer import Trainer
 from composer.core.precision import Precision, get_precision_context
 from composer.optim import DecoupledAdamW
 from composer.trainer.dist_strategy import prepare_fsdp_module
@@ -1800,3 +1802,59 @@ def test_head_dim_8_triton_mqa_attn(batch_size: int = 2):
     output = model(batch)
 
     assert not torch.isnan(output.logits).any()
+
+@pytest.mark.world_size(2)
+@pytest.mark.gpu
+def test_tp_qkvo():
+    local_world_size = dist.get_local_world_size()
+    model_cfg = {
+        'name': 'mpt_causal_lm',
+        'init_device': 'cpu',
+        'd_model': 128,
+        'n_heads': 4,  # head size 32
+        'n_layers': 2,
+        'expansion_ratio': 1,
+        'max_seq_len': 16,
+        'vocab_size': 50368,
+        'attn_config': {
+            'attn_type': 'multihead_attention',
+            'alibi': False,
+            'attn_impl': 'torch',
+            'tensor_parallel_qkvo': True,
+            'tp_world_size': local_world_size
+        }
+    }
+
+    model_cfg = om.create(model_cfg)
+    fsdp_config = {
+        'sharding_strategy': 'NO_SHARD',
+        'mixed_precision': 'DEFAULT'
+    }
+
+    model = COMPOSER_MODEL_REGISTRY[model_cfg.name](model_cfg)
+
+    # The trainer is used to wrap the model in FSDP, which can be used
+    # alongside TP for 2D parallelism
+    trainer = Trainer(
+        model=model,
+        fsdp_config=fsdp_config,
+    )
+
+    transformer_blocks = model.model.transformer.blocks
+    for block in transformer_blocks:
+        attn_module = block._fsdp_wrapped_module.attn
+
+        # Check that all attention module weights are DTensors
+        assert isinstance(attn_module.Wqkv.weight, DTensor)
+        assert isinstance(attn_module.out_proj.weight, DTensor)
+
+        Wqkv_local = attn_module.Wqkv.weight._local_tensor
+        out_proj_local = attn_module.out_proj.weight._local_tensor
+
+        # Wqkv is colwise-sharded, so its output dimension (dim 0, since torch
+        # stores the weight transposed) is sharded across the device mesh
+        assert Wqkv_local.shape[0] * local_world_size == model_cfg.d_model * 3
+
+        # The out projection is rowwise-sharded, so its input dimension (dim 1)
+        # is sharded across the device mesh
+        assert out_proj_local.shape[1] * local_world_size == model_cfg.d_model
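The two shape assertions above reduce to simple arithmetic. A standalone sanity check of the expected local shard shapes (assuming d_model=128 and a tensor-parallel world size of 2, as configured in the test):

# Expected per-rank shapes for the test configuration above.
d_model, world_size = 128, 2

wqkv_local = (3 * d_model // world_size, d_model)   # colwise shard -> (192, 128)
out_proj_local = (d_model, d_model // world_size)   # rowwise shard -> (128, 64)

# nn.Linear stores its weight as (out_features, in_features), so colwise
# sharding splits dim 0 of Wqkv and rowwise sharding splits dim 1 of out_proj.
assert wqkv_local[0] * world_size == 3 * d_model
assert out_proj_local[1] * world_size == d_model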
