diff --git a/test/test_gpu/skip_tests_h100_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml index c2d62fdb..2ef84ca8 100644 --- a/test/test_gpu/skip_tests_h100_pytorch.yaml +++ b/test/test_gpu/skip_tests_h100_pytorch.yaml @@ -19,6 +19,7 @@ flash_attention: - triton_tutorial_flash_v2_ws - triton_tutorial_flash_v2_tma_ws - triton_tutorial_flash_v2_tma_ws_persistent + - triton_tutorial_flash_v2_bwd_ws fp8_attention: - colfax_fmha # triton_flash_v2 requires triton-main diff --git a/test/test_gpu/skip_tests_h100_triton_main.yaml b/test/test_gpu/skip_tests_h100_triton_main.yaml index 58ac80d9..c2b1ac79 100644 --- a/test/test_gpu/skip_tests_h100_triton_main.yaml +++ b/test/test_gpu/skip_tests_h100_triton_main.yaml @@ -11,6 +11,7 @@ flash_attention: - triton_tutorial_flash_v2_ws - triton_tutorial_flash_v2_tma_ws - triton_tutorial_flash_v2_tma_ws_persistent + - triton_tutorial_flash_v2_bwd_ws fp8_attention: # fb-only kernel - colfax_fmha diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py index 9b8c527f..aa953c9c 100644 --- a/tritonbench/kernels/triton_fused_attention.py +++ b/tritonbench/kernels/triton_fused_attention.py @@ -1596,8 +1596,213 @@ def _attn_bwd_dq( return dq +# The main inner-loop logic for computing dK and dV. @triton.jit -def _attn_bwd( +def _attn_bwd_dkdv_ws( + dk, + dv, # + Q, + k, + v, + sm_scale, # + DO, # + M, + D, # + # shared by Q/K/V/DO. + stride_tok, + stride_d, # + H, + N_CTX, + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + HEAD_DIM: tl.constexpr, # + # Filled in by the wrapper. + start_n, + start_m, + num_steps, # + MASK: tl.constexpr, +): + offs_m = start_m + tl.arange(0, BLOCK_M1) + offs_n = start_n + tl.arange(0, BLOCK_N1) + offs_k = tl.arange(0, HEAD_DIM) + qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d + do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + curr_m = start_m + step_m = BLOCK_M1 + for blk_idx in range(num_steps): + with tl.async_task([0]): + qT = tl.load(qT_ptrs) + with tl.async_task([1, 2]): + # Load m before computing qk to reduce pipeline stall. + offs_m = curr_m + tl.arange(0, BLOCK_M1) + m = tl.load(M + offs_m) + qkT = tl.dot(k, qT) + # dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + pT = tl.math.exp2(qkT - m[None, :]) + # Autoregressive masking. + if MASK: + mask = offs_m[None, :] >= offs_n[:, None] + pT = tl.where(mask, pT, 0.0) + with tl.async_task([0]): + do = tl.load(do_ptrs) + # Compute dV. + with tl.async_task([1, 2]): + ppT = pT + ppT = ppT.to(tl.bfloat16) + dv += tl.dot(ppT, do) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(tl.bfloat16) + dk += tl.dot(dsT, tl.trans(qT)) + # Increment pointers. + curr_m += step_m + with tl.async_task([0]): + qT_ptrs += step_m * stride_tok + do_ptrs += step_m * stride_tok + return dk, dv + + +# the main inner-loop logic for computing dQ +@triton.jit +def _attn_bwd_dq_ws( + dq, + q, + K, + V, # + do, + m, + D, + # shared by Q/K/V/DO. + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM: tl.constexpr, + # Filled in by the wrapper. 
+ start_m, + start_n, + num_steps, # + MASK: tl.constexpr, +): + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, HEAD_DIM) + kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + for blk_idx in range(num_steps): + with tl.async_task([0]): + kT = tl.load(kT_ptrs) + vT = tl.load(vT_ptrs) + with tl.async_task([1, 2]): + qk = tl.dot(q, kT) + p = tl.math.exp2(qk - m) + # Autoregressive masking. + if MASK: + offs_n = curr_n + tl.arange(0, BLOCK_N2) + mask = offs_m[:, None] >= offs_n[None, :] + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + dp = tl.dot(do, vT).to(tl.float32) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.bfloat16) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += step_n + with tl.async_task([0]): + kT_ptrs += step_n * stride_tok + vT_ptrs += step_n * stride_tok + return dq + + +def keep2(conf): + BLOCK_M = conf.kwargs["BLOCK_M1"] + BLOCK_N = conf.kwargs["BLOCK_N1"] + if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8: + return False + return True + + +configsBwd = [ + ( + triton.Config( + { + "BLOCK_M1": BM, + "BLOCK_N1": BN, + "BLOCK_M2": BN, + "BLOCK_N2": BM, + }, + num_stages=s, # 0 or s, + num_warps=w, + num_buffers_warp_spec=0, # 0 or 2, + num_consumer_groups=0, # 0 or 1, + ) + if has_warp_spec + else triton.Config( + { + "BLOCK_M1": BM, + "BLOCK_N1": BN, + "BLOCK_M2": BN, + "BLOCK_N2": BM, + }, + num_stages=s, + num_warps=w, + ) + ) + for BM in [64] # 32, 64] # BLOCK_N1 % BLOCK_M1 == 0 + for BN in [128] # 64, 128] + for s in [3] # , 4, 7] + for w in [4] # , 8] +] +configsBwdWs = [ + ( + triton.Config( + { + "BLOCK_M1": BM, + "BLOCK_N1": BN, + "BLOCK_M2": BN, + "BLOCK_N2": BM, + }, + num_stages=s, + num_warps=w, + num_buffers_warp_spec=2, + num_consumer_groups=2, + ) + if has_warp_spec + else triton.Config( + { + "BLOCK_M1": BM, + "BLOCK_N1": BN, + "BLOCK_M2": BN, + "BLOCK_N2": BM, + }, + num_stages=s, + num_warps=w, + ) + ) + for BM in [64] # 32, 64] # BLOCK_N1 % BLOCK_M1 == 0 + for BN in [128] # [64] #64, 128] + for s in [3] # , 4, 7] + for w in [4] # , 8] +] + + +@triton.jit +def _attn_bwd_compute( Q, K, V, @@ -1785,9 +1990,310 @@ def _attn_bwd( tl.store(dq_ptrs, dq) +@triton.jit +def _attn_bwd_compute_ws( + Q, + K, + V, + sm_scale, # + DO, # + DQ, + DK, + DV, # + M, + D, + # shared by Q/K/V/DO. 
+ stride_z, + stride_h, + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLK_SLICE_FACTOR: tl.constexpr, # + HEAD_DIM: tl.constexpr, +): + LN2: tl.constexpr = 0.6931471824645996 # = ln(2) + + bhid = tl.program_id(2) + off_chz = (bhid * N_CTX).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + pid = tl.program_id(0) + + # offset pointers for batch/head + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DK += adj + DV += adj + M += off_chz + D += off_chz + + # load scales + offs_k = tl.arange(0, HEAD_DIM) + + start_n = pid * BLOCK_N1 + start_m = start_n + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + offs_n = start_n + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + + # load K and V: they stay in SRAM throughout the inner loop. + with tl.async_task([0]): + k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + + dk, dv = _attn_bwd_dkdv_ws( + dk, + dv, # + Q, + k, + v, + sm_scale, # + DO, # + M, + D, # + stride_tok, + stride_d, # + H, + N_CTX, # + MASK_BLOCK_M1, + BLOCK_N1, + HEAD_DIM, # + start_n, + start_m, + num_steps, # + MASK=True, # + ) + + start_m += num_steps * MASK_BLOCK_M1 + num_steps = (N_CTX - start_m) // BLOCK_M1 + + # Compute dK and dV for non-masked blocks. + dk, dv = _attn_bwd_dkdv_ws( # + dk, + dv, # + Q, + k, + v, + sm_scale, # + DO, # + M, + D, # + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M1, + BLOCK_N1, + HEAD_DIM, # + start_n, + start_m, + num_steps, # + MASK=False, # + ) + + with tl.async_task([1, 2]): + dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dv_ptrs, dv) + + # Write back dK. + dk *= sm_scale + dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dk_ptrs, dk) + + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + end_n = start_m + BLOCK_M2 + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) + + with tl.async_task([0]): + q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + + m = tl.load(M + offs_m) + m = m[:, None] + + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _attn_bwd_dq_ws( + dq, + q, + K, + V, # + do, + m, + D, # + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M2, + MASK_BLOCK_N2, + HEAD_DIM, # + start_m, + end_n - num_steps * MASK_BLOCK_N2, + num_steps, # + MASK=True, # + ) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2 + num_steps = end_n // BLOCK_N2 + dq = _attn_bwd_dq_ws( + dq, + q, + K, + V, # + do, + m, + D, # + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M2, + BLOCK_N2, + HEAD_DIM, # + start_m, + end_n - num_steps * BLOCK_N2, + num_steps, # + MASK=False, # + ) + # Write back dQ. 
+ with tl.async_task([1, 2]): + dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + dq *= LN2 + tl.store(dq_ptrs, dq) + + +@triton.autotune(list(filter(keep2, configsBwd)), key=["N_CTX"]) +@triton.jit +def _attn_bwd( + Q, + K, + V, + sm_scale, # + DO, # + DQ, + DK, + DV, # + M, + D, + # shared by Q/K/V/DO. + stride_z, + stride_h, + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLK_SLICE_FACTOR: tl.constexpr, # + HEAD_DIM: tl.constexpr, +): + _attn_bwd_compute( + Q, + K, + V, + sm_scale, # + DO, # + DQ, + DK, + DV, # + M, + D, + # shared by Q/K/V/DO. + stride_z, + stride_h, + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M1, + BLOCK_N1, + BLOCK_M2, + BLOCK_N2, + BLK_SLICE_FACTOR, + HEAD_DIM, + ) + + +@triton.autotune(list(filter(keep2, configsBwdWs)), key=["N_CTX"]) +@triton.jit +def _attn_bwd_ws( + Q, + K, + V, + sm_scale, # + DO, # + DQ, + DK, + DV, # + M, + D, + # shared by Q/K/V/DO. + stride_z, + stride_h, + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLK_SLICE_FACTOR: tl.constexpr, # + HEAD_DIM: tl.constexpr, +): + _attn_bwd_compute_ws( + Q, + K, + V, + sm_scale, # + DO, # + DQ, + DK, + DV, # + M, + D, + # shared by Q/K/V/DO. + stride_z, + stride_h, + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M1, + BLOCK_N1, + BLOCK_M2, + BLOCK_N2, + BLK_SLICE_FACTOR, + HEAD_DIM, + ) + + class _attention_opt(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, causal, sm_scale, baseVariant): + def forward(ctx, q, k, v, causal, sm_scale, baseVariant, bwdVariant): # shape constraints HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] # when v is in float8_e5m2 it is transposed. @@ -2175,6 +2681,8 @@ def grid_tma_persistent(META): ctx.sm_scale = sm_scale ctx.HEAD_DIM = HEAD_DIM_K ctx.causal = causal + ctx.bwdVariant = bwdVariant + # If we want to use different variants for bwd, save bwd mode here. 
return o @staticmethod @@ -2189,8 +2697,10 @@ def backward(ctx, do): dv = torch.empty_like(v) BATCH, N_HEAD, N_CTX = q.shape[:3] PRE_BLOCK = 128 - NUM_WARPS, NUM_STAGES = 4, 5 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + + # NUM_WARPS, NUM_STAGES = 4, 5 + # BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) arg_k = k @@ -2209,35 +2719,64 @@ def backward(ctx, do): BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM, # ) - grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) - _attn_bwd[grid]( - q, - arg_k, - v, - ctx.sm_scale, - do, - dq, - dk, - dv, # - M, - delta, # - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), # - N_HEAD, - N_CTX, # - BLOCK_M1=BLOCK_M1, - BLOCK_N1=BLOCK_N1, # - BLOCK_M2=BLOCK_M2, - BLOCK_N2=BLOCK_N2, # - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # - HEAD_DIM=ctx.HEAD_DIM, # - num_warps=NUM_WARPS, # - num_stages=NUM_STAGES, # - ) + grid = lambda args: (N_CTX // args["BLOCK_N1"], 1, BATCH * N_HEAD) + # grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) + if ctx.bwdVariant == "base": + _attn_bwd[grid]( + q, + arg_k, + v, + ctx.sm_scale, + do, + dq, + dk, + dv, # + M, + delta, # + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), # + N_HEAD, + N_CTX, # + # BLOCK_M1=BLOCK_M1, + # BLOCK_N1=BLOCK_N1, # + # BLOCK_M2=BLOCK_M2, + # BLOCK_N2=BLOCK_N2, # + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # + HEAD_DIM=ctx.HEAD_DIM, # + # num_warps=NUM_WARPS, # + # num_stages=NUM_STAGES, # + ) + elif ctx.bwdVariant == "ws": + _attn_bwd_ws[grid]( + q, + arg_k, + v, + ctx.sm_scale, + do, + dq, + dk, + dv, # + M, + delta, # + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), # + N_HEAD, + N_CTX, # + # BLOCK_M1=BLOCK_M1, + # BLOCK_N1=BLOCK_N1, # + # BLOCK_M2=BLOCK_M2, + # BLOCK_N2=BLOCK_N2, # + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # + HEAD_DIM=ctx.HEAD_DIM, # + # num_warps=NUM_WARPS, # + # num_stages=NUM_STAGES, # + ) - return dq, dk, dv, None, None, None + return dq, dk, dv, None, None, None, None attention_opt = _attention_opt.apply diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py index 5d3bcc46..0d4ad2d9 100644 --- a/tritonbench/operators/flash_attention/operator.py +++ b/tritonbench/operators/flash_attention/operator.py @@ -252,7 +252,19 @@ def triton_tutorial_flash_v2( ) -> Callable: # base: do not enable TMA/WarpSpec/CompPipe return lambda: triton_tutorial_FA2_opt( - q, k, v, self.causal, self.sm_scale, "base" + q, k, v, self.causal, self.sm_scale, "base", "base" + ) + + @register_benchmark() + def triton_tutorial_flash_v2_bwd_ws( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ) -> Callable: + # base: do not enable TMA/WarpSpec/CompPipe + return lambda: triton_tutorial_FA2_opt( + q, k, v, self.causal, self.sm_scale, "base", "ws" ) @register_benchmark(enabled=HAS_CUDA_124) @@ -264,7 +276,7 @@ def triton_tutorial_flash_v2_opt( ) -> Callable: # autotune CompPipe return lambda: triton_tutorial_FA2_opt( - q, k, v, self.causal, self.sm_scale, "opt" + q, k, v, self.causal, self.sm_scale, "opt", "base" ) @register_benchmark(enabled=HAS_CUDA_124) @@ -276,7 +288,7 @@ def triton_tutorial_flash_v2_tma( ) -> Callable: # autotune TMA/CompPipe return lambda: triton_tutorial_FA2_opt( - q, k, v, self.causal, self.sm_scale, "tma" + q, k, v, self.causal, self.sm_scale, "tma", "base" ) @register_benchmark(enabled=HAS_CUDA_124) @@ -288,7 +300,7 @@ def triton_tutorial_flash_v2_ws( ) -> Callable: # autotune WarpSpec/CompPipe return lambda: 
triton_tutorial_FA2_opt( - q, k, v, self.causal, self.sm_scale, "ws" + q, k, v, self.causal, self.sm_scale, "ws", "base" ) @register_benchmark(enabled=HAS_CUDA_124) @@ -300,7 +312,7 @@ def triton_tutorial_flash_v2_tma_ws( ) -> Callable: # autotune TMA/WarpSpec/CompPipe return lambda: triton_tutorial_FA2_opt( - q, k, v, self.causal, self.sm_scale, "tma_ws" + q, k, v, self.causal, self.sm_scale, "tma_ws", "base" ) @register_benchmark(enabled=HAS_CUDA_124) @@ -312,7 +324,7 @@ def triton_tutorial_flash_v2_tma_ws_persistent( ) -> Callable: # autotune TMA/WarpSpec/CompPipe/Persistent return lambda: triton_tutorial_FA2_opt( - q, k, v, self.causal, self.sm_scale, "tma_ws_persistent" + q, k, v, self.causal, self.sm_scale, "tma_ws_persistent", "base" ) @register_benchmark(enabled=HAS_KERNELS) diff --git a/tritonbench/operators/fp8_attention/operator.py b/tritonbench/operators/fp8_attention/operator.py index 0131ca73..71c0ffcc 100644 --- a/tritonbench/operators/fp8_attention/operator.py +++ b/tritonbench/operators/fp8_attention/operator.py @@ -110,7 +110,7 @@ def triton_flash_v2( triton_q, triton_k, triton_v = self.triton_preprocess(q, k, v) # full fp8 will be enabled if type of q,k,v is fp8 return lambda: triton_attention( - triton_q, triton_k, triton_v, False, self.sm_scale, "base" + triton_q, triton_k, triton_v, False, self.sm_scale, "base", "base" ) def get_x_val(self, _example_inputs) -> Tuple[int, int, int, int]:
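
Usage sketch (not part of the patch above): a minimal, hedged example of exercising the new warp-specialized backward variant end to end through `attention_opt` (defined in `tritonbench/kernels/triton_fused_attention.py` as `_attention_opt.apply`). The shapes, dtype, and `sm_scale` below are illustrative assumptions, not values taken from the patch; the "ws" backward path only takes effect on a Triton build that exposes `tl.async_task` (i.e. `has_warp_spec` is true), otherwise `configsBwdWs` falls back to the plain configs.

    import torch

    from tritonbench.kernels.triton_fused_attention import attention_opt

    # Illustrative sizes; N_CTX is assumed to be a multiple of BLOCK_N1 (128)
    # so the backward grid (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) is exact.
    BATCH, N_HEAD, N_CTX, HEAD_DIM = 4, 8, 4096, 64
    q, k, v = (
        torch.randn(
            BATCH, N_HEAD, N_CTX, HEAD_DIM,
            dtype=torch.bfloat16, device="cuda", requires_grad=True,
        )
        for _ in range(3)
    )
    sm_scale = HEAD_DIM ** -0.5

    # Forward runs the "base" variant; backward dispatches to _attn_bwd_ws
    # because ctx.bwdVariant == "ws", the same pairing that the new
    # triton_tutorial_flash_v2_bwd_ws benchmark registers.
    o = attention_opt(q, k, v, True, sm_scale, "base", "ws")
    o.backward(torch.randn_like(o))

This mirrors what the new `triton_tutorial_flash_v2_bwd_ws` entry in `tritonbench/operators/flash_attention/operator.py` does: the forward variant stays "base" and only the backward pass switches to the warp-specialized kernel, which is why the two skip-list YAMLs gain a matching entry for hardware/toolchains that cannot run it.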