Commit ec71399
feat: fp8 tests. small amount of error
alexkranias-amd committed Dec 4, 2024
1 parent 83d33cc commit ec71399
Showing 2 changed files with 5 additions and 6 deletions.
5 changes: 0 additions & 5 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py

@@ -638,8 +638,6 @@ def attention_prefill_forward_triton_impl(
     else:
         alibi_strides = (0, 0)
 
-    import pdb; pdb.set_trace()
-
 
     attn_fwd[grid](q, k, v, bias, sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
                    *bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
@@ -650,9 +648,6 @@ def attention_prefill_forward_triton_impl(
                    BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True,
                    USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p
                    > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_scores)
-
-
-    import pdb; pdb.set_trace()
 
     if DEBUG:
         print()
6 changes: 5 additions & 1 deletion flash_attn/flash_attn_triton_amd/test.py

@@ -382,7 +382,7 @@ def test_op_bwd(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, torch_sdpa_test, use_ali
 @pytest.mark.parametrize('use_exp2', [True, False]) # works when use_exp2 is false
 @pytest.mark.parametrize('DEBUG_INPUT', [False]) # NOTE: debug input can overflow when the tensors are large. Just use to figure out issues
 def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scores, layout, use_exp2, DEBUG_INPUT):
-    dtype = torch.float16
+    dtype = torch.float8_e4m3fnuz
     torch.manual_seed(0)
     alibi_slopes = None
     dropout_p = 0.0
@@ -473,6 +473,10 @@ def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, return
         print("softmax_triton:", softmax_triton, softmax_triton.shape)
         print("softmax_ref:", softmax_ref, softmax_ref.shape)
     torch.testing.assert_close(softmax_triton, softmax_ref, atol=ATOL, rtol=RTOL)
 
+    # if triton is fp8, cast to fp16 in order to compare with ref
+    if output_triton.dtype in {torch.float8_e4m3fnuz}:
+        output_triton = output_triton.to(torch.float16)
+
     if DEBUG:
         print("output_triton:", output_triton, output_triton.shape)
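Note on the dtype change above: the test now hard-codes torch.float8_e4m3fnuz, the fp8 (e4m3) variant exposed by ROCm builds of PyTorch. A minimal sketch of how the dtype could instead be picked per backend follows; the helper name pick_fp8_dtype and the fallback to torch.float8_e4m3fn are illustrative assumptions, not part of this commit.

import torch

# Hypothetical helper (not in this commit): choose an fp8 dtype per backend.
# ROCm builds expose the "fnuz" e4m3 variant; CUDA builds expose torch.float8_e4m3fn.
# Assumes a PyTorch version that defines these float8 dtypes.
def pick_fp8_dtype() -> torch.dtype:
    if torch.version.hip is not None and hasattr(torch, "float8_e4m3fnuz"):
        return torch.float8_e4m3fnuz
    return torch.float8_e4m3fn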
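The second test.py hunk upcasts the Triton output to fp16 before comparing it with the reference, which is what makes the fp8 path comparable; the commit message notes a small amount of error remains. A standalone sketch of that comparison pattern is below; the helper name assert_fp8_close and the widened atol/rtol values are illustrative assumptions, not values taken from the repository.

import torch

# Illustrative helper (not in this commit): compare an fp8 Triton output against an
# fp16 reference by casting the fp8 tensor up first, then checking closeness with
# tolerances loose enough to absorb fp8 quantization error.
def assert_fp8_close(output_triton: torch.Tensor, output_ref: torch.Tensor,
                     atol: float = 2.5e-2, rtol: float = 2.5e-2) -> None:
    if output_triton.dtype in {torch.float8_e4m3fnuz}:
        output_triton = output_triton.to(torch.float16)
    torch.testing.assert_close(output_triton, output_ref.to(torch.float16),
                               atol=atol, rtol=rtol)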
