Skip to content

Commit

Permalink
precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Nov 19, 2023
1 parent 0a9e0f2 commit 7a8bc43
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions tests/test_rope_dail_vs_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@
'attn_type',
['multihead_attention', 'multiquery_attention', 'grouped_query_attention'])
@pytest.mark.parametrize('seq_len', [1, 233, 2048])
def test_rope_dail_vs_hf(clip_qkv: bool,
qk_ln: bool,
attn_type: str,
seq_len: int,
device: str = 'cuda'):
def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
# compare rope rotations for the dail vs hf implementations
if not is_flash_v2_installed():
pytest.skip('dail implementation of rope requires flash attention 2.')
Expand All @@ -31,8 +27,8 @@ def test_rope_dail_vs_hf(clip_qkv: bool,
'd_model': 128,
'n_heads': 4,
'attn_pdrop': 0,
'clip_qkv': clip_qkv,
'qk_ln': qk_ln,
'clip_qkv': True,
'qk_ln': False,
})

batch_size = 2
Expand Down

0 comments on commit 7a8bc43

Please sign in to comment.