diff --git a/tests/test_model.py b/tests/test_model.py index e5644712ee..19a89dcc07 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1250,11 +1250,11 @@ def test_forward_with_cache(attn_impl: str, pos_emb_config: dict, ) -@pytest.mark.parametrize('attn_impl,device', [ +@pytest.mark.parametrize('attn_impl', [ 'torch', - pytest.param('flash', 'gpu', marks=pytest.mark.gpu), - pytest.param('triton', 'gpu', marks=pytest.mark.gpu), - pytest.param('torch', 'gpu', marks=pytest.mark.gpu), + pytest.param('flash', marks=pytest.mark.gpu), + pytest.param('triton', marks=pytest.mark.gpu), + pytest.param('torch', marks=pytest.mark.gpu), ]) @pytest.mark.parametrize('pos_emb_config', [{ 'alibi': False,