Commit e96b234
ShashankMosaicML committed Nov 18, 2023
1 parent c94e7fe commit e96b234
Showing 4 changed files with 77 additions and 234 deletions.
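Most of the diff is in tests/test_flash_attn.py, whose test_gqa_kv_repetition compares Flash Attention v2's native grouped-query attention path (should_repeat_kv_for_gqa=False) against the fallback that first repeats each KV head so every query head has a matching key/value head. A minimal sketch of that repetition step, assuming a (batch, seqlen, kv_n_heads, head_dim) layout; the helper name and shapes are illustrative, not the foundry implementation:

import torch

def repeat_kv_for_gqa_sketch(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
    # kv: (batch, seqlen, kv_n_heads, head_dim). Repeat each KV head n_rep
    # times so the result has the (batch, seqlen, n_heads, head_dim) layout
    # expected when every query head gets its own key/value head.
    if n_rep == 1:
        return kv
    b, s, kv_n_heads, d = kv.shape
    kv = kv[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
    return kv.reshape(b, s, kv_n_heads * n_rep, d)

# With n_heads=8 and kv_n_heads=2 (one of the parametrized cases below),
# each KV head is repeated 4 times.
value = torch.randn(2, 6, 2, 128)
assert repeat_kv_for_gqa_sketch(value, 8 // 2).shape == (2, 6, 8, 128)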
145 changes: 76 additions & 69 deletions tests/test_flash_attn.py
@@ -1,21 +1,20 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+import math
+
 import pytest
 import torch
-import math
 
-from llmfoundry.models.layers.attention import flash_attn_fn
-from llmfoundry.models.layers.attention import is_flash_v2_installed
+from llmfoundry.models.layers.attention import (flash_attn_fn,
+                                                is_flash_v2_installed)
 
 
 @pytest.mark.gpu
 @pytest.mark.parametrize('kv_n_heads', [1, 2, 4, 8])
 def test_gqa_kv_repetition(kv_n_heads: int):
     if not is_flash_v2_installed():
-        pytest.skip(
-            'GQA natively only supported by Flash Attention after v2.'
-        )
+        pytest.skip('GQA natively only supported by Flash Attention after v2.')
     d = 128
     n_heads = 8
     seqlen_1 = 6
@@ -25,26 +24,27 @@ def test_gqa_kv_repetition(kv_n_heads: int):
     query_1.requires_grad = True
     key_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda()
     key_1.requires_grad = True
-    value_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda()
+    value_1 = torch.randn(bsz, seqlen_1,
+                          kv_n_heads * d).to(torch.bfloat16).cuda()
     value_1.requires_grad = True
 
     output_1, _, _ = flash_attn_fn(query=query_1,
                                    key=key_1,
                                    value=value_1,
                                    n_heads=n_heads,
                                    kv_n_heads=kv_n_heads,
                                    past_key_value=None,
                                    softmax_scale=1 / math.sqrt(d),
                                    attn_bias=None,
                                    key_padding_mask=None,
                                    is_causal=True,
                                    dropout_p=0.0,
                                    training=False,
                                    needs_weights=False,
                                    multiquery=False,
                                    key_attention_mask_in_length=None,
                                    query_attention_mask_in_length=None,
                                    should_repeat_kv_for_gqa=True)
 
     output_1.sum().backward()

@@ -56,28 +56,28 @@ def test_gqa_kv_repetition(kv_n_heads: int):
     value_2.requires_grad = True
 
     output_2, _, _ = flash_attn_fn(query=query_2,
                                    key=key_2,
                                    value=value_2,
                                    n_heads=n_heads,
                                    kv_n_heads=kv_n_heads,
                                    past_key_value=None,
                                    softmax_scale=1 / math.sqrt(d),
                                    attn_bias=None,
                                    key_padding_mask=None,
                                    is_causal=True,
                                    dropout_p=0.0,
                                    training=False,
                                    needs_weights=False,
                                    multiquery=False,
                                    key_attention_mask_in_length=None,
                                    query_attention_mask_in_length=None,
                                    should_repeat_kv_for_gqa=False)
 
     output_2.sum().backward()
     assert torch.allclose(output_1, output_2)
-    assert torch.allclose(query_1.grad, query_2.grad)
-    assert torch.allclose(key_1.grad, key_2.grad)
-    assert torch.allclose(value_1.grad, value_2.grad)
+    assert torch.allclose(query_1.grad, query_2.grad)  # type: ignore
+    assert torch.allclose(key_1.grad, key_2.grad)  # type: ignore
+    assert torch.allclose(value_1.grad, value_2.grad)  # type: ignore
 
 
 @pytest.mark.gpu
@@ -96,15 +96,18 @@ def test_seq_id_masking_FA_v2():
     query_1.requires_grad = True
     key_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda()
     key_1.requires_grad = True
-    value_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda()
+    value_1 = torch.randn(bsz, seqlen_1,
+                          kv_n_heads * d).to(torch.bfloat16).cuda()
     value_1.requires_grad = True
 
-    seq_ranges = [(0, 3), (3, 5), (5, 6)]  # Each batch has 3 sequences of length 3, 2, and 1 respectively.
-    query_attention_mask_in_length_1 = torch.tensor([[3, 2, 1, 0, 0, 0],
-                                                     [3, 2, 1, 0, 0,
-                                                      0]]).to(torch.int64).cuda()
+    seq_ranges = [
+        (0, 3), (3, 5), (5, 6)
+    ]  # Each batch has 3 sequences of length 3, 2, and 1 respectively.
+    query_attention_mask_in_length_1 = torch.tensor([[3, 2, 1, 0, 0, 0],
+                                                     [3, 2, 1, 0, 0, 0]
+                                                    ]).to(torch.int64).cuda()
     key_attention_mask_in_length_1 = torch.tensor([[3, 2, 1, 0, 0, 0],
                                                    [3, 2, 1, 0, 0,
                                                     0]]).to(torch.int64).cuda()
 
     output_1, _, _ = flash_attn_fn(
@@ -136,27 +139,31 @@ def test_seq_id_masking_FA_v2():
         value_2.requires_grad = True
 
         output_2, _, _ = flash_attn_fn(query=query_2,
                                        key=key_2,
                                        value=value_2,
                                        n_heads=n_heads,
                                        kv_n_heads=kv_n_heads,
                                        past_key_value=None,
                                        softmax_scale=1 / math.sqrt(d),
                                        attn_bias=None,
                                        key_padding_mask=None,
                                        is_causal=True,
                                        dropout_p=0.0,
                                        training=False,
                                        needs_weights=False,
                                        multiquery=False,
                                        key_attention_mask_in_length=None,
                                        query_attention_mask_in_length=None)
 
         output_2.sum().backward()
-        assert torch.allclose(output_1[:, seq_range[0]:seq_range[1], :], output_2)
-        assert torch.allclose(query_1.grad[:, seq_range[0]:seq_range[1], :],
-                              query_2.grad)
-        assert torch.allclose(key_1.grad[:, seq_range[0]:seq_range[1], :],
-                              key_2.grad)
-        assert torch.allclose(value_1.grad[:, seq_range[0]:seq_range[1], :],
-                              value_2.grad)
+        assert torch.allclose(output_1[:, seq_range[0]:seq_range[1], :],
+                              output_2)
+        assert torch.allclose(
+            query_1.grad[:, seq_range[0]:seq_range[1], :],  # type: ignore
+            query_2.grad)  # type: ignore
+        assert torch.allclose(
+            key_1.grad[:, seq_range[0]:seq_range[1], :],  # type: ignore
+            key_2.grad)  # type: ignore
+        assert torch.allclose(
+            value_1.grad[:, seq_range[0]:seq_range[1], :],  # type: ignore
+            value_2.grad)  # type: ignore
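test_seq_id_masking_FA_v2 packs three sequences of lengths 3, 2, and 1 into each row of a length-6 batch and checks that the packed output and gradients match attending to each sequence on its own. The *_attention_mask_in_length tensors encode those per-row sequence lengths, zero-padded out to the sequence dimension. A small sketch of how such a tensor could be derived from per-token sequence ids; the helper below is illustrative, not the function llm-foundry uses:

import torch

def mask_in_length_from_seq_ids(sequence_ids: torch.Tensor) -> torch.Tensor:
    # sequence_ids: (batch, seqlen) ints, e.g. [[0, 0, 0, 1, 1, 2]].
    # Returns per-row sequence lengths, zero-padded to seqlen,
    # e.g. [[3, 2, 1, 0, 0, 0]] as used in the test above.
    batch, seqlen = sequence_ids.shape
    out = torch.zeros(batch, seqlen, dtype=torch.int64)
    for row in range(batch):
        counts = torch.bincount(sequence_ids[row])
        out[row, :counts.numel()] = counts
    return out

mask = mask_in_length_from_seq_ids(torch.tensor([[0, 0, 0, 1, 1, 2]]))
assert mask.tolist() == [[3, 2, 1, 0, 0, 0]]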
1 change: 1 addition & 0 deletions tests/test_model.py
@@ -519,6 +519,7 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
     assert block.resid_ffn_dropout.p == 0.2
 
 
+@pytest.mark.gpu
 @pytest.mark.parametrize('attention_impl,device', [('torch', 'cpu'),
                                                    ('flash', 'gpu'),
                                                    ('triton', 'gpu'),
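The only change to tests/test_model.py is the new @pytest.mark.gpu marker on the test parametrized over attention_impl and device. A hedged sketch of how a device-parametrized test of this shape is commonly guarded when CUDA is unavailable; the test body and skip logic here are illustrative, not the foundry test:

import pytest
import torch

@pytest.mark.gpu
@pytest.mark.parametrize('attention_impl,device', [('torch', 'cpu'),
                                                   ('flash', 'gpu'),
                                                   ('triton', 'gpu')])
def test_attention_impl_device_sketch(attention_impl: str, device: str):
    # Illustrative guard: skip GPU-only configurations on machines without CUDA.
    if device == 'gpu' and not torch.cuda.is_available():
        pytest.skip(f'{attention_impl} attention requires a CUDA device.')
    assert attention_impl in ('torch', 'flash', 'triton')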
78 changes: 0 additions & 78 deletions tests/tst_dont_repeat_kv_for_gqa.py

This file was deleted.

87 changes: 0 additions & 87 deletions tests/tst_seq_id_masking_works_correctly.py

This file was deleted.
