Shashank/seq id flash attn (#738)

* Update llmfoundry/models/layers/attention.py

Co-authored-by: Vitaliy Chiley <[email protected]>

* Update llmfoundry/models/mpt/modeling_mpt.py

Co-authored-by: Vitaliy Chiley <[email protected]>

* Update llmfoundry/models/mpt/modeling_mpt.py

Co-authored-by: Vitaliy Chiley <[email protected]>

---------

Co-authored-by: Vitaliy Chiley <[email protected]>
ShashankMosaicML and vchiley authored Dec 4, 2023
1 parent b2e4b0e commit 84b5d96
Showing 7 changed files with 613 additions and 76 deletions.
104 changes: 65 additions & 39 deletions llmfoundry/models/layers/attention.py
@@ -92,7 +92,6 @@ def scaled_multihead_dot_product_attention(
multiquery: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:

if multiquery:
warnings.warn(
DeprecationWarning(
@@ -219,6 +218,9 @@ def flash_attn_fn(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
attention_mask_in_length: Optional[torch.Tensor] = None,
should_repeat_kv_for_gqa: Optional[bool] = True,
sliding_window_size: int = -1,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
try:
@@ -249,58 +251,65 @@ def flash_attn_fn(

past_key_value = (key, value)

if attn_bias is not None:
# clamp to 0 necessary for torch 2.0 compile()
_s_q = max(0, attn_bias.size(2) - query.size(1))
_s_k = max(0, attn_bias.size(3) - key.size(1))
attn_bias = attn_bias[:, :, _s_q:, _s_k:]

if attn_bias is not None:
raise NotImplementedError(f'attn_bias not implemented for flash attn.')

batch_size, seqlen = query.shape[:2]

if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
query_padding_mask = key_padding_mask[:, -query.size(1):]
if attention_mask_in_length is None:
if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
query_padding_mask = key_padding_mask[:, -query.size(1):]
unpadding_function = bert_padding.unpad_input
else:
key_padding_mask = attention_mask_in_length
query_padding_mask = attention_mask_in_length
unpadding_function = bert_padding.unpad_input_for_concatenated_sequences

query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
query, query_padding_mask)
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
key_unpad, _, cu_seqlens_k, max_seqlen_k = unpadding_function(
key, key_padding_mask)
key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask)
value_unpad, _, _, _ = unpadding_function(value, key_padding_mask)
value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

# multi-query case
if kv_n_heads == 1:
# Expanding a tensor does not allocate new memory, but only creates a new
# view on the existing tensor where a dimension of size one is expanded
# to a larger size by setting the stride to 0.
# - pytorch docs
#
# hopefully the kernels can utilize this and we're not just wasting BW here
key_unpad = key_unpad.expand(key_unpad.size(0), n_heads,
key_unpad.size(-1))
value_unpad = value_unpad.expand(value_unpad.size(0), n_heads,
value_unpad.size(-1))
# grouped query case
elif kv_n_heads < n_heads:
# Each query belongs to a group of kv heads of group size n_heads // kv_n_heads
# We repeat each kv head by the group size number to use the underlying MHA kernels

# since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d)
# we use .view to modify {key, value}_unpad appropriately
if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and (
not should_repeat_kv_for_gqa):
raise ValueError(
'For Grouped Query Attention or Multi Query Attention, should_repeat_kv_for_gqa should be set to True if not using Flash Attention v2.'
)

key_unpad = repeat_kv_for_gqa(
key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1),
n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1)
value_unpad = repeat_kv_for_gqa(
value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1),
n_heads // kv_n_heads).view(value_unpad.size(0), n_heads, -1)
if should_repeat_kv_for_gqa:
# multi-query case
if kv_n_heads == 1:
# Expanding a tensor does not allocate new memory, but only creates a new
# view on the existing tensor where a dimension of size one is expanded
# to a larger size by setting the stride to 0.
# - pytorch docs
#
# hopefully the kernels can utilize this and we're not just wasting BW here
key_unpad = key_unpad.expand(key_unpad.size(0), n_heads,
key_unpad.size(-1))
value_unpad = value_unpad.expand(value_unpad.size(0), n_heads,
value_unpad.size(-1))
# grouped query case
elif kv_n_heads < n_heads:
# Each query belongs to a group of kv heads of group size n_heads // kv_n_heads
# We repeat each kv head by the group size number to use the underlying MHA kernels

# since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d)
# we use .view to modify {key, value}_unpad appropriately

key_unpad = repeat_kv_for_gqa(
key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1),
n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1)
value_unpad = repeat_kv_for_gqa(
value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1),
n_heads // kv_n_heads).view(value_unpad.size(0), n_heads, -1)

dropout_p = dropout_p if training else 0.0
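
To make the should_repeat_kv_for_gqa branch above concrete, here is a small illustrative sketch in plain torch (not the repository's repeat_kv_for_gqa helper) of expanding a single KV head for the multi-query case versus repeating grouped KV heads up to n_heads:

import torch

nnz, n_heads, kv_n_heads, head_dim = 8, 8, 2, 16

# Multi-query case (kv_n_heads == 1): expand returns a stride-0 view, no copy.
k_mqa = torch.randn(nnz, 1, head_dim)
k_mqa = k_mqa.expand(nnz, n_heads, head_dim)  # (8, 8, 16), shares storage

# Grouped-query case: repeat each KV head n_heads // kv_n_heads times so a
# standard MHA kernel sees matching head counts. This mirrors the effect of
# repeat_kv_for_gqa, which uses expand + reshape on a (b, s, kv_n_heads, d) view.
k_gqa = torch.randn(nnz, kv_n_heads, head_dim)
k_gqa = k_gqa.repeat_interleave(n_heads // kv_n_heads, dim=1)  # (8, 8, 16)

print(k_mqa.shape, k_gqa.shape)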

@@ -331,7 +340,8 @@ def flash_attn_fn(
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=reset_is_causal,
return_attn_probs=needs_weights)
return_attn_probs=needs_weights,
window_size=(sliding_window_size, sliding_window_size))
else:
raise RuntimeError(
'flash-attn==1.0.9 or flash-attn==2.3.2 is required.')
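
For orientation, a hedged sketch of the variable-length flash-attention call that the unpadded tensors and cu_seqlens above feed into. It assumes flash-attn >= 2.3.0 and its public flash_attn_varlen_func; the repository routes the call through flash_attn_interface, so the exact entry point shown here is illustrative:

import torch
from flash_attn import flash_attn_varlen_func  # assumes flash-attn >= 2.3.0 is installed

# Two sequences of lengths 2 and 4 packed along a single nnz dimension.
nnz, n_heads, head_dim = 6, 4, 64
q = torch.randn(nnz, n_heads, head_dim, dtype=torch.bfloat16, device='cuda')
k = torch.randn(nnz, n_heads, head_dim, dtype=torch.bfloat16, device='cuda')
v = torch.randn(nnz, n_heads, head_dim, dtype=torch.bfloat16, device='cuda')

# cu_seqlens mark the start offset of each packed sequence, as produced by the
# unpadding_function selected above.
cu_seqlens = torch.tensor([0, 2, 6], dtype=torch.int32, device='cuda')

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=4,
    max_seqlen_k=4,
    dropout_p=0.0,
    causal=True,
    window_size=(-1, -1),  # (-1, -1) disables the sliding window, matching sliding_window_size=-1
)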
@@ -490,6 +500,7 @@ def __init__(
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
sliding_window_size: int = -1,
):
super().__init__()

@@ -500,6 +511,7 @@ def __init__(
self.d_model = d_model
self.n_heads = n_heads
self.kv_n_heads = kv_n_heads
self.sliding_window_size = sliding_window_size

self.head_dim = d_model // n_heads

@@ -569,6 +581,7 @@ def forward(
rotary_emb_w_meta_info: Optional[dict] = None,
is_causal: bool = True,
needs_weights: bool = False,
attention_mask_in_length: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
qkv = self.Wqkv(x)
@@ -626,6 +639,14 @@ def forward(
query = query.view(bsz, seqlen, self.d_model)
key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim)

extra_attn_kwargs = {}
if self.attn_impl == 'flash':
extra_attn_kwargs = {
'attention_mask_in_length': attention_mask_in_length,
'should_repeat_kv_for_gqa': not is_flash_v2_installed(),
'sliding_window_size': self.sliding_window_size,
}

context, attn_weights, past_key_value = self.attn_fn(
query,
key,
@@ -640,6 +661,7 @@ def forward(
dropout_p=self.attn_dropout_p,
training=self.training,
needs_weights=needs_weights,
**extra_attn_kwargs,
)

return self.out_proj(context), attn_weights, past_key_value
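
As background for the attention_mask_in_length argument threaded through here, a hedged sketch of converting a packed sequence_id tensor into the per-row sequence-length layout that bert_padding.unpad_input_for_concatenated_sequences expects. The helper name is illustrative; the PR's actual conversion lives in modeling_mpt.py, which is not rendered above:

import torch

def sequence_id_to_attention_mask_in_length(sequence_id: torch.Tensor) -> torch.Tensor:
    # sequence_id: (batch, seqlen) with values 0..k-1 marking which packed
    # sub-sequence each token belongs to. Returns (batch, seqlen) where entry j
    # holds the length of the j-th packed sub-sequence, right-padded with zeros.
    seqlen = sequence_id.size(1)
    one_hot = torch.nn.functional.one_hot(sequence_id, num_classes=seqlen)
    return one_hot.sum(dim=1)

sequence_id = torch.tensor([[0, 0, 0, 1, 1, 2]])
print(sequence_id_to_attention_mask_in_length(sequence_id))
# tensor([[3, 2, 1, 0, 0, 0]])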
@@ -665,6 +687,7 @@ def __init__(
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
sliding_window_size: int = -1,
):
super().__init__(
d_model=d_model,
@@ -679,6 +702,7 @@ def __init__(
fc_type=fc_type,
device=device,
bias=bias,
sliding_window_size=sliding_window_size,
)


@@ -702,6 +726,7 @@ def __init__(
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
sliding_window_size: int = -1,
):
super().__init__(
d_model=d_model,
@@ -716,6 +741,7 @@ def __init__(
fc_type=fc_type,
device=device,
bias=bias,
sliding_window_size=sliding_window_size,
)


3 changes: 3 additions & 0 deletions llmfoundry/models/layers/blocks.py
@@ -21,6 +21,7 @@
'softmax_scale': None,
'prefix_lm': False,
'attn_uses_sequence_id': False,
'sliding_window_size': -1,
'alibi': False,
'alibi_bias_max': 8,
'rope': False,
@@ -113,6 +114,7 @@ def forward(
attention_mask: Optional[torch.ByteTensor] = None,
is_causal: bool = True,
output_attentions: bool = False,
attention_mask_in_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
torch.Tensor, torch.Tensor]]]:
a = self.norm_1(x)
@@ -124,6 +126,7 @@ def forward(
attention_mask=attention_mask,
is_causal=is_causal,
needs_weights=output_attentions,
attention_mask_in_length=attention_mask_in_length,
)
x = x + self.resid_attn_dropout(b)
m = x
15 changes: 12 additions & 3 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -91,6 +91,7 @@ def __init__(
When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
which sub-sequence each token belongs to.
Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
sliding_window_size (int): Window size for sliding window local attention. Defaults to -1, which means no sliding window. Query at position i will only attend to keys between [i + seqlen_k - seqlen_q - window_size, i + seqlen_k - seqlen_q + window_size] inclusive. Only works for flash attention v2.3.0 or higher.
alibi (bool): Whether to use the alibi bias instead of position embeddings.
alibi_bias_max (int): The maximum value of the alibi bias.
rope (bool): Whether to use rotary positional embeddings.
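
A hedged usage sketch of turning on the two new options through MPTConfig; the field names follow the attn_config defaults shown in blocks.py above, and the hyperparameters are illustrative:

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

# Sketch only: packed-sequence masking under flash attention plus a local
# attention window. Per the validation below, this path needs flash-attn
# >= 2.1.2 for attn_uses_sequence_id and >= 2.3.0 for sliding_window_size.
config = MPTConfig(
    d_model=768,
    n_heads=12,
    n_layers=12,
    max_seq_len=2048,
    attn_config={
        'attn_impl': 'flash',
        'attn_uses_sequence_id': True,  # honor sequence_id for packed training examples
        'sliding_window_size': 256,     # each query only sees keys within a 256-token window
    },
)
model = MPTForCausalLM(config)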
@@ -221,10 +222,12 @@ def _validate_config(self) -> None:
]:
raise NotImplementedError(
'alibi only implemented with torch and triton attention.')
if self.attn_config['attn_uses_sequence_id'] and self.attn_config[
'attn_impl'] not in ['torch', 'triton']:
if self.attn_config['attn_uses_sequence_id'] and not (
self.attn_config['attn_impl'] in ['torch', 'triton'] or
(self.attn_config['attn_impl'] == 'flash' and
is_flash_v2_installed(v2_version='v2.1.2'))):
raise NotImplementedError(
'attn_uses_sequence_id only implemented with torch and triton attention.'
'attn_uses_sequence_id only implemented with torch, triton, and flash (v2.1.2 or higher) attention.'
)
if self.attn_config['rope'] and (self.attn_config['rope_impl']
not in ['dail', 'hf']):
@@ -251,6 +254,12 @@ def _validate_config(self) -> None:
raise ImportError(
'If using the dail implementation of rope, the flash_attn library v2.0.1 or higher must be installed. Please check the instructions at https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#what-kinds-of-positional-embeddings-does-llm-foundry-support'
)
if self.attn_config['sliding_window_size'] != -1 and not (
self.attn_config['attn_impl'] == 'flash' and
is_flash_v2_installed(v2_version='v2.3.0')):
raise NotImplementedError(
'sliding window only implemented with flash attention v2.3.0 or higher.'
)
if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
raise ValueError(
'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!'