Add tensor parallelism for attention QKVO #690

Closed
wants to merge 11 commits
81 changes: 67 additions & 14 deletions llmfoundry/models/layers/attention.py
@@ -90,6 +90,8 @@ def scaled_multihead_dot_product_attention(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
tensor_parallel_qkvo: bool = False,
tp_world_size: Optional[int] = None
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:

@@ -106,9 +108,19 @@ def scaled_multihead_dot_product_attention(
))
kv_n_heads = n_heads

q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
if tensor_parallel_qkvo:
assert tp_world_size is not None
q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads // tp_world_size)
k = rearrange(key,
'b s (h d) -> b h d s',
h=kv_n_heads // tp_world_size)
v = rearrange(value,
'b s (h d) -> b h s d',
h=kv_n_heads // tp_world_size)
else:
q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
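# Illustrative note (added for clarity, hypothetical sizes): with n_heads=32,
# kv_n_heads=32 and tp_world_size=4, each rank only sees 8 query heads and
# 8 key/value heads in its local activation shard.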

if past_key_value is not None:
# attn_impl: flash & triton use kernels which expect input shape [b, s, h, d_head].
@@ -357,6 +369,8 @@ def triton_flash_attn_fn(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
tensor_parallel_qkvo: bool = False,
tp_world_size: Optional[int] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
try:
@@ -440,9 +454,21 @@ def triton_flash_attn_fn(
~key_padding_mask.view((b_size, 1, 1, s_k)),
torch.finfo(query.dtype).min)

query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
key = rearrange(key, 'b s (h d) -> b s h d', h=kv_n_heads)
value = rearrange(value, 'b s (h d) -> b s h d', h=kv_n_heads)
if tensor_parallel_qkvo:
assert tp_world_size is not None
query = rearrange(query,
'b s (h d) -> b s h d',
h=n_heads // tp_world_size)
key = rearrange(key,
'b s (h d) -> b s h d',
h=kv_n_heads // tp_world_size)
value = rearrange(value,
'b s (h d) -> b s h d',
h=kv_n_heads // tp_world_size)
else:
query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
key = rearrange(key, 'b s (h d) -> b s h d', h=kv_n_heads)
value = rearrange(value, 'b s (h d) -> b s h d', h=kv_n_heads)

# multi-query case
if kv_n_heads == 1:
@@ -484,6 +510,8 @@ def __init__(
attn_impl: str = 'triton',
clip_qkv: Optional[float] = None,
qk_ln: bool = False,
tensor_parallel_qkvo: bool = False,
tp_world_size: Optional[int] = None,
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
@@ -497,6 +525,9 @@ def __init__(
self.clip_qkv = clip_qkv
self.qk_ln = qk_ln

self.tensor_parallel_qkvo = tensor_parallel_qkvo
self.tp_world_size = tp_world_size

self.d_model = d_model
self.n_heads = n_heads
self.kv_n_heads = kv_n_heads
@@ -576,14 +607,26 @@ def forward(
if self.clip_qkv:
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)

query, key, value = qkv.split(
[
self.d_model,
self.kv_n_heads * self.head_dim,
self.kv_n_heads * self.head_dim,
],
dim=2,
)
if self.tensor_parallel_qkvo:
# If tensor parallelism is used, each of the Q, K and V tensors gets a
# 1 / tp_world_size fraction of its original split size.
query, key, value = qkv.split(
[
self.d_model // self.tp_world_size,
self.kv_n_heads * self.head_dim // self.tp_world_size,
self.kv_n_heads * self.head_dim // self.tp_world_size,
],
dim=2,
)
else:
query, key, value = qkv.split(
[
self.d_model,
self.kv_n_heads * self.head_dim,
self.kv_n_heads * self.head_dim,
],
dim=2,
)
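# Illustrative example (added for clarity, hypothetical sizes): with
# d_model=4096, n_heads=32 (head_dim=128), kv_n_heads=8 and tp_world_size=4,
# each rank's local qkv activation splits into (1024, 256, 256) features
# instead of the unsharded (4096, 1024, 1024).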

key_padding_mask = attention_mask

@@ -640,6 +683,8 @@ def forward(
dropout_p=self.attn_dropout_p,
training=self.training,
needs_weights=needs_weights,
tensor_parallel_qkvo=self.tensor_parallel_qkvo,
tp_world_size=self.tp_world_size,
)

return self.out_proj(context), attn_weights, past_key_value
@@ -659,6 +704,8 @@ def __init__(
attn_impl: str = 'triton',
clip_qkv: Optional[float] = None,
qk_ln: bool = False,
tensor_parallel_qkvo: bool = False,
tp_world_size: Optional[int] = None,
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
@@ -673,6 +720,8 @@ def __init__(
attn_impl=attn_impl,
clip_qkv=clip_qkv,
qk_ln=qk_ln,
tensor_parallel_qkvo=tensor_parallel_qkvo,
tp_world_size=tp_world_size,
softmax_scale=softmax_scale,
attn_pdrop=attn_pdrop,
norm_type=norm_type,
@@ -696,6 +745,8 @@ def __init__(
attn_impl: str = 'triton',
clip_qkv: Optional[float] = None,
qk_ln: bool = False,
tensor_parallel_qkvo: bool = False,
tp_world_size: Optional[int] = None,
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
@@ -710,6 +761,8 @@ def __init__(
attn_impl=attn_impl,
clip_qkv=clip_qkv,
qk_ln=qk_ln,
tensor_parallel_qkvo=tensor_parallel_qkvo,
tp_world_size=tp_world_size,
softmax_scale=softmax_scale,
attn_pdrop=attn_pdrop,
norm_type=norm_type,
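A minimal standalone sketch of the per-rank shape arithmetic that the attention.py changes above introduce (all sizes here are illustrative assumptions, not taken from the PR; only the split and rearrange logic mirrors the diff):

import torch
from einops import rearrange

d_model, n_heads, kv_n_heads, tp_world_size = 1024, 16, 16, 2
head_dim = d_model // n_heads
batch, seqlen = 2, 8

# Each rank holds a 1/tp_world_size slice of the fused QKV activation.
qkv_local = torch.randn(
    batch, seqlen, (d_model + 2 * kv_n_heads * head_dim) // tp_world_size)

query, key, value = qkv_local.split(
    [
        d_model // tp_world_size,
        kv_n_heads * head_dim // tp_world_size,
        kv_n_heads * head_dim // tp_world_size,
    ],
    dim=2,
)

# The attention kernel then operates on n_heads // tp_world_size local heads.
q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads // tp_world_size)
k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads // tp_world_size)
v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads // tp_world_size)
print(q.shape, k.shape, v.shape)  # (2, 8, 8, 64), (2, 8, 64, 8), (2, 8, 8, 64)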
2 changes: 2 additions & 0 deletions llmfoundry/models/layers/blocks.py
@@ -35,6 +35,8 @@
'type': 'no_scaling',
'factor': 1.0,
},
'tensor_parallel_qkvo': False,
'tp_world_size': None,
}


3 changes: 2 additions & 1 deletion llmfoundry/models/mpt/configuration_mpt.py
@@ -7,7 +7,6 @@
from typing import Any, Dict, Optional, Union

from transformers import PretrainedConfig

from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.layers.blocks import attn_config_defaults

@@ -82,6 +81,8 @@ def __init__(
qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
this value.
tensor_parallel_qkvo (bool): Whether to shard the attention QKV and output projections with tensor parallelism.
tp_world_size (Optional[int]): The number of GPUs in the tensor-parallel group. Must be set if ``tensor_parallel_qkvo`` is True.
softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
use the default scale of ``1/sqrt(d_keys)``.
prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
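A minimal usage sketch for the two new config keys (the surrounding values are placeholder assumptions, not taken from this PR; MPTConfig is expected to fill any attn_config keys that are omitted from attn_config_defaults):

from llmfoundry.models.mpt.configuration_mpt import MPTConfig

# Hypothetical example: enable tensor-parallel QKV/output projections across 8 GPUs.
config = MPTConfig(
    d_model=1024,
    n_heads=16,
    n_layers=4,
    attn_config={
        'attn_impl': 'triton',
        'tensor_parallel_qkvo': True,
        'tp_world_size': 8,  # must equal the tensor-parallel group size
    },
)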
119 changes: 118 additions & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
@@ -8,6 +8,7 @@

import math
import warnings
from functools import partial
from typing import (Any, Dict, List, Mapping, MutableMapping, Optional, Tuple,
Union)

@@ -22,7 +23,7 @@
InContextLearningQAAccuracy)
from composer.metrics.nlp import LanguageCrossEntropy, LanguagePerplexity
from composer.models import HuggingFaceModel
from composer.utils import dist
from composer.utils import dist, get_device

from llmfoundry.models.layers.attention import is_flash_v2_installed

@@ -35,6 +36,14 @@

from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from torch.distributed._tensor import (DeviceMesh, Shard, distribute_module,
distribute_tensor)
from torch.distributed.tensor.parallel import (RowwiseParallel,
make_input_replicate_1d,
make_sharded_output_tensor,
parallelize_module)
from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh
from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.modeling_outputs import (BaseModelOutputWithPast,
CausalLMOutputWithPast)
@@ -137,6 +146,47 @@ class MPTPreTrainedModel(PreTrainedModel):
base_model_prefix = 'model'
_no_split_modules = ['MPTBlock']

def rearrange_tensor(t: torch.Tensor, n_devices: int, d_model: int,
head_dim: int, kv_n_heads: int):
# Split output dim into three chunks: query proj. weights, key proj. weights, value proj. weights
# The Wqkv projection is a (n_heads * head_dim + 2 * kv_n_heads * head_dim, d_model)-dim tensor.
# The projection Wqkv(x) in the attention module will eventually be split into
# Q, K, V tensors with the split (n_heads * head_dim, kv_n_heads * head_dim, kv_n_heads * head_dim).
# As a result, each device must hold a 1/n_devices fraction of the rows of each projection's chunk
# to remain numerically equivalent to the unsharded computation.
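# Worked example (illustrative, hypothetical sizes): with n_devices=2, d_model=4,
# head_dim=1 and kv_n_heads=4, the fused weight rows are ordered
#   [q0 q1 q2 q3 | k0 k1 k2 k3 | v0 v1 v2 v3].
# After the split/chunk/concat below they become
#   [q0 q1 k0 k1 v0 v1 | q2 q3 k2 k3 v2 v3],
# so Shard(0) hands device 0 the first half and device 1 the second half,
# each containing its own slice of the Q, K and V rows.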
t_chunks = torch.split(
t, [d_model, kv_n_heads * head_dim, kv_n_heads * head_dim], dim=0)

# For each chunk, split the rows (dim=0) into n_devices pieces;
# this samples a 1/n_devices fraction of the rows of each projection chunk.
sub_chunks = [
torch.chunk(chunk, chunks=n_devices, dim=0) for chunk in t_chunks
]

# concatenate the q, k, v chunks for each device
new_chunks = [
torch.cat([sub_chunk[i]
for sub_chunk in sub_chunks], dim=0)
for i in range(n_devices)
]

return torch.cat(new_chunks, dim=0)


def shard_qkv(
mod_name: str,
mod: nn.Module,
mesh: DeviceMesh,
d_model: int,
head_dim: int,
kv_n_heads: int,
):
placement = [Shard(0)]
rearr_weight = rearrange_tensor(mod.weight, mesh.size(), d_model, head_dim,
kv_n_heads)
rearr_weight_param = torch.nn.Parameter(rearr_weight)
mod.weight = torch.nn.Parameter(
distribute_tensor(rearr_weight_param, mesh, placement))

class MPTModel(MPTPreTrainedModel):

@@ -145,6 +195,8 @@ def __init__(self, config: MPTConfig):
super().__init__(config)

self.attn_impl = config.attn_config['attn_impl']
self.tensor_parallel_qkvo = config.attn_config['tensor_parallel_qkvo']
self.tp_world_size = config.attn_config['tp_world_size']
self.prefix_lm = config.attn_config['prefix_lm']
self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
self.alibi = config.attn_config['alibi']
@@ -202,6 +254,71 @@ def __init__(self, config: MPTConfig):
f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.'
)
self.apply(self.param_init_fn)

if self.tensor_parallel_qkvo:
device_type = 'cuda' if get_device(None).name == 'gpu' else 'cpu'
world_size = dist.get_world_size()
node_count = world_size // dist.get_local_world_size()
# Configures intranode tensor parallelism
twod_mesh = DeviceMesh(
device_type=device_type,
mesh=torch.arange(0, world_size).view(node_count, -1),
mesh_dim_names=['ep', 'tp'],
)
new_blocks = nn.ModuleList()
for block in self.blocks:
qkv_module = block.get_submodule('attn.Wqkv')
oned_mesh = _create_1d_device_mesh(twod_mesh, tp_mesh_dim=1)

kv_n_heads = config.n_heads
if config.attn_config['attn_type'] == 'grouped_query_attention':
kv_n_heads = config.attn_config['kv_n_heads']
elif config.attn_config['attn_type'] == 'multiquery_attention':
raise NotImplementedError(
'Tensor parallelism is not currently supported for multiquery attention.'
)

# Megatron-style trick:
# Shard the fused QKV projection column-wise.
# Shard the output projection row-wise.
# Note: since PyTorch does not yet support interleaved sharding, we need to
# manually rearrange the fused QKV weight tensor (see rearrange_tensor above).
distribute_module(
qkv_module,
oned_mesh,
partition_fn=partial(shard_qkv,
d_model=config.d_model,
head_dim=config.d_model //
config.n_heads,
kv_n_heads=kv_n_heads),
input_fn=make_input_replicate_1d,
output_fn=make_sharded_output_tensor,
)

block = parallelize_module(
module=block,
device_mesh=twod_mesh,
parallelize_plan={
'attn.out_proj': RowwiseParallel(),
},
tp_mesh_dim=1,
)

# The call to parallelize_module moves CPU params to GPU.
# Move them back to CPU so that FSDP wrapping sees all params on CPU;
# otherwise wrapping fails because it sees some params on CPU and others on GPU.
assert config.init_device == 'cpu'
if config.init_device == 'cpu':
block = block.to('cpu')
new_blocks.append(block)
self.blocks = new_blocks
print('Tensor parallelism initialized...')

# This call registers the hooks needed to make tensor parallelism compatible with FSDP.
if not enable_2d_with_fsdp():
raise RuntimeError(
'Failed to enable 2D parallelism with FSDP. Please check your environment.'
)

self.is_causal = not self.prefix_lm

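A minimal single-process sketch of why the column-then-row sharding above is numerically sound (toy sizes and plain matmuls are assumptions for illustration; no DTensor or distributed setup is involved):

import torch

torch.manual_seed(0)
d_model, n_ranks = 8, 2
x = torch.randn(3, d_model)            # (tokens, d_model)
w_v = torch.randn(d_model, d_model)    # stand-in for the value part of Wqkv
w_o = torch.randn(d_model, d_model)    # stand-in for out_proj

dense = x @ w_v @ w_o                  # unsharded reference

cols = w_v.chunk(n_ranks, dim=1)       # column-parallel projection shards
rows = w_o.chunk(n_ranks, dim=0)       # row-parallel out_proj shards
partials = [(x @ cols[r]) @ rows[r] for r in range(n_ranks)]
sharded = sum(partials)                # what the all-reduce would produce

print(torch.allclose(dense, sharded, atol=1e-5))  # True

In the actual module the per-head attention sits between the two projections, but because each rank owns whole heads, the same identity carries through.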