Commit c275365
ShashankMosaicML committed Nov 25, 2023
1 parent 6a4f73e
Showing 5 changed files with 44 additions and 78 deletions.
31 changes: 11 additions & 20 deletions llmfoundry/models/layers/attention.py
@@ -90,13 +90,12 @@ def scaled_multihead_dot_product_attention(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
query_attention_mask_in_length: Optional[torch.Tensor] = None,
key_attention_mask_in_length: Optional[torch.Tensor] = None,
attention_mask_in_length: Optional[torch.Tensor] = None,
should_repeat_kv_for_gqa: Optional[bool] = True,
sliding_window_size: int = -1,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
del query_attention_mask_in_length, key_attention_mask_in_length, should_repeat_kv_for_gqa, sliding_window_size
del attention_mask_in_length, should_repeat_kv_for_gqa, sliding_window_size

if multiquery:
warnings.warn(
@@ -224,8 +223,7 @@ def flash_attn_fn(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
key_attention_mask_in_length: Optional[torch.Tensor] = None,
query_attention_mask_in_length: Optional[torch.Tensor] = None,
attention_mask_in_length: Optional[torch.Tensor] = None,
should_repeat_kv_for_gqa: Optional[bool] = True,
sliding_window_size: int = -1,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
@@ -269,7 +267,7 @@ def flash_attn_fn(

batch_size, seqlen = query.shape[:2]

if query_attention_mask_in_length is None:
if attention_mask_in_length is None:
if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
query_padding_mask = key_padding_mask[:, -query.size(1):]
@@ -287,20 +285,16 @@
'nnz (h d) -> nnz h d',
h=kv_n_heads)
else:
if key_attention_mask_in_length is None:
raise ValueError(
'key_attention_mask_in_length must not be None if query_attention_mask_in_length is not None.'
)
query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input_for_concatenated_sequences(
query, query_attention_mask_in_length)
query, attention_mask_in_length)
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input_for_concatenated_sequences(
key, key_attention_mask_in_length)
key, attention_mask_in_length)
key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

value_unpad, _, _, _ = bert_padding.unpad_input_for_concatenated_sequences(
value, key_attention_mask_in_length)
value, attention_mask_in_length)
value_unpad = rearrange(value_unpad,
'nnz (h d) -> nnz h d',
h=kv_n_heads)
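The unpadding path above hands the single `attention_mask_in_length` tensor to `bert_padding.unpad_input_for_concatenated_sequences`, which flattens the packed sequences and returns the cumulative sequence lengths (`cu_seqlens_q`, `cu_seqlens_k`) that the variable-length FlashAttention kernels consume. The snippet below is only a minimal sketch of that mapping in plain PyTorch, not the flash-attn routine itself; the helper name is hypothetical.

```python
import torch

def cu_seqlens_from_mask_in_length(attention_mask_in_length: torch.Tensor) -> torch.Tensor:
    # Collapse the per-sample sequence lengths into one cumulative-length vector
    # (with a leading zero), the format varlen attention kernels expect.
    lengths = attention_mask_in_length.flatten()
    lengths = lengths[lengths > 0]  # drop the zero padding entries
    return torch.nn.functional.pad(torch.cumsum(lengths, dim=0), (1, 0)).to(torch.int32)

mask_in_length = torch.tensor([[3, 2, 1, 0, 0, 0],
                               [3, 2, 1, 0, 0, 0]])
print(cu_seqlens_from_mask_in_length(mask_in_length))
# tensor([ 0,  3,  5,  6,  9, 11, 12], dtype=torch.int32)
```

Because the same tensor now describes queries, keys, and values alike, a single argument suffices, which is what this commit exploits.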
@@ -389,13 +383,12 @@ def triton_flash_attn_fn(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
query_attention_mask_in_length: Optional[torch.Tensor] = None,
key_attention_mask_in_length: Optional[torch.Tensor] = None,
attention_mask_in_length: Optional[torch.Tensor] = None,
should_repeat_kv_for_gqa: Optional[bool] = True,
sliding_window_size: int = -1,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
del query_attention_mask_in_length, key_attention_mask_in_length, should_repeat_kv_for_gqa, sliding_window_size
del attention_mask_in_length, should_repeat_kv_for_gqa, sliding_window_size
try:
from llmfoundry.models.layers.flash_attn_triton import flash_attn_func
except:
@@ -608,8 +601,7 @@ def forward(
rotary_emb_w_meta_info: Optional[dict] = None,
is_causal: bool = True,
needs_weights: bool = False,
query_attention_mask_in_length: Optional[torch.Tensor] = None,
key_attention_mask_in_length: Optional[torch.Tensor] = None,
attention_mask_in_length: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
qkv = self.Wqkv(x)
@@ -681,8 +673,7 @@ def forward(
dropout_p=self.attn_dropout_p,
training=self.training,
needs_weights=needs_weights,
query_attention_mask_in_length=query_attention_mask_in_length,
key_attention_mask_in_length=key_attention_mask_in_length,
attention_mask_in_length=attention_mask_in_length,
should_repeat_kv_for_gqa=not is_flash_v2_installed(),
sliding_window_size=self.sliding_window_size,
)
6 changes: 2 additions & 4 deletions llmfoundry/models/layers/blocks.py
@@ -114,8 +114,7 @@ def forward(
attention_mask: Optional[torch.ByteTensor] = None,
is_causal: bool = True,
output_attentions: bool = False,
query_attention_mask_in_length: Optional[torch.Tensor] = None,
key_attention_mask_in_length: Optional[torch.Tensor] = None,
attention_mask_in_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
torch.Tensor, torch.Tensor]]]:
a = self.norm_1(x)
@@ -127,8 +126,7 @@ def forward(
attention_mask=attention_mask,
is_causal=is_causal,
needs_weights=output_attentions,
query_attention_mask_in_length=query_attention_mask_in_length,
key_attention_mask_in_length=key_attention_mask_in_length,
attention_mask_in_length=attention_mask_in_length,
)
x = x + self.resid_attn_dropout(b)
m = x
22 changes: 6 additions & 16 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -140,28 +140,19 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
# In case of left padding:
# 1. Training with left padding is not supported in MPT (see https://github.com/mosaicml/llm-foundry/blob/1eecd4cb8e734499f77f6a35f657b8b20c0adfcb/llmfoundry/models/mpt/modeling_mpt.py#L407).
# 2. For generation with left padding, we only have a single sequence id per sample, so we don't need sequence id based sparse attention.
query_attention_mask_in_length = None
key_attention_mask_in_length = None
attention_mask_in_length = None
left_padding = (attention_mask is not None) and (attention_mask[:, 0].sum()
!= attention_mask.shape[0])
if (not left_padding) and (
sequence_id is not None) and attn_uses_sequence_id and (attn_impl
== 'flash'):
assert S == sequence_id.shape[-1]
if attention_mask is not None:
sequence_id = sequence_id.masked_fill(~attention_mask, S)
query_attention_mask_in_length = torch.nn.functional.one_hot(
sequence_id[:, -S:], num_classes=S + 1).sum(dim=1)[:, :-1]
# We use S as the number of classes while creating key_attention_mask_in_length instead of sequence_id.shape[-1]
# because in case of inference, sequence_id.shape[-1] can become very large. In that case, the one_hot vectors
# would've become very large as well.
key_attention_mask_in_length = torch.nn.functional.one_hot(
attention_mask_in_length = torch.nn.functional.one_hot(
sequence_id, num_classes=S + 1).sum(dim=1)[:, :-1]
# Since Flash Attention expects the masks to have same shape as the keys, we pad it with zeros.
key_attention_mask_in_length = torch.nn.functional.pad(
key_attention_mask_in_length, (0, sequence_id.shape[-1] - S),
value=0)

return query_attention_mask_in_length, key_attention_mask_in_length
return attention_mask_in_length
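The one-hot construction above counts, for every sample, how many positions carry each sequence id; padding positions are first pushed into an extra class `S` (via `masked_fill`) and then sliced off. A minimal, self-contained sketch of that idea (not the foundry function itself, and with a hypothetical helper name) is:

```python
import torch

def mask_in_length_from_sequence_id(sequence_id: torch.Tensor, S: int) -> torch.Tensor:
    # Count how many tokens belong to each sequence id. Padding tokens are assumed
    # to have been set to the value S, so they land in the extra class that the
    # trailing [:, :-1] slice drops.
    return torch.nn.functional.one_hot(sequence_id, num_classes=S + 1).sum(dim=1)[:, :-1]

# Three packed sequences of lengths 3, 2 and 1 in a single sample of length 6.
sequence_id = torch.tensor([[0, 0, 0, 1, 1, 2]])
print(mask_in_length_from_sequence_id(sequence_id, S=6))
# tensor([[3, 2, 1, 0, 0, 0]])
```

Since queries and keys now share this single tensor, the separate zero-padding of the key-side mask (the `torch.nn.functional.pad` call removed above) is no longer needed.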


def apply_sequence_id(attn_bias: torch.Tensor, sequence_id: torch.LongTensor,
@@ -543,7 +534,7 @@ def forward(
prefix_mask=prefix_mask,
sequence_id=sequence_id,
)
query_attention_mask_in_length, key_attention_mask_in_length = gen_attention_mask_in_length(
attention_mask_in_length = gen_attention_mask_in_length(
sequence_id=sequence_id,
S=S,
attn_uses_sequence_id=self.attn_uses_sequence_id,
Expand Down Expand Up @@ -571,8 +562,7 @@ def forward(
attention_mask=attention_mask,
is_causal=self.is_causal,
output_attentions=bool(output_attentions),
query_attention_mask_in_length=query_attention_mask_in_length,
key_attention_mask_in_length=key_attention_mask_in_length,
attention_mask_in_length=attention_mask_in_length,
)
if presents is not None:
presents += (present,)
27 changes: 9 additions & 18 deletions tests/test_flash_attn.py
@@ -45,8 +45,7 @@ def test_gqa_kv_repetition(kv_n_heads: int):
training=False,
needs_weights=False,
multiquery=False,
key_attention_mask_in_length=None,
query_attention_mask_in_length=None,
attention_mask_in_length=None,
should_repeat_kv_for_gqa=True)

output_1.sum().backward()
@@ -72,8 +71,7 @@ def test_gqa_kv_repetition(kv_n_heads: int):
training=False,
needs_weights=False,
multiquery=False,
key_attention_mask_in_length=None,
query_attention_mask_in_length=None,
attention_mask_in_length=None,
should_repeat_kv_for_gqa=False)

output_2.sum().backward()
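The test toggles `should_repeat_kv_for_gqa`, i.e. whether grouped-query attention is emulated by materially repeating each key/value head until it matches the number of query heads (useful when the installed kernel cannot broadcast kv heads itself, as in the `not is_flash_v2_installed()` case seen earlier). A rough sketch of that repetition, assuming a `(batch, seqlen, kv_n_heads, head_dim)` layout and a hypothetical helper name:

```python
import torch

def repeat_kv_for_gqa_sketch(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (batch, seqlen, kv_n_heads, head_dim) to
    # (batch, seqlen, kv_n_heads * n_rep, head_dim) by repeating every kv head
    # n_rep times, so a kernel expecting n_heads == kv_n_heads still works.
    if n_rep == 1:
        return hidden
    b, s, kv_n_heads, d = hidden.shape
    expanded = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
    return expanded.reshape(b, s, kv_n_heads * n_rep, d)

k = torch.randn(2, 5, 2, 8)                        # 2 kv heads
print(repeat_kv_for_gqa_sketch(k, n_rep=4).shape)  # torch.Size([2, 5, 8, 8]) -> matches 8 query heads
```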
@@ -107,12 +105,9 @@ def test_seq_id_masking_FA_v2():
seq_ranges = [
(0, 3), (3, 5), (5, 6)
] # Each batch has 3 sequences of length 3, 2, and 1 respectively.
query_attention_mask_in_length_1 = torch.tensor([[3, 2, 1, 0, 0, 0],
[3, 2, 1, 0, 0, 0]
]).to(torch.int64).cuda()
key_attention_mask_in_length_1 = torch.tensor([[3, 2, 1, 0, 0, 0],
[3, 2, 1, 0, 0,
0]]).to(torch.int64).cuda()
attention_mask_in_length_1 = torch.tensor([[3, 2, 1, 0, 0, 0],
[3, 2, 1, 0, 0,
0]]).to(torch.int64).cuda()

output_1, _, _ = flash_attn_fn(
query=query_1,
@@ -129,8 +124,7 @@ def test_seq_id_masking_FA_v2():
training=False,
needs_weights=False,
multiquery=False,
key_attention_mask_in_length=key_attention_mask_in_length_1,
query_attention_mask_in_length=query_attention_mask_in_length_1)
attention_mask_in_length=attention_mask_in_length_1)

output_1.sum().backward()

@@ -156,8 +150,7 @@ def test_seq_id_masking_FA_v2():
training=False,
needs_weights=False,
multiquery=False,
key_attention_mask_in_length=None,
query_attention_mask_in_length=None)
attention_mask_in_length=None)

output_2.sum().backward()
assert torch.allclose(output_1[:, seq_range[0]:seq_range[1], :],
@@ -212,8 +205,7 @@ def test_sliding_window(sliding_window_size: int):
training=False,
needs_weights=False,
multiquery=False,
key_attention_mask_in_length=None,
query_attention_mask_in_length=None,
attention_mask_in_length=None,
should_repeat_kv_for_gqa=True,
sliding_window_size=sliding_window_size)

@@ -247,8 +239,7 @@ def test_sliding_window(sliding_window_size: int):
training=False,
needs_weights=False,
multiquery=False,
key_attention_mask_in_length=None,
query_attention_mask_in_length=None,
attention_mask_in_length=None,
should_repeat_kv_for_gqa=False,
sliding_window_size=-1)
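`sliding_window_size` limits each query position to attending over itself and at most that many preceding key positions, on top of the causal constraint; `-1` leaves the window unrestricted. As a hedged illustration only (plain PyTorch, hypothetical helper name), the equivalent boolean mask can be built like this:

```python
import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # True where attention is allowed: key j is visible to query i iff j <= i
    # (causal) and, when window >= 0, i - j <= window (recency restriction).
    idx = torch.arange(seq_len)
    causal = idx[None, :] <= idx[:, None]
    if window < 0:  # mirror the "-1 disables the window" convention
        return causal
    return causal & ((idx[:, None] - idx[None, :]) <= window)

print(sliding_window_causal_mask(5, 2).int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1]])
```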

36 changes: 16 additions & 20 deletions tests/test_flash_triton_torch.py
@@ -157,13 +157,13 @@ def gen_bias(attn_impl: str):

return attn_bias

query_attention_mask_in_length_0, key_attention_mask_in_length_0 = gen_attention_mask_in_length(
attention_mask_in_length_0 = gen_attention_mask_in_length(
sequence_id=sequence_id,
S=s,
attn_uses_sequence_id=attn_uses_sequence_id,
attn_impl=attn_impl_0,
attention_mask=attention_mask)
query_attention_mask_in_length_1, key_attention_mask_in_length_1 = gen_attention_mask_in_length(
attention_mask_in_length_1 = gen_attention_mask_in_length(
sequence_id=sequence_id,
S=s,
attn_uses_sequence_id=attn_uses_sequence_id,
@@ -204,25 +204,21 @@ def gen_bias(attn_impl: str):
s,
}

y0, _, _ = attn0(
x0,
past_key_value=None,
attn_bias=attn_bias,
attention_mask=attention_mask,
rotary_emb_w_meta_info=rotary_emb_w_meta_info,
is_causal=True,
query_attention_mask_in_length=query_attention_mask_in_length_0,
key_attention_mask_in_length=key_attention_mask_in_length_0)
y0, _, _ = attn0(x0,
past_key_value=None,
attn_bias=attn_bias,
attention_mask=attention_mask,
rotary_emb_w_meta_info=rotary_emb_w_meta_info,
is_causal=True,
attention_mask_in_length=attention_mask_in_length_0)
attn_bias = gen_bias(attn_impl_1)
y1, _, _ = attn1(
x1,
past_key_value=None,
attn_bias=attn_bias,
attention_mask=attention_mask,
rotary_emb_w_meta_info=rotary_emb_w_meta_info,
is_causal=True,
query_attention_mask_in_length=query_attention_mask_in_length_1,
key_attention_mask_in_length=key_attention_mask_in_length_1)
y1, _, _ = attn1(x1,
past_key_value=None,
attn_bias=attn_bias,
attention_mask=attention_mask,
rotary_emb_w_meta_info=rotary_emb_w_meta_info,
is_causal=True,
attention_mask_in_length=attention_mask_in_length_1)
y0 *= attention_mask.unsqueeze(-1)
y1 *= attention_mask.unsqueeze(-1)

