diff --git a/tests/test_model.py b/tests/test_model.py index e7025b2ebe..4ed3a385cb 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1575,11 +1575,11 @@ def test_alibi_vs_hf(): torch.testing.assert_close(alibi_bias_hf, alibi_bias_m) -@pytest.mark.parametrize('attn_impl,device', [ +@pytest.mark.parametrize('attn_impl', [ 'torch', - pytest.param('flash', 'gpu', marks=pytest.mark.gpu), - pytest.param('triton', 'gpu', marks=pytest.mark.gpu), - pytest.param('torch', 'gpu', marks=pytest.mark.gpu), + pytest.param('flash', marks=pytest.mark.gpu), + pytest.param('triton', marks=pytest.mark.gpu), + pytest.param('torch', marks=pytest.mark.gpu), ]) @pytest.mark.parametrize('pos_emb_config', [{ 'alibi': False,