From 41924322a4d1db975d7a0ed7bb7310b7cf706ad0 Mon Sep 17 00:00:00 2001 From: Linden Li Date: Mon, 23 Oct 2023 20:23:11 -0700 Subject: [PATCH 1/8] add tensor parallelism --- llmfoundry/models/layers/attention.py | 84 ++++++++++++--- llmfoundry/models/mpt/configuration_mpt.py | 4 + llmfoundry/models/mpt/modeling_mpt.py | 120 ++++++++++++++++++++- 3 files changed, 193 insertions(+), 15 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 39fa7162ac..c8d578cb2d 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -79,6 +79,8 @@ def scaled_multihead_dot_product_attention( training: bool = False, needs_weights: bool = False, multiquery: bool = False, + tensor_parallel_qkvo: bool = False, + tp_world_size: Optional[int] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]: @@ -95,9 +97,19 @@ def scaled_multihead_dot_product_attention( )) kv_n_heads = n_heads - q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads) - k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads) - v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads) + if tensor_parallel_qkvo: + assert tp_world_size is not None + q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads // tp_world_size) + k = rearrange(key, + 'b s (h d) -> b h d s', + h=kv_n_heads // tp_world_size) + v = rearrange(value, + 'b s (h d) -> b h s d', + h=kv_n_heads // tp_world_size) + else: + q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads) + k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads) + v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads) if past_key_value is not None: # attn_impl: flash & triton use kernels which expect input shape [b, s, h, d_head]. 
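(Editor's aside, not part of the diff: the divide-by-tp_world_size above reflects that, under tensor parallelism, each rank's Wqkv output already carries only its own shard of the heads, so the local head count is n_heads // tp_world_size. A toy shape check with made-up sizes, sketching the layout this branch assumes:)

import torch
from einops import rearrange

# Made-up toy sizes: 2-way tensor parallelism, 8 heads of dim 4.
tp_world_size, n_heads, head_dim = 2, 8, 4
b, s = 2, 3
local_heads = n_heads // tp_world_size

# The column-sharded Wqkv produces activations whose hidden dim is already
# divided by tp_world_size, hence the smaller `h` passed to rearrange.
query_local = torch.randn(b, s, local_heads * head_dim)
q = rearrange(query_local, 'b s (h d) -> b h s d', h=local_heads)
assert q.shape == (b, local_heads, s, head_dim)
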
@@ -346,6 +358,9 @@ def triton_flash_attn_fn( training: bool = False, needs_weights: bool = False, multiquery: bool = False, + tensor_parallel_qkvo: bool = False, + tp_world_size: Optional[int] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]: try: @@ -429,9 +444,21 @@ def triton_flash_attn_fn( ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min) - query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads) - key = rearrange(key, 'b s (h d) -> b s h d', h=kv_n_heads) - value = rearrange(value, 'b s (h d) -> b s h d', h=kv_n_heads) + if tensor_parallel_qkvo: + assert tp_world_size is not None + query = rearrange(query, + 'b s (h d) -> b s h d', + h=n_heads // tp_world_size) + key = rearrange(key, + 'b s (h d) -> b s h d', + h=kv_n_heads // tp_world_size) + value = rearrange(value, + 'b s (h d) -> b s h d', + h=kv_n_heads // tp_world_size) + else: + query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads) + key = rearrange(key, 'b s (h d) -> b s h d', h=kv_n_heads) + value = rearrange(value, 'b s (h d) -> b s h d', h=kv_n_heads) # multi-query case if kv_n_heads == 1: @@ -473,6 +500,8 @@ def __init__( attn_impl: str = 'triton', clip_qkv: Optional[float] = None, qk_ln: bool = False, + tensor_parallel_qkvo: bool = False, + tp_world_size: Optional[int] = None, softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', @@ -486,6 +515,9 @@ def __init__( self.clip_qkv = clip_qkv self.qk_ln = qk_ln + self.tensor_parallel_qkvo = tensor_parallel_qkvo + self.tp_world_size = tp_world_size + self.d_model = d_model self.n_heads = n_heads self.kv_n_heads = kv_n_heads @@ -564,14 +596,26 @@ def forward( if self.clip_qkv: qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv) - query, key, value = qkv.split( - [ - self.d_model, - self.kv_n_heads * self.head_dim, - self.kv_n_heads * self.head_dim, - ], - dim=2, - ) + if self.tensor_parallel_qkvo: + # If tensor parallelism is used, each of the QKV tensors gets a + # 1 / tp_world_size fraction of the original split. 
+ query, key, value = qkv.split( + [ + self.d_model // self.tp_world_size, + self.kv_n_heads * self.head_dim // self.tp_world_size, + self.kv_n_heads * self.head_dim // self.tp_world_size, + ], + dim=2, + ) + else: + query, key, value = qkv.split( + [ + self.d_model, + self.kv_n_heads * self.head_dim, + self.kv_n_heads * self.head_dim, + ], + dim=2, + ) key_padding_mask = attention_mask @@ -595,6 +639,8 @@ def forward( dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, + tensor_parallel_qkvo=self.tensor_parallel_qkvo, + tp_world_size=self.tp_world_size, ) return self.out_proj(context), attn_weights, past_key_value @@ -614,6 +660,8 @@ def __init__( attn_impl: str = 'triton', clip_qkv: Optional[float] = None, qk_ln: bool = False, + tensor_parallel_qkvo: bool = False, + tp_world_size: Optional[int] = None, softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', @@ -627,7 +675,11 @@ def __init__( kv_n_heads=n_heads, # for MHA, same # heads as kv groups attn_impl=attn_impl, clip_qkv=clip_qkv, + tensor_parallel_qkvo=tensor_parallel_qkvo, + tp_world_size=tp_world_size, qk_ln=qk_ln, + tensor_parallel_qkvo=tensor_parallel_qkvo, + tp_world_size=tp_world_size, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, @@ -651,6 +703,8 @@ def __init__( attn_impl: str = 'triton', clip_qkv: Optional[float] = None, qk_ln: bool = False, + tensor_parallel_qkvo: bool = False, + tp_world_size: Optional[int] = None, softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', @@ -665,6 +719,8 @@ def __init__( attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, + tensor_parallel_qkvo=tensor_parallel_qkvo, + tp_world_size=tp_world_size, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 251e4f5caf..9d13617ab7 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -14,6 +14,8 @@ 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, + 'tensor_parallel_qkvo': False, + 'tp_world_size': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, @@ -83,6 +85,8 @@ def __init__( qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer. clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to this value. + tensor_parallel_qkvo (bool): Whether to implement tensor parallel attention projections + tp_world_size (Optional[Int]): Must be set if tensor_parallel_qkvo is True. The number of GPUs to use for tensor parallelism. softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None, use the default scale of ``1/sqrt(d_keys)``. prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. 
This requires passing an diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 4f4581b177..1b26e2cdbf 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -8,6 +8,7 @@ import math import warnings +from functools import cached_property, partial from typing import (Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, Union) @@ -22,9 +23,17 @@ InContextLearningQAAccuracy) from composer.metrics.nlp import LanguageCrossEntropy, LanguagePerplexity from composer.models import HuggingFaceModel -from composer.utils import dist +from composer.utils import dist, get_device from omegaconf import DictConfig from omegaconf import OmegaConf as om +from torch.distributed._tensor import (DeviceMesh, Shard, distribute_module, + distribute_tensor) +from torch.distributed.tensor.parallel import (ColwiseParallel, RowwiseParallel, + make_input_replicate_1d, + make_sharded_output_tensor, + parallelize_module) +from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh +from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp from transformers import PreTrainedModel, PreTrainedTokenizerBase from transformers.modeling_outputs import (BaseModelOutputWithPast, CausalLMOutputWithPast) @@ -75,6 +84,47 @@ class MPTPreTrainedModel(PreTrainedModel): base_model_prefix = 'model' _no_split_modules = ['MPTBlock'] +def rearrange_tensor(t: torch.Tensor, n_devices: int, d_model: int, + head_dim: int, kv_n_heads: int): + # Split output dim into three chunks: query proj. weights, key proj. weights, value proj. weights + # The Wqkv projection is a (n_heads * head_dim + 2 * kv_n_heads * head_dim, d_model)-dim tensor. + # The projection Wqkv(x) in the attention module will eventually be split into + # Q, K, V tensors with the split (n_heads * head_dim, kv_n_heads * head_dim, kv_n_heads * head_dim). + # As a result, each device should have a 1/n_devices fraction of the rows of each chunk responsible for a projection + # in order for numerical equivalence. 
+ t_chunks = torch.split( + t, [d_model, kv_n_heads * head_dim, kv_n_heads * head_dim], dim=0) + + # For each chunk, split d_model (dim=0) into n_devices chunks + # this ends up sampling a 1/n_devices fraction of the rows for each chunk + sub_chunks = [ + torch.chunk(chunk, chunks=n_devices, dim=0) for chunk in t_chunks + ] + + # concatenate the q, k, v chunks for each device + new_chunks = [ + torch.cat([sub_chunk[i] + for sub_chunk in sub_chunks], dim=0) + for i in range(n_devices) + ] + + return torch.cat(new_chunks, dim=0) + + +def shard_qkv( + mod_name: str, + mod: nn.Module, + mesh: DeviceMesh, + d_model: int, + head_dim: int, + kv_n_heads: int, +): + placement = [Shard(0)] + rearr_weight = rearrange_tensor(mod.weight, mesh.size(), d_model, head_dim, + kv_n_heads) + rearr_weight_param = torch.nn.Parameter(rearr_weight) + mod.weight = torch.nn.Parameter( + distribute_tensor(rearr_weight_param, mesh, placement)) class MPTModel(MPTPreTrainedModel): @@ -83,6 +133,8 @@ def __init__(self, config: MPTConfig): super().__init__(config) self.attn_impl = config.attn_config['attn_impl'] + self.tensor_parallel_qkvo = config.attn_config['tensor_parallel_qkvo'] + self.tp_world_size = config.attn_config['tp_world_size'] self.prefix_lm = config.attn_config['prefix_lm'] self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id'] self.alibi = config.attn_config['alibi'] @@ -128,6 +180,72 @@ def __init__(self, config: MPTConfig): f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.' ) self.apply(self.param_init_fn) + + if self.tensor_parallel_qkvo: + device_type = 'cuda' if get_device(None).name == 'gpu' else 'cpu' + world_size = dist.get_world_size() + node_count = world_size // dist.get_local_world_size() + # Configures intranode tensor parallelism + twod_mesh = DeviceMesh( + device_type=device_type, + mesh=torch.arange(0, world_size).view(node_count, -1), + mesh_dim_names=['ep', 'tp'], + ) + new_blocks = nn.ModuleList() + torch.set_printoptions(profile='full', sci_mode=False) + for block in self.blocks: + qkv_module = block.get_submodule('attn.Wqkv') + oned_mesh = _create_1d_device_mesh(twod_mesh, tp_mesh_dim=1) + + kv_n_heads = config.n_heads + if config.attn_config['attn_type'] == 'grouped_query_attention': + kv_n_heads = config.attn_config['kv_n_heads'] + elif config.attn_config['attn_type'] == 'multiquery_attention': + raise NotImplementedError( + 'Tensor parallel currently does not work for multiquery attention.' + ) + + # Megatron trick: + # Shard qkv module column wise + # Shard output projection row wise + # Note: since PyTorch does not support interleaved sharding yet, we need to + # manually rearrange the weight tensors since the QKV projection is fused. + distribute_module( + qkv_module, + oned_mesh, + partition_fn=partial(shard_qkv, + d_model=config.d_model, + head_dim=config.d_model // + config.n_heads, + kv_n_heads=kv_n_heads), + input_fn=make_input_replicate_1d, + output_fn=make_sharded_output_tensor, + ) + + block = parallelize_module( + module=block, + device_mesh=twod_mesh, + parallelize_plan={ + 'attn.out_proj': RowwiseParallel(), + }, + tp_mesh_dim=1, + ) + + # Call to parallelize_module moves params to gpu if they are cpu params. + # Move them back to cpu so that FSDP wrapping sees all params on cpu. + # Othewise FSDP wrapping fails as it sees some params on cpu and others on gpu. 
+ assert config.init_device == 'cpu' + if config.init_device == 'cpu': + block = block.to('cpu') + new_blocks.append(block) + self.blocks = new_blocks + print('Tensor parallelism initialized...') + + # This call is needed to register the hooks to be compatible with FSDP + if not enable_2d_with_fsdp(): + raise RuntimeError( + 'Failed to enable 2D parallelism with FSDP. Please check your environment.' + ) self.is_causal = not self.prefix_lm From 0e84450ff26dcdf197782cc3f8a1107956c528eb Mon Sep 17 00:00:00 2001 From: Linden Li Date: Tue, 24 Oct 2023 03:53:13 +0000 Subject: [PATCH 2/8] fix attention --- llmfoundry/models/layers/attention.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index c8d578cb2d..596d6e7004 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -675,8 +675,6 @@ def __init__( kv_n_heads=n_heads, # for MHA, same # heads as kv groups attn_impl=attn_impl, clip_qkv=clip_qkv, - tensor_parallel_qkvo=tensor_parallel_qkvo, - tp_world_size=tp_world_size, qk_ln=qk_ln, tensor_parallel_qkvo=tensor_parallel_qkvo, tp_world_size=tp_world_size, From e547a28d6e128f6777360e0b8662226ea43f4ff7 Mon Sep 17 00:00:00 2001 From: Linden Li Date: Sun, 19 Nov 2023 01:03:42 +0000 Subject: [PATCH 3/8] Add shapes test --- llmfoundry/models/layers/blocks.py | 2 + llmfoundry/models/mpt/configuration_mpt.py | 15 ------ llmfoundry/models/mpt/modeling_mpt.py | 5 +- tests/test_model.py | 58 ++++++++++++++++++++++ 4 files changed, 62 insertions(+), 18 deletions(-) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 6605807c6b..2bd678ddb1 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -35,6 +35,8 @@ 'type': 'no_scaling', 'factor': 1.0, }, + 'tensor_parallel_qkvo': False, + 'tp_world_size': None, } diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index a30ee655dd..0df6f7c29a 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -22,21 +22,6 @@ 'ffn_type': 'mptmlp', } -attn_config_defaults: Dict = { - 'attn_type': 'multihead_attention', - 'attn_pdrop': 0.0, - 'attn_impl': 'triton', - 'qk_ln': False, - 'clip_qkv': None, - 'tensor_parallel_qkvo': False, - 'tp_world_size': None, - 'softmax_scale': None, - 'prefix_lm': False, - 'attn_uses_sequence_id': False, - 'alibi': False, - 'alibi_bias_max': 8, -} - init_config_defaults: Dict = { 'name': 'kaiming_normal_', 'fan_mode': 'fan_in', diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index dfb26967ae..e1ed15f520 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -8,7 +8,7 @@ import math import warnings -from functools import cached_property, partial +from functools import partial from typing import (Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, Union) @@ -38,7 +38,7 @@ from omegaconf import OmegaConf as om from torch.distributed._tensor import (DeviceMesh, Shard, distribute_module, distribute_tensor) -from torch.distributed.tensor.parallel import (ColwiseParallel, RowwiseParallel, +from torch.distributed.tensor.parallel import (RowwiseParallel, make_input_replicate_1d, make_sharded_output_tensor, parallelize_module) @@ -266,7 +266,6 @@ def __init__(self, config: MPTConfig): mesh_dim_names=['ep', 'tp'], ) new_blocks = nn.ModuleList() - 
torch.set_printoptions(profile='full', sci_mode=False) for block in self.blocks: qkv_module = block.get_submodule('attn.Wqkv') oned_mesh = _create_1d_device_mesh(twod_mesh, tp_mesh_dim=1) diff --git a/tests/test_model.py b/tests/test_model.py index c160c064dc..8c6c78e2d3 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -12,7 +12,9 @@ import pytest import torch import torch.nn as nn +from torch.distributed._tensor.api import DTensor from accelerate import init_empty_weights +from composer import Trainer from composer.core.precision import Precision, get_precision_context from composer.optim import DecoupledAdamW from composer.trainer.dist_strategy import prepare_fsdp_module @@ -1800,3 +1802,59 @@ def test_head_dim_8_triton_mqa_attn(batch_size: int = 2): output = model(batch) assert not torch.isnan(output.logits).any() + +@pytest.mark.world_size(2) +@pytest.mark.gpu +def test_tp_qkvo(): + local_world_size = dist.get_local_world_size() + model_cfg = { + 'name': 'mpt_causal_lm', + 'init_device': 'cpu', + 'd_model': 128, + 'n_heads': 4, # head size 32 + 'n_layers': 2, + 'expansion_ratio': 1, + 'max_seq_len': 16, + 'vocab_size': 50368, + 'attn_config': { + 'attn_type': 'multihead_attention', + 'alibi': False, + 'attn_impl': 'torch', + 'tensor_parallel_qkvo': True, + 'tp_world_size': local_world_size + } + } + + model_cfg = om.create(model_cfg) + fsdp_config = { + 'sharding_strategy': 'NO_SHARD', + 'mixed_precision': 'DEFAULT' + } + + model = COMPOSER_MODEL_REGISTRY[model_cfg.name](model_cfg) + + # The trainer is used to wrap the model in FSDP, which can be used + # alongside with TP for 2D parallelism + trainer = Trainer( + model=model, + fsdp_config=fsdp_config, + ) + + transformer_blocks = model.model.transformer.blocks + for block in transformer_blocks: + attn_module = block._fsdp_wrapped_module.attn + + # Check that all attention module weights are DTensors + assert isinstance(attn_module.Wqkv.weight, DTensor) + assert isinstance(attn_module.out_proj.weight, DTensor) + + Wqkv_local = attn_module.Wqkv.weight._local_tensor + out_proj_local = attn_module.out_proj.weight._local_tensor + + # Wqkv is colwise-sharded, so its output dimension (dim 0 since torch + # stores everything along the transpose) is sharded along the device mesh + assert Wqkv_local.shape[0] * local_world_size == model_cfg.d_model * 3 + + # The out projection is row-wise sharded, so its input dimension (dim 1) + # is sharded along the device mesh + assert out_proj_local.shape[1] * local_world_size == model_cfg.d_model \ No newline at end of file From daf5a4a31faa5802a7b6dc798335801bccc883e1 Mon Sep 17 00:00:00 2001 From: Linden Li Date: Sun, 19 Nov 2023 01:49:41 +0000 Subject: [PATCH 4/8] Add weight test --- tests/test_model.py | 59 +++++++++++++++++++++++++++++++++------------ 1 file changed, 44 insertions(+), 15 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 8c6c78e2d3..df46fd18b6 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1806,8 +1806,14 @@ def test_head_dim_8_triton_mqa_attn(batch_size: int = 2): @pytest.mark.world_size(2) @pytest.mark.gpu def test_tp_qkvo(): + # Note: we need the RNG state in this test to ensure that weights + # are initialized with the same values in both models. Without it, + # even with a random seed, the weights will be different since the + # RNG state changes with each init. 
+ rng_state = reproducibility.get_rng_state() + local_world_size = dist.get_local_world_size() - model_cfg = { + sharded_model_cfg = { 'name': 'mpt_causal_lm', 'init_device': 'cpu', 'd_model': 128, @@ -1825,36 +1831,59 @@ def test_tp_qkvo(): } } - model_cfg = om.create(model_cfg) + # Create the same model config, but with TP turned off + full_model_cfg = copy.deepcopy(sharded_model_cfg) + full_model_cfg['attn_config']['tensor_parallel_qkvo'] = False + del full_model_cfg['attn_config']['tp_world_size'] + + sharded_model_cfg = om.create(sharded_model_cfg) + full_model_cfg = om.create(full_model_cfg) + + sharded_model = COMPOSER_MODEL_REGISTRY[sharded_model_cfg.name](sharded_model_cfg) + reproducibility.load_rng_state(rng_state) + + full_model = COMPOSER_MODEL_REGISTRY[full_model_cfg.name](full_model_cfg) + reproducibility.load_rng_state(rng_state) + fsdp_config = { 'sharding_strategy': 'NO_SHARD', 'mixed_precision': 'DEFAULT' } - - model = COMPOSER_MODEL_REGISTRY[model_cfg.name](model_cfg) - # The trainer is used to wrap the model in FSDP, which can be used # alongside with TP for 2D parallelism trainer = Trainer( - model=model, + model=sharded_model, + fsdp_config=fsdp_config, + seed=0 + ) + + trainer = Trainer( + model=full_model, fsdp_config=fsdp_config, + seed=0 ) - transformer_blocks = model.model.transformer.blocks - for block in transformer_blocks: - attn_module = block._fsdp_wrapped_module.attn + sharded_transformer_blocks = sharded_model.model.transformer.blocks + full_transformer_blocks = full_model.model.transformer.blocks + for sharded_block, full_block in zip(sharded_transformer_blocks, full_transformer_blocks): + sharded_attn_module = sharded_block._fsdp_wrapped_module.attn + full_attn_module = full_block._fsdp_wrapped_module.attn # Check that all attention module weights are DTensors - assert isinstance(attn_module.Wqkv.weight, DTensor) - assert isinstance(attn_module.out_proj.weight, DTensor) + assert isinstance(sharded_attn_module.Wqkv.weight, DTensor) + assert isinstance(sharded_attn_module.out_proj.weight, DTensor) - Wqkv_local = attn_module.Wqkv.weight._local_tensor - out_proj_local = attn_module.out_proj.weight._local_tensor + Wqkv_local = sharded_attn_module.Wqkv.weight._local_tensor + out_proj_local = sharded_attn_module.out_proj.weight._local_tensor # Wqkv is colwise-sharded, so its output dimension (dim 0 since torch # stores everything along the transpose) is sharded along the device mesh - assert Wqkv_local.shape[0] * local_world_size == model_cfg.d_model * 3 + assert Wqkv_local.shape[0] * local_world_size == sharded_model_cfg.d_model * 3 # The out projection is row-wise sharded, so its input dimension (dim 1) # is sharded along the device mesh - assert out_proj_local.shape[1] * local_world_size == model_cfg.d_model \ No newline at end of file + assert out_proj_local.shape[1] * local_world_size == sharded_model_cfg.d_model + + # Check that the sharded output weights are the same as the full model + # weights + assert torch.equal(out_proj_local, full_attn_module.out_proj.weight[:, :out_proj_local.shape[1]]) \ No newline at end of file From 372255c212ac13f2db8683ff7724bb575d31b79e Mon Sep 17 00:00:00 2001 From: Linden Li Date: Sun, 19 Nov 2023 01:53:50 +0000 Subject: [PATCH 5/8] tests actually pass now --- tests/test_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index df46fd18b6..82461de7f3 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1885,5 +1885,9 @@ def test_tp_qkvo(): 
assert out_proj_local.shape[1] * local_world_size == sharded_model_cfg.d_model # Check that the sharded output weights are the same as the full model - # weights - assert torch.equal(out_proj_local, full_attn_module.out_proj.weight[:, :out_proj_local.shape[1]]) \ No newline at end of file + # weights - rank 0 should have the top half and rank 1 should have the + # bottom half + if dist.get_local_rank() == 0: + assert torch.equal(out_proj_local, full_attn_module.out_proj.weight[:, :out_proj_local.shape[1]]) + else: + assert torch.equal(out_proj_local, full_attn_module.out_proj.weight[:, out_proj_local.shape[1]:]) \ No newline at end of file From d5bba2e820d34ef2c6d43fff4c196d91fe17a282 Mon Sep 17 00:00:00 2001 From: Linden Li Date: Sun, 19 Nov 2023 01:54:51 +0000 Subject: [PATCH 6/8] get rid of unnecessary rng stage calls --- tests/test_model.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 82461de7f3..41a7bc852b 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1843,7 +1843,6 @@ def test_tp_qkvo(): reproducibility.load_rng_state(rng_state) full_model = COMPOSER_MODEL_REGISTRY[full_model_cfg.name](full_model_cfg) - reproducibility.load_rng_state(rng_state) fsdp_config = { 'sharding_strategy': 'NO_SHARD', @@ -1854,13 +1853,11 @@ def test_tp_qkvo(): trainer = Trainer( model=sharded_model, fsdp_config=fsdp_config, - seed=0 ) trainer = Trainer( model=full_model, fsdp_config=fsdp_config, - seed=0 ) sharded_transformer_blocks = sharded_model.model.transformer.blocks From 6a2b18a4b0efb9ff561458957d8318a23511d2d0 Mon Sep 17 00:00:00 2001 From: Linden Li Date: Sun, 19 Nov 2023 02:12:38 +0000 Subject: [PATCH 7/8] Add other weight test --- tests/test_model.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 41a7bc852b..8ec19bb27b 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -33,6 +33,7 @@ from llmfoundry.models.layers.attention import is_flash_v2_installed from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM +from llmfoundry.models.mpt.modeling_mpt import rearrange_tensor from llmfoundry.utils import build_tokenizer @@ -1877,14 +1878,24 @@ def test_tp_qkvo(): # stores everything along the transpose) is sharded along the device mesh assert Wqkv_local.shape[0] * local_world_size == sharded_model_cfg.d_model * 3 - # The out projection is row-wise sharded, so its input dimension (dim 1) + # The out projection is rowwise-sharded, so its input dimension (dim 1) # is sharded along the device mesh assert out_proj_local.shape[1] * local_world_size == sharded_model_cfg.d_model + Wqkv_interleaved = rearrange_tensor( + full_attn_module.Wqkv.weight, + local_world_size, + sharded_model_cfg.d_model, + sharded_model_cfg.d_model // sharded_model_cfg.n_heads, + sharded_model_cfg.n_heads + ) # Check that the sharded output weights are the same as the full model - # weights - rank 0 should have the top half and rank 1 should have the - # bottom half + # weights: + # rank 0 should have the top half of out proj and the left half of Wqkv + # rank 1 should have the bottom half of out proj and the right half of Wqkv if dist.get_local_rank() == 0: assert torch.equal(out_proj_local, full_attn_module.out_proj.weight[:, :out_proj_local.shape[1]]) + assert torch.equal(Wqkv_local, Wqkv_interleaved[:Wqkv_local.shape[0], :]) else: - assert torch.equal(out_proj_local, full_attn_module.out_proj.weight[:, 
out_proj_local.shape[1]:]) \ No newline at end of file + assert torch.equal(out_proj_local, full_attn_module.out_proj.weight[:, out_proj_local.shape[1]:]) + assert torch.equal(Wqkv_local, Wqkv_interleaved[Wqkv_local.shape[0]:, :]) \ No newline at end of file From 6d201656435f3da4521ce6c75169141711bca572 Mon Sep 17 00:00:00 2001 From: Linden Li Date: Tue, 5 Dec 2023 09:40:44 -0800 Subject: [PATCH 8/8] address comments --- llmfoundry/models/mpt/modeling_mpt.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 1b26e2cdbf..e27a26f13a 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -234,12 +234,11 @@ def __init__(self, config: MPTConfig): # Call to parallelize_module moves params to gpu if they are cpu params. # Move them back to cpu so that FSDP wrapping sees all params on cpu. # Othewise FSDP wrapping fails as it sees some params on cpu and others on gpu. - assert config.init_device == 'cpu' - if config.init_device == 'cpu': - block = block.to('cpu') + assert config.init_device == 'cpu', "config.init_device must be 'cpu' when using tensor parallelism." + block = block.to('cpu') new_blocks.append(block) self.blocks = new_blocks - print('Tensor parallelism initialized...') + log.info('Tensor parallelism initialized...') # This call is needed to register the hooks to be compatible with FSDP if not enable_2d_with_fsdp():
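
Editor's appendix (a sketch, not part of the patch series): the comment in PATCH 1 argues that each device needs a 1/n_devices slice of the Q, K, and V row blocks of the fused Wqkv weight for the sharded projection to stay numerically equivalent. The minimal CPU-only check below illustrates that argument. It assumes the fused weight is laid out as [Wq; Wk; Wv] along dim 0 (matching the qkv.split in attention.py), uses made-up toy sizes, and imports rearrange_tensor from llmfoundry.models.mpt.modeling_mpt, which requires these patches to be applied (the same import PATCH 7 adds to tests/test_model.py).

import torch

from llmfoundry.models.mpt.modeling_mpt import rearrange_tensor

# Made-up toy sizes; multihead attention, so kv_n_heads == n_heads.
d_model, n_heads, n_devices = 8, 4, 2
kv_n_heads, head_dim = n_heads, d_model // n_heads

# Fused projection weight ([Wq; Wk; Wv] along dim 0) and a dummy input.
Wqkv = torch.randn(d_model + 2 * kv_n_heads * head_dim, d_model)
x = torch.randn(2, 3, d_model)  # (batch, seq, d_model)

# Reference: unsharded projection followed by the usual Q/K/V split.
q_ref, k_ref, v_ref = (x @ Wqkv.t()).split(
    [d_model, kv_n_heads * head_dim, kv_n_heads * head_dim], dim=2)

# Interleave rows so device i's row block is [Q_i; K_i; V_i], then shard
# the weight column-wise (dim 0, since torch stores the transpose).
shards = torch.chunk(
    rearrange_tensor(Wqkv, n_devices, d_model, head_dim, kv_n_heads),
    n_devices, dim=0)

q_parts, k_parts, v_parts = [], [], []
for shard in shards:
    # Each "device" performs the same split as the tensor-parallel branch
    # in attention.py, with every size divided by the TP world size.
    q_i, k_i, v_i = (x @ shard.t()).split(
        [d_model // n_devices,
         kv_n_heads * head_dim // n_devices,
         kv_n_heads * head_dim // n_devices], dim=2)
    q_parts.append(q_i)
    k_parts.append(k_i)
    v_parts.append(v_i)

# Concatenating the per-device outputs head-wise recovers the reference
# Q, K, and V, which is the equivalence the PATCH 1 comment refers to.
assert torch.allclose(torch.cat(q_parts, dim=2), q_ref)
assert torch.allclose(torch.cat(k_parts, dim=2), k_ref)
assert torch.allclose(torch.cat(v_parts, dim=2), v_ref)

Without the interleave, a plain Shard(0) over the fused weight would hand each device a contiguous slab of rows that mixes whole Q, K, and V regions, so the per-device query/key/value split in attention.py would no longer line up; that is why shard_qkv rearranges the weight before calling distribute_tensor.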