diff --git a/tests/models/layers/test_flash_triton_torch.py b/tests/models/layers/test_flash_triton_torch.py
index 2d0e7ad05b..1a1a54fb7d 100644
--- a/tests/models/layers/test_flash_triton_torch.py
+++ b/tests/models/layers/test_flash_triton_torch.py
@@ -27,7 +27,6 @@ def allclose_helper(t0: torch.Tensor,
     ('triton', 'torch'),
 ])
 @pytest.mark.parametrize('clip_qkv', [True, False])
-@pytest.mark.parametrize('qk_ln', [True, False])
 @pytest.mark.parametrize('qk_ln, qk_gn', [
     (True, False),
     (False, True),
@@ -68,6 +67,7 @@ def test_attn_impl(attn_impl_0: str,
                    attn_impl_1: str,
                    clip_qkv: bool,
                    qk_ln: bool,
+                   qk_gn: bool,
                    pos_emb_config: dict,
                    attn_type: str,
                    attn_uses_sequence_id: bool,
@@ -75,8 +75,8 @@
                    device: str = 'cuda'):
     """Compare all attn impl with each other.
 
-    Includes testing with and without attn_clip_qkv, attn_qk_ln, alibi, and
-    rope.
+    Includes testing with and without attn_clip_qkv, attn_qk_ln, attn_qk_gn,
+    alibi, and rope.
     """
     alibi = pos_emb_config['alibi']
     rope = pos_emb_config['rope']
@@ -104,6 +104,7 @@ def test_attn_impl(attn_impl_0: str,
         'attn_pdrop': 0,
         'clip_qkv': clip_qkv,
         'qk_ln': qk_ln,
+        'qk_gn': qk_gn,
     })
 
     n, s, f = 2, 4, cfg.d_model
@@ -260,7 +261,7 @@ def gen_bias(attn_impl: str):
             'rope_impl'] == 'hf'
 
         # special case that (likely) fails due to numerics
-        if clip_qkv and qk_ln and using_hf_rope and attn_type == 'grouped_query_attention':
+        if clip_qkv and (qk_ln or qk_gn) and using_hf_rope and attn_type == 'grouped_query_attention':
             assert allclose_helper(p.grad, tp.grad, atol=2.e-2, rtol=2.e-2)
         else:
             assert allclose_helper(p.grad, tp.grad)