From 7a8bc43834348b222d7ec61b36cc8ab56bb26acc Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Sun, 19 Nov 2023 06:31:02 +0000
Subject: [PATCH] precommit

---
 tests/test_rope_dail_vs_hf.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/tests/test_rope_dail_vs_hf.py b/tests/test_rope_dail_vs_hf.py
index 45b2ad9aa5..70a00470f9 100644
--- a/tests/test_rope_dail_vs_hf.py
+++ b/tests/test_rope_dail_vs_hf.py
@@ -15,11 +15,7 @@
     'attn_type',
     ['multihead_attention', 'multiquery_attention', 'grouped_query_attention'])
 @pytest.mark.parametrize('seq_len', [1, 233, 2048])
-def test_rope_dail_vs_hf(clip_qkv: bool,
-                         qk_ln: bool,
-                         attn_type: str,
-                         seq_len: int,
-                         device: str = 'cuda'):
+def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
     # compare rope rotations for the dail vs hf implementations
     if not is_flash_v2_installed():
         pytest.skip('dail implementation of rope requires flash attention 2.')
@@ -31,8 +27,8 @@ def test_rope_dail_vs_hf(clip_qkv: bool,
         'd_model': 128,
         'n_heads': 4,
         'attn_pdrop': 0,
-        'clip_qkv': clip_qkv,
-        'qk_ln': qk_ln,
+        'clip_qkv': True,
+        'qk_ln': False,
     })

     batch_size = 2
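
For readers skimming the diff: the net effect is that clip_qkv and qk_ln are no longer pytest parameters and are instead fixed in the test config. A minimal sketch of the patched test follows; everything not visible in the hunks (the enclosing om.create(...) call, imports, and the rest of the test body) is assumed and abbreviated, and is_flash_v2_installed is stubbed here only so the sketch is self-contained.

import pytest
from omegaconf import OmegaConf as om


def is_flash_v2_installed() -> bool:
    # Stub standing in for the real helper imported by the test module.
    return False


@pytest.mark.parametrize(
    'attn_type',
    ['multihead_attention', 'multiquery_attention', 'grouped_query_attention'])
@pytest.mark.parametrize('seq_len', [1, 233, 2048])
def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
    # compare rope rotations for the dail vs hf implementations
    if not is_flash_v2_installed():
        pytest.skip('dail implementation of rope requires flash attention 2.')

    cfg = om.create({  # enclosing call assumed; only the keys below appear in the hunk
        'd_model': 128,
        'n_heads': 4,
        'attn_pdrop': 0,
        'clip_qkv': True,  # previously the parametrized clip_qkv argument
        'qk_ln': False,  # previously the parametrized qk_ln argument
    })

    batch_size = 2
    # ... remainder of the test is unchanged by this patch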