diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py index 3c0900e8..89160948 100644 --- a/tritonbench/kernels/triton_fused_attention.py +++ b/tritonbench/kernels/triton_fused_attention.py @@ -22,7 +22,8 @@ # check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498). HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl) -DATA_PARTITION = os.getenv("DATA_PARTITION_FA") +WITH_COMPPIPE = os.getenv("ENABLE_COMPPIPE") +PEEL_LAST = os.getenv("PEEL_LAST_ITER") if HAS_TMA_DESC: print( @@ -113,6 +114,110 @@ def get_tma_descriptor_kernel_param(self, name): return self.cuda_descriptors[name] +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, # + K_block_ptr, + V_block_ptr, # + desc_k, + desc_v, + Q, + qvk_offset, + stride_kn, + stride_vn, + stride_vk, # + start_m, + qk_scale, # + BLOCK_M: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr, + offs_m: tl.constexpr, + offs_n: tl.constexpr, # + N_CTX: tl.constexpr, + fp8_v: tl.constexpr, + ENABLE_TMA: tl.constexpr, + LOOP_SCHEDULE: tl.constexpr, +): + # range of values handled by this stage + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + # causal = False + else: + lo, hi = 0, N_CTX + if not ENABLE_TMA: + K_block_ptr = tl.advance(K_block_ptr, (0, lo)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) + # loop over k, v and update accumulator + for start_n in tl.range(lo, hi, BLOCK_N): # , loop_schedule=LOOP_SCHEDULE): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + if ENABLE_TMA: + k = tl._experimental_descriptor_load( # load in row major + desc_k, + [start_n.to(tl.int32) + (qvk_offset // stride_kn).to(tl.int32), 0], + [BLOCK_N, HEAD_DIM], + Q.dtype.element_ty, + ) + else: + k = tl.load(K_block_ptr) + if ENABLE_TMA: + k = tl.trans(k) + qk = tl.dot(q, k) + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + qk = qk * qk_scale - m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + # -- update output accumulator -- + acc = acc * alpha[:, None] + # update acc + if ENABLE_TMA: + if fp8_v: + v = tl._experimental_descriptor_load( # load in row major + desc_v, + [(qvk_offset // stride_vn).to(tl.int32), start_n.to(tl.int32)], + [HEAD_DIM, BLOCK_N], + Q.dtype.element_ty, + ) + else: + v = tl._experimental_descriptor_load( # load in row major + desc_v, + [(qvk_offset // stride_vk + start_n).to(tl.int32), 0], + [BLOCK_N, HEAD_DIM], + Q.dtype.element_ty, + ) + else: + v = tl.load(V_block_ptr) + if fp8_v: + if ENABLE_TMA: + v = tl.trans(v) + p = p.to(tl.float8e5) + else: + p = p.to(tl.bfloat16) + acc = tl.dot(p, v, acc) + # update m_i and l_i + m_i = m_ij + if not ENABLE_TMA: + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + return acc, l_i, m_i + + @triton.jit def _attn_fwd_inner_ws( acc, @@ -121,6 +226,13 @@ def _attn_fwd_inner_ws( q, # K_block_ptr, V_block_ptr, # + desc_k, + desc_v, + Q, + qvk_offset, + stride_kn, + stride_vn, + stride_vk, # start_m, qk_scale, # BLOCK_M: tl.constexpr, @@ -131,6 +243,8 @@ def _attn_fwd_inner_ws( offs_n: tl.constexpr, # N_CTX: 
tl.constexpr, fp8_v: tl.constexpr, + ENABLE_TMA: tl.constexpr, + LOOP_SCHEDULE: tl.constexpr, ): # range of values handled by this stage if STAGE == 1: @@ -141,17 +255,26 @@ def _attn_fwd_inner_ws( # causal = False else: lo, hi = 0, N_CTX - K_block_ptr = tl.advance(K_block_ptr, (0, lo)) - V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) + if not ENABLE_TMA: + K_block_ptr = tl.advance(K_block_ptr, (0, lo)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) # loop over k, v and update accumulator - for start_n in tl.range( - lo, hi, BLOCK_N, loop_schedule="FA_secondDot" - ): # FA_firstDot FA_secondDot + for start_n in tl.range(lo, hi, BLOCK_N): # , loop_schedule=LOOP_SCHEDULE): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- with tl.async_task([0]): - k = tl.load(K_block_ptr) + if ENABLE_TMA: + k = tl._experimental_descriptor_load( # load in row major + desc_k, + [start_n.to(tl.int32) + (qvk_offset // stride_kn).to(tl.int32), 0], + [BLOCK_N, HEAD_DIM], + Q.dtype.element_ty, + ) + else: + k = tl.load(K_block_ptr) with tl.async_task([1, 2]): + if ENABLE_TMA: + k = tl.trans(k) qk = tl.dot(q, k) if STAGE == 2: mask = offs_m[:, None] >= (start_n + offs_n[None, :]) @@ -170,30 +293,118 @@ def _attn_fwd_inner_ws( acc = acc * alpha[:, None] # update acc with tl.async_task([0]): - v = tl.load(V_block_ptr) + if ENABLE_TMA: + if fp8_v: + v = tl._experimental_descriptor_load( # load in row major + desc_v, + [(qvk_offset // stride_vn).to(tl.int32), start_n.to(tl.int32)], + [HEAD_DIM, BLOCK_N], + Q.dtype.element_ty, + ) + else: + v = tl._experimental_descriptor_load( # load in row major + desc_v, + [(qvk_offset // stride_vk + start_n).to(tl.int32), 0], + [BLOCK_N, HEAD_DIM], + Q.dtype.element_ty, + ) + else: + v = tl.load(V_block_ptr) with tl.async_task([1, 2]): if fp8_v: + if ENABLE_TMA: + v = tl.trans(v) p = p.to(tl.float8e5) else: p = p.to(tl.bfloat16) acc = tl.dot(p, v, acc) # update m_i and l_i m_i = m_ij - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if not ENABLE_TMA: + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) return acc, l_i, m_i # We don't run auto-tuning every time to keep the tutorial fast. Uncommenting # the code below and commenting out the equivalent parameters is convenient for # re-tuning. 
-BMIter = [128] if DATA_PARTITION else [64] has_warp_spec = hasattr(tl, "async_task") +schedList = ["default", "FA_firstDot", "FA_secondDot"] if WITH_COMPPIPE else ["default"] +# TODO: incorrect result with PEEL_LAST + FA_firstDot + WarpSpec + TMA +schedList = ["FA_secondDot"] if PEEL_LAST else schedList +# no WS, no TMA, with CompPipe +configsOpt = [ + ( + triton.Config( + { + "BLOCK_M": BM, + "BLOCK_N": BN, + "ENABLE_TMA": enable_tma, + "LOOP_SCHEDULE": sched, + }, + num_stages=4 if sched == "FA_firstDot" or sched == "FA_secondDot" else 3, + num_warps=w, + num_buffers_warp_spec=0, + num_consumer_groups=0, + ) + if has_warp_spec + else triton.Config( + { + "BLOCK_M": BM, + "BLOCK_N": BN, + "ENABLE_TMA": enable_tma, + "LOOP_SCHEDULE": sched, + }, + num_stages=4 if sched == "FA_firstDot" or sched == "FA_secondDot" else 3, + num_warps=w, + ) + ) + for BM in [128] + for BN in [128] + for sched in schedList + for enable_tma in [False] + for w in [8] +] +# no WS, with TMA and CompPipe +configsTma = [ + ( + triton.Config( + { + "BLOCK_M": BM, + "BLOCK_N": BN, + "ENABLE_TMA": enable_tma, + "LOOP_SCHEDULE": sched, + }, + num_stages=4 if sched == "FA_firstDot" or sched == "FA_secondDot" else 3, + num_warps=w, + num_buffers_warp_spec=0, + num_consumer_groups=0, + ) + if has_warp_spec + else triton.Config( + { + "BLOCK_M": BM, + "BLOCK_N": BN, + "ENABLE_TMA": enable_tma, + "LOOP_SCHEDULE": sched, + }, + num_stages=4 if sched == "FA_firstDot" or sched == "FA_secondDot" else 3, + num_warps=w, + ) + ) + for BM in [128] + for BN in [128] + for sched in schedList + for enable_tma in [True] + for w in [8] +] +# no TMA, with WS and CompPipe configsWS = [ ( triton.Config( - {"BLOCK_M": BM, "BLOCK_N": BN}, - num_stages=s, + {"BLOCK_M": BM, "BLOCK_N": BN, "ENABLE_TMA": False, "LOOP_SCHEDULE": sched}, + num_stages=2 if sched == "FA_firstDot" or sched == "FA_secondDot" else 0, num_warps=w, num_buffers_warp_spec=buf, num_consumer_groups=grp, @@ -201,38 +412,64 @@ def _attn_fwd_inner_ws( reg_inc_consumer=inc, ) if has_warp_spec - else triton.Config({"BLOCK_M": BM, "BLOCK_N": BN}, num_stages=s, num_warps=w) + else triton.Config( + {"BLOCK_M": BM, "BLOCK_N": BN, "ENABLE_TMA": False, "LOOP_SCHEDULE": sched}, + num_stages=2 if sched == "FA_firstDot" or sched == "FA_secondDot" else 0, + num_warps=w, + ) ) - for BM in BMIter # 128 with data partitioning, 64 with grid partitioning + for BM in [128] for BN in [128] - for s in [2] # change to 2 if firstDot or secondDot + for sched in schedList + for enable_ws in [True] for w in [4] for buf in [2] for grp in [2] for dec, inc in [(24, 240), (40, 232)] # (32, 240) hangs; (24, 240) and (40, 232) work ] +# BLOCK_M: 128, BLOCK_N: 128, ENABLE_TMA: False, LOOP_SCHEDULE: default, num_warps: 8, num_ctas: 1, num_stages: 3 configsOrig = [ ( triton.Config( - {"BLOCK_M": BM, "BLOCK_N": BN}, + { + "BLOCK_M": BM, + "BLOCK_N": BN, + "ENABLE_TMA": False, + "LOOP_SCHEDULE": "default", + }, num_stages=s, num_warps=w, num_buffers_warp_spec=0, num_consumer_groups=0, ) if has_warp_spec - else triton.Config({"BLOCK_M": BM, "BLOCK_N": BN}, num_stages=s, num_warps=w) + else triton.Config( + { + "BLOCK_M": BM, + "BLOCK_N": BN, + "ENABLE_TMA": False, + "LOOP_SCHEDULE": "default", + }, + num_stages=s, + num_warps=w, + ) ) - for BM in [64, 128] - for BN in [64, 128] - for s in [3, 4, 7] - for w in [4, 8] + for BM in [128] # 64, 128] + for BN in [128] # 64, 128] + for s in [3] # , 4, 7] + for w in [8] # 4, 8] ] -configsTma = [ +# TMA, WS, and CompPipe +configsTmaWS = [ ( triton.Config( - {"BLOCK_M": BM,
"BLOCK_N": BN}, - num_stages=s, + { + "BLOCK_M": BM, + "BLOCK_N": BN, + "ENABLE_TMA": enable_tma, + "LOOP_SCHEDULE": sched, + }, + num_stages=2 if sched == "FA_firstDot" or sched == "FA_secondDot" else 0, num_warps=w, num_buffers_warp_spec=buf, num_consumer_groups=grp, @@ -240,11 +477,22 @@ def _attn_fwd_inner_ws( reg_inc_consumer=inc, ) if has_warp_spec - else triton.Config({"BLOCK_M": BM, "BLOCK_N": BN}, num_stages=s, num_warps=w) + else triton.Config( + { + "BLOCK_M": BM, + "BLOCK_N": BN, + "ENABLE_TMA": enable_tma, + "LOOP_SCHEDULE": sched, + }, + num_stages=2 if sched == "FA_firstDot" or sched == "FA_secondDot" else 0, + num_warps=w, + ) ) - for BM in BMIter # 128 with data partitioning, 64 with grid partitioning + for BM in [128] for BN in [128] - for s in [2] # change to 2 if firstDot or secondDot + for sched in schedList + for enable_tma in [True] + for enable_ws in [True] for w in [4] for buf in [2] for grp in [2] # 2 @@ -262,15 +510,18 @@ def keep(conf): return True -@triton.autotune(list(filter(keep, configsWS)), key=["N_CTX"]) @triton.jit -def _attn_fwd_ws( +def _attn_fwd_compute( Q, K, V, sm_scale, M, Out, # + desc_q, + desc_k, + desc_v, + desc_o, stride_qz, stride_qh, stride_qm, @@ -294,6 +545,8 @@ def _attn_fwd_ws( BLOCK_N: tl.constexpr, # HEAD_DIM: tl.constexpr, # STAGE: tl.constexpr, # + ENABLE_TMA: tl.constexpr, + LOOP_SCHEDULE: tl.constexpr, ): tl.static_assert(BLOCK_N <= HEAD_DIM) start_m = tl.program_id(0) @@ -302,40 +555,45 @@ def _attn_fwd_ws( off_h = off_hz % H qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh - # block pointers - Q_block_ptr = tl.make_block_ptr( - base=Q + qvk_offset, - shape=(N_CTX, HEAD_DIM), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, HEAD_DIM), - order=(1, 0), - ) - v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) - V_block_ptr = tl.make_block_ptr( - base=V + qvk_offset, - shape=(N_CTX, HEAD_DIM), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, HEAD_DIM), - order=v_order, - ) - K_block_ptr = tl.make_block_ptr( - base=K + qvk_offset, - shape=(HEAD_DIM, N_CTX), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(HEAD_DIM, BLOCK_N), - order=(0, 1), - ) - O_block_ptr = tl.make_block_ptr( - base=Out + qvk_offset, - shape=(N_CTX, HEAD_DIM), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, HEAD_DIM), - order=(1, 0), - ) + K_block_ptr = None + V_block_ptr = None + Q_block_ptr = None + O_block_ptr = None + if not ENABLE_TMA: + # block pointers + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=v_order, + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(HEAD_DIM, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(HEAD_DIM, BLOCK_N), + order=(0, 1), + ) + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) # initialize offsets offs_m = start_m * BLOCK_M + 
tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) @@ -347,19 +605,33 @@ def _attn_fwd_ws( qk_scale = sm_scale qk_scale *= 1.44269504 # 1/log(2) # load q: it will stay in SRAM throughout - with tl.async_task([0]): + if ENABLE_TMA: + q = tl._experimental_descriptor_load( # load in row major + desc_q, + [(qvk_offset // stride_qm + start_m * BLOCK_M).to(tl.int32), 0], + [BLOCK_M, HEAD_DIM], + Q.dtype.element_ty, + ) + else: q = tl.load(Q_block_ptr) # stage 1: off-band # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE if STAGE & 1: - acc, l_i, m_i = _attn_fwd_inner_ws( + acc, l_i, m_i = _attn_fwd_inner( acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + desc_k, + desc_v, + Q, + qvk_offset, + stride_kn, + stride_vn, + stride_vk, # start_m, qk_scale, # BLOCK_M, @@ -370,18 +642,27 @@ def _attn_fwd_ws( offs_n, N_CTX, V.dtype.element_ty == tl.float8e5, # + ENABLE_TMA, + LOOP_SCHEDULE, ) # stage 2: on-band if STAGE & 2: # barrier makes it easier for compiler to schedule the # two loops independently - acc, l_i, m_i = _attn_fwd_inner_ws( + acc, l_i, m_i = _attn_fwd_inner( acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + desc_k, + desc_v, + Q, + qvk_offset, + stride_kn, + stride_vn, + stride_vk, # start_m, qk_scale, # BLOCK_M, @@ -392,90 +673,36 @@ def _attn_fwd_ws( offs_n, N_CTX, V.dtype.element_ty == tl.float8e5, # + ENABLE_TMA, + LOOP_SCHEDULE, ) # epilogue - with tl.async_task([1, 2]): - m_i += tl.math.log2(l_i) - acc = acc / l_i[:, None] - m_ptrs = M + off_hz * N_CTX + offs_m - tl.store(m_ptrs, m_i) - tl.store(O_block_ptr, acc.to(Out.type.element_ty)) - - -@triton.jit -def _attn_fwd_inner( - acc, - l_i, - m_i, - q, # - K_block_ptr, - V_block_ptr, # - start_m, - qk_scale, # - BLOCK_M: tl.constexpr, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, # - STAGE: tl.constexpr, - offs_m: tl.constexpr, - offs_n: tl.constexpr, # - N_CTX: tl.constexpr, - fp8_v: tl.constexpr, -): - # range of values handled by this stage - if STAGE == 1: - lo, hi = 0, start_m * BLOCK_M - elif STAGE == 2: - lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M - lo = tl.multiple_of(lo, BLOCK_M) - # causal = False + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(m_ptrs, m_i) + if ENABLE_TMA: + tl._experimental_descriptor_store( + desc_o, + acc.to(Out.type.element_ty), + [(qvk_offset // stride_om + start_m * BLOCK_M).to(tl.int32), 0], + ) else: - lo, hi = 0, N_CTX - K_block_ptr = tl.advance(K_block_ptr, (0, lo)) - V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) - # loop over k, v and update accumulator - for start_n in range(lo, hi, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(K_block_ptr) - qk = tl.dot(q, k) - if STAGE == 2: - mask = offs_m[:, None] >= (start_n + offs_n[None, :]) - qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk -= m_ij[:, None] - else: - m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) - qk = qk * qk_scale - m_ij[:, None] - p = tl.math.exp2(qk) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - alpha = tl.math.exp2(m_i - m_ij) - l_i = l_i * alpha + l_ij - # -- update output accumulator -- - acc = acc * alpha[:, None] - # update acc - v = tl.load(V_block_ptr) - if fp8_v: - p = p.to(tl.float8e5) - else: - p = p.to(tl.bfloat16) - acc = tl.dot(p, v, acc) - # update m_i and l_i - m_i = m_ij - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr,
(0, BLOCK_N)) - return acc, l_i, m_i + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) -@triton.autotune(list(filter(keep, configsOrig)), key=["N_CTX"]) @triton.jit -def _attn_fwd( +def _attn_fwd_compute_ws( Q, K, V, sm_scale, M, Out, # + desc_q, + desc_k, + desc_v, + desc_o, stride_qz, stride_qh, stride_qm, @@ -494,11 +721,13 @@ def _attn_fwd( stride_on, # Z, H, - N_CTX, # + N_CTX, #: tl.constexpr, # BLOCK_M: tl.constexpr, # BLOCK_N: tl.constexpr, # HEAD_DIM: tl.constexpr, # STAGE: tl.constexpr, # + ENABLE_TMA: tl.constexpr, + LOOP_SCHEDULE: tl.constexpr, ): tl.static_assert(BLOCK_N <= HEAD_DIM) start_m = tl.program_id(0) @@ -507,40 +736,45 @@ def _attn_fwd( off_h = off_hz % H qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh - # block pointers - Q_block_ptr = tl.make_block_ptr( - base=Q + qvk_offset, - shape=(N_CTX, HEAD_DIM), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, HEAD_DIM), - order=(1, 0), - ) - v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) - V_block_ptr = tl.make_block_ptr( - base=V + qvk_offset, - shape=(N_CTX, HEAD_DIM), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, HEAD_DIM), - order=v_order, - ) - K_block_ptr = tl.make_block_ptr( - base=K + qvk_offset, - shape=(HEAD_DIM, N_CTX), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(HEAD_DIM, BLOCK_N), - order=(0, 1), - ) - O_block_ptr = tl.make_block_ptr( - base=Out + qvk_offset, - shape=(N_CTX, HEAD_DIM), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, HEAD_DIM), - order=(1, 0), - ) + K_block_ptr = None + V_block_ptr = None + Q_block_ptr = None + O_block_ptr = None + if not ENABLE_TMA: + # block pointers + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=v_order, + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(HEAD_DIM, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(HEAD_DIM, BLOCK_N), + order=(0, 1), + ) + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) # initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) @@ -552,18 +786,34 @@ def _attn_fwd( qk_scale = sm_scale qk_scale *= 1.44269504 # 1/log(2) # load q: it will stay in SRAM throughout - q = tl.load(Q_block_ptr) + with tl.async_task([0]): + if ENABLE_TMA: + q = tl._experimental_descriptor_load( # load in row major + desc_q, + [(qvk_offset // stride_qm + start_m * BLOCK_M).to(tl.int32), 0], + [BLOCK_M, HEAD_DIM], + Q.dtype.element_ty, + ) + else: + q = tl.load(Q_block_ptr) # stage 1: off-band # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE if STAGE & 1: - acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i = _attn_fwd_inner_ws( acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + desc_k, + desc_v, + Q, + qvk_offset, + 
stride_kn, + stride_vn, + stride_vk, # start_m, qk_scale, # BLOCK_M, @@ -574,18 +824,27 @@ def _attn_fwd( offs_n, N_CTX, V.dtype.element_ty == tl.float8e5, # + ENABLE_TMA, + LOOP_SCHEDULE, ) # stage 2: on-band if STAGE & 2: # barrier makes it easier for compiler to schedule the # two loops independently - acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i = _attn_fwd_inner_ws( acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + desc_k, + desc_v, + Q, + qvk_offset, + stride_kn, + stride_vn, + stride_vk, # start_m, qk_scale, # BLOCK_M, @@ -596,119 +855,196 @@ def _attn_fwd( offs_n, N_CTX, V.dtype.element_ty == tl.float8e5, # + ENABLE_TMA, + LOOP_SCHEDULE, ) # epilogue - m_i += tl.math.log2(l_i) - acc = acc / l_i[:, None] - m_ptrs = M + off_hz * N_CTX + offs_m - tl.store(m_ptrs, m_i) - tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + with tl.async_task([1, 2]): + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(m_ptrs, m_i) + if ENABLE_TMA: + tl._experimental_descriptor_store( + desc_o, + acc.to(Out.type.element_ty), + [(qvk_offset // stride_om + start_m * BLOCK_M).to(tl.int32), 0], + ) + else: + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) +@triton.autotune(list(filter(keep, configsWS)), key=["N_CTX"]) @triton.jit -def _attn_fwd_inner_tma( - acc, - l_i, - m_i, - q, # - K_desc_ptr, - V_desc_ptr, +def _attn_fwd_ws( Q, - qvk_offset, + K, + V, + sm_scale, + M, + Out, # + desc_q, + desc_k, + desc_v, + desc_o, + stride_qz, + stride_qh, + stride_qm, + stride_qk, # + stride_kz, + stride_kh, stride_kn, - stride_vn, - stride_vk, # - start_m, - qk_scale, # - BLOCK_M: tl.constexpr, - HEAD_DIM: tl.constexpr, + stride_kk, # + stride_vz, + stride_vh, + stride_vk, + stride_vn, # + stride_oz, + stride_oh, + stride_om, + stride_on, # + Z, + H, + N_CTX, #: tl.constexpr, # + BLOCK_M: tl.constexpr, # BLOCK_N: tl.constexpr, # - STAGE: tl.constexpr, - offs_m: tl.constexpr, - offs_n: tl.constexpr, # - N_CTX: tl.constexpr, - fp8_v: tl.constexpr, + HEAD_DIM: tl.constexpr, # + STAGE: tl.constexpr, # + ENABLE_TMA: tl.constexpr, + LOOP_SCHEDULE: tl.constexpr, + ENABLE_WS: tl.constexpr, ): - # range of values handled by this stage - if STAGE == 1: - lo, hi = 0, start_m * BLOCK_M - elif STAGE == 2: - lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M - lo = tl.multiple_of(lo, BLOCK_M) - # causal = False - else: - lo, hi = 0, N_CTX - # loop over k, v and update accumulator - for start_n in tl.range( - lo, hi, BLOCK_N, loop_schedule="FA_secondDot" - ): # FA_firstDot FA_secondDot - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - with tl.async_task([0]): - k = tl._experimental_descriptor_load( # load in row major - K_desc_ptr, - [start_n.to(tl.int32) + (qvk_offset // stride_kn).to(tl.int32), 0], - [BLOCK_N, HEAD_DIM], - Q.dtype.element_ty, - ) - with tl.async_task([1, 2]): - k = tl.trans(k) - qk = tl.dot(q, k) - if STAGE == 2: - mask = offs_m[:, None] >= (start_n + offs_n[None, :]) - qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk -= m_ij[:, None] - else: - m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) - qk = qk * qk_scale - m_ij[:, None] - p = tl.math.exp2(qk) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - alpha = tl.math.exp2(m_i - m_ij) - l_i = l_i * alpha + l_ij - # -- update output accumulator -- - acc = acc * alpha[:, None] - # update acc - with tl.async_task([0]): - if fp8_v: - v = tl._experimental_descriptor_load( # load in row major - V_desc_ptr, - [(qvk_offset // stride_vn).to(tl.int32),
start_n.to(tl.int32)], - [HEAD_DIM, BLOCK_N], - Q.dtype.element_ty, - ) - else: - v = tl._experimental_descriptor_load( # load in row major - V_desc_ptr, - [(qvk_offset // stride_vk + start_n).to(tl.int32), 0], - [BLOCK_N, HEAD_DIM], - Q.dtype.element_ty, - ) - with tl.async_task([1, 2]): - if fp8_v: - v = tl.trans(v) - p = p.to(tl.float8e5) - else: - p = p.to(tl.bfloat16) - acc = tl.dot(p, v, acc) - # update m_i and l_i - m_i = m_ij - return acc, l_i, m_i + _attn_fwd_compute_ws( + Q, + K, + V, + sm_scale, + M, + Out, # + desc_q, + desc_k, + desc_v, + desc_o, + stride_qz, + stride_qh, + stride_qm, + stride_qk, # + stride_kz, + stride_kh, + stride_kn, + stride_kk, # + stride_vz, + stride_vh, + stride_vk, + stride_vn, # + stride_oz, + stride_oh, + stride_om, + stride_on, # + Z, + H, + N_CTX, #: tl.constexpr, # + BLOCK_M, + BLOCK_N, + HEAD_DIM, + STAGE, + ENABLE_TMA, + LOOP_SCHEDULE, + ) -@triton.autotune(list(filter(keep, configsTma)), key=["N_CTX"]) +@triton.autotune(list(filter(keep, configsOrig)), key=["N_CTX"]) @triton.jit -def _attn_fwd_tma( # Q, V, desc_k, desc_v, sm_scale, M, Out, # +def _attn_fwd( Q, + K, V, - Out, + sm_scale, + M, + Out, # desc_q, desc_k, desc_v, + desc_o, + stride_qz, + stride_qh, + stride_qm, + stride_qk, # + stride_kz, + stride_kh, + stride_kn, + stride_kk, # + stride_vz, + stride_vh, + stride_vk, + stride_vn, # + stride_oz, + stride_oh, + stride_om, + stride_on, # + Z, + H, + N_CTX, #: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_N: tl.constexpr, # + HEAD_DIM: tl.constexpr, # + STAGE: tl.constexpr, # + ENABLE_TMA: tl.constexpr, + LOOP_SCHEDULE: tl.constexpr, + ENABLE_WS: tl.constexpr, +): + _attn_fwd_compute( + Q, + K, + V, + sm_scale, + M, + Out, # + desc_q, + desc_k, + desc_v, + desc_o, + stride_qz, + stride_qh, + stride_qm, + stride_qk, # + stride_kz, + stride_kh, + stride_kn, + stride_kk, # + stride_vz, + stride_vh, + stride_vk, + stride_vn, # + stride_oz, + stride_oh, + stride_om, + stride_on, # + Z, + H, + N_CTX, #: tl.constexpr, # + BLOCK_M, + BLOCK_N, + HEAD_DIM, + STAGE, + ENABLE_TMA, + LOOP_SCHEDULE, + ) + + +@triton.autotune(list(filter(keep, configsOpt)), key=["N_CTX"]) +@triton.jit +def _attn_fwd_opt( # Q, V, desc_k, desc_v, sm_scale, M, Out, # + Q, + K, + V, sm_scale, M, - desc_o, # + Out, # + desc_q, + desc_k, + desc_v, + desc_o, stride_qz, stride_qh, stride_qm, @@ -727,121 +1063,210 @@ def _attn_fwd_tma( # Q, V, desc_k, desc_v, sm_scale, M, Out, # stride_on, # Z, H, - N_CTX, # + N_CTX, #: tl.constexpr, # BLOCK_M: tl.constexpr, # BLOCK_N: tl.constexpr, # HEAD_DIM: tl.constexpr, # STAGE: tl.constexpr, # + ENABLE_TMA: tl.constexpr, + LOOP_SCHEDULE: tl.constexpr, + ENABLE_WS: tl.constexpr, ): - tl.static_assert(BLOCK_N <= HEAD_DIM) - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - off_z = off_hz // H - off_h = off_hz % H - qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh + _attn_fwd_compute( + Q, + K, + V, + sm_scale, + M, + Out, # + desc_q, + desc_k, + desc_v, + desc_o, + stride_qz, + stride_qh, + stride_qm, + stride_qk, # + stride_kz, + stride_kh, + stride_kn, + stride_kk, # + stride_vz, + stride_vh, + stride_vk, + stride_vn, # + stride_oz, + stride_oh, + stride_om, + stride_on, # + Z, + H, + N_CTX, #: tl.constexpr, # + BLOCK_M, + BLOCK_N, + HEAD_DIM, + STAGE, + ENABLE_TMA, + LOOP_SCHEDULE, + ) - # block pointers - # Q_block_ptr = tl.make_block_ptr( - # base=Q + qvk_offset, - # shape=(N_CTX, HEAD_DIM), - # strides=(stride_qm, stride_qk), - # offsets=(start_m * BLOCK_M, 0), - # block_shape=(BLOCK_M, 
HEAD_DIM), - # order=(1, 0), - # ) - # O_block_ptr = tl.make_block_ptr( - # base=Out + qvk_offset, - # shape=(N_CTX, HEAD_DIM), - # strides=(stride_om, stride_on), - # offsets=(start_m * BLOCK_M, 0), - # block_shape=(BLOCK_M, HEAD_DIM), - # order=(1, 0), - # ) - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 - acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) - # load scales - qk_scale = sm_scale - qk_scale *= 1.44269504 # 1/log(2) - # load q: it will stay in SRAM throughout - # q = tl.load(Q_block_ptr) - with tl.async_task([0]): - q = tl._experimental_descriptor_load( # load in row major - desc_q, - [(qvk_offset // stride_qm + start_m * BLOCK_M).to(tl.int32), 0], - [BLOCK_M, HEAD_DIM], - Q.dtype.element_ty, - ) - # stage 1: off-band - # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE - # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE - if STAGE & 1: - acc, l_i, m_i = _attn_fwd_inner_tma( - acc, - l_i, - m_i, - q, - desc_k, - desc_v, - Q, - qvk_offset, - stride_kn, - stride_vn, - stride_vk, # - start_m, - qk_scale, # - BLOCK_M, - HEAD_DIM, - BLOCK_N, # - 4 - STAGE, - offs_m, - offs_n, - N_CTX, - V.dtype.element_ty == tl.float8e5, # - ) - # stage 2: on-band - if STAGE & 2: - # barrier makes it easier for compielr to schedule the - # two loops independently - acc, l_i, m_i = _attn_fwd_inner_tma( - acc, - l_i, - m_i, - q, - desc_k, - desc_v, - Q, - qvk_offset, - stride_kn, - stride_vn, - stride_vk, # - start_m, - qk_scale, # - BLOCK_M, - HEAD_DIM, - BLOCK_N, # - 2, - offs_m, - offs_n, - N_CTX, - V.dtype.element_ty == tl.float8e5, # - ) - # epilogue - with tl.async_task([1, 2]): - m_i += tl.math.log2(l_i) - acc = acc / l_i[:, None] - m_ptrs = M + off_hz * N_CTX + offs_m - tl.store(m_ptrs, m_i) - # tl.device_print("tma", acc.to(Out.type.element_ty)) - tl._experimental_descriptor_store( - desc_o, - acc.to(Out.type.element_ty), - [(qvk_offset // stride_om + start_m * BLOCK_M).to(tl.int32), 0], - ) + +@triton.autotune(list(filter(keep, configsTma)), key=["N_CTX"]) +@triton.jit +def _attn_fwd_tma( # Q, V, desc_k, desc_v, sm_scale, M, Out, # + Q, + K, + V, + sm_scale, + M, + Out, # + desc_q, + desc_k, + desc_v, + desc_o, + stride_qz, + stride_qh, + stride_qm, + stride_qk, # + stride_kz, + stride_kh, + stride_kn, + stride_kk, # + stride_vz, + stride_vh, + stride_vk, + stride_vn, # + stride_oz, + stride_oh, + stride_om, + stride_on, # + Z, + H, + N_CTX, #: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_N: tl.constexpr, # + HEAD_DIM: tl.constexpr, # + STAGE: tl.constexpr, # + ENABLE_TMA: tl.constexpr, + LOOP_SCHEDULE: tl.constexpr, + ENABLE_WS: tl.constexpr, +): + _attn_fwd_compute( + Q, + K, + V, + sm_scale, + M, + Out, # + desc_q, + desc_k, + desc_v, + desc_o, + stride_qz, + stride_qh, + stride_qm, + stride_qk, # + stride_kz, + stride_kh, + stride_kn, + stride_kk, # + stride_vz, + stride_vh, + stride_vk, + stride_vn, # + stride_oz, + stride_oh, + stride_om, + stride_on, # + Z, + H, + N_CTX, #: tl.constexpr, # + BLOCK_M, + BLOCK_N, + HEAD_DIM, + STAGE, + ENABLE_TMA, + LOOP_SCHEDULE, + ) + + +@triton.autotune(list(filter(keep, configsTmaWS)), key=["N_CTX"]) +@triton.jit +def _attn_fwd_tma_ws( # Q, V, desc_k, desc_v, sm_scale, M, Out, # + Q, + K, + V, + sm_scale, + M, + Out, # + desc_q, + desc_k, + desc_v, + desc_o, + stride_qz, + stride_qh, + 
stride_qm, + stride_qk, # + stride_kz, + stride_kh, + stride_kn, + stride_kk, # + stride_vz, + stride_vh, + stride_vk, + stride_vn, # + stride_oz, + stride_oh, + stride_om, + stride_on, # + Z, + H, + N_CTX, #: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_N: tl.constexpr, # + HEAD_DIM: tl.constexpr, # + STAGE: tl.constexpr, # + ENABLE_TMA: tl.constexpr, + LOOP_SCHEDULE: tl.constexpr, + ENABLE_WS: tl.constexpr, +): + _attn_fwd_compute_ws( + Q, + K, + V, + sm_scale, + M, + Out, # + desc_q, + desc_k, + desc_v, + desc_o, + stride_qz, + stride_qh, + stride_qm, + stride_qk, # + stride_kz, + stride_kh, + stride_kn, + stride_kk, # + stride_vz, + stride_vh, + stride_vk, + stride_vn, # + stride_oz, + stride_oh, + stride_om, + stride_on, # + Z, + H, + N_CTX, #: tl.constexpr, # + BLOCK_M, + BLOCK_N, + HEAD_DIM, + STAGE, + ENABLE_TMA, + LOOP_SCHEDULE, + ) @triton.jit @@ -1094,344 +1519,98 @@ def _attn_bwd( M, D, # stride_tok, - stride_d, # - H, - N_CTX, # - BLOCK_M1, - BLOCK_N1, - HEAD_DIM, # - start_n, - start_m, - num_steps, # - MASK=False, # - ) - - dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d - tl.store(dv_ptrs, dv) - - # Write back dK. - dk *= sm_scale - dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d - tl.store(dk_ptrs, dk) - - # THIS BLOCK DOES DQ: - start_m = pid * BLOCK_M2 - end_n = start_m + BLOCK_M2 - - MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR - offs_m = start_m + tl.arange(0, BLOCK_M2) - - q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) - dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) - do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) - - m = tl.load(M + offs_m) - m = m[:, None] - - # Compute dQ for masked (diagonal) blocks. - # NOTE: This code scans each row of QK^T backward (from right to left, - # but inside each call to _attn_bwd_dq, from left to right), but that's - # not due to anything important. I just wanted to reuse the loop - # structure for dK & dV above as much as possible. - num_steps = BLOCK_M2 // MASK_BLOCK_N2 - dq = _attn_bwd_dq( - dq, - q, - K, - V, # - do, - m, - D, # - stride_tok, - stride_d, # - H, - N_CTX, # - BLOCK_M2, - MASK_BLOCK_N2, - HEAD_DIM, # - start_m, - end_n - num_steps * MASK_BLOCK_N2, - num_steps, # - MASK=True, # - ) - end_n -= num_steps * MASK_BLOCK_N2 - # stage 2 - num_steps = end_n // BLOCK_N2 - dq = _attn_bwd_dq( - dq, - q, - K, - V, # - do, - m, - D, # - stride_tok, - stride_d, # - H, - N_CTX, # - BLOCK_M2, - BLOCK_N2, - HEAD_DIM, # - start_m, - end_n - num_steps * BLOCK_N2, - num_steps, # - MASK=False, # - ) - # Write back dQ. - dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d - dq *= LN2 - tl.store(dq_ptrs, dq) - - -class _attention_ws(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, causal, sm_scale): - # shape constraints - HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] - # when v is in float8_e5m2 it is transposed. 
- HEAD_DIM_V = v.shape[-2] if v.dtype == torch.float8_e5m2 else v.shape[-1] - assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V - assert HEAD_DIM_K in {16, 32, 64, 128, 256} - o = torch.empty_like(q) - stage = 3 if causal else 1 - extra_kern_args = {} - - grid = lambda args: ( - # grid partitioning: num_consumer_groups * BLOCK_M - # data partitioning: BLOCK_M - triton.cdiv( - q.shape[2], (1 if DATA_PARTITION else 2) * args["BLOCK_M"] - ), # num_consumer_groups, or 1 for debugging - q.shape[0] * q.shape[1], - 1, - ) - M = torch.empty( - (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32 - ) - _attn_fwd_ws[grid]( - q, - k, - v, - sm_scale, - M, - o, # - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), # - k.stride(0), - k.stride(1), - k.stride(2), - k.stride(3), # - v.stride(0), - v.stride(1), - v.stride(2), - v.stride(3), # - o.stride(0), - o.stride(1), - o.stride(2), - o.stride(3), # - q.shape[0], - q.shape[1], # - N_CTX=q.shape[2], # - HEAD_DIM=HEAD_DIM_K, # - STAGE=stage, # - **extra_kern_args, - ) - - ctx.save_for_backward(q, k, v, o, M) - ctx.grid = grid - ctx.sm_scale = sm_scale - ctx.HEAD_DIM = HEAD_DIM_K - ctx.causal = causal - return o - - @staticmethod - def backward(ctx, do): - q, k, v, o, M = ctx.saved_tensors - assert do.is_contiguous() - assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() - dq = torch.empty_like(q) - dk = torch.empty_like(k) - dv = torch.empty_like(v) - BATCH, N_HEAD, N_CTX = q.shape[:3] - PRE_BLOCK = 128 - NUM_WARPS, NUM_STAGES = 4, 5 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 - BLK_SLICE_FACTOR = 2 - RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) - arg_k = k - arg_k = arg_k * (ctx.sm_scale * RCP_LN2) - PRE_BLOCK = 128 - assert N_CTX % PRE_BLOCK == 0 - pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) - delta = torch.empty_like(M) - _attn_bwd_preprocess[pre_grid]( - o, - do, # - delta, # - BATCH, - N_HEAD, - N_CTX, # - BLOCK_M=PRE_BLOCK, - HEAD_DIM=ctx.HEAD_DIM, # - ) - grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) - _attn_bwd[grid]( - q, - arg_k, - v, - ctx.sm_scale, - do, - dq, - dk, - dv, # - M, - delta, # - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), # - N_HEAD, - N_CTX, # - BLOCK_M1=BLOCK_M1, - BLOCK_N1=BLOCK_N1, # - BLOCK_M2=BLOCK_M2, - BLOCK_N2=BLOCK_N2, # - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # - HEAD_DIM=ctx.HEAD_DIM, # - num_warps=NUM_WARPS, # - num_stages=NUM_STAGES, # - ) + stride_d, # + H, + N_CTX, # + BLOCK_M1, + BLOCK_N1, + HEAD_DIM, # + start_n, + start_m, + num_steps, # + MASK=False, # + ) - return dq, dk, dv, None, None + dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dv_ptrs, dv) + # Write back dK. + dk *= sm_scale + dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dk_ptrs, dk) -class _attention(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, causal, sm_scale): - # shape constraints - HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] - # when v is in float8_e5m2 it is transposed. 
- HEAD_DIM_V = v.shape[-2] if v.dtype == torch.float8_e5m2 else v.shape[-1] - assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V - assert HEAD_DIM_K in {16, 32, 64, 128, 256} - o = torch.empty_like(q) - stage = 3 if causal else 1 - extra_kern_args = {} + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + end_n = start_m + BLOCK_M2 - grid = lambda args: ( - triton.cdiv( - q.shape[2], args["BLOCK_M"] - ), # num_consumer_groups, or 1 for debugging - q.shape[0] * q.shape[1], - 1, - ) - M = torch.empty( - (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32 - ) - _attn_fwd[grid]( - q, - k, - v, - sm_scale, - M, - o, # - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), # - k.stride(0), - k.stride(1), - k.stride(2), - k.stride(3), # - v.stride(0), - v.stride(1), - v.stride(2), - v.stride(3), # - o.stride(0), - o.stride(1), - o.stride(2), - o.stride(3), # - q.shape[0], - q.shape[1], # - N_CTX=q.shape[2], # - HEAD_DIM=HEAD_DIM_K, # - STAGE=stage, # - **extra_kern_args, - ) + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) - ctx.save_for_backward(q, k, v, o, M) - ctx.grid = grid - ctx.sm_scale = sm_scale - ctx.HEAD_DIM = HEAD_DIM_K - ctx.causal = causal - return o + q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) - @staticmethod - def backward(ctx, do): - q, k, v, o, M = ctx.saved_tensors - assert do.is_contiguous() - assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() - dq = torch.empty_like(q) - dk = torch.empty_like(k) - dv = torch.empty_like(v) - BATCH, N_HEAD, N_CTX = q.shape[:3] - PRE_BLOCK = 128 - NUM_WARPS, NUM_STAGES = 4, 5 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 - BLK_SLICE_FACTOR = 2 - RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) - arg_k = k - arg_k = arg_k * (ctx.sm_scale * RCP_LN2) - PRE_BLOCK = 128 - assert N_CTX % PRE_BLOCK == 0 - pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) - delta = torch.empty_like(M) - _attn_bwd_preprocess[pre_grid]( - o, - do, # - delta, # - BATCH, - N_HEAD, - N_CTX, # - BLOCK_M=PRE_BLOCK, - HEAD_DIM=ctx.HEAD_DIM, # - ) - grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) - _attn_bwd[grid]( - q, - arg_k, - v, - ctx.sm_scale, - do, - dq, - dk, - dv, # - M, - delta, # - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), # - N_HEAD, - N_CTX, # - BLOCK_M1=BLOCK_M1, - BLOCK_N1=BLOCK_N1, # - BLOCK_M2=BLOCK_M2, - BLOCK_N2=BLOCK_N2, # - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # - HEAD_DIM=ctx.HEAD_DIM, # - num_warps=NUM_WARPS, # - num_stages=NUM_STAGES, # - ) + m = tl.load(M + offs_m) + m = m[:, None] - return dq, dk, dv, None, None + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. 
+ num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _attn_bwd_dq( + dq, + q, + K, + V, # + do, + m, + D, # + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M2, + MASK_BLOCK_N2, + HEAD_DIM, # + start_m, + end_n - num_steps * MASK_BLOCK_N2, + num_steps, # + MASK=True, # + ) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2 + num_steps = end_n // BLOCK_N2 + dq = _attn_bwd_dq( + dq, + q, + K, + V, # + do, + m, + D, # + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M2, + BLOCK_N2, + HEAD_DIM, # + start_m, + end_n - num_steps * BLOCK_N2, + num_steps, # + MASK=False, # + ) + # Write back dQ. + dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + dq *= LN2 + tl.store(dq_ptrs, dq) -class _attention_tma(torch.autograd.Function): +class _attention_opt(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, causal, sm_scale): + def forward(ctx, q, k, v, causal, sm_scale, baseVariant): # shape constraints HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] # when v is in float8_e5m2 it is transposed. @@ -1445,44 +1624,6 @@ def forward(ctx, q, k, v, causal, sm_scale): TMA_SIZE = 128 BATCH, H, N_CTX = q.shape[0], q.shape[1], q.shape[2] # no autotune with fixed BLOCK_N - """ - BLOCK_N = 128 - desc_k = np.empty(TMA_SIZE, dtype=np.int8) - desc_v = np.empty(TMA_SIZE, dtype=np.int8) - # order is (0, 1) for fp8 in make_block_ptr, reverse here - triton.runtime.driver.active.utils.fill_2d_tma_descriptor( - k.data_ptr(), - BATCH * H * N_CTX, - HEAD_DIM_Q, - BLOCK_N, - HEAD_DIM_Q, - k.element_size(), - desc_k, - ) - if v.dtype == torch.float8_e5m2: - triton.runtime.driver.active.utils.fill_2d_tma_descriptor( - v.data_ptr(), - BATCH * H * HEAD_DIM_Q, - N_CTX, - HEAD_DIM_Q, - BLOCK_N, - v.element_size(), - desc_v, - ) - else: - triton.runtime.driver.active.utils.fill_2d_tma_descriptor( - v.data_ptr(), - BATCH * H * N_CTX, - HEAD_DIM_Q, - BLOCK_N, - HEAD_DIM_Q, - v.element_size(), - desc_v, - ) - desc_k = torch.tensor(desc_k, device=v.device) - desc_v = torch.tensor(desc_v, device=v.device) - grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1) - """ desc_helper = TmaAutoTuneHelper() desc_helper.init_tma_descriptor("k") desc_helper.init_tma_descriptor("v") @@ -1490,6 +1631,14 @@ def forward(ctx, q, k, v, causal, sm_scale): desc_helper.init_tma_descriptor("o") def grid_tma(META): + if META["ENABLE_TMA"] == False: + return ( + # grid partitioning: num_consumer_groups * BLOCK_M + # data partitioning: BLOCK_M + triton.cdiv(q.shape[2], META["BLOCK_M"]), # num_consumer_groups + q.shape[0] * q.shape[1], + 1, + ) nonlocal desc_helper desc_helper.fill_2d_tma_descriptor( "k", @@ -1526,7 +1675,7 @@ def grid_tma(META): BATCH * H * N_CTX, HEAD_DIM_Q, META["BLOCK_M"] - // (2 if DATA_PARTITION else 1), # data partitioning: halve + // (2 if META["ENABLE_WS"] else 1), # data partitioning: halve HEAD_DIM_Q, q.element_size(), ) @@ -1536,16 +1685,14 @@ def grid_tma(META): BATCH * H * N_CTX, HEAD_DIM_Q, META["BLOCK_M"] - // (2 if DATA_PARTITION else 1), # data partitioning: halve + // (2 if META["ENABLE_WS"] else 1), # data partitioning: halve HEAD_DIM_Q, o.element_size(), ) return ( # grid partitioning: num_consumer_groups * BLOCK_M # data partitioning: BLOCK_M - triton.cdiv( - q.shape[2], (1 if DATA_PARTITION else 2) * META["BLOCK_M"] - ), # num_consumer_groups + triton.cdiv(q.shape[2], META["BLOCK_M"]), # num_consumer_groups q.shape[0] * q.shape[1], 1, ) @@ -1558,48 +1705,253 @@ def grid_tma(META): M = torch.empty( (q.shape[0], q.shape[1], q.shape[2]), 
device=q.device, dtype=torch.float32 ) - _attn_fwd_tma[grid_tma]( + if baseVariant == "base": + _attn_fwd[grid_tma]( + q, + k, + v, + sm_scale, + M, + o, + desc_q, + desc_k, + desc_v, + desc_o, # + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), # + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), # + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), # + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), # + q.shape[0], + q.shape[1], # + N_CTX=q.shape[2], # + HEAD_DIM=HEAD_DIM_K, # + STAGE=stage, # + ENABLE_WS=False, + **extra_kern_args, + ) + elif baseVariant == "ws": + _attn_fwd_ws[grid_tma]( + q, + k, + v, + sm_scale, + M, + o, + desc_q, + desc_k, + desc_v, + desc_o, # + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), # + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), # + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), # + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), # + q.shape[0], + q.shape[1], # + N_CTX=q.shape[2], # + HEAD_DIM=HEAD_DIM_K, # + STAGE=stage, # + ENABLE_WS=True, + **extra_kern_args, + ) + elif baseVariant == "opt": + _attn_fwd_opt[grid_tma]( + q, + k, + v, + sm_scale, + M, + o, + desc_q, + desc_k, + desc_v, + desc_o, # + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), # + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), # + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), # + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), # + q.shape[0], + q.shape[1], # + N_CTX=q.shape[2], # + HEAD_DIM=HEAD_DIM_K, # + STAGE=stage, # + ENABLE_WS=False, + **extra_kern_args, + ) + elif baseVariant == "tma": + _attn_fwd_tma[grid_tma]( + q, + k, + v, + sm_scale, + M, + o, + desc_q, + desc_k, + desc_v, + desc_o, # + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), # + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), # + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), # + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), # + q.shape[0], + q.shape[1], # + N_CTX=q.shape[2], # + HEAD_DIM=HEAD_DIM_K, # + STAGE=stage, # + ENABLE_WS=False, + **extra_kern_args, + ) + elif baseVariant == "tma_ws": + _attn_fwd_tma_ws[grid_tma]( + q, + k, + v, + sm_scale, + M, + o, + desc_q, + desc_k, + desc_v, + desc_o, # + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), # + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), # + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), # + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), # + q.shape[0], + q.shape[1], # + N_CTX=q.shape[2], # + HEAD_DIM=HEAD_DIM_K, # + STAGE=stage, # + ENABLE_WS=True, + **extra_kern_args, + ) + + ctx.save_for_backward(q, k, v, o, M) + ctx.grid = grid_tma + ctx.sm_scale = sm_scale + ctx.HEAD_DIM = HEAD_DIM_K + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, M = ctx.saved_tensors + assert do.is_contiguous() + assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + BATCH, N_HEAD, N_CTX = q.shape[:3] + PRE_BLOCK = 128 + NUM_WARPS, NUM_STAGES = 4, 5 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 + RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) + arg_k = k + arg_k = arg_k * (ctx.sm_scale * RCP_LN2) + PRE_BLOCK = 128 + assert N_CTX % PRE_BLOCK == 0 + pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) + delta = torch.empty_like(M) + _attn_bwd_preprocess[pre_grid]( + o, + do, # + delta, # + BATCH, + 
N_HEAD, + N_CTX, # + BLOCK_M=PRE_BLOCK, + HEAD_DIM=ctx.HEAD_DIM, # + ) + grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) + _attn_bwd[grid]( q, + arg_k, v, - o, - desc_q, - desc_k, - desc_v, - sm_scale, + ctx.sm_scale, + do, + dq, + dk, + dv, # M, - desc_o, # + delta, # q.stride(0), q.stride(1), q.stride(2), q.stride(3), # - k.stride(0), - k.stride(1), - k.stride(2), - k.stride(3), # - v.stride(0), - v.stride(1), - v.stride(2), - v.stride(3), # - o.stride(0), - o.stride(1), - o.stride(2), - o.stride(3), # - q.shape[0], - q.shape[1], # - N_CTX=q.shape[2], # - HEAD_DIM=HEAD_DIM_K, # - STAGE=stage, # - **extra_kern_args, + N_HEAD, + N_CTX, # + BLOCK_M1=BLOCK_M1, + BLOCK_N1=BLOCK_N1, # + BLOCK_M2=BLOCK_M2, + BLOCK_N2=BLOCK_N2, # + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # + HEAD_DIM=ctx.HEAD_DIM, # + num_warps=NUM_WARPS, # + num_stages=NUM_STAGES, # ) - ctx.save_for_backward(q, k, v, o, M) - ctx.grid = grid_tma - ctx.sm_scale = sm_scale - ctx.HEAD_DIM = HEAD_DIM_K - ctx.causal = causal - return o + return dq, dk, dv, None, None -attention = _attention.apply -attention_ws = _attention_ws.apply -attention_tma = _attention_tma.apply +attention_opt = _attention_opt.apply diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py index 59db4a44..de21bfd1 100644 --- a/tritonbench/operators/flash_attention/operator.py +++ b/tritonbench/operators/flash_attention/operator.py @@ -55,9 +55,7 @@ from torch.nn.functional import scaled_dot_product_attention as sdpa from tritonbench.kernels.triton_fused_attention import ( - attention as triton_tutorial_FA2, - attention_tma as triton_tutorial_FA2_tma, - attention_ws as triton_tutorial_FA2_ws, + attention_opt as triton_tutorial_FA2_opt, ) # [Optional] flash_attn v2 @@ -248,7 +246,22 @@ def triton_tutorial_flash_v2( k: torch.Tensor, v: torch.Tensor, ) -> Callable: - return lambda: triton_tutorial_FA2(q, k, v, self.causal, self.sm_scale) + # base: do not enable TMA/WarpSpec/CompPipe + return lambda: triton_tutorial_FA2_opt( + q, k, v, self.causal, self.sm_scale, "base" + ) + + @register_benchmark(enabled=HAS_CUDA_124) + def triton_tutorial_flash_v2_opt( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ) -> Callable: + # autotune CompPipe + return lambda: triton_tutorial_FA2_opt( + q, k, v, self.causal, self.sm_scale, "opt" + ) @register_benchmark(enabled=HAS_CUDA_124) def triton_tutorial_flash_v2_tma( @@ -257,7 +270,10 @@ def triton_tutorial_flash_v2_tma( k: torch.Tensor, v: torch.Tensor, ) -> Callable: - return lambda: triton_tutorial_FA2_tma(q, k, v, self.causal, self.sm_scale) + # autotune TMA/CompPipe + return lambda: triton_tutorial_FA2_opt( + q, k, v, self.causal, self.sm_scale, "tma" + ) @register_benchmark(enabled=HAS_CUDA_124) def triton_tutorial_flash_v2_ws( @@ -266,7 +282,22 @@ def triton_tutorial_flash_v2_ws( k: torch.Tensor, v: torch.Tensor, ) -> Callable: - return lambda: triton_tutorial_FA2_ws(q, k, v, self.causal, self.sm_scale) + # autotune WarpSpec/CompPipe + return lambda: triton_tutorial_FA2_opt( + q, k, v, self.causal, self.sm_scale, "ws" + ) + + @register_benchmark(enabled=HAS_CUDA_124) + def triton_tutorial_flash_v2_tma_ws( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ) -> Callable: + # autotune TMA/WarpSpec/CompPipe + return lambda: triton_tutorial_FA2_opt( + q, k, v, self.causal, self.sm_scale, "tma_ws" + ) @register_benchmark(enabled=HAS_KERNELS) def triton_op_flash_v2(
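Usage note (not part of the patch): the three exports removed above (attention, attention_ws, attention_tma) are folded into the single attention_opt entry point, with the kernel variant selected by the new baseVariant string and the autotune space widened by the ENABLE_COMPPIPE / PEEL_LAST_ITER environment variables. Below is a minimal sketch of how a caller might drive it, assuming an H100-class GPU, CUDA 12.4+, and a Triton build that exposes tl.async_task and the experimental TMA descriptors; tensor shapes follow the benchmark's (BATCH, H, N_CTX, HEAD_DIM) layout.

    import os

    # Both variables are read at import time of triton_fused_attention,
    # so they must be set before the module is imported.
    os.environ["ENABLE_COMPPIPE"] = "1"   # adds FA_firstDot/FA_secondDot to schedList
    # os.environ["PEEL_LAST_ITER"] = "1"  # restricts schedules to FA_secondDot (see TODO above)

    import torch
    from tritonbench.kernels.triton_fused_attention import attention_opt

    # Illustrative sizes; HEAD_DIM must be one of {16, 32, 64, 128, 256}.
    BATCH, H, N_CTX, HEAD_DIM = 4, 32, 4096, 128
    q = torch.randn(BATCH, H, N_CTX, HEAD_DIM, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    sm_scale = HEAD_DIM ** -0.5
    causal = True

    # baseVariant selects the autotuned kernel: "base" (no TMA/WarpSpec/CompPipe),
    # "opt" (CompPipe), "ws" (WarpSpec), "tma" (TMA), or "tma_ws" (TMA + WarpSpec),
    # matching the five benchmark registrations in operator.py.
    out = attention_opt(q, k, v, causal, sm_scale, "tma_ws")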