diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py
index e573c068d6..5a3196b295 100644
--- a/tests/test_flash_triton_torch.py
+++ b/tests/test_flash_triton_torch.py
@@ -182,10 +182,15 @@ def gen_bias(attn_impl: str):
         tp = torch_name_param_map[n]
         assert p.grad is not None
         assert tp.grad is not None
-        assert allclose_helper(p, tp)
-        # Increased tolerance due to rope_impl=hf having 1 failing element
-        # in the torch vs. triton, clip=True, qk_ln=True case
-        assert allclose_helper(p.grad, tp.grad, atol=2.e-2, rtol=2.e-2)
+
+        using_hf_rope = pos_emb_config['rope'] and pos_emb_config[
+            'rope_impl'] == 'hf'
+
+        # special case that (likely) fails due to numerics
+        if clip_qkv and qk_ln and using_hf_rope and attn_type == 'grouped_query_attention':
+            assert allclose_helper(p.grad, tp.grad, atol=2.e-2, rtol=2.e-2)
+        else:
+            assert allclose_helper(p.grad, tp.grad)

     assert x0.grad is not None
     assert x1.grad is not None
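
For context on the tolerances above, `allclose_helper` is the test suite's tensor-comparison wrapper. A minimal sketch of what such a helper typically looks like, assuming it simply forwards to `torch.allclose` (the actual definition lives elsewhere in `tests/test_flash_triton_torch.py`, and its default tolerances are an assumption here):

```python
import torch


def allclose_helper(t0: torch.Tensor,
                    t1: torch.Tensor,
                    rtol: float = 1e-2,
                    atol: float = 1e-2) -> bool:
    # True iff |t0 - t1| <= atol + rtol * |t1| holds element-wise.
    # The 1e-2 defaults are hypothetical; the real helper's may differ.
    return torch.allclose(t0, t1, rtol=rtol, atol=atol)
```

With a helper shaped like this, the change above loosens `atol`/`rtol` to `2.e-2` only for the one configuration suspected of diverging due to numerics (`clip_qkv` + `qk_ln` + HF rope + grouped query attention), rather than applying the wider tolerance to every gradient comparison as the old code did.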