
Commit

ShashankMosaicML committed Nov 30, 2023
1 parent 4b25da2 commit 67deef8
Showing 3 changed files with 61 additions and 58 deletions.
62 changes: 27 additions & 35 deletions llmfoundry/models/layers/attention.py
@@ -90,13 +90,8 @@ def scaled_multihead_dot_product_attention(
     training: bool = False,
     needs_weights: bool = False,
     multiquery: bool = False,
-    attention_mask_in_length: Optional[torch.Tensor] = None,
-    should_repeat_kv_for_gqa: Optional[bool] = True,
-    sliding_window_size: int = -1,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
                                                                 torch.Tensor]]]:
-    del attention_mask_in_length, should_repeat_kv_for_gqa, sliding_window_size
-
     if multiquery:
         warnings.warn(
             DeprecationWarning(
@@ -271,33 +266,28 @@ def flash_attn_fn(
         if key_padding_mask is None:
             key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
         query_padding_mask = key_padding_mask[:, -query.size(1):]
+        unpadding_function = bert_padding.unpad_input
+    else:
+        unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
+        query_padding_mask = attention_mask_in_length
+        key_padding_mask = attention_mask_in_length
 
-        query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
-            query, query_padding_mask)
-        query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
-
-        key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
-            key, key_padding_mask)
-        key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
+    query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
+        query, query_padding_mask)
+    query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
 
-        value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask)
-        value_unpad = rearrange(value_unpad,
-                                'nnz (h d) -> nnz h d',
-                                h=kv_n_heads)
-    else:
-        query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input_for_concatenated_sequences(
-            query, attention_mask_in_length)
-        query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
+    key_unpad, _, cu_seqlens_k, max_seqlen_k = unpadding_function(
+        key, key_padding_mask)
+    key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
 
-        key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input_for_concatenated_sequences(
-            key, attention_mask_in_length)
-        key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
+    value_unpad, _, _, _ = unpadding_function(value, key_padding_mask)
+    value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
 
-        value_unpad, _, _, _ = bert_padding.unpad_input_for_concatenated_sequences(
-            value, attention_mask_in_length)
-        value_unpad = rearrange(value_unpad,
-                                'nnz (h d) -> nnz h d',
-                                h=kv_n_heads)
+    if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and (
+            not should_repeat_kv_for_gqa):
+        raise ValueError(
+            'For Grouped Query Attention or Multi Query Attention, should_repeat_kv_for_gqa should be set to True if not using Flash Attention v2.'
+        )
 
     if should_repeat_kv_for_gqa:
         # multi-query case
@@ -383,12 +373,8 @@ def triton_flash_attn_fn(
     training: bool = False,
     needs_weights: bool = False,
     multiquery: bool = False,
-    attention_mask_in_length: Optional[torch.Tensor] = None,
-    should_repeat_kv_for_gqa: Optional[bool] = True,
-    sliding_window_size: int = -1,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
                                                                 torch.Tensor]]]:
-    del attention_mask_in_length, should_repeat_kv_for_gqa, sliding_window_size
     try:
         from llmfoundry.models.layers.flash_attn_triton import flash_attn_func
     except:
@@ -659,6 +645,14 @@ def forward(
         query = query.view(bsz, seqlen, self.d_model)
         key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
 
+        extra_attn_kwargs = {}
+        if self.attn_impl == 'flash':
+            extra_attn_kwargs = {
+                'attention_mask_in_length': attention_mask_in_length,
+                'should_repeat_kv_for_gqa': not is_flash_v2_installed(),
+                'sliding_window_size': self.sliding_window_size,
+            }
+
         context, attn_weights, past_key_value = self.attn_fn(
             query,
             key,
@@ -673,9 +667,7 @@
             dropout_p=self.attn_dropout_p,
             training=self.training,
             needs_weights=needs_weights,
-            attention_mask_in_length=attention_mask_in_length,
-            should_repeat_kv_for_gqa=not is_flash_v2_installed(),
-            sliding_window_size=self.sliding_window_size,
+            **extra_attn_kwargs,
         )
 
         return self.out_proj(context), attn_weights, past_key_value
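The forward() hunks above are the heart of this file's change: the flash-only arguments (attention_mask_in_length, should_repeat_kv_for_gqa, sliding_window_size) are collected into extra_attn_kwargs and only passed when attn_impl == 'flash', so scaled_multihead_dot_product_attention and triton_flash_attn_fn no longer have to accept and immediately del parameters they never use. A minimal sketch of this dispatch pattern, using hypothetical stand-in attention functions rather than the llm-foundry implementations:

from typing import Any, Callable, Optional

import torch


def simple_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Stand-in for an implementation with a narrow signature.
    return torch.softmax(q @ k.transpose(-2, -1), dim=-1) @ v


def flash_like_attn(q: torch.Tensor,
                    k: torch.Tensor,
                    v: torch.Tensor,
                    sliding_window_size: int = -1,
                    attention_mask_in_length: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Stand-in for the flash path, which is the only one that accepts the extras.
    return simple_attn(q, k, v)


def run_attention(attn_impl: str, q: torch.Tensor, k: torch.Tensor,
                  v: torch.Tensor) -> torch.Tensor:
    attn_fn: Callable[..., torch.Tensor] = (flash_like_attn
                                            if attn_impl == 'flash' else simple_attn)
    extra_attn_kwargs: dict[str, Any] = {}
    if attn_impl == 'flash':
        # Implementation-specific knobs live in one dict instead of widening
        # every attention function's signature.
        extra_attn_kwargs = {'sliding_window_size': 128}
    return attn_fn(q, k, v, **extra_attn_kwargs)


q = k = v = torch.randn(1, 4, 8)
assert run_attention('torch', q, k, v).shape == (1, 4, 8)
assert run_attention('flash', q, k, v).shape == (1, 4, 8)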
24 changes: 18 additions & 6 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -135,11 +135,20 @@ def gen_rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int,
 def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
                                  attn_uses_sequence_id: bool, attn_impl: str,
                                  attention_mask: Union[torch.Tensor, None]):
-    # Generates the attention masks used for sequence masking in flash attention
-    # NOTE: Only supports sequence id based sparse attention for no attention masking or attention masking with right padding.
-    # In case of left padding:
-    #   1. Training with left padding is not supported in MPT (see https://github.com/mosaicml/llm-foundry/blob/1eecd4cb8e734499f77f6a35f657b8b20c0adfcb/llmfoundry/models/mpt/modeling_mpt.py#L407).
-    #   2. For generation with left padding, we only have a single sequence id per sample, so we don't need sequence id based sparse attention.
+    """Generates the attention mask used for sequence masking in FA v2.
+
+    Only supports sequence id based sparse attention for no attention masking or attention masking with right padding.
+    In case of left padding:
+        1. Training with left padding is not supported in MPT (see https://github.com/mosaicml/llm-foundry/blob/1eecd4cb8e734499f77f6a35f657b8b20c0adfcb/llmfoundry/models/mpt/modeling_mpt.py#L407).
+        2. For generation with left padding, we only have a single sequence id per sample, so we don't need sequence id based sparse attention.
+
+    Args:
+        sequence_id (Union[None, torch.Tensor]): Tensor containing the sequence id for each token. Shape (batch_size, seq_len).
+        S (int): Sequence length.
+        attn_uses_sequence_id (bool): Whether the attention uses sequence-id-based masking.
+        attn_impl (str): Attention implementation. This function only creates attention_mask_in_length for flash attention.
+        attention_mask (Union[torch.Tensor, None]): Attention mask tensor of shape (batch_size, seq_len).
+    """
     attention_mask_in_length = None
     if (sequence_id is not None) and attn_uses_sequence_id and (attn_impl
                                                                 == 'flash'):
@@ -149,7 +158,10 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
             raise NotImplementedError(
                 'Left padding is not supported with flash attention when attn_uses_sequence_id is set to True.'
             )
-        assert S == sequence_id.shape[-1]
+        if S != sequence_id.shape[-1]:
+            raise ValueError(
+                f'Sequence length ({S}) does not match length of sequences in sequence_id ({sequence_id.shape[-1]}).'
+            )
         attention_mask_in_length = torch.nn.functional.one_hot(sequence_id)
         if attention_mask is not None:
             attention_mask_in_length = attention_mask_in_length.masked_fill(
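As a reading aid, here is a small worked example of the mask this function builds. It is a sketch, not the library code: the one_hot and masked_fill steps match the hunk above, while the final sum/pad into the (batch_size, seq_len) lengths layout consumed by flash-attn's unpad_input_for_concatenated_sequences is assumed, since that part of the function is hidden behind the fold.

import torch

# Toy batch: one sample packs three sequences (ids 0, 1, 2) and is right padded.
sequence_id = torch.tensor([[0, 0, 1, 1, 1, 2, 2, 2]])                      # (1, 8)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 0, 0]], dtype=torch.bool)

# Steps shown in the hunk: one-hot the sequence ids, then zero out padded positions.
one_hot = torch.nn.functional.one_hot(sequence_id)                          # (1, 8, 3)
one_hot = one_hot.masked_fill(~attention_mask.unsqueeze(-1), 0)

# Assumed continuation (not shown in the hunk): reduce over the time axis to get
# one length per packed sequence, then pad back out to seq_len. The result lists
# per-sequence lengths followed by zeros, which is the layout flash_attn_fn feeds
# to bert_padding.unpad_input_for_concatenated_sequences.
lengths = one_hot.sum(dim=1)                                                 # tensor([[2, 3, 1]])
attention_mask_in_length = torch.nn.functional.pad(
    lengths, (0, sequence_id.shape[-1] - lengths.shape[-1]))
print(attention_mask_in_length)                                              # tensor([[2, 3, 1, 0, 0, 0, 0, 0]])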
33 changes: 16 additions & 17 deletions tests/test_flash_attn.py
@@ -225,23 +225,22 @@ def test_sliding_window(sliding_window_size: int):
         torch.ones(seqlen_1, seqlen_1), diagonal=-(sliding_window_size + 1)).to(
             dtype=dtype, device=device) * torch.finfo(attn_bias_2.dtype).min
     attn_bias_2 = attn_bias_2 + window_mask_2
-    output_2, _, _ = triton_flash_attn_fn(query=query_2,
-                                          key=key_2,
-                                          value=value_2,
-                                          n_heads=n_heads,
-                                          kv_n_heads=n_heads,
-                                          past_key_value=None,
-                                          softmax_scale=1 / math.sqrt(d),
-                                          attn_bias=attn_bias_2,
-                                          key_padding_mask=None,
-                                          is_causal=True,
-                                          dropout_p=0.0,
-                                          training=False,
-                                          needs_weights=False,
-                                          multiquery=False,
-                                          attention_mask_in_length=None,
-                                          should_repeat_kv_for_gqa=False,
-                                          sliding_window_size=-1)
+    output_2, _, _ = triton_flash_attn_fn(
+        query=query_2,
+        key=key_2,
+        value=value_2,
+        n_heads=n_heads,
+        kv_n_heads=n_heads,
+        past_key_value=None,
+        softmax_scale=1 / math.sqrt(d),
+        attn_bias=attn_bias_2,
+        key_padding_mask=None,
+        is_causal=True,
+        dropout_p=0.0,
+        training=False,
+        needs_weights=False,
+        multiquery=False,
+    )
 
     output_2.sum().backward()
 
Expand Down
