Commit 0474e05

ShashankMosaicML committed Jan 17, 2024
1 parent 09d9bdf commit 0474e05
Showing 4 changed files with 140 additions and 107 deletions.
2 changes: 2 additions & 0 deletions llmfoundry/models/layers/attention.py
@@ -235,6 +235,8 @@ def flash_attn_fn(
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
del key_padding_mask
if flash_attn_padding_info is None:
raise ValueError('flash_attn_padding_info is required for flash attn.')
try:
from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip
except:
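With this guard, `flash_attn_fn` no longer tries to reconstruct padding metadata itself; callers must pass `flash_attn_padding_info` explicitly. A minimal sketch of the failure mode the check enforces, assuming llm-foundry is installed and that the new argument defaults to `None` (tensor shapes are illustrative only):

```python
import math

import pytest
import torch

from llmfoundry.models.layers.attention import flash_attn_fn

# Illustrative shapes; the ValueError fires before any flash-attn kernel
# (or even the flash_attn import inside the function) is reached.
bsz, seqlen, n_heads, d = 2, 6, 4, 128
query = torch.randn(bsz, seqlen, n_heads * d)
key = torch.randn(bsz, seqlen, n_heads * d)
value = torch.randn(bsz, seqlen, n_heads * d)

with pytest.raises(ValueError, match='flash_attn_padding_info is required'):
    flash_attn_fn(query=query,
                  key=key,
                  value=value,
                  n_heads=n_heads,
                  kv_n_heads=n_heads,
                  past_key_value=None,
                  softmax_scale=1 / math.sqrt(d),
                  attn_bias=None,
                  key_padding_mask=None,
                  is_causal=True,
                  dropout_p=0.0,
                  training=False,
                  needs_weights=False,
                  multiquery=False)  # flash_attn_padding_info omitted -> None
```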
49 changes: 26 additions & 23 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -217,17 +217,19 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
return attention_mask_in_length


def get_flash_attn_padding_info(
def gen_flash_attn_padding_info(
bsz: int,
S: int,
past_key_len: int,
attention_mask_in_length: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None):
flash_attn_padding_info = {}
if attention_mask_in_length is None:
key_padding_mask = attention_mask
if key_padding_mask is None:
key_padding_mask = torch.ones(
(x.shape[0], past_key_len + x.shape[1]), dtype=torch.bool)
query_padding_mask = key_padding_mask[:, -x.shape[1]:]
key_padding_mask = torch.ones((bsz, past_key_len + S),
dtype=torch.bool)
query_padding_mask = key_padding_mask[:, -S:]
unpadding_function = bert_padding.unpad_input
else:
key_padding_mask = attention_mask_in_length
@@ -550,10 +552,12 @@ def forward(
raise ValueError(
'You cannot specify both input_ids and inputs_embeds.')
elif input_ids is not None:
bsz = input_ids.size(0)
S = input_ids.size(1)
x = self.wte(input_ids)
input_device = input_ids.device
elif inputs_embeds is not None:
bsz = inputs_embeds.size(0)
S = inputs_embeds.size(1)
x = inputs_embeds
input_device = inputs_embeds.device
@@ -565,22 +569,23 @@ def forward(
), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'

rotary_emb_w_meta_info = None
if self.learned_pos_emb or self.rope:
past_position = 0
if past_key_values is not None:
if len(past_key_values) != self.config.n_layers:
raise ValueError(
f'past_key_values must provide a past_key_value for each attention '
+
f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
)
# For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
# For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
# Here we shift position embedding using the `seq` dim of the past key
past_position = past_key_values[0][0].size(1)
if self.attn_impl == 'torch':
past_position = past_key_values[0][0].size(3)

past_position = 0
if past_key_values is not None:
if len(past_key_values) != self.config.n_layers:
raise ValueError(
f'past_key_values must provide a past_key_value for each attention '
+
f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
)
# For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
# For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
# Here we shift position embedding using the `seq` dim of the past key
past_position = past_key_values[0][0].size(1)
if self.attn_impl == 'torch':
past_position = past_key_values[0][0].size(3)

if self.learned_pos_emb or self.rope:
if self.learned_pos_emb and (S + past_position >
self.config.max_seq_len):
raise ValueError(
@@ -660,10 +665,8 @@ def forward(
all_self_attns = () if output_attentions else None
flash_attn_padding_info = {}
if self.attn_impl == 'flash':
past_key_len = past_key_values[0].shape[
1] if past_key_values is not None else 0
flash_attn_padding_info = get_flash_attn_padding_info(
past_key_len, attention_mask_in_length, attention_mask)
flash_attn_padding_info = gen_flash_attn_padding_info(
bsz, S, past_position, attention_mask_in_length, attention_mask)

for b_idx, block in enumerate(self.blocks):
if output_hidden_states:
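The relocated comment block above is what drives the `past_position` branch: flash/triton KV caches keep the past key as `(batch, seq, dim)`, while the torch implementation keeps `(batch, heads, head_dim, seq)`, so the past sequence length — which is also what `gen_flash_attn_padding_info` receives as `past_key_len` — lives on a different dimension. A standalone illustration, with hypothetical cache sizes chosen only for the example:

```python
import torch

# Hypothetical cache for a model with 8 heads of head_dim 64 that has already
# processed 16 tokens; the numbers are for illustration only.
flash_past_key = torch.zeros(2, 16, 8 * 64)   # flash/triton: (batch, seq, dim)
torch_past_key = torch.zeros(2, 8, 64, 16)    # torch: (batch, heads, head_dim, seq)


def past_position_from_cache(past_key: torch.Tensor, attn_impl: str) -> int:
    # Mirrors the branch in MPTModel.forward above: the past sequence length
    # sits on dim 3 for the torch implementation and dim 1 otherwise.
    return past_key.size(3) if attn_impl == 'torch' else past_key.size(1)


assert past_position_from_cache(flash_past_key, 'flash') == 16
assert past_position_from_cache(torch_past_key, 'torch') == 16
```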
180 changes: 98 additions & 82 deletions tests/models/layers/test_flash_attn.py
@@ -12,6 +12,7 @@
flash_attn_fn, gen_slopes,
is_flash_v2_installed,
triton_flash_attn_fn)
from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info


@pytest.mark.gpu
@@ -35,22 +36,24 @@ def test_gqa_kv_repetition(kv_n_heads: int):
kv_n_heads * d).to(torch.bfloat16).cuda()
value_1.requires_grad = True

output_1, _, _ = flash_attn_fn(query=query_1,
key=key_1,
value=value_1,
n_heads=n_heads,
kv_n_heads=kv_n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
attn_bias=None,
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
multiquery=False,
attention_mask_in_length=None,
should_repeat_kv_for_gqa=True)
output_1, _, _ = flash_attn_fn(
query=query_1,
key=key_1,
value=value_1,
n_heads=n_heads,
kv_n_heads=kv_n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
attn_bias=None,
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
multiquery=False,
flash_attn_padding_info=gen_flash_attn_padding_info(
bsz, seqlen_1, 0, None, None),
should_repeat_kv_for_gqa=True)

output_1.sum().backward()

@@ -61,22 +64,24 @@ def test_gqa_kv_repetition(kv_n_heads: int):
value_2 = value_1.detach().clone()
value_2.requires_grad = True

output_2, _, _ = flash_attn_fn(query=query_2,
key=key_2,
value=value_2,
n_heads=n_heads,
kv_n_heads=kv_n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
attn_bias=None,
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
multiquery=False,
attention_mask_in_length=None,
should_repeat_kv_for_gqa=False)
output_2, _, _ = flash_attn_fn(
query=query_2,
key=key_2,
value=value_2,
n_heads=n_heads,
kv_n_heads=kv_n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
attn_bias=None,
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
multiquery=False,
flash_attn_padding_info=gen_flash_attn_padding_info(
bsz, seqlen_1, 0, None, None),
should_repeat_kv_for_gqa=False)

output_2.sum().backward()
assert torch.allclose(output_1, output_2)
@@ -114,6 +119,9 @@ def test_seq_id_masking_FA_v2():
[3, 2, 1, 0, 0,
0]]).to(torch.int64).cuda()

flash_attn_padding_info_1 = gen_flash_attn_padding_info(
bsz, seqlen_1, 0, attention_mask_in_length_1, None)

output_1, _, _ = flash_attn_fn(
query=query_1,
key=key_1,
@@ -129,7 +137,7 @@ def test_seq_id_masking_FA_v2():
training=False,
needs_weights=False,
multiquery=False,
attention_mask_in_length=attention_mask_in_length_1)
flash_attn_padding_info=flash_attn_padding_info_1)

output_1.sum().backward()

@@ -141,21 +149,25 @@ def test_seq_id_masking_FA_v2():
value_2 = value_1.detach().clone()[:, seq_range[0]:seq_range[1], :]
value_2.requires_grad = True

output_2, _, _ = flash_attn_fn(query=query_2,
key=key_2,
value=value_2,
n_heads=n_heads,
kv_n_heads=kv_n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
attn_bias=None,
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
multiquery=False,
attention_mask_in_length=None)
flash_attn_padding_info_2 = gen_flash_attn_padding_info(
bsz, seq_range[1] - seq_range[0], 0, None, None)

output_2, _, _ = flash_attn_fn(
query=query_2,
key=key_2,
value=value_2,
n_heads=n_heads,
kv_n_heads=kv_n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
attn_bias=None,
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
multiquery=False,
flash_attn_padding_info=flash_attn_padding_info_2)

output_2.sum().backward()
assert torch.allclose(output_1[:, seq_range[0]:seq_range[1], :],
@@ -196,23 +208,25 @@ def test_sliding_window(sliding_window_size: int):
device=device)
value_1.requires_grad = True

output_1, _, _ = flash_attn_fn(query=query_1,
key=key_1,
value=value_1,
n_heads=n_heads,
kv_n_heads=n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
attn_bias=None,
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
multiquery=False,
attention_mask_in_length=None,
should_repeat_kv_for_gqa=True,
sliding_window_size=sliding_window_size)
output_1, _, _ = flash_attn_fn(
query=query_1,
key=key_1,
value=value_1,
n_heads=n_heads,
kv_n_heads=n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
attn_bias=None,
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
multiquery=False,
flash_attn_padding_info=gen_flash_attn_padding_info(
bsz, seqlen_1, 0, None, None),
should_repeat_kv_for_gqa=True,
sliding_window_size=sliding_window_size)

output_1.sum().backward()

@@ -284,23 +298,25 @@ def test_alibi_bias(n_heads: int):
alibi_bias_max=8,
device=torch.device(device),
return_1d=True)
output_1, _, _ = flash_attn_fn(query=query_1,
key=key_1,
value=value_1,
n_heads=n_heads,
kv_n_heads=n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
attn_bias=None,
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
multiquery=False,
attention_mask_in_length=None,
should_repeat_kv_for_gqa=True,
alibi_slopes=alibi_slopes_1)
output_1, _, _ = flash_attn_fn(
query=query_1,
key=key_1,
value=value_1,
n_heads=n_heads,
kv_n_heads=n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
attn_bias=None,
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
multiquery=False,
flash_attn_padding_info=gen_flash_attn_padding_info(
bsz, seqlen_1, 0, None, None),
should_repeat_kv_for_gqa=True,
alibi_slopes=alibi_slopes_1)

output_1.sum().backward()

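As a reading aid for `test_seq_id_masking_FA_v2` above: `attention_mask_in_length_1` packs several sub-sequences into a single row, which is why the test can compare the packed output against recomputation over each `seq_range[0]:seq_range[1]` slice. A short sketch of that encoding — the lengths come from the tensor shown in the test; the derived ranges are one natural reading of it:

```python
import torch

# Row [3, 2, 1, 0, 0, 0] describes a packed sequence of length 6 holding three
# sub-sequences of lengths 3, 2 and 1; trailing zeros are unused slots.
attention_mask_in_length = torch.tensor([[3, 2, 1, 0, 0, 0]], dtype=torch.int64)

sub_seq_lens = attention_mask_in_length[attention_mask_in_length > 0]
assert sub_seq_lens.tolist() == [3, 2, 1]
assert int(sub_seq_lens.sum()) == attention_mask_in_length.shape[1]  # 3 + 2 + 1 == 6

# Contiguous per-sub-sequence ranges implied by those lengths (cf. seq_range above).
ends = sub_seq_lens.cumsum(0).tolist()
starts = [0] + ends[:-1]
assert list(zip(starts, ends)) == [(0, 3), (3, 5), (5, 6)]
```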
16 changes: 14 additions & 2 deletions tests/models/layers/test_flash_triton_torch.py
@@ -10,6 +10,7 @@
is_flash_v2_installed)
from llmfoundry.models.mpt.modeling_mpt import (apply_sequence_id,
gen_attention_mask_in_length,
gen_flash_attn_padding_info,
gen_rotary_embedding)


@@ -164,13 +165,24 @@ def gen_bias(attn_impl: str):
attn_uses_sequence_id=attn_uses_sequence_id,
attn_impl=attn_impl_0,
attention_mask=attention_mask)

flash_attn_padding_info_0 = {}
if attn_impl_0 == 'flash':
flash_attn_padding_info_0 = gen_flash_attn_padding_info(
n, s, 0, attention_mask_in_length_0, attention_mask)

attention_mask_in_length_1 = gen_attention_mask_in_length(
sequence_id=sequence_id,
S=s,
attn_uses_sequence_id=attn_uses_sequence_id,
attn_impl=attn_impl_1,
attention_mask=attention_mask)

flash_attn_padding_info_1 = {}
if attn_impl_1 == 'flash':
flash_attn_padding_info_1 = gen_flash_attn_padding_info(
n, s, 0, attention_mask_in_length_1, attention_mask)

x0 = torch.randn(n, s, f).to(device)
x1 = x0.clone().detach()
x0.requires_grad = True
@@ -216,7 +228,7 @@ def gen_bias(attn_impl: str):
attention_mask=attention_mask,
rotary_emb_w_meta_info=rotary_emb_w_meta_info,
is_causal=True,
attention_mask_in_length=attention_mask_in_length_0,
flash_attn_padding_info=flash_attn_padding_info_0,
alibi_slopes=alibi_slopes_0)
attn_bias_1 = gen_bias(attn_impl_1)
alibi_slopes_1 = None
@@ -231,7 +243,7 @@ def gen_bias(attn_impl: str):
attention_mask=attention_mask,
rotary_emb_w_meta_info=rotary_emb_w_meta_info,
is_causal=True,
attention_mask_in_length=attention_mask_in_length_1,
flash_attn_padding_info=flash_attn_padding_info_1,
alibi_slopes=alibi_slopes_1)
y0 *= attention_mask.unsqueeze(-1)
y1 *= attention_mask.unsqueeze(-1)
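The pattern in this test — building `flash_attn_padding_info` only when the implementation under comparison is `flash`, and otherwise passing an empty dict — mirrors the new call site in `MPTModel.forward`. A hedged sketch of that pattern; the wrapper name `maybe_gen_flash_attn_padding_info` is hypothetical and exists only for this example:

```python
from typing import Optional

import torch

from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info


def maybe_gen_flash_attn_padding_info(
        attn_impl: str, bsz: int, S: int, past_position: int,
        attention_mask_in_length: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor]) -> dict:
    # Only the flash path consumes the unpadding metadata; torch/triton paths
    # keep receiving an empty dict, as in the diff above.
    if attn_impl == 'flash':
        return gen_flash_attn_padding_info(bsz, S, past_position,
                                           attention_mask_in_length,
                                           attention_mask)
    return {}


# Non-flash implementations simply skip the unpadding metadata.
assert maybe_gen_flash_attn_padding_info('torch', 2, 6, 0, None, None) == {}
```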
