Add option to unfuse Wqkv (#1367)
* yo

* logging

* bro

* yo

* yo

* yo

* lint

* if

* datest

* webacc

* liny
snarayan21 authored Jul 17, 2024
1 parent 93dd4f3 commit 221d252
Showing 4 changed files with 220 additions and 14 deletions.
71 changes: 57 additions & 14 deletions llmfoundry/models/layers/attention.py
@@ -411,6 +411,7 @@ def __init__(
         clip_qkv: Optional[float] = None,
         qk_ln: bool = False,
         qk_gn: bool = False,
+        fused_qkv: bool = True,
         softmax_scale: Optional[float] = None,
         attn_pdrop: float = 0.0,
         norm_type: str = 'low_precision_layernorm',
@@ -426,6 +427,7 @@ def __init__(
         self.clip_qkv = clip_qkv
         self.qk_ln = qk_ln
         self.qk_gn = qk_gn
+        self.fused_qkv = fused_qkv
 
         self.d_model = d_model
         self.n_heads = n_heads
@@ -462,7 +464,17 @@ def __init__(
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
         self.attn_dropout_p = attn_pdrop
 
-        if self.reuse_kv_layer_idx is None:
+        if self.reuse_kv_layer_idx is not None:
+            self.Wq = build_fc(
+                name=fc_type_name,
+                in_features=self.d_model,
+                out_features=self.d_model,
+                fc_kwargs=fc_type,
+            )
+            # for param init fn; enables shape based init of fused layers
+            fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)]
+            self.Wq._fused = (0, fuse_splits)
+        elif self.fused_qkv:
             self.Wqkv = build_fc(
                 name=fc_type_name,
                 in_features=self.d_model,
@@ -482,9 +494,26 @@ def __init__(
                 out_features=self.d_model,
                 fc_kwargs=fc_type,
             )
+            self.Wk = build_fc(
+                name=fc_type_name,
+                in_features=self.d_model,
+                out_features=self.kv_n_heads * self.head_dim,
+                fc_kwargs=fc_type,
+            )
+            self.Wv = build_fc(
+                name=fc_type_name,
+                in_features=self.d_model,
+                out_features=self.kv_n_heads * self.head_dim,
+                fc_kwargs=fc_type,
+            )
             # for param init fn; enables shape based init of fused layers
-            fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)]
-            self.Wq._fused = (0, fuse_splits)
+            q_fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)]
+            kv_fuse_splits = [
+                i * self.head_dim for i in range(1, self.kv_n_heads)
+            ]
+            self.Wq._fused = (0, q_fuse_splits)
+            self.Wk._fused = (0, kv_fuse_splits)
+            self.Wv._fused = (0, kv_fuse_splits)
 
         if self.qk_ln or self.qk_gn:
             norm_size = self.head_dim if qk_gn else d_model
@@ -601,19 +630,29 @@ def get_qkv(
                 query = self.q_ln(query).to(dtype).view(q_shape)
             return query, key, value
 
-        qkv = self.Wqkv(x)
+        if self.fused_qkv:
+            qkv = self.Wqkv(x)
 
-        if self.clip_qkv:
-            qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+            if self.clip_qkv:
+                qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
 
-        query, key, value = qkv.split(
-            [
-                self.d_model,
-                self.kv_n_heads * self.head_dim,
-                self.kv_n_heads * self.head_dim,
-            ],
-            dim=2,
-        )
+            query, key, value = qkv.split(
+                [
+                    self.d_model,
+                    self.kv_n_heads * self.head_dim,
+                    self.kv_n_heads * self.head_dim,
+                ],
+                dim=2,
+            )
+        else:
+            query = self.Wq(x)
+            key = self.Wk(x)
+            value = self.Wv(x)
+
+            if self.clip_qkv:
+                query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+                key = key.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+                value = value.clamp(min=-self.clip_qkv, max=self.clip_qkv)
 
         if self.qk_ln or self.qk_gn:
             # Applying layernorm to qk
@@ -753,6 +792,7 @@ def __init__(
         clip_qkv: Optional[float] = None,
         qk_ln: bool = False,
         qk_gn: bool = False,
+        fused_qkv: bool = True,
         softmax_scale: Optional[float] = None,
         attn_pdrop: float = 0.0,
         norm_type: str = 'low_precision_layernorm',
@@ -770,6 +810,7 @@ def __init__(
             clip_qkv=clip_qkv,
             qk_ln=qk_ln,
             qk_gn=qk_gn,
+            fused_qkv=fused_qkv,
             softmax_scale=softmax_scale,
             attn_pdrop=attn_pdrop,
             norm_type=norm_type,
@@ -796,6 +837,7 @@ def __init__(
         clip_qkv: Optional[float] = None,
         qk_ln: bool = False,
         qk_gn: bool = False,
+        fused_qkv: bool = True,
         softmax_scale: Optional[float] = None,
         attn_pdrop: float = 0.0,
         norm_type: str = 'low_precision_layernorm',
@@ -813,6 +855,7 @@ def __init__(
             clip_qkv=clip_qkv,
             qk_ln=qk_ln,
             qk_gn=qk_gn,
+            fused_qkv=fused_qkv,
             softmax_scale=softmax_scale,
             attn_pdrop=attn_pdrop,
             norm_type=norm_type,
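The equivalence the new `fused_qkv` flag relies on is that the fused `Wqkv` projection is simply the three per-projection matmuls stacked along the output dimension, so one large matmul can be traded for three smaller ones without changing the math. A minimal standalone sketch of that relationship in plain PyTorch (toy dimensions chosen for illustration, not the llm-foundry layer classes):

```python
import torch
import torch.nn as nn

# Toy sizes for illustration; in MPT these come from the model config.
d_model, n_heads, kv_n_heads = 256, 8, 2
head_dim = d_model // n_heads
kv_dim = kv_n_heads * head_dim

# Fused projection: one matmul produces q, k, and v together.
wqkv = nn.Linear(d_model, d_model + 2 * kv_dim)

# Unfused projections: three smaller matmuls.
wq = nn.Linear(d_model, d_model)
wk = nn.Linear(d_model, kv_dim)
wv = nn.Linear(d_model, kv_dim)

# Copy row slices of the fused weight/bias into the separate layers.
with torch.no_grad():
    wq.weight.copy_(wqkv.weight[:d_model])
    wk.weight.copy_(wqkv.weight[d_model:d_model + kv_dim])
    wv.weight.copy_(wqkv.weight[d_model + kv_dim:])
    wq.bias.copy_(wqkv.bias[:d_model])
    wk.bias.copy_(wqkv.bias[d_model:d_model + kv_dim])
    wv.bias.copy_(wqkv.bias[d_model + kv_dim:])

x = torch.randn(2, 16, d_model)
q1, k1, v1 = wqkv(x).split([d_model, kv_dim, kv_dim], dim=2)
q2, k2, v2 = wq(x), wk(x), wv(x)
assert torch.allclose(q1, q2, atol=1e-6)
assert torch.allclose(k1, k2, atol=1e-6)
assert torch.allclose(v1, v2, atol=1e-6)
```

The unfused branch added to `get_qkv` above performs exactly this split, just with the weights living in separate modules from the start instead of being sliced out of one matrix.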
2 changes: 2 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -70,6 +70,8 @@ def __init__(
             attn_impl (str): The attention implementation to use. One of 'torch' or 'flash'.
             qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
             qk_gn (bool): Whether to apply group normalization to the queries and keys in the attention layer.
+            fused_qkv (bool): Whether to fuse the Wq, Wk, and Wv weight matrices in the attention layer. If True, the weights are fused into a single
+                Wqkv matrix, which can be faster for matmuls. If False, the weights are kept separate. Defaults to True.
             clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
                 this value.
             softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
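Assuming the usual `MPTConfig` constructor and the default-merging behavior this docstring describes, toggling the new option from user code would look roughly like the sketch below; the model sizes are placeholders, not values from the PR.

```python
from llmfoundry.models.mpt import MPTConfig

# Sketch: flip the new flag through the attn_config dict. Attention keys that
# are not specified here are expected to fall back to the library defaults,
# so fused_qkv stays True unless it is explicitly overridden.
cfg = MPTConfig(
    d_model=768,
    n_heads=12,
    n_layers=12,
    attn_config={
        'attn_impl': 'flash',
        'fused_qkv': False,  # build separate Wq, Wk, Wv instead of one Wqkv
    },
)
print(cfg.attn_config['fused_qkv'])  # False
```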
1 change: 1 addition & 0 deletions llmfoundry/models/utils/config_defaults.py
@@ -15,6 +15,7 @@
     'attn_impl': 'flash',
     'qk_ln': False,
     'qk_gn': False,
+    'fused_qkv': True,
     'clip_qkv': None,
     'softmax_scale': None,
     'attn_uses_sequence_id': False,
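The defaults file only gains the new key; user configs continue to override it. As an illustration of that layering (assuming the dict shown in this diff is the module's `attn_config_defaults`, and using a plain dict merge rather than the library's actual merge logic):

```python
from llmfoundry.models.utils.config_defaults import attn_config_defaults

# Illustration only: user-supplied attention settings are layered on top of
# the defaults, so a config that never mentions fused_qkv keeps the fused
# (pre-existing) behavior.
user_attn_config = {'fused_qkv': False}
effective = {**attn_config_defaults, **user_attn_config}

assert attn_config_defaults['fused_qkv'] is True
assert effective['fused_qkv'] is False
assert effective['qk_ln'] is False  # untouched keys keep their defaults
```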
160 changes: 160 additions & 0 deletions tests/models/layers/test_attention.py
@@ -0,0 +1,160 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from llmfoundry.models.layers.layer_builders import build_attention_layer


@pytest.mark.parametrize(
    'attn_name',
    ['multihead_attention', 'grouped_query_attention', 'multiquery_attention'],
)
@pytest.mark.parametrize('dim', [1024])
def test_unfused_wqkv(attn_name: str, dim: int):
    d_head = 128
    n_heads = dim // d_head

    generic_attn_kwargs = {
        'd_model': dim,
        'n_heads': n_heads,
        'fc_type': {
            'name': 'torch',
        },
        'device': 'cpu',
        'attn_pdrop': 0.0,
        'attn_impl': 'torch',
        'qk_ln': False,
        'qk_gn': False,
        'clip_qkv': None,
        'softmax_scale': None,
        'sliding_window_size': -1,
    }

    if attn_name == 'grouped_query_attention':
        kv_n_heads = 2
        generic_attn_kwargs['kv_n_heads'] = kv_n_heads
    elif attn_name == 'multiquery_attention':
        kv_n_heads = 1
    elif attn_name == 'multihead_attention':
        kv_n_heads = n_heads
    else:
        raise ValueError(f'Unknown attention name: {attn_name}')

    attn_config_fused = generic_attn_kwargs.copy()
    attn_config_fused['fused_qkv'] = True

    attn_config_unfused = generic_attn_kwargs.copy()
    attn_config_unfused['fused_qkv'] = False

    attn_fused = build_attention_layer(
        name=attn_name,
        attn_kwargs=attn_config_fused,
    )
    attn_unfused = build_attention_layer(
        name=attn_name,
        attn_kwargs=attn_config_unfused,
    )

    # Make sure unfused attention has the same params as the fused one.
    fused_wqkv = attn_fused.Wqkv.weight.detach().clone()
    kv_heads_len = (fused_wqkv.shape[0] - dim) // 2
    Wq_shape_before = (attn_unfused.Wq.weight.shape, attn_unfused.Wq.bias.shape)
    Wk_shape_before = (attn_unfused.Wk.weight.shape, attn_unfused.Wk.bias.shape)
    Wv_shape_before = (attn_unfused.Wv.weight.shape, attn_unfused.Wv.bias.shape)

    attn_unfused.Wq.weight.data = fused_wqkv[:dim, :]
    attn_unfused.Wk.weight.data = fused_wqkv[dim:dim + kv_heads_len, :]
    attn_unfused.Wv.weight.data = fused_wqkv[dim + kv_heads_len:, :]
    attn_unfused.out_proj.weight.data = attn_fused.out_proj.weight
    attn_unfused.Wq.bias.data = attn_fused.Wqkv.bias[:dim]
    attn_unfused.Wk.bias.data = attn_fused.Wqkv.bias[dim:dim + kv_heads_len]
    attn_unfused.Wv.bias.data = attn_fused.Wqkv.bias[dim + kv_heads_len:]
    attn_unfused.out_proj.bias.data = attn_fused.out_proj.bias

    # Make sure initialization fuse splits are as expected.
    all_fuse_splits = (
        0,
        [i * d_head for i in range(1, n_heads + 2 * kv_n_heads)],
    )
    q_fuse_splits = (0, [i * d_head for i in range(1, n_heads)])
    kv_fuse_splits = (0, [i * d_head for i in range(1, kv_n_heads)])

    assert attn_fused.Wqkv._fused == all_fuse_splits
    assert attn_unfused.Wq._fused == q_fuse_splits
    assert attn_unfused.Wk._fused == kv_fuse_splits
    assert attn_unfused.Wv._fused == kv_fuse_splits

    assert torch.allclose(
        attn_fused.Wqkv.weight,
        torch.cat(
            [
                attn_unfused.Wq.weight,
                attn_unfused.Wk.weight,
                attn_unfused.Wv.weight,
            ],
            dim=0,
        ),
    )
    assert torch.allclose(
        attn_fused.Wqkv.bias,
        torch.cat(
            [
                attn_unfused.Wq.bias,
                attn_unfused.Wk.bias,
                attn_unfused.Wv.bias,
            ],
            dim=0,
        ),
    )
    assert torch.allclose(
        attn_fused.out_proj.weight,
        attn_unfused.out_proj.weight,
    )
    assert torch.allclose(attn_fused.out_proj.bias, attn_unfused.out_proj.bias)

    assert Wq_shape_before == (
        attn_unfused.Wq.weight.shape,
        attn_unfused.Wq.bias.shape,
    )
    assert Wk_shape_before == (
        attn_unfused.Wk.weight.shape,
        attn_unfused.Wk.bias.shape,
    )
    assert Wv_shape_before == (
        attn_unfused.Wv.weight.shape,
        attn_unfused.Wv.bias.shape,
    )

    x1 = torch.randn(1, 1, dim)
    x2 = x1.detach().clone()
    x1.requires_grad = True
    x2.requires_grad = True

    out_fused, _, _ = attn_fused(x1)
    out_unfused, _, _ = attn_unfused(x2)

    assert torch.allclose(out_fused, out_unfused)

    # Dummy loss function is simply the sum.
    loss_fused = out_fused.sum()
    loss_fused.backward()

    loss_unfused = out_unfused.sum()
    loss_unfused.backward()

    assert isinstance(x1.grad, torch.Tensor)
    assert isinstance(x2.grad, torch.Tensor)
    assert torch.allclose(x1.grad, x2.grad)
    combined_grad = torch.concat(
        [
            attn_unfused.Wq.weight.grad,
            attn_unfused.Wk.weight.grad,
            attn_unfused.Wv.weight.grad,
        ],
        dim=0,
    )
    assert isinstance(attn_fused.Wqkv.weight.grad, torch.Tensor)
    assert isinstance(combined_grad, torch.Tensor)
    assert torch.allclose(attn_fused.Wqkv.weight.grad, combined_grad)