add special case
sashaDoubov committed Nov 15, 2023
1 parent f603b46 commit 51a43b1
Showing 1 changed file with 9 additions and 4 deletions.
tests/test_flash_triton_torch.py (13 changes: 9 additions & 4 deletions)
@@ -182,10 +182,15 @@ def gen_bias(attn_impl: str):
         tp = torch_name_param_map[n]
         assert p.grad is not None
         assert tp.grad is not None
         assert allclose_helper(p, tp)
-        # Increased tolerance due to rope_impl=hf having 1 failing element
-        # in the torch vs. triton, clip=True, qk_ln=True case
-        assert allclose_helper(p.grad, tp.grad, atol=2.e-2, rtol=2.e-2)
-
+        using_hf_rope = pos_emb_config['rope'] and pos_emb_config[
+            'rope_impl'] == 'hf'
+
+        # special case that (likely) fails due to numerics
+        if clip_qkv and qk_ln and using_hf_rope and attn_type == 'grouped_query_attention':
+            assert allclose_helper(p.grad, tp.grad, atol=2.e-2, rtol=2.e-2)
+        else:
+            assert allclose_helper(p.grad, tp.grad)
+
         assert x0.grad is not None
         assert x1.grad is not None
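The effect of the change: before, every parameter-gradient comparison in this test ran with the loosened atol=2.e-2 / rtol=2.e-2; after, only the one parametrization observed to fail on numerics (clip_qkv + qk_ln + HF RoPE + grouped query attention) keeps the loosened tolerance, and all other configurations are checked at allclose_helper's stricter defaults.

For reference, a minimal sketch of what such a helper could look like, assuming it simply delegates to torch.allclose with test-friendly default tolerances; the actual allclose_helper in llm-foundry's test utilities may differ:

import torch

def allclose_helper(t0: torch.Tensor,
                    t1: torch.Tensor,
                    rtol: float = 1e-2,
                    atol: float = 1e-2) -> bool:
    # Hypothetical stand-in for the helper used in the diff above.
    # torch.allclose passes iff |t0 - t1| <= atol + rtol * |t1| holds
    # elementwise, so raising atol/rtol to 2e-2 in the special case
    # absorbs the single outlier element that rope_impl=hf produces
    # in the torch vs. triton comparison.
    return torch.allclose(t0, t1, rtol=rtol, atol=atol)

Under this reading, the bare allclose_helper(p.grad, tp.grad) call in the else branch compares gradients at the (assumed) 1e-2 defaults, while the special case doubles both tolerances to 2e-2.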
