Commit
Update test_flash_triton_torch.py
vchiley authored Jan 13, 2024
1 parent bcef3fe commit 63de0b4
Showing 1 changed file with 5 additions and 4 deletions.
tests/models/layers/test_flash_triton_torch.py (9 changes: 5 additions & 4 deletions)
@@ -27,7 +27,6 @@ def allclose_helper(t0: torch.Tensor,
     ('triton', 'torch'),
 ])
 @pytest.mark.parametrize('clip_qkv', [True, False])
-@pytest.mark.parametrize('qk_ln', [True, False])
 @pytest.mark.parametrize('qk_ln, qk_gn', [
     (True, False),
     (False, True),
@@ -68,15 +67,16 @@ def test_attn_impl(attn_impl_0: str,
                    attn_impl_1: str,
                    clip_qkv: bool,
                    qk_ln: bool,
+                   qk_gn: bool,
                    pos_emb_config: dict,
                    attn_type: str,
                    attn_uses_sequence_id: bool,
                    pad_attention_mask: bool,
                    device: str = 'cuda'):
     """Compare all attn impl with each other.
 
-    Includes testing with and without attn_clip_qkv, attn_qk_ln, alibi, and
-    rope.
+    Includes testing with and without attn_clip_qkv, attn_qk_ln, attn_qk_gn,
+    alibi, and rope.
     """
     alibi = pos_emb_config['alibi']
     rope = pos_emb_config['rope']
@@ -104,6 +104,7 @@ def test_attn_impl(attn_impl_0: str,
         'attn_pdrop': 0,
         'clip_qkv': clip_qkv,
         'qk_ln': qk_ln,
+        'qk_gn': qk_gn,
     })
 
     n, s, f = 2, 4, cfg.d_model
@@ -260,7 +261,7 @@ def gen_bias(attn_impl: str):
         'rope_impl'] == 'hf'
 
     # special case that (likely) fails due to numerics
-    if clip_qkv and qk_ln and using_hf_rope and attn_type == 'grouped_query_attention':
+    if clip_qkv and (qk_ln or qk_gn) and using_hf_rope and attn_type == 'grouped_query_attention':
         assert allclose_helper(p.grad, tp.grad, atol=2.e-2, rtol=2.e-2)
     else:
         assert allclose_helper(p.grad, tp.grad)
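For readers skimming the diff: below is a minimal, self-contained sketch (not code from this commit; the helper name make_attn_cfg and the exact flag pairs are illustrative assumptions) of why qk_ln and qk_gn are parametrized together as pairs rather than with two independent decorators. Pairing them keeps the LayerNorm and GroupNorm variants of Q/K normalization mutually exclusive while both flags still flow into the attention config, mirroring the cfg update in the hunk above.

# Illustrative sketch only -- not part of the commit. make_attn_cfg and the
# chosen flag pairs are assumptions for demonstration.
import pytest


def make_attn_cfg(clip_qkv: bool, qk_ln: bool, qk_gn: bool) -> dict:
    # Mirrors the shape of the attention cfg updated in the diff above.
    return {
        'attn_pdrop': 0,
        'clip_qkv': clip_qkv,
        'qk_ln': qk_ln,
        'qk_gn': qk_gn,
    }


@pytest.mark.parametrize('clip_qkv', [True, False])
@pytest.mark.parametrize('qk_ln, qk_gn', [
    (True, False),   # LayerNorm on Q/K only
    (False, True),   # GroupNorm on Q/K only
    (False, False),  # neither
])
def test_qk_norm_flags_are_exclusive(clip_qkv: bool, qk_ln: bool,
                                     qk_gn: bool):
    cfg = make_attn_cfg(clip_qkv, qk_ln, qk_gn)
    # A paired parametrize never generates qk_ln=True together with
    # qk_gn=True, which two independent parametrize decorators would.
    assert not (cfg['qk_ln'] and cfg['qk_gn'])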
