fix: refactor fp32 for torch, moved scaling of fp8 to out of kernel
alexkranias-amd committed Dec 13, 2024
1 parent fd342f7 commit 543736b
Showing 4 changed files with 34 additions and 54 deletions.
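At a glance, the commit moves the fp8 scaling step out of the Triton kernel: q, k, and v are divided by their per-head scales in PyTorch before the launch, and the kernel only multiplies the matmul results back by q_scale * k_scale (for qk) and v_scale (for the output). A minimal host-side sketch of that pattern, assuming a bhsd layout and per-head scales of shape (batch, heads); the helper name below is illustrative, not from the repository:

```python
import torch

def scale_per_head(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # x: (batch, heads, seqlen, head_dim) in an fp8 dtype
    # scale: (batch, heads) per-head scale factors
    # divide in fp32, then cast back to the original fp8 dtype
    return (x.to(torch.float32) / scale[:, :, None, None]).to(x.dtype)

# illustrative usage mirroring attention_prefill_forward_triton_impl below;
# the kernel then descales qk with q_scale * k_scale and the output with v_scale
# q = scale_per_head(q, q_scale)
# k = scale_per_head(k, k_scale)
# v = scale_per_head(v, v_scale)
```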
33 changes: 13 additions & 20 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -82,14 +82,11 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
k_offs_n = None
k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL)
k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k)
if IS_FP8:
k = (k.to(tl.float16) / k_scale.to(tl.float16)).to(k.type.element_ty)

if PRE_LOAD_V:
# We can use the same offsets as k, just with dims transposed.
v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL)
if IS_FP8:
v = (v.to(tl.float16) / v_scale.to(tl.float16)).to(v.type.element_ty)

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n
@@ -107,7 +104,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
qk = tl.where(mask, qk, float("-inf"))

# -- compute qk ----
qk += tl.dot(q, k)
qk += tl.dot(q.to(tl.float16), k.to(tl.float16))
qk_scaled = qk * SM_SCALE
if IS_FP8:
qk_scaled = qk_scaled * q_scale * k_scale # descale qk after matmul if quantized
@@ -173,17 +170,13 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
acc = acc * alpha[:, None]
if not PRE_LOAD_V:
v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL)
if IS_FP8:
v = (v.to(tl.float16) / v_scale.to(tl.float16)).to(v.type.element_ty)
# -- update m_i and l_i
l_i = l_i * alpha + l_ij
# update m_i and l_i
m_i = m_ij

if IS_FP8:
p_scale = 1 # NOTE: for proper scaling set this = tl.max(p) (increases error)
p_scaled = (p / p_scale)
acc += tl.dot(p_scaled.to(v.type.element_ty), v.to(v.type.element_ty)).to(tl.float32) * v_scale * p_scale # if you want to use p_scaled: tl.dot(p_scaled.to(v.type.element_ty), v.to(v.type.element_ty)) * v_scale * p_scale
acc += tl.dot(p.to(v.type.element_ty), v.to(v.type.element_ty)).to(tl.float32) * v_scale # if you want to use p_scaled: tl.dot(p_scaled.to(v.type.element_ty), v.to(v.type.element_ty)) * v_scale * p_scale
else:
# NOTE: if you make the below operation tl.float16 + set FLASH_ATTENTION_TRITON_AMD_REMOVE_QUANT_SCALE=1. It passes. --> acc += tl.dot(p.to(tl.float16), v.to(tl.float16)) PASSES
acc += tl.dot(p.to(v.type.element_ty), v).to(tl.float32)
@@ -416,8 +409,6 @@ def attn_fwd(Q, K, V, bias, Q_SCALE, K_SCALE, V_SCALE, stride_qscale_z, stride_k
# if IS FP8 get q_scale and quantize
if IS_FP8:
q_scale = tl.load(Q_SCALE + off_z*stride_qscale_z + off_h_q)
q = (q.to(tl.float16) / q_scale.to(tl.float16)).to(q.type.element_ty) # scale q by q_scale

k_scale = tl.load(K_SCALE + off_z*stride_kvscale_z + off_h_k)
v_scale = tl.load(V_SCALE + off_z*stride_kvscale_z + off_h_k)
else:
@@ -570,12 +561,16 @@ def attention_prefill_forward_triton_impl(

is_fp8 = check_is_fp8(q)

# if qkv are fp8, then find scaling factor for quantization
q_scale, k_scale, v_scale = create_scale_tensors(q, k, v, SCALE_PER_HEAD=True, layout=layout) # TODO: if SCALE_PER_HEAD: within the kernel itself just compute qkv_scale = tl.max(q or k or v)
q_scale_stride_z = q_scale.stride(0)
kv_scale_stride_z = k_scale.stride(0)

# import pdb; pdb.set_trace()
if is_fp8:
# if qkv are fp8, then find scaling factor for quantization
q_scale, k_scale, v_scale = create_scale_tensors(q, k, v, SCALE_PER_HEAD=True, layout=layout) # TODO: if SCALE_PER_HEAD: within the kernel itself just compute qkv_scale = tl.max(q or k or v)
q_scale_stride_z = q_scale.stride(0)
kv_scale_stride_z = k_scale.stride(0)
q = (q.to(torch.float32) / q_scale.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, q.shape[-2], q.shape[-1])).to(q.dtype)
k = (k.to(torch.float32) / k_scale.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, k.shape[-2], k.shape[-1])).to(k.dtype)
v = (v.to(torch.float32) / v_scale.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, v.shape[-2], v.shape[-1])).to(v.dtype)
else:
q_scale = k_scale = v_scale = 1

if DEBUG:
print()
@@ -661,8 +656,6 @@ def attention_prefill_forward_triton_impl(
else:
alibi_strides = (0, 0)

# import pdb; pdb.set_trace()

attn_fwd[grid](q, k, v, bias, q_scale, k_scale, v_scale, q_scale_stride_z, kv_scale_stride_z, sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
*bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes,
48 changes: 18 additions & 30 deletions flash_attn/flash_attn_triton_amd/fwd_ref.py
@@ -4,20 +4,8 @@

DEBUG_CORE = False

def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2):
is_fp8 = check_is_fp8(q)
if is_fp8:
# if qkv are fp8, then find scaling factor for quantization
q_scale, k_scale, v_scale = create_scale_tensors(q, k, v, SCALE_PER_HEAD=True, layout=layout) # TODO: if SCALE_PER_HEAD: within the kernel itself just compute qkv_scale = tl.max(q or k or v)
q_scale_stride_z = q_scale.stride(0)
kv_scale_stride_z = k_scale.stride(0)

# scale qkv tensors if FP8
q = q / q_scale
k = k / k_scale
v = v / v_scale
else:
q_scale = k_scale = v_scale = 1
def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2, is_fp8):

if DEBUG_CORE:
print()
print("attention_forward_core_ref_impl")
@@ -32,20 +20,14 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p
print("use_exp2:", use_exp2)
print('layout:', layout)
print('is_fp8:', is_fp8)
print('q_scale:', q_scale)
print('k_scale:', k_scale)
print('v_scale:', v_scale)

# cast to float32
q = q.to(torch.float32)
k = k.to(torch.float32)
v = v.to(torch.float32)

# Compute attention scores
if is_fp8:
attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32)) * q_scale * v_scale
else:
attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32))
attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32))
if DEBUG_CORE:
print("attention_scores:", attention_scores, attention_scores.shape)

@@ -150,10 +132,7 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p
print("softmax_lse:", softmax_lse, softmax_lse.shape)

# Compute output
if is_fp8:
o = torch.matmul(p, v.to(torch.float32)) * v_scale
else:
o = torch.matmul(p, v)
o = torch.matmul(p, v)
if DEBUG_CORE:
print("o:", o, o.shape)

@@ -164,7 +143,7 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p

return o, softmax_lse, sd_mask

def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2):
def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2, is_fp8):
"""Compute reference output and softmax_lse using PyTorch's built-in function"""

# Ensure the layout is 'bhsd'
@@ -200,7 +179,7 @@ def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout

# Call the core attention function
o, softmax_lse, sd_mask = attention_forward_core_ref_impl(
q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2
q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2, is_fp8
)

if group_size != 1:
@@ -238,7 +217,8 @@ def attention_varlen_forward_pytorch_ref_impl(
dropout_p,
philox_seed,
philox_offset,
use_exp2
use_exp2,
is_fp8
):
# Ensure the layout is 'thd'
if layout != 'thd':
@@ -302,7 +282,7 @@ def attention_varlen_forward_pytorch_ref_impl(
v_i = v_i.reshape(nheads_k, seqlen_k, head_dim)

# Call the core attention function for this sequence
o_i, softmax_lse_i, sd_mask_i = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2)
o_i, softmax_lse_i, sd_mask_i = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2, is_fp8)

# Reshape outputs back to original dimensions
if group_size != 1:
@@ -365,6 +345,12 @@ def attention_forward_pytorch_ref_impl(
print("philox_offset:", philox_offset)
print("use_exp2:", use_exp2)

is_fp8 = check_is_fp8(q)

# if is fp8 upcast to fp32 for torch ops to be supported
if is_fp8:
q, k, v = q.to(torch.float32), k.to(torch.float32), v.to(torch.float32)

# compute reference
if layout == "thd":
o_ref, softmax_lse_ref, sd_mask_ref = attention_varlen_forward_pytorch_ref_impl(
@@ -382,6 +368,7 @@
philox_seed,
philox_offset,
use_exp2,
is_fp8
)
else:
o_ref, softmax_lse_ref, sd_mask_ref = attention_vanilla_forward_pytorch_ref_impl(q.clone(),
@@ -393,7 +380,8 @@
dropout_p,
philox_seed,
philox_offset,
use_exp2)
use_exp2,
is_fp8)

if DEBUG:
print()
4 changes: 2 additions & 2 deletions flash_attn/flash_attn_triton_amd/test.py
@@ -390,7 +390,7 @@ def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropou
if layout == "thd":
q, k, v, metadata = varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device=device, DEBUG_INPUT=DEBUG_INPUT)
else:
q, k, v, q_fp32, k_fp32, v_fp32, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device=device, DEBUG_INPUT=DEBUG_INPUT)
q, k, v, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device=device, DEBUG_INPUT=DEBUG_INPUT)
if DEBUG_INPUT:
output_triton = torch.zeros_like(q).contiguous()
else:
@@ -436,7 +436,7 @@ def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropou
metadata.use_exp2)

output_ref, softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl(
q_fp32, k_fp32, v_fp32,
q, k, v,
metadata.sm_scale,
causal,
layout,
3 changes: 1 addition & 2 deletions flash_attn/flash_attn_triton_amd/utils.py
@@ -167,7 +167,6 @@ def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device="cud
v = torch.randn(k_tensor_shape, dtype=torch.float32, device=device, requires_grad=True)

q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)
q_fp32, k_fp32, v_fp32 = q.to(torch.float32), k.to(torch.float32), v.to(torch.float32)

if DEBUG_INPUT:
sm_scale = 1
@@ -177,7 +176,7 @@ def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device="cud
input_metadata.max_seqlens_q = N_CTX_Q
input_metadata.max_seqlens_k = N_CTX_K
input_metadata.layout = layout
return q, k, v, q_fp32, k_fp32, v_fp32, input_metadata
return q, k, v, input_metadata


def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device="cuda", equal_seqlens=False, DEBUG_INPUT=False):
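The fp8 branch throughout these files is gated by check_is_fp8 from utils.py, whose body is not part of this diff. A hypothetical sketch of such a check, assuming "fp8" means one of PyTorch's float8 dtypes (the real helper may differ):

```python
import torch

# guarded with hasattr because float8 dtype availability depends on the torch build
_FP8_DTYPES = {
    getattr(torch, name)
    for name in ("float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2", "float8_e5m2fnuz")
    if hasattr(torch, name)
}

def check_is_fp8(x: torch.Tensor) -> bool:
    # True when the tensor is stored in a float8 dtype
    return x.dtype in _FP8_DTYPES
```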
