From cb6864aafd736b5e9b0108e1a37f9b98c3c43201 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 2 Dec 2023 01:46:29 +0000 Subject: [PATCH] .. --- tests/models/layers/test_flash_attn.py | 22 ++++++++++++---------- tests/models/test_model.py | 5 +---- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/tests/models/layers/test_flash_attn.py b/tests/models/layers/test_flash_attn.py index 7e282dbc9d..acefd2c42d 100644 --- a/tests/models/layers/test_flash_attn.py +++ b/tests/models/layers/test_flash_attn.py @@ -12,12 +12,13 @@ @pytest.mark.gpu +@pytest.mark.skipif( + not is_flash_v2_installed(), + reason='GQA natively only supported by Flash Attention after v2.') @pytest.mark.parametrize('kv_n_heads', [1, 4, 8]) def test_gqa_kv_repetition(kv_n_heads: int): # Test that flash attention v2 with GQA (kv_n_heads < n_heads) works the same # whether we repeat the kv_n_heads explicitly or flash attention v2 handles it on its own. - if not is_flash_v2_installed(): - pytest.skip('GQA natively only supported by Flash Attention after v2.') d = 128 n_heads = 8 seqlen_1 = 6 @@ -82,12 +83,13 @@ def test_gqa_kv_repetition(kv_n_heads: int): @pytest.mark.gpu +@pytest.mark.skipif( + not is_flash_v2_installed(v2_version='v2.1.2'), + reason= + 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.' +) def test_seq_id_masking_FA_v2(): # Test that flash attention v2 with sequence id masking works correctly. - if not is_flash_v2_installed(v2_version='v2.1.2'): - pytest.skip( - 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.' - ) d = 128 n_heads = 4 kv_n_heads = 4 @@ -167,13 +169,13 @@ def test_seq_id_masking_FA_v2(): @pytest.mark.gpu +@pytest.mark.skipif( + not is_flash_v2_installed(v2_version='v2.3.0'), + reason= + 'Sliding window attention only supported by Flash Attention after v2.3.0.') @pytest.mark.parametrize('sliding_window_size', [1, 4, 8]) def test_sliding_window(sliding_window_size: int): # Test that sliding window attention works as expected. - if not is_flash_v2_installed('v2.3.0'): - pytest.skip( - 'Sliding window attention only supported by Flash Attention after v2.3.0.' - ) dtype = torch.bfloat16 device = 'cuda' d = 128 diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 9bac6b11b7..98a556f534 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -580,9 +580,7 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool): 'factor': 1.0, }, }]) -@pytest.mark.parametrize('tie_word_embeddings', [True, False]) -def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict, - tie_word_embeddings: bool): +def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict): # Testing the output of concatenated sequence with sequence id masking vs individual sequences. alibi = pos_emb_config['alibi'] if alibi and attention_impl == 'flash': @@ -620,7 +618,6 @@ def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict, 'name': 'baseline_', 'init_std': 0.02, }, - tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) mpt.eval()