diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index 50dd92f26c..604472e901 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -13,6 +13,8 @@
 @pytest.mark.gpu
 @pytest.mark.parametrize('kv_n_heads', [1, 2, 4, 8])
 def test_gqa_kv_repetition(kv_n_heads: int):
+    # Test that flash attention v2 with GQA (kv_n_heads < n_heads) works the same
+    # whether we repeat the kv_n_heads explicitly or flash attention v2 handles it on its own.
     if not is_flash_v2_installed():
         pytest.skip('GQA natively only supported by Flash Attention after v2.')
     d = 128
@@ -82,6 +84,7 @@ def test_gqa_kv_repetition(kv_n_heads: int):
 
 @pytest.mark.gpu
 def test_seq_id_masking_FA_v2():
+    # Test that flash attention v2 with sequence id masking works correctly.
     if not is_flash_v2_installed(v2_version='v2.1.2'):
         pytest.skip(
             'Using sequence id with flash attention requires flash attention v2.1.2 or higher.'
diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py
index bde14a2775..9dc412b162 100644
--- a/tests/test_flash_triton_torch.py
+++ b/tests/test_flash_triton_torch.py
@@ -106,7 +106,7 @@ def test_attn_impl(attn_impl_0: str,
         assert s >= 8
         sequence_id = torch.Tensor([[0] * 4 + [1] * (s - 4),
                                     [0] * 8 + [1] * (s - 8)
-                                    ]).to(device=device, dtype=torch.long)
+                                    ]).to(device=device).long()
 
     cfg.attn_impl = attn_impl_0
     attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
@@ -140,7 +140,10 @@ def gen_bias(attn_impl: str):
             )
             if attn_impl != 'flash' and attn_uses_sequence_id and sequence_id is not None:
                 assert isinstance(attn_bias, torch.Tensor)  # pyright
-                attn_bias = apply_sequence_id(attn_bias, sequence_id, s)  # type: ignore
+                attn_bias = apply_sequence_id(
+                    attn_bias,
+                    sequence_id,  # type: ignore
+                    s)
 
         return attn_bias
 
diff --git a/tests/test_model.py b/tests/test_model.py
index 882573bfdd..6a1863a6fa 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -515,7 +515,6 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
     assert block.resid_ffn_dropout.p == 0.2
 
 
-
 @pytest.mark.parametrize('attention_impl', [
     'torch',
     pytest.param('flash', marks=pytest.mark.gpu),
@@ -546,22 +545,16 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
     },
 }])
 @pytest.mark.parametrize('tie_word_embeddings', [True, False])
-def test_sequence_id_based_masking(attention_impl: str,
-                                   pos_emb_config: dict,
+def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict,
                                    tie_word_embeddings: bool):
     # Testing the output of concatenated sequence with sequence id masking vs individual sequences.
-    device = 'gpu' if torch.cuda.is_available() else 'cpu'
-    if not torch.cuda.is_available() and device == 'gpu':
-        pytest.skip(
-            f'This test requires CUDA to be available in order to run with {attention_impl} attention.'
-        )
     alibi = pos_emb_config['alibi']
     if alibi and attention_impl == 'flash':
         pytest.skip(f'alibi only implemented with torch and triton attention.')
 
     rope = pos_emb_config['rope']
-    if rope and pos_emb_config['rope_impl'] == 'dail' and (
-            device != 'gpu' or not is_flash_v2_installed()):
+    if rope and pos_emb_config[
+            'rope_impl'] == 'dail' and not is_flash_v2_installed():
         pytest.skip(
             f'dail implementation of rope requires gpu and flash attention 2.')
 
@@ -571,7 +564,7 @@ def test_sequence_id_based_masking(attention_impl: str,
             'Using sequence id with flash attention requires flash attention v2.1.2 or higher.'
         )
 
-    composer_device = get_device(device)
+    composer_device = get_device(None)
 
     hf_config = MPTConfig(
         init_device='cpu',
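Note on the first hunk: `test_gqa_kv_repetition` exercises grouped-query attention (GQA), where `kv_n_heads < n_heads`, and checks that Flash Attention v2's native handling of the smaller KV head count agrees with explicitly repeating the KV heads. The sketch below illustrates that repetition trick with plain PyTorch `scaled_dot_product_attention` instead of flash-attn; `gqa_via_repetition` is a made-up helper name for illustration and is not part of the repo.

```python
# Illustrative, CPU-only sketch (not from the PR): emulate GQA by repeating
# each KV head n_heads // kv_n_heads times and running ordinary attention.
import torch
import torch.nn.functional as F


def gqa_via_repetition(q: torch.Tensor, k: torch.Tensor,
                       v: torch.Tensor) -> torch.Tensor:
    # q: (batch, n_heads, seq, head_dim); k, v: (batch, kv_n_heads, seq, head_dim)
    n_heads, kv_n_heads = q.shape[1], k.shape[1]
    repeats = n_heads // kv_n_heads
    k = k.repeat_interleave(repeats, dim=1)
    v = v.repeat_interleave(repeats, dim=1)
    return F.scaled_dot_product_attention(q, k, v)


if __name__ == '__main__':
    torch.manual_seed(0)
    b, s, head_dim, n_heads, kv_n_heads = 2, 16, 32, 8, 2
    q = torch.randn(b, n_heads, s, head_dim)
    k = torch.randn(b, kv_n_heads, s, head_dim)
    v = torch.randn(b, kv_n_heads, s, head_dim)
    print(gqa_via_repetition(q, k, v).shape)  # torch.Size([2, 8, 16, 32])
```

The test in the diff makes the analogous comparison on GPU, letting flash attention v2 handle the grouped KV heads natively on one side and repeating them explicitly on the other, and asserts the two paths agree.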
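The sequence-id tests (`test_seq_id_masking_FA_v2` and `test_sequence_id_based_masking`) cover attention over packed sequences: tokens must not attend across the boundary between sequences sharing a batch row. The sketch below shows the general idea of deriving a block-diagonal mask from sequence ids; `sequence_id_mask` is a hypothetical helper for illustration, not the repo's `apply_sequence_id`, which operates on the attention bias tensor seen in the `gen_bias` hunk above.

```python
# Illustrative sketch (not the repo's apply_sequence_id): build a boolean
# block-diagonal mask from sequence ids so that packed sequences behave as if
# they had been run through attention separately.
import torch


def sequence_id_mask(sequence_id: torch.Tensor) -> torch.Tensor:
    # sequence_id: (batch, seq_len) integer ids; returns (batch, 1, seq_len, seq_len),
    # True where attention is allowed (query and key share a sequence id).
    return (sequence_id.unsqueeze(-1) == sequence_id.unsqueeze(-2)).unsqueeze(1)


if __name__ == '__main__':
    s = 12
    # Same construction as the test: one row packing a 4-token and an 8-token sequence.
    sequence_id = torch.tensor([[0] * 4 + [1] * (s - 4)])
    mask = sequence_id_mask(sequence_id)
    print(mask[0, 0].int())  # two diagonal blocks of ones: 4x4, then 8x8
```

In the torch/triton code paths this kind of mask is combined with the causal mask and folded into an additive attention bias (disallowed positions set to a large negative value), which is why the flash path above, which avoids materializing a full bias, requires flash attention v2.1.2 or higher to honor sequence ids.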