From f02034b78f2e8333768fbb8d55099f9ed23fba15 Mon Sep 17 00:00:00 2001
From: Shashank Rajput
Date: Sat, 18 Nov 2023 00:59:44 +0000
Subject: [PATCH] ..

---
 tests/test_model.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/tests/test_model.py b/tests/test_model.py
index f8531b3470..38387192af 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -519,11 +519,12 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
     assert block.resid_ffn_dropout.p == 0.2
 
 
-@pytest.mark.gpu
-@pytest.mark.parametrize('attention_impl,device', [('torch', 'cpu'),
-                                                    ('flash', 'gpu'),
-                                                    ('triton', 'gpu'),
-                                                    ('torch', 'gpu')])
+@pytest.mark.parametrize('attention_impl', [
+    'torch',
+    pytest.param('flash', marks=pytest.mark.gpu),
+    pytest.param('triton', marks=pytest.mark.gpu),
+    pytest.param('torch', marks=pytest.mark.gpu),
+])
 @pytest.mark.parametrize('pos_emb_config', [{
     'alibi': True,
     'rope': False
@@ -548,10 +549,11 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
     },
 }])
 @pytest.mark.parametrize('tie_word_embeddings', [True, False])
-def test_sequence_id_based_masking(attention_impl: str, device: str,
+def test_sequence_id_based_masking(attention_impl: str,
                                    pos_emb_config: dict, tie_word_embeddings: bool):
     # Testing the output of concatenated sequence with sequence id masking vs individual sequences.
+    device = 'gpu' if torch.cuda.is_available() else 'cpu'
     if not torch.cuda.is_available() and device == 'gpu':
         pytest.skip(
             f'This test requires CUDA to be available in order to run with {attention_impl} attention.'
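
The hunks above drop the separate device parametrization in favor of per-value GPU marks and a device derived at runtime. A minimal standalone sketch of that pytest pattern follows; the test name, body, and skip message are illustrative only and are not taken from tests/test_model.py:

import pytest
import torch


# Per-value GPU marks via pytest.param: the plain 'torch' case runs in a
# CPU-only session, while the gpu-marked cases can be selected with
# `pytest -m gpu` on a machine that has CUDA available.
@pytest.mark.parametrize('attention_impl', [
    'torch',
    pytest.param('flash', marks=pytest.mark.gpu),
    pytest.param('triton', marks=pytest.mark.gpu),
    pytest.param('torch', marks=pytest.mark.gpu),
])
def test_attention_impl_pattern(attention_impl: str):
    # The device is inferred from the environment instead of being parametrized.
    device = 'gpu' if torch.cuda.is_available() else 'cpu'
    if device == 'cpu' and attention_impl in ('flash', 'triton'):
        pytest.skip(f'{attention_impl} attention requires CUDA.')
    assert attention_impl in ('torch', 'flash', 'triton')

One parametrize decorator then covers both environments: a CPU run still exercises the 'torch' implementation, and the GPU-only variants are gated by the gpu mark rather than by a duplicated (attention_impl, device) grid.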