fp8 bwd #108

Draft: wants to merge 10 commits into base branch main_perf
2 changes: 1 addition & 1 deletion .github/workflows/amd_tests.yml
@@ -27,7 +27,7 @@ jobs:
        id: set-matrix
        run: |
          if [ x"${{ github.repository }}" == x"ROCm/flash-attention" ]; then
            echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"], ["self-hosted", "gfx1100"]]'
            echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"]]'
          else
            echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]'
          fi
126 changes: 126 additions & 0 deletions benchmarks/benchmark_fp8.py
@@ -0,0 +1,126 @@
# Install the newest triton version with
# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
import pickle
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

from einops import rearrange, repeat

from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined

from flash_attn import flash_attn_qkvpacked_func

try:
    from triton.ops.flash_attention import attention as attention_triton
except ImportError:
    attention_triton = None

try:
    import xformers.ops as xops
except ImportError:
    xops = None


def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
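# The count above follows the usual flash-attention benchmark convention: the forward pass is two
# (seqlen x seqlen) matmuls per head (QK^T and PV), i.e. roughly 4 * batch * nheads * seqlen^2 * headdim
# FLOPs, halved for causal masking; the backward pass is counted as 2.5x the forward (5 matmuls vs 2)
# and fwd+bwd as 3.5x.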

def efficiency(flop, time):
    return (flop / time / 10**12) if not math.isnan(time) else 0.0


def attention_pytorch(qkv, dropout_p=0.0, causal=True):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        dropout_p: float
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
    """
    batch_size, seqlen, _, nheads, d = qkv.shape
    q, k, v = qkv.unbind(dim=2)
    q = rearrange(q, 'b t h d -> (b h) t d')
    k = rearrange(k, 'b s h d -> (b h) d s')
    softmax_scale = 1.0 / math.sqrt(d)
    # Preallocate attn_weights for `baddbmm`
    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
                       '(b h) t s -> b h t s', h=nheads)
    if causal:
        # "triu_tril_cuda_template" not implemented for 'BFloat16'
        # So we have to construct the mask in float
        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
        scores = scores + causal_mask.to(dtype=scores.dtype)
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    return output.to(dtype=qkv.dtype)


def time_fwd(func, *args, **kwargs):
    time_f = benchmark_forward(func, *args, **kwargs)
    return time_f[1].mean


repeats = 30
device = 'cuda'
dtype = torch.float16

bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
causal_vals = [False, True]
headdim_vals = [64, 128]
dim = 2048
dropout_p = 0.0

methods = ["Flash2_fp8", "Flash2_fp16"]

time_f = {}
time_b = {}
time_f_b = {}
speed_f = {}
speed_b = {}
speed_f_b = {}
"""
NOTE: currently the torch.amax needed to find the scaling values of the fp8 tensors adds a huge amount of overhead.
It makes fp8 kernels have lower throughput than fp16 counterparts.

We can disable scaling using the following env variable:

FLASH_ATTENTION_TRITON_AMD_REMOVE_QUANT_SCALE=1
"""
for causal in causal_vals:
    for headdim in headdim_vals:
        for batch_size, seqlen in bs_seqlen_vals:
            config = (causal, headdim, batch_size, seqlen)
            nheads = dim // headdim
            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                              requires_grad=True)
            t_fp8 = time_fwd(
                flash_attn_qkvpacked_func, qkv.to(torch.float8_e4m3fnuz), dropout_p, causal=causal, repeats=repeats, verbose=False
            )
            time_f[config, "Flash2_fp8"] = t_fp8

            t_fp16 = time_fwd(
                flash_attn_qkvpacked_func, qkv.to(torch.float16), dropout_p, causal=causal, repeats=repeats, verbose=False
            )
            time_f[config, "Flash2_fp16"] = t_fp16

            print(f"\n### ENABLE_QUANTIZATION_SCALING={os.getenv('FLASH_ATTENTION_TRITON_AMD_REMOVE_QUANT_SCALE') != '1'}, causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
            for method in methods:
                speed_f[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
                    time_f[config, method]
                )
                print(
                    f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
                )


# with open('flash2_attn_time.plk', 'wb') as fp:
# pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
62 changes: 53 additions & 9 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -1,7 +1,7 @@
import torch
import triton
import triton.language as tl
from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, write_dropout_mask, create_dropout_mask
from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, write_dropout_mask, create_dropout_mask, create_scale_tensors, check_is_fp8

# NOTE: triton fails to import tl.constexprs so create them here for the file
tl_DROPOUT_USE_PYTORCH: tl.constexpr = DROPOUT_USE_PYTORCH
@@ -63,6 +63,7 @@ def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpo
def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m,
actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs,
block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope,
q_scale, k_scale, v_scale, IS_FP8: tl.constexpr,
IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr,
@@ -81,9 +82,14 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
k_offs_n = None
k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL)
k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k)
if IS_FP8:
k = (k.to(tl.float16) / k_scale.to(tl.float16)).to(k.type.element_ty)

if PRE_LOAD_V:
# We can use the same offsets as k, just with dims transposed.
v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL)
if IS_FP8:
v = (v.to(tl.float16) / v_scale.to(tl.float16)).to(v.type.element_ty)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n
@@ -103,7 +109,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
# -- compute qk ----
qk += tl.dot(q, k)
qk_scaled = qk * SM_SCALE

if IS_FP8:
qk_scaled = qk_scaled * q_scale * k_scale # descale qk after matmul if quantized
if IS_CAUSAL:
causal_boundary = start_n + offs_n_causal
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
@@ -135,7 +142,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
p_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)

# CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1)
l_ij = tl.sum(p, 1) # p = fp32 at this point
if ENABLE_DROPOUT:
if tl_DROPOUT_USE_PYTORCH:
dropout_mask = tl.load(dropout_mask_ptrs, mask=p_mask)
@@ -166,11 +173,21 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
acc = acc * alpha[:, None]
if not PRE_LOAD_V:
v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL)
if IS_FP8:
v = (v.to(tl.float16) / v_scale.to(tl.float16)).to(v.type.element_ty)
# -- update m_i and l_i
l_i = l_i * alpha + l_ij
# update m_i and l_i
m_i = m_ij
acc += tl.dot(p.to(v.type.element_ty), v)

if IS_FP8:
p_scale = 1 # NOTE: for proper scaling, set this to tl.max(p) (increases error)
p_scaled = (p / p_scale)
acc += tl.dot(p_scaled.to(v.type.element_ty), v.to(v.type.element_ty)).to(tl.float32) * v_scale * p_scale
else:
# NOTE: casting both operands to tl.float16 here (acc += tl.dot(p.to(tl.float16), v.to(tl.float16))) together with FLASH_ATTENTION_TRITON_AMD_REMOVE_QUANT_SCALE=1 makes the tests pass
acc += tl.dot(p.to(v.type.element_ty), v).to(tl.float32)
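# NOTE (explanatory, based on the scaling applied above): when IS_FP8, q/k/v were divided by their
# per-(batch, head) scale factors at load time; multiplying the tl.dot result by v_scale (and p_scale)
# undoes that scaling so acc is accumulated in the unscaled fp32 domain, mirroring the
# q_scale * k_scale descale applied to qk after its matmul.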

k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
if bias_ptrs is not None:
@@ -259,15 +276,15 @@ def get_autotune_configs():
use_cuda_graph=True,
)
@triton.jit
def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk,
def attn_fwd(Q, K, V, bias, Q_SCALE, K_SCALE, V_SCALE, stride_qscale_z, stride_kvscale_z, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah,
stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr,
HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr):
ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, IS_FP8: tl.constexpr):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1)
off_z = tl.program_id(2)
@@ -396,6 +413,16 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL)
q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0)

# if IS_FP8, load the per-(batch, head) scale factors and rescale q
if IS_FP8:
q_scale = tl.load(Q_SCALE + off_z*stride_qscale_z + off_h_q)
q = (q.to(tl.float16) / q_scale.to(tl.float16)).to(q.type.element_ty) # divide q by q_scale (descaled again after the qk matmul)

k_scale = tl.load(K_SCALE + off_z*stride_kvscale_z + off_h_k)
v_scale = tl.load(V_SCALE + off_z*stride_kvscale_z + off_h_k)
else:
q_scale, k_scale, v_scale = 1.0, 1.0, 1.0
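# NOTE (assumed layout): Q_SCALE holds one entry per (batch, q-head) and K_SCALE/V_SCALE one entry per
# (batch, kv-head), indexed as base + off_z*stride + off_h; when the inputs are not fp8, unit scales are
# passed so the inner-loop call signature stays the same.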

# Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
@@ -421,11 +448,12 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
sd_mask_ptrs, dropout_mask_ptrs,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min, block_max, 0, 0, 0, alibi_slope,
q_scale, k_scale, v_scale, IS_FP8,
# IS_CAUSAL, ....
False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V, False, ENABLE_DROPOUT, PADDED_HEAD,
ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES)
block_min = block_max
block_max = n_blocks * BLOCK_N

@@ -449,6 +477,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs,
sd_mask_ptrs, dropout_mask_ptrs, block_min, block_max, offs_n_causal, masked_blocks,
n_extra_tokens, alibi_slope,
q_scale, k_scale, v_scale, IS_FP8,
IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD,
@@ -538,6 +567,15 @@ def attention_prefill_forward_triton_impl(
# misc
return_softmax,
use_exp2):

is_fp8 = check_is_fp8(q)

# if q/k/v are fp8, compute the scaling factors used for quantization
q_scale, k_scale, v_scale = create_scale_tensors(q, k, v, SCALE_PER_HEAD=True, layout=layout) # TODO: with SCALE_PER_HEAD, the scale could instead be computed inside the kernel as tl.max over the loaded q/k/v block
q_scale_stride_z = q_scale.stride(0)
kv_scale_stride_z = k_scale.stride(0)
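# NOTE (assumption): with SCALE_PER_HEAD=True, create_scale_tensors is expected to return one scale per
# (batch, head) for each of q, k and v (shape (batch, nheads)), presumably derived from a torch.amax over
# each head's values, which is the overhead mentioned in benchmarks/benchmark_fp8.py.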


if DEBUG:
print()
@@ -546,6 +584,9 @@ def attention_prefill_forward_triton_impl(
print("k:", k, k.shape)
print("v:", v, v.shape)
print("o:", o, o.shape)
print("q_scale", q_scale)
print("k_scale", k_scale)
print("v_scale", v_scale)
print("sm_scale:", sm_scale)
print("alibi_slopes:", alibi_slopes)
print("causal:", causal)
@@ -561,6 +602,8 @@ def attention_prefill_forward_triton_impl(
print("return_scores:", return_softmax)
print("use_exp2:", use_exp2)


# check if varlen
is_varlen = layout == "thd"

@@ -618,15 +661,16 @@ def attention_prefill_forward_triton_impl(
else:
alibi_strides = (0, 0)


attn_fwd[grid](q, k, v, bias, sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
attn_fwd[grid](q, k, v, bias, q_scale, k_scale, v_scale, q_scale_stride_z, kv_scale_stride_z, sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
*bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes,
HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen,
BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True,
USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p
> 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax)
> 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=is_fp8)

if DEBUG:
print()