From e96b234bf913aa4f97aab6fe7cee0d52675d58b2 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 18 Nov 2023 00:34:11 +0000 Subject: [PATCH] .. --- tests/test_flash_attn.py | 145 ++++++++++---------- tests/test_model.py | 1 + tests/tst_dont_repeat_kv_for_gqa.py | 78 ----------- tests/tst_seq_id_masking_works_correctly.py | 87 ------------ 4 files changed, 77 insertions(+), 234 deletions(-) delete mode 100644 tests/tst_dont_repeat_kv_for_gqa.py delete mode 100644 tests/tst_seq_id_masking_works_correctly.py diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 95b040bd04..50dd92f26c 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -1,21 +1,20 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import math + import pytest import torch -import math -from llmfoundry.models.layers.attention import flash_attn_fn -from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.layers.attention import (flash_attn_fn, + is_flash_v2_installed) @pytest.mark.gpu @pytest.mark.parametrize('kv_n_heads', [1, 2, 4, 8]) def test_gqa_kv_repetition(kv_n_heads: int): if not is_flash_v2_installed(): - pytest.skip( - 'GQA natively only supported by Flash Attention after v2.' - ) + pytest.skip('GQA natively only supported by Flash Attention after v2.') d = 128 n_heads = 8 seqlen_1 = 6 @@ -25,26 +24,27 @@ def test_gqa_kv_repetition(kv_n_heads: int): query_1.requires_grad = True key_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda() key_1.requires_grad = True - value_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda() + value_1 = torch.randn(bsz, seqlen_1, + kv_n_heads * d).to(torch.bfloat16).cuda() value_1.requires_grad = True output_1, _, _ = flash_attn_fn(query=query_1, - key=key_1, - value=value_1, - n_heads=n_heads, - kv_n_heads=kv_n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - key_attention_mask_in_length=None, - query_attention_mask_in_length=None, - should_repeat_kv_for_gqa=True) + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + key_attention_mask_in_length=None, + query_attention_mask_in_length=None, + should_repeat_kv_for_gqa=True) output_1.sum().backward() @@ -56,28 +56,28 @@ def test_gqa_kv_repetition(kv_n_heads: int): value_2.requires_grad = True output_2, _, _ = flash_attn_fn(query=query_2, - key=key_2, - value=value_2, - n_heads=n_heads, - kv_n_heads=kv_n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - key_attention_mask_in_length=None, - query_attention_mask_in_length=None, - should_repeat_kv_for_gqa=False) + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + key_attention_mask_in_length=None, + query_attention_mask_in_length=None, + should_repeat_kv_for_gqa=False) output_2.sum().backward() assert 
torch.allclose(output_1, output_2) - assert torch.allclose(query_1.grad, query_2.grad) - assert torch.allclose(key_1.grad, key_2.grad) - assert torch.allclose(value_1.grad, value_2.grad) + assert torch.allclose(query_1.grad, query_2.grad) # type: ignore + assert torch.allclose(key_1.grad, key_2.grad) # type: ignore + assert torch.allclose(value_1.grad, value_2.grad) # type: ignore @pytest.mark.gpu @@ -96,15 +96,18 @@ def test_seq_id_masking_FA_v2(): query_1.requires_grad = True key_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda() key_1.requires_grad = True - value_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda() + value_1 = torch.randn(bsz, seqlen_1, + kv_n_heads * d).to(torch.bfloat16).cuda() value_1.requires_grad = True - seq_ranges = [(0, 3), (3, 5), (5, 6)] # Each batch has 3 sequences of length 3, 2, and 1 respectively. + seq_ranges = [ + (0, 3), (3, 5), (5, 6) + ] # Each batch has 3 sequences of length 3, 2, and 1 respectively. query_attention_mask_in_length_1 = torch.tensor([[3, 2, 1, 0, 0, 0], - [3, 2, 1, 0, 0, - 0]]).to(torch.int64).cuda() + [3, 2, 1, 0, 0, 0] + ]).to(torch.int64).cuda() key_attention_mask_in_length_1 = torch.tensor([[3, 2, 1, 0, 0, 0], - [3, 2, 1, 0, 0, + [3, 2, 1, 0, 0, 0]]).to(torch.int64).cuda() output_1, _, _ = flash_attn_fn( @@ -136,27 +139,31 @@ def test_seq_id_masking_FA_v2(): value_2.requires_grad = True output_2, _, _ = flash_attn_fn(query=query_2, - key=key_2, - value=value_2, - n_heads=n_heads, - kv_n_heads=kv_n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - key_attention_mask_in_length=None, - query_attention_mask_in_length=None) + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + key_attention_mask_in_length=None, + query_attention_mask_in_length=None) output_2.sum().backward() - assert torch.allclose(output_1[:, seq_range[0]:seq_range[1], :], output_2) - assert torch.allclose(query_1.grad[:, seq_range[0]:seq_range[1], :], - query_2.grad) - assert torch.allclose(key_1.grad[:, seq_range[0]:seq_range[1], :], - key_2.grad) - assert torch.allclose(value_1.grad[:, seq_range[0]:seq_range[1], :], - value_2.grad) + assert torch.allclose(output_1[:, seq_range[0]:seq_range[1], :], + output_2) + assert torch.allclose( + query_1.grad[:, seq_range[0]:seq_range[1], :], # type: ignore + query_2.grad) # type: ignore + assert torch.allclose( + key_1.grad[:, seq_range[0]:seq_range[1], :], # type: ignore + key_2.grad) # type: ignore + assert torch.allclose( + value_1.grad[:, seq_range[0]:seq_range[1], :], # type: ignore + value_2.grad) # type: ignore diff --git a/tests/test_model.py b/tests/test_model.py index 372ff41ae9..f8531b3470 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -519,6 +519,7 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool): assert block.resid_ffn_dropout.p == 0.2 +@pytest.mark.gpu @pytest.mark.parametrize('attention_impl,device', [('torch', 'cpu'), ('flash', 'gpu'), ('triton', 'gpu'), diff --git a/tests/tst_dont_repeat_kv_for_gqa.py b/tests/tst_dont_repeat_kv_for_gqa.py deleted file mode 100644 index fb8e4f936d..0000000000 --- a/tests/tst_dont_repeat_kv_for_gqa.py +++ /dev/null @@ 
-1,78 +0,0 @@ -# Copyright 2022 MosaicML LLM Foundry authors -# SPDX-License-Identifier: Apache-2.0 - -# NOTE: This is a temporary test file to test that we do not need to repeat kv tensors for gqa on our end for flash_attn_2. -# Will be deleted before merging to main (or not, if people think this should be added). -# To run, simply run `python test_dont_repeat_kv_for_gqa.py`. -import math - -import torch - -from llmfoundry.models.layers.attention import (flash_attn_fn, - is_flash_v2_installed) - -assert is_flash_v2_installed() - -d = 128 -n_heads = 4 -kv_n_heads = 2 -seqlen_1 = 6 -bsz = 2 - -query_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(torch.bfloat16).cuda() -query_1.requires_grad = True -key_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda() -key_1.requires_grad = True -value_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda() -value_1.requires_grad = True - -output_1, _, _ = flash_attn_fn(query=query_1, - key=key_1, - value=value_1, - n_heads=n_heads, - kv_n_heads=kv_n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - key_attention_mask_in_length=None, - query_attention_mask_in_length=None, - should_repeat_kv_for_gqa=True) - -output_1.sum().backward() - -query_2 = query_1.detach().clone() -query_2.requires_grad = True -key_2 = key_1.detach().clone() -key_2.requires_grad = True -value_2 = value_1.detach().clone() -value_2.requires_grad = True - -output_2, _, _ = flash_attn_fn(query=query_2, - key=key_2, - value=value_2, - n_heads=n_heads, - kv_n_heads=kv_n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - key_attention_mask_in_length=None, - query_attention_mask_in_length=None, - should_repeat_kv_for_gqa=False) - -output_2.sum().backward() -assert torch.allclose(output_1, output_2) -assert torch.allclose(query_1.grad, query_2.grad) -assert torch.allclose(key_1.grad, key_2.grad) -assert torch.allclose(value_1.grad, value_2.grad) diff --git a/tests/tst_seq_id_masking_works_correctly.py b/tests/tst_seq_id_masking_works_correctly.py deleted file mode 100644 index 07b0f977d5..0000000000 --- a/tests/tst_seq_id_masking_works_correctly.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2022 MosaicML LLM Foundry authors -# SPDX-License-Identifier: Apache-2.0 - -# NOTE: This is a temporary test file to test the correctness of the seq_id_masking function. -# Will be deleted before merging to main (or not, if people think this should be added). -# To run, simply run `python test_seq_id_masking_works_correctly.py`. 
-import math - -import torch - -from llmfoundry.models.layers.attention import flash_attn_fn - -d = 128 -n_heads = 4 -kv_n_heads = 4 -seqlen_1 = 6 -bsz = 2 - -query_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(torch.bfloat16).cuda() -query_1.requires_grad = True -key_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda() -key_1.requires_grad = True -value_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda() -value_1.requires_grad = True - -query_attention_mask_in_length_1 = torch.tensor([[3, 2, 1, 0, 0, 0], - [3, 2, 1, 0, 0, - 0]]).to(torch.int64).cuda() -key_attention_mask_in_length_1 = torch.tensor([[3, 2, 1, 0, 0, 0], - [3, 2, 1, 0, 0, - 0]]).to(torch.int64).cuda() - -output_1, _, _ = flash_attn_fn( - query=query_1, - key=key_1, - value=value_1, - n_heads=n_heads, - kv_n_heads=kv_n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - key_attention_mask_in_length=key_attention_mask_in_length_1, - query_attention_mask_in_length=query_attention_mask_in_length_1) - -output_1.sum().backward() - -seq_ranges = [(0, 3), (3, 5), (5, 6)] -for seq_range in seq_ranges: - seqlen_2 = seq_range[1] - seq_range[0] - query_2 = query_1.detach().clone()[:, seq_range[0]:seq_range[1], :] - query_2.requires_grad = True - key_2 = key_1.detach().clone()[:, seq_range[0]:seq_range[1], :] - key_2.requires_grad = True - value_2 = value_1.detach().clone()[:, seq_range[0]:seq_range[1], :] - value_2.requires_grad = True - - output_2, _, _ = flash_attn_fn(query=query_2, - key=key_2, - value=value_2, - n_heads=n_heads, - kv_n_heads=kv_n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - key_attention_mask_in_length=None, - query_attention_mask_in_length=None) - - output_2.sum().backward() - assert torch.allclose(output_1[:, seq_range[0]:seq_range[1], :], output_2) - assert torch.allclose(query_1.grad[:, seq_range[0]:seq_range[1], :], - query_2.grad) - assert torch.allclose(key_1.grad[:, seq_range[0]:seq_range[1], :], - key_2.grad) - assert torch.allclose(value_1.grad[:, seq_range[0]:seq_range[1], :], - value_2.grad)
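Note on the GQA test in this patch: test_gqa_kv_repetition runs flash_attn_fn twice on identical inputs, once with should_repeat_kv_for_gqa=True (the KV heads are expanded to match the query head count before the kernel call) and once with should_repeat_kv_for_gqa=False (relying on Flash Attention v2's native grouped-query support, which the skip message notes is only available after v2), and asserts that outputs and gradients agree. The sketch below only illustrates what "repeating KV for GQA" means; the tensor names and shapes are assumptions for the example, not llm-foundry's internal code.

import torch

bsz, seqlen, n_heads, kv_n_heads, d = 2, 6, 8, 2, 128
key = torch.randn(bsz, seqlen, kv_n_heads, d)
value = torch.randn(bsz, seqlen, kv_n_heads, d)

# Each KV head serves n_heads // kv_n_heads query heads. Repeating along the
# head dimension expands (bsz, seqlen, kv_n_heads, d) to (bsz, seqlen, n_heads, d),
# which is the layout a kernel without native GQA support expects.
n_rep = n_heads // kv_n_heads
key_repeated = key.repeat_interleave(n_rep, dim=2)
value_repeated = value.repeat_interleave(n_rep, dim=2)
assert key_repeated.shape == (bsz, seqlen, n_heads, d)
assert value_repeated.shape == (bsz, seqlen, n_heads, d)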
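Note on the sequence-id masking test in this patch: test_seq_id_masking_FA_v2 packs three sequences of lengths 3, 2, and 1 into each batch row and encodes that packing as attention_mask_in_length = [[3, 2, 1, 0, 0, 0], ...], i.e. the per-row sequence lengths, zero-padded out to the full sequence length. It then re-runs attention on each sub-sequence in isolation and checks that outputs and gradients match the corresponding slice of the packed run, so tokens never attend across sequence boundaries. The sketch below shows one plausible way such a length encoding maps onto the cumulative sequence offsets (cu_seqlens) consumed by Flash Attention's variable-length kernels; it is an assumption for illustration, not the helper llm-foundry actually calls.

import torch

attention_mask_in_length = torch.tensor([[3, 2, 1, 0, 0, 0],
                                         [3, 2, 1, 0, 0, 0]])

# Drop the zero padding and concatenate the per-row sequence lengths.
lengths = attention_mask_in_length.flatten()
lengths = lengths[lengths > 0]  # tensor([3, 2, 1, 3, 2, 1])

# Cumulative offsets delimiting each packed sequence in the flattened batch.
cu_seqlens = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    torch.cumsum(lengths, dim=0).to(torch.int32),
])
print(cu_seqlens)  # tensor([0, 3, 5, 6, 9, 11, 12], dtype=torch.int32)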