From 21031b59cb949e9b6fa5a490ec05a92f7ba32e6c Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 17 Nov 2023 16:34:06 -0800 Subject: [PATCH] fix another --- tests/test_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 19a89dcc07..e7025b2ebe 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1352,11 +1352,11 @@ def test_generate_with_past_kv(attn_impl: str, pos_emb_config: dict, hf_config.d_model) -@pytest.mark.parametrize('attn_impl,device', [ +@pytest.mark.parametrize('attn_impl', [ 'torch', - pytest.param('flash', 'gpu', marks=pytest.mark.gpu), - pytest.param('triton', 'gpu', marks=pytest.mark.gpu), - pytest.param('torch', 'gpu', marks=pytest.mark.gpu), + pytest.param('flash', marks=pytest.mark.gpu), + pytest.param('triton', marks=pytest.mark.gpu), + pytest.param('torch', marks=pytest.mark.gpu), ]) @pytest.mark.parametrize('generation_kwargs', [{ 'max_new_tokens': 2,