From 152d4768446854fc70f5a96fb634a6f7d9f06712 Mon Sep 17 00:00:00 2001 From: Manman Ren Date: Fri, 29 Nov 2024 16:43:02 -0800 Subject: [PATCH 1/9] bwd two variants Summary: Also add support for warp spec Test Plan: Reviewers: Subscribers: Tasks: Tags: --- tritonbench/kernels/triton_fused_attention.py | 375 ++++++++++++++---- .../operators/flash_attention/operator.py | 26 +- 2 files changed, 321 insertions(+), 80 deletions(-) diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py index 9b8c527f..1864e5bb 100644 --- a/tritonbench/kernels/triton_fused_attention.py +++ b/tritonbench/kernels/triton_fused_attention.py @@ -458,10 +458,10 @@ def _attn_fwd_inner_ws( num_warps=w, ) ) - for BM in [64, 128] - for BN in [64, 128] - for s in [3, 4, 7] - for w in [4, 8] + for BM in [128] #64, 128] + for BN in [128] #64, 128] + for s in [3] #3, 4, 7] + for w in [8] #4, 8] ] # TMA, WS, and CompPipe configsTmaWS = [ @@ -1508,28 +1508,35 @@ def _attn_bwd_dkdv( curr_m = start_m step_m = BLOCK_M1 for blk_idx in range(num_steps): - qT = tl.load(qT_ptrs) + with tl.async_task([0]): + qT = tl.load(qT_ptrs) # Load m before computing qk to reduce pipeline stall. offs_m = curr_m + tl.arange(0, BLOCK_M1) m = tl.load(M + offs_m) - qkT = tl.dot(k, qT) - pT = tl.math.exp2(qkT - m[None, :]) - # Autoregressive masking. - if MASK: - mask = offs_m[None, :] >= offs_n[:, None] - pT = tl.where(mask, pT, 0.0) - do = tl.load(do_ptrs) + #with tl.async_task([0]): + # do = tl.load(do_ptrs) + with tl.async_task([1]): + qkT = tl.dot(k, qT) + #dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + pT = tl.math.exp2(qkT - m[None, :]) + # Autoregressive masking. + if MASK: + mask = offs_m[None, :] >= offs_n[:, None] + pT = tl.where(mask, pT, 0.0) + with tl.async_task([0]): + do = tl.load(do_ptrs) # Compute dV. - ppT = pT - ppT = ppT.to(tl.bfloat16) - dv += tl.dot(ppT, do) - # D (= delta) is pre-divided by ds_scale. - Di = tl.load(D + offs_m) - # Compute dP and dS. - dpT = tl.dot(v, tl.trans(do)).to(tl.float32) - dsT = pT * (dpT - Di[None, :]) - dsT = dsT.to(tl.bfloat16) - dk += tl.dot(dsT, tl.trans(qT)) + with tl.async_task([1]): + ppT = pT + ppT = ppT.to(tl.bfloat16) + dv += tl.dot(ppT, do) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(tl.bfloat16) + dk += tl.dot(dsT, tl.trans(qT)) # Increment pointers. curr_m += step_m qT_ptrs += step_m * stride_tok @@ -1573,22 +1580,24 @@ def _attn_bwd_dq( curr_n = start_n step_n = BLOCK_N2 for blk_idx in range(num_steps): - kT = tl.load(kT_ptrs) - vT = tl.load(vT_ptrs) - qk = tl.dot(q, kT) - p = tl.math.exp2(qk - m) - # Autoregressive masking. - if MASK: - offs_n = curr_n + tl.arange(0, BLOCK_N2) - mask = offs_m[:, None] >= offs_n[None, :] - p = tl.where(mask, p, 0.0) - # Compute dP and dS. - dp = tl.dot(do, vT).to(tl.float32) - ds = p * (dp - Di[:, None]) - ds = ds.to(tl.bfloat16) - # Compute dQ. - # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. - dq += tl.dot(ds, tl.trans(kT)) + with tl.async_task([0]): + kT = tl.load(kT_ptrs) + vT = tl.load(vT_ptrs) + with tl.async_task([1]): + qk = tl.dot(q, kT) + p = tl.math.exp2(qk - m) + # Autoregressive masking. + if MASK: + offs_n = curr_n + tl.arange(0, BLOCK_N2) + mask = offs_m[:, None] >= offs_n[None, :] + p = tl.where(mask, p, 0.0) + # Compute dP and dS. 
+ dp = tl.dot(do, vT).to(tl.float32) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.bfloat16) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + dq += tl.dot(ds, tl.trans(kT)) # Increment pointers. curr_n += step_n kT_ptrs += step_n * stride_tok @@ -1596,8 +1605,80 @@ def _attn_bwd_dq( return dq +def keep2(conf): + BLOCK_M = conf.kwargs["BLOCK_M1"] + BLOCK_N = conf.kwargs["BLOCK_N1"] + if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8: + return False + return True + + +configsBwd = [ + ( + triton.Config( + { + "BLOCK_M1": BM, + "BLOCK_N1": BN, + "BLOCK_M2": BN, + "BLOCK_N2": BM, + }, + num_stages=s, #0 or s, + num_warps=w, + num_buffers_warp_spec=0, #0 or 2, + num_consumer_groups=0, #0 or 1, + ) + if has_warp_spec + else triton.Config( + { + "BLOCK_M1": BM, + "BLOCK_N1": BN, + "BLOCK_M2": BN, + "BLOCK_N2": BM, + }, + num_stages=s, + num_warps=w, + ) + ) + for BM in [64] #32, 64] # BLOCK_N1 % BLOCK_M1 == 0 + for BN in [128] #64, 128] + for s in [3]#, 4, 7] + for w in [4]#, 8] +] +configsBwd2 = [ + ( + triton.Config( + { + "BLOCK_M1": BM, + "BLOCK_N1": BN, + "BLOCK_M2": BN, + "BLOCK_N2": BM, + }, + num_stages=s, + num_warps=w, + num_buffers_warp_spec=0, + num_consumer_groups=0, + ) + if has_warp_spec + else triton.Config( + { + "BLOCK_M1": BM, + "BLOCK_N1": BN, + "BLOCK_M2": BN, + "BLOCK_N2": BM, + }, + num_stages=s, + num_warps=w, + ) + ) + for BM in [32] #32, 64] # BLOCK_N1 % BLOCK_M1 == 0 + for BN in [64] #64, 128] + for s in [3]#, 4, 7] + for w in [4]#, 8] +] + + @triton.jit -def _attn_bwd( +def _attn_bwd_compute( Q, K, V, @@ -1653,8 +1734,9 @@ def _attn_bwd( dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) - v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + with tl.async_task([0]): + k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) num_steps = BLOCK_N1 // MASK_BLOCK_M1 @@ -1723,9 +1805,10 @@ def _attn_bwd( MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR offs_m = start_m + tl.arange(0, BLOCK_M2) - q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + with tl.async_task([0]): + q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) - do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) m = tl.load(M + offs_m) m = m[:, None] @@ -1785,9 +1868,117 @@ def _attn_bwd( tl.store(dq_ptrs, dq) +@triton.autotune(list(filter(keep2, configsBwd)), key=["N_CTX"]) +@triton.jit +def _attn_bwd( + Q, + K, + V, + sm_scale, # + DO, # + DQ, + DK, + DV, # + M, + D, + # shared by Q/K/V/DO. + stride_z, + stride_h, + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLK_SLICE_FACTOR: tl.constexpr, # + HEAD_DIM: tl.constexpr, +): + _attn_bwd_compute( + Q, + K, + V, + sm_scale, # + DO, # + DQ, + DK, + DV, # + M, + D, + # shared by Q/K/V/DO. 
+ stride_z, + stride_h, + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M1, + BLOCK_N1, + BLOCK_M2, + BLOCK_N2, + BLK_SLICE_FACTOR, + HEAD_DIM, + ) + + +@triton.autotune(list(filter(keep2, configsBwd2)), key=["N_CTX"]) +@triton.jit +def _attn_bwd2( + Q, + K, + V, + sm_scale, # + DO, # + DQ, + DK, + DV, # + M, + D, + # shared by Q/K/V/DO. + stride_z, + stride_h, + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLK_SLICE_FACTOR: tl.constexpr, # + HEAD_DIM: tl.constexpr, +): + _attn_bwd_compute( + Q, + K, + V, + sm_scale, # + DO, # + DQ, + DK, + DV, # + M, + D, + # shared by Q/K/V/DO. + stride_z, + stride_h, + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M1, + BLOCK_N1, + BLOCK_M2, + BLOCK_N2, + BLK_SLICE_FACTOR, + HEAD_DIM, + ) + + class _attention_opt(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, causal, sm_scale, baseVariant): + def forward(ctx, q, k, v, causal, sm_scale, baseVariant): #, bwdVariant): # shape constraints HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] # when v is in float8_e5m2 it is transposed. @@ -2175,6 +2366,8 @@ def grid_tma_persistent(META): ctx.sm_scale = sm_scale ctx.HEAD_DIM = HEAD_DIM_K ctx.causal = causal + #ctx.bwdVariant = bwdVariant + # If we want to use different variants for bwd, save bwd mode here. return o @staticmethod @@ -2189,8 +2382,12 @@ def backward(ctx, do): dv = torch.empty_like(v) BATCH, N_HEAD, N_CTX = q.shape[:3] PRE_BLOCK = 128 - NUM_WARPS, NUM_STAGES = 4, 5 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + + #NUM_WARPS, NUM_STAGES = 4, 5 + #BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + NUM_WARPS, NUM_STAGES = 4, 3 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 64, 128, 128, 64 + BLK_SLICE_FACTOR = 2 RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) arg_k = k @@ -2209,35 +2406,65 @@ def backward(ctx, do): BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM, # ) - grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) - _attn_bwd[grid]( - q, - arg_k, - v, - ctx.sm_scale, - do, - dq, - dk, - dv, # - M, - delta, # - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), # - N_HEAD, - N_CTX, # - BLOCK_M1=BLOCK_M1, - BLOCK_N1=BLOCK_N1, # - BLOCK_M2=BLOCK_M2, - BLOCK_N2=BLOCK_N2, # - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # - HEAD_DIM=ctx.HEAD_DIM, # - num_warps=NUM_WARPS, # - num_stages=NUM_STAGES, # - ) + grid = lambda args: (N_CTX // args["BLOCK_N1"], 1, BATCH * N_HEAD) + #grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) + print(q.stride(0), q.stride(1), q.stride(2), q.stride(3)) + if True: #ctx.bwdVariant == "base": + _attn_bwd[grid]( + q, + arg_k, + v, + ctx.sm_scale, + do, + dq, + dk, + dv, # + M, + delta, # + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), # + N_HEAD, + N_CTX, # + #BLOCK_M1=BLOCK_M1, + #BLOCK_N1=BLOCK_N1, # + #BLOCK_M2=BLOCK_M2, + #BLOCK_N2=BLOCK_N2, # + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # + HEAD_DIM=ctx.HEAD_DIM, # + #num_warps=NUM_WARPS, # + #num_stages=NUM_STAGES, # + ) + else: #if ctx.bwdVariant == "base2": + _attn_bwd2[grid]( + q, + arg_k, + v, + ctx.sm_scale, + do, + dq, + dk, + dv, # + M, + delta, # + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), # + N_HEAD, + N_CTX, # + #BLOCK_M1=BLOCK_M1, + #BLOCK_N1=BLOCK_N1, # + #BLOCK_M2=BLOCK_M2, + #BLOCK_N2=BLOCK_N2, # + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # + HEAD_DIM=ctx.HEAD_DIM, # + #num_warps=NUM_WARPS, # + #num_stages=NUM_STAGES, # + ) - return dq, dk, dv, None, None, None + return dq, dk, dv, None, None, None, None 
attention_opt = _attention_opt.apply diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py index 5d3bcc46..370ae727 100644 --- a/tritonbench/operators/flash_attention/operator.py +++ b/tritonbench/operators/flash_attention/operator.py @@ -252,7 +252,19 @@ def triton_tutorial_flash_v2( ) -> Callable: # base: do not enable TMA/WarpSpec/CompPipe return lambda: triton_tutorial_FA2_opt( - q, k, v, self.causal, self.sm_scale, "base" + q, k, v, self.causal, self.sm_scale, "base"#, "base" + ) + + @register_benchmark() + def triton_tutorial_flash_v2_bwd2( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ) -> Callable: + # base: do not enable TMA/WarpSpec/CompPipe + return lambda: triton_tutorial_FA2_opt( + q, k, v, self.causal, self.sm_scale, "base", "base2" ) @register_benchmark(enabled=HAS_CUDA_124) @@ -264,7 +276,7 @@ def triton_tutorial_flash_v2_opt( ) -> Callable: # autotune CompPipe return lambda: triton_tutorial_FA2_opt( - q, k, v, self.causal, self.sm_scale, "opt" + q, k, v, self.causal, self.sm_scale, "opt", "base" ) @register_benchmark(enabled=HAS_CUDA_124) @@ -276,7 +288,7 @@ def triton_tutorial_flash_v2_tma( ) -> Callable: # autotune TMA/CompPipe return lambda: triton_tutorial_FA2_opt( - q, k, v, self.causal, self.sm_scale, "tma" + q, k, v, self.causal, self.sm_scale, "tma", "base" ) @register_benchmark(enabled=HAS_CUDA_124) @@ -288,7 +300,7 @@ def triton_tutorial_flash_v2_ws( ) -> Callable: # autotune WarpSpec/CompPipe return lambda: triton_tutorial_FA2_opt( - q, k, v, self.causal, self.sm_scale, "ws" + q, k, v, self.causal, self.sm_scale, "ws", "base" ) @register_benchmark(enabled=HAS_CUDA_124) @@ -300,7 +312,7 @@ def triton_tutorial_flash_v2_tma_ws( ) -> Callable: # autotune TMA/WarpSpec/CompPipe return lambda: triton_tutorial_FA2_opt( - q, k, v, self.causal, self.sm_scale, "tma_ws" + q, k, v, self.causal, self.sm_scale, "tma_ws", "base" ) @register_benchmark(enabled=HAS_CUDA_124) @@ -312,7 +324,7 @@ def triton_tutorial_flash_v2_tma_ws_persistent( ) -> Callable: # autotune TMA/WarpSpec/CompPipe/Persistent return lambda: triton_tutorial_FA2_opt( - q, k, v, self.causal, self.sm_scale, "tma_ws_persistent" + q, k, v, self.causal, self.sm_scale, "tma_ws_persistent", "base" ) @register_benchmark(enabled=HAS_KERNELS) @@ -458,6 +470,8 @@ def tflops( BATCH, H, N_CTX, D_HEAD = q.shape flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD tflops = 2 * flops_per_matmul + print("causal, mode: ", self.causal, self.mode) + print("fn_name: ", fn_name, metrics.latency) if self.causal: tflops *= 0.5 if self.mode == BenchmarkMode.BWD: From e652eb014e8c93d51aaabd70bae7c37ee3b872a1 Mon Sep 17 00:00:00 2001 From: Manman Ren Date: Mon, 2 Dec 2024 19:25:25 -0800 Subject: [PATCH 2/9] fix variants Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- tritonbench/kernels/triton_fused_attention.py | 26 +++++++++---------- .../operators/flash_attention/operator.py | 6 ++--- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py index 1864e5bb..c8931436 100644 --- a/tritonbench/kernels/triton_fused_attention.py +++ b/tritonbench/kernels/triton_fused_attention.py @@ -1644,7 +1644,7 @@ def keep2(conf): for s in [3]#, 4, 7] for w in [4]#, 8] ] -configsBwd2 = [ +configsBwdWs = [ ( triton.Config( { @@ -1655,8 +1655,8 @@ def keep2(conf): }, num_stages=s, num_warps=w, - num_buffers_warp_spec=0, - num_consumer_groups=0, + 
num_buffers_warp_spec=2, + num_consumer_groups=2, ) if has_warp_spec else triton.Config( @@ -1670,8 +1670,8 @@ def keep2(conf): num_warps=w, ) ) - for BM in [32] #32, 64] # BLOCK_N1 % BLOCK_M1 == 0 - for BN in [64] #64, 128] + for BM in [64] #32, 64] # BLOCK_N1 % BLOCK_M1 == 0 + for BN in [128] #[64] #64, 128] for s in [3]#, 4, 7] for w in [4]#, 8] ] @@ -1922,9 +1922,9 @@ def _attn_bwd( ) -@triton.autotune(list(filter(keep2, configsBwd2)), key=["N_CTX"]) +@triton.autotune(list(filter(keep2, configsBwdWs)), key=["N_CTX"]) @triton.jit -def _attn_bwd2( +def _attn_bwd_ws( Q, K, V, @@ -1978,7 +1978,7 @@ def _attn_bwd2( class _attention_opt(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, causal, sm_scale, baseVariant): #, bwdVariant): + def forward(ctx, q, k, v, causal, sm_scale, baseVariant, bwdVariant): # shape constraints HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] # when v is in float8_e5m2 it is transposed. @@ -2366,7 +2366,7 @@ def grid_tma_persistent(META): ctx.sm_scale = sm_scale ctx.HEAD_DIM = HEAD_DIM_K ctx.causal = causal - #ctx.bwdVariant = bwdVariant + ctx.bwdVariant = bwdVariant # If we want to use different variants for bwd, save bwd mode here. return o @@ -2385,8 +2385,6 @@ def backward(ctx, do): #NUM_WARPS, NUM_STAGES = 4, 5 #BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 - NUM_WARPS, NUM_STAGES = 4, 3 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 64, 128, 128, 64 BLK_SLICE_FACTOR = 2 RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) @@ -2409,7 +2407,7 @@ def backward(ctx, do): grid = lambda args: (N_CTX // args["BLOCK_N1"], 1, BATCH * N_HEAD) #grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) print(q.stride(0), q.stride(1), q.stride(2), q.stride(3)) - if True: #ctx.bwdVariant == "base": + if ctx.bwdVariant == "base": _attn_bwd[grid]( q, arg_k, @@ -2436,8 +2434,8 @@ def backward(ctx, do): #num_warps=NUM_WARPS, # #num_stages=NUM_STAGES, # ) - else: #if ctx.bwdVariant == "base2": - _attn_bwd2[grid]( + elif ctx.bwdVariant == "ws": + _attn_bwd_ws[grid]( q, arg_k, v, diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py index 370ae727..a9d80833 100644 --- a/tritonbench/operators/flash_attention/operator.py +++ b/tritonbench/operators/flash_attention/operator.py @@ -252,11 +252,11 @@ def triton_tutorial_flash_v2( ) -> Callable: # base: do not enable TMA/WarpSpec/CompPipe return lambda: triton_tutorial_FA2_opt( - q, k, v, self.causal, self.sm_scale, "base"#, "base" + q, k, v, self.causal, self.sm_scale, "base", "base" ) @register_benchmark() - def triton_tutorial_flash_v2_bwd2( + def triton_tutorial_flash_v2_bwd_ws( self, q: torch.Tensor, k: torch.Tensor, @@ -264,7 +264,7 @@ def triton_tutorial_flash_v2_bwd2( ) -> Callable: # base: do not enable TMA/WarpSpec/CompPipe return lambda: triton_tutorial_FA2_opt( - q, k, v, self.causal, self.sm_scale, "base", "base2" + q, k, v, self.causal, self.sm_scale, "base", "ws" ) @register_benchmark(enabled=HAS_CUDA_124) From 50db08c980abcc7c4747f0bcad113a00fb92db22 Mon Sep 17 00:00:00 2001 From: Manman Ren Date: Mon, 2 Dec 2024 21:34:13 -0800 Subject: [PATCH 3/9] make a copy for ws version Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- tritonbench/kernels/triton_fused_attention.py | 332 +++++++++++++++++- 1 file changed, 324 insertions(+), 8 deletions(-) diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py index c8931436..9466682c 100644 --- a/tritonbench/kernels/triton_fused_attention.py +++ 
b/tritonbench/kernels/triton_fused_attention.py @@ -1497,6 +1497,131 @@ def _attn_bwd_dkdv( start_m, num_steps, # MASK: tl.constexpr, +): + offs_m = start_m + tl.arange(0, BLOCK_M1) + offs_n = start_n + tl.arange(0, BLOCK_N1) + offs_k = tl.arange(0, HEAD_DIM) + qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d + do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + curr_m = start_m + step_m = BLOCK_M1 + for blk_idx in range(num_steps): + qT = tl.load(qT_ptrs) + # Load m before computing qk to reduce pipeline stall. + offs_m = curr_m + tl.arange(0, BLOCK_M1) + m = tl.load(M + offs_m) + qkT = tl.dot(k, qT) + #dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + pT = tl.math.exp2(qkT - m[None, :]) + # Autoregressive masking. + if MASK: + mask = offs_m[None, :] >= offs_n[:, None] + pT = tl.where(mask, pT, 0.0) + do = tl.load(do_ptrs) + # Compute dV. + ppT = pT + ppT = ppT.to(tl.bfloat16) + dv += tl.dot(ppT, do) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(tl.bfloat16) + dk += tl.dot(dsT, tl.trans(qT)) + # Increment pointers. + curr_m += step_m + qT_ptrs += step_m * stride_tok + do_ptrs += step_m * stride_tok + return dk, dv + + +# the main inner-loop logic for computing dQ +@triton.jit +def _attn_bwd_dq( + dq, + q, + K, + V, # + do, + m, + D, + # shared by Q/K/V/DO. + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM: tl.constexpr, + # Filled in by the wrapper. + start_m, + start_n, + num_steps, # + MASK: tl.constexpr, +): + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, HEAD_DIM) + kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + for blk_idx in range(num_steps): + kT = tl.load(kT_ptrs) + vT = tl.load(vT_ptrs) + qk = tl.dot(q, kT) + p = tl.math.exp2(qk - m) + # Autoregressive masking. + if MASK: + offs_n = curr_n + tl.arange(0, BLOCK_N2) + mask = offs_m[:, None] >= offs_n[None, :] + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + dp = tl.dot(do, vT).to(tl.float32) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.bfloat16) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += step_n + kT_ptrs += step_n * stride_tok + vT_ptrs += step_n * stride_tok + return dq + + +# The main inner-loop logic for computing dK and dV. +@triton.jit +def _attn_bwd_dkdv_ws( + dk, + dv, # + Q, + k, + v, + sm_scale, # + DO, # + M, + D, # + # shared by Q/K/V/DO. + stride_tok, + stride_d, # + H, + N_CTX, + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + HEAD_DIM: tl.constexpr, # + # Filled in by the wrapper. 
+ start_n, + start_m, + num_steps, # + MASK: tl.constexpr, ): offs_m = start_m + tl.arange(0, BLOCK_M1) offs_n = start_n + tl.arange(0, BLOCK_N1) @@ -1546,7 +1671,7 @@ def _attn_bwd_dkdv( # the main inner-loop logic for computing dQ @triton.jit -def _attn_bwd_dq( +def _attn_bwd_dq_ws( dq, q, K, @@ -1734,9 +1859,8 @@ def _attn_bwd_compute( dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) # load K and V: they stay in SRAM throughout the inner loop. - with tl.async_task([0]): - k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) - v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) num_steps = BLOCK_N1 // MASK_BLOCK_M1 @@ -1805,9 +1929,8 @@ def _attn_bwd_compute( MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR offs_m = start_m + tl.arange(0, BLOCK_M2) - with tl.async_task([0]): - q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) - do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) m = tl.load(M + offs_m) @@ -1868,6 +1991,199 @@ def _attn_bwd_compute( tl.store(dq_ptrs, dq) +@triton.jit +def _attn_bwd_compute_ws( + Q, + K, + V, + sm_scale, # + DO, # + DQ, + DK, + DV, # + M, + D, + # shared by Q/K/V/DO. + stride_z, + stride_h, + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLK_SLICE_FACTOR: tl.constexpr, # + HEAD_DIM: tl.constexpr, +): + LN2: tl.constexpr = 0.6931471824645996 # = ln(2) + + bhid = tl.program_id(2) + off_chz = (bhid * N_CTX).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + pid = tl.program_id(0) + + # offset pointers for batch/head + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DK += adj + DV += adj + M += off_chz + D += off_chz + + # load scales + offs_k = tl.arange(0, HEAD_DIM) + + start_n = pid * BLOCK_N1 + start_m = start_n + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + offs_n = start_n + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + + # load K and V: they stay in SRAM throughout the inner loop. + with tl.async_task([0]): + k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + + dk, dv = _attn_bwd_dkdv_ws( + dk, + dv, # + Q, + k, + v, + sm_scale, # + DO, # + M, + D, # + stride_tok, + stride_d, # + H, + N_CTX, # + MASK_BLOCK_M1, + BLOCK_N1, + HEAD_DIM, # + start_n, + start_m, + num_steps, # + MASK=True, # + ) + + start_m += num_steps * MASK_BLOCK_M1 + num_steps = (N_CTX - start_m) // BLOCK_M1 + + # Compute dK and dV for non-masked blocks. 
+ dk, dv = _attn_bwd_dkdv_ws( # + dk, + dv, # + Q, + k, + v, + sm_scale, # + DO, # + M, + D, # + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M1, + BLOCK_N1, + HEAD_DIM, # + start_n, + start_m, + num_steps, # + MASK=False, # + ) + + with tl.async_task([1, 2]): + dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dv_ptrs, dv) + + # Write back dK. + dk *= sm_scale + dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dk_ptrs, dk) + + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + end_n = start_m + BLOCK_M2 + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) + + with tl.async_task([0]): + q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + + m = tl.load(M + offs_m) + m = m[:, None] + + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _attn_bwd_dq_ws( + dq, + q, + K, + V, # + do, + m, + D, # + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M2, + MASK_BLOCK_N2, + HEAD_DIM, # + start_m, + end_n - num_steps * MASK_BLOCK_N2, + num_steps, # + MASK=True, # + ) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2 + num_steps = end_n // BLOCK_N2 + dq = _attn_bwd_dq_ws( + dq, + q, + K, + V, # + do, + m, + D, # + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M2, + BLOCK_N2, + HEAD_DIM, # + start_m, + end_n - num_steps * BLOCK_N2, + num_steps, # + MASK=False, # + ) + # Write back dQ. + with tl.async_task([1, 2]): + dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + dq *= LN2 + tl.store(dq_ptrs, dq) + + @triton.autotune(list(filter(keep2, configsBwd)), key=["N_CTX"]) @triton.jit def _attn_bwd( @@ -1949,7 +2265,7 @@ def _attn_bwd_ws( BLK_SLICE_FACTOR: tl.constexpr, # HEAD_DIM: tl.constexpr, ): - _attn_bwd_compute( + _attn_bwd_compute_ws( Q, K, V, From 8d2ca2edbb69164d974731e4446b947c11d4ad1a Mon Sep 17 00:00:00 2001 From: Manman Ren Date: Mon, 2 Dec 2024 21:54:24 -0800 Subject: [PATCH 4/9] use [1,2] Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- tritonbench/kernels/triton_fused_attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py index 9466682c..9664fe88 100644 --- a/tritonbench/kernels/triton_fused_attention.py +++ b/tritonbench/kernels/triton_fused_attention.py @@ -1640,7 +1640,7 @@ def _attn_bwd_dkdv_ws( m = tl.load(M + offs_m) #with tl.async_task([0]): # do = tl.load(do_ptrs) - with tl.async_task([1]): + with tl.async_task([1, 2]): qkT = tl.dot(k, qT) #dpT = tl.dot(v, tl.trans(do)).to(tl.float32) pT = tl.math.exp2(qkT - m[None, :]) @@ -1651,7 +1651,7 @@ def _attn_bwd_dkdv_ws( with tl.async_task([0]): do = tl.load(do_ptrs) # Compute dV. 
- with tl.async_task([1]): + with tl.async_task([1, 2]): ppT = pT ppT = ppT.to(tl.bfloat16) dv += tl.dot(ppT, do) @@ -1708,7 +1708,7 @@ def _attn_bwd_dq_ws( with tl.async_task([0]): kT = tl.load(kT_ptrs) vT = tl.load(vT_ptrs) - with tl.async_task([1]): + with tl.async_task([1, 2]): qk = tl.dot(q, kT) p = tl.math.exp2(qk - m) # Autoregressive masking. From bc63261e8bc0a4d916beb27f1d1487616565e440 Mon Sep 17 00:00:00 2001 From: Manman Ren Date: Mon, 2 Dec 2024 21:57:45 -0800 Subject: [PATCH 5/9] minor fix for original variant Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- tritonbench/kernels/triton_fused_attention.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py index 9664fe88..c0c4f1f5 100644 --- a/tritonbench/kernels/triton_fused_attention.py +++ b/tritonbench/kernels/triton_fused_attention.py @@ -1513,7 +1513,6 @@ def _attn_bwd_dkdv( offs_m = curr_m + tl.arange(0, BLOCK_M1) m = tl.load(M + offs_m) qkT = tl.dot(k, qT) - #dpT = tl.dot(v, tl.trans(do)).to(tl.float32) pT = tl.math.exp2(qkT - m[None, :]) # Autoregressive masking. if MASK: @@ -1930,8 +1929,8 @@ def _attn_bwd_compute( offs_m = start_m + tl.arange(0, BLOCK_M2) q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) - do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) m = tl.load(M + offs_m) m = m[:, None] From 79c804f33d7d6f75306a4520ab32eb06bfa7e2a4 Mon Sep 17 00:00:00 2001 From: Manman Ren Date: Tue, 3 Dec 2024 08:39:04 -0800 Subject: [PATCH 6/9] skip tests Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/test_gpu/skip_tests_h100_pytorch.yaml | 1 + test/test_gpu/skip_tests_h100_triton_main.yaml | 1 + 2 files changed, 2 insertions(+) diff --git a/test/test_gpu/skip_tests_h100_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml index c2d62fdb..2ef84ca8 100644 --- a/test/test_gpu/skip_tests_h100_pytorch.yaml +++ b/test/test_gpu/skip_tests_h100_pytorch.yaml @@ -19,6 +19,7 @@ flash_attention: - triton_tutorial_flash_v2_ws - triton_tutorial_flash_v2_tma_ws - triton_tutorial_flash_v2_tma_ws_persistent + - triton_tutorial_flash_v2_bwd_ws fp8_attention: - colfax_fmha # triton_flash_v2 requires triton-main diff --git a/test/test_gpu/skip_tests_h100_triton_main.yaml b/test/test_gpu/skip_tests_h100_triton_main.yaml index 58ac80d9..c2b1ac79 100644 --- a/test/test_gpu/skip_tests_h100_triton_main.yaml +++ b/test/test_gpu/skip_tests_h100_triton_main.yaml @@ -11,6 +11,7 @@ flash_attention: - triton_tutorial_flash_v2_ws - triton_tutorial_flash_v2_tma_ws - triton_tutorial_flash_v2_tma_ws_persistent + - triton_tutorial_flash_v2_bwd_ws fp8_attention: # fb-only kernel - colfax_fmha From 757db43666a0688db63ef2bb7353ae03bb63f9dd Mon Sep 17 00:00:00 2001 From: Manman Ren Date: Tue, 3 Dec 2024 08:40:34 -0800 Subject: [PATCH 7/9] lint Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- tritonbench/kernels/triton_fused_attention.py | 65 +++++++++---------- .../operators/flash_attention/operator.py | 2 - 2 files changed, 32 insertions(+), 35 deletions(-) diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py index c0c4f1f5..6e10fe62 100644 --- a/tritonbench/kernels/triton_fused_attention.py +++ b/tritonbench/kernels/triton_fused_attention.py @@ -458,10 +458,10 @@ def 
_attn_fwd_inner_ws( num_warps=w, ) ) - for BM in [128] #64, 128] - for BN in [128] #64, 128] - for s in [3] #3, 4, 7] - for w in [8] #4, 8] + for BM in [128] # 64, 128] + for BN in [128] # 64, 128] + for s in [3] # 3, 4, 7] + for w in [8] # 4, 8] ] # TMA, WS, and CompPipe configsTmaWS = [ @@ -1637,11 +1637,11 @@ def _attn_bwd_dkdv_ws( # Load m before computing qk to reduce pipeline stall. offs_m = curr_m + tl.arange(0, BLOCK_M1) m = tl.load(M + offs_m) - #with tl.async_task([0]): + # with tl.async_task([0]): # do = tl.load(do_ptrs) with tl.async_task([1, 2]): qkT = tl.dot(k, qT) - #dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + # dpT = tl.dot(v, tl.trans(do)).to(tl.float32) pT = tl.math.exp2(qkT - m[None, :]) # Autoregressive masking. if MASK: @@ -1746,10 +1746,10 @@ def keep2(conf): "BLOCK_M2": BN, "BLOCK_N2": BM, }, - num_stages=s, #0 or s, + num_stages=s, # 0 or s, num_warps=w, - num_buffers_warp_spec=0, #0 or 2, - num_consumer_groups=0, #0 or 1, + num_buffers_warp_spec=0, # 0 or 2, + num_consumer_groups=0, # 0 or 1, ) if has_warp_spec else triton.Config( @@ -1763,10 +1763,10 @@ def keep2(conf): num_warps=w, ) ) - for BM in [64] #32, 64] # BLOCK_N1 % BLOCK_M1 == 0 - for BN in [128] #64, 128] - for s in [3]#, 4, 7] - for w in [4]#, 8] + for BM in [64] # 32, 64] # BLOCK_N1 % BLOCK_M1 == 0 + for BN in [128] # 64, 128] + for s in [3] # , 4, 7] + for w in [4] # , 8] ] configsBwdWs = [ ( @@ -1794,10 +1794,10 @@ def keep2(conf): num_warps=w, ) ) - for BM in [64] #32, 64] # BLOCK_N1 % BLOCK_M1 == 0 - for BN in [128] #[64] #64, 128] - for s in [3]#, 4, 7] - for w in [4]#, 8] + for BM in [64] # 32, 64] # BLOCK_N1 % BLOCK_M1 == 0 + for BN in [128] # [64] #64, 128] + for s in [3] # , 4, 7] + for w in [4] # , 8] ] @@ -2698,8 +2698,8 @@ def backward(ctx, do): BATCH, N_HEAD, N_CTX = q.shape[:3] PRE_BLOCK = 128 - #NUM_WARPS, NUM_STAGES = 4, 5 - #BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + # NUM_WARPS, NUM_STAGES = 4, 5 + # BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 BLK_SLICE_FACTOR = 2 RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) @@ -2720,8 +2720,7 @@ def backward(ctx, do): HEAD_DIM=ctx.HEAD_DIM, # ) grid = lambda args: (N_CTX // args["BLOCK_N1"], 1, BATCH * N_HEAD) - #grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) - print(q.stride(0), q.stride(1), q.stride(2), q.stride(3)) + # grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) if ctx.bwdVariant == "base": _attn_bwd[grid]( q, @@ -2740,14 +2739,14 @@ def backward(ctx, do): q.stride(3), # N_HEAD, N_CTX, # - #BLOCK_M1=BLOCK_M1, - #BLOCK_N1=BLOCK_N1, # - #BLOCK_M2=BLOCK_M2, - #BLOCK_N2=BLOCK_N2, # + # BLOCK_M1=BLOCK_M1, + # BLOCK_N1=BLOCK_N1, # + # BLOCK_M2=BLOCK_M2, + # BLOCK_N2=BLOCK_N2, # BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # HEAD_DIM=ctx.HEAD_DIM, # - #num_warps=NUM_WARPS, # - #num_stages=NUM_STAGES, # + # num_warps=NUM_WARPS, # + # num_stages=NUM_STAGES, # ) elif ctx.bwdVariant == "ws": _attn_bwd_ws[grid]( @@ -2767,14 +2766,14 @@ def backward(ctx, do): q.stride(3), # N_HEAD, N_CTX, # - #BLOCK_M1=BLOCK_M1, - #BLOCK_N1=BLOCK_N1, # - #BLOCK_M2=BLOCK_M2, - #BLOCK_N2=BLOCK_N2, # + # BLOCK_M1=BLOCK_M1, + # BLOCK_N1=BLOCK_N1, # + # BLOCK_M2=BLOCK_M2, + # BLOCK_N2=BLOCK_N2, # BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # HEAD_DIM=ctx.HEAD_DIM, # - #num_warps=NUM_WARPS, # - #num_stages=NUM_STAGES, # + # num_warps=NUM_WARPS, # + # num_stages=NUM_STAGES, # ) return dq, dk, dv, None, None, None, None diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py index a9d80833..0d4ad2d9 100644 --- 
a/tritonbench/operators/flash_attention/operator.py +++ b/tritonbench/operators/flash_attention/operator.py @@ -470,8 +470,6 @@ def tflops( BATCH, H, N_CTX, D_HEAD = q.shape flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD tflops = 2 * flops_per_matmul - print("causal, mode: ", self.causal, self.mode) - print("fn_name: ", fn_name, metrics.latency) if self.causal: tflops *= 0.5 if self.mode == BenchmarkMode.BWD: From 2867e2fb00efee63f280f5f43a5087e4201d479c Mon Sep 17 00:00:00 2001 From: Manman Ren Date: Wed, 4 Dec 2024 17:12:46 -0800 Subject: [PATCH 8/9] small fix Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- tritonbench/kernels/triton_fused_attention.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py index 6e10fe62..20c29f70 100644 --- a/tritonbench/kernels/triton_fused_attention.py +++ b/tritonbench/kernels/triton_fused_attention.py @@ -1634,12 +1634,10 @@ def _attn_bwd_dkdv_ws( for blk_idx in range(num_steps): with tl.async_task([0]): qT = tl.load(qT_ptrs) - # Load m before computing qk to reduce pipeline stall. - offs_m = curr_m + tl.arange(0, BLOCK_M1) - m = tl.load(M + offs_m) - # with tl.async_task([0]): - # do = tl.load(do_ptrs) with tl.async_task([1, 2]): + # Load m before computing qk to reduce pipeline stall. + offs_m = curr_m + tl.arange(0, BLOCK_M1) + m = tl.load(M + offs_m) qkT = tl.dot(k, qT) # dpT = tl.dot(v, tl.trans(do)).to(tl.float32) pT = tl.math.exp2(qkT - m[None, :]) @@ -1661,10 +1659,11 @@ def _attn_bwd_dkdv_ws( dsT = pT * (dpT - Di[None, :]) dsT = dsT.to(tl.bfloat16) dk += tl.dot(dsT, tl.trans(qT)) - # Increment pointers. - curr_m += step_m - qT_ptrs += step_m * stride_tok - do_ptrs += step_m * stride_tok + # Increment pointers. + curr_m += step_m + with tl.async_task([0]): + qT_ptrs += step_m * stride_tok + do_ptrs += step_m * stride_tok return dk, dv @@ -1722,10 +1721,11 @@ def _attn_bwd_dq_ws( # Compute dQ. # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. dq += tl.dot(ds, tl.trans(kT)) - # Increment pointers. - curr_n += step_n - kT_ptrs += step_n * stride_tok - vT_ptrs += step_n * stride_tok + # Increment pointers. 
+ curr_n += step_n + with tl.async_task([0]): + kT_ptrs += step_n * stride_tok + vT_ptrs += step_n * stride_tok return dq From 02ef2c2a367ed60351da26fa9c95827f06b5753c Mon Sep 17 00:00:00 2001 From: Manman Ren Date: Wed, 4 Dec 2024 17:17:29 -0800 Subject: [PATCH 9/9] fix fp8 Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- tritonbench/kernels/triton_fused_attention.py | 8 ++++---- tritonbench/operators/fp8_attention/operator.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py index 20c29f70..aa953c9c 100644 --- a/tritonbench/kernels/triton_fused_attention.py +++ b/tritonbench/kernels/triton_fused_attention.py @@ -458,10 +458,10 @@ def _attn_fwd_inner_ws( num_warps=w, ) ) - for BM in [128] # 64, 128] - for BN in [128] # 64, 128] - for s in [3] # 3, 4, 7] - for w in [8] # 4, 8] + for BM in [64, 128] + for BN in [64, 128] + for s in [3, 4, 7] + for w in [4, 8] ] # TMA, WS, and CompPipe configsTmaWS = [ diff --git a/tritonbench/operators/fp8_attention/operator.py b/tritonbench/operators/fp8_attention/operator.py index 0131ca73..71c0ffcc 100644 --- a/tritonbench/operators/fp8_attention/operator.py +++ b/tritonbench/operators/fp8_attention/operator.py @@ -110,7 +110,7 @@ def triton_flash_v2( triton_q, triton_k, triton_v = self.triton_preprocess(q, k, v) # full fp8 will be enabled if type of q,k,v is fp8 return lambda: triton_attention( - triton_q, triton_k, triton_v, False, self.sm_scale, "base" + triton_q, triton_k, triton_v, False, self.sm_scale, "base", "base" ) def get_x_val(self, _example_inputs) -> Tuple[int, int, int, int]:
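
Usage sketch (not part of the patches above; the import path, shapes, and dtype are assumptions): after this series, _attention_opt.forward takes an extra bwdVariant argument ("base" or "ws") that selects between _attn_bwd and the warp-specialized _attn_bwd_ws in backward. A minimal end-to-end call, assuming a CUDA device and, for the "ws" path, a Triton build with warp-specialization support (has_warp_spec), might look like:

    import torch
    from tritonbench.kernels.triton_fused_attention import attention_opt

    # Shapes follow the operator's (BATCH, H, N_CTX, D_HEAD) layout; N_CTX is a
    # multiple of 128 so it divides PRE_BLOCK and the autotuned BLOCK_N1.
    BATCH, H, N_CTX, HEAD_DIM = 4, 16, 4096, 64

    def make_input():
        # bfloat16 matches the casts used inside the backward kernels.
        return torch.randn(BATCH, H, N_CTX, HEAD_DIM, device="cuda",
                           dtype=torch.bfloat16, requires_grad=True)

    q, k, v = make_input(), make_input(), make_input()
    sm_scale = HEAD_DIM ** -0.5

    # Forward uses the "base" variant; the 7th argument picks the backward path:
    # "base" dispatches to _attn_bwd, "ws" to _attn_bwd_ws (see _attention_opt.backward).
    o = attention_opt(q, k, v, True, sm_scale, "base", "ws")
    o.backward(torch.randn_like(o))

In the warp-specialized backward kernels, the loads of qT/do (dK/dV loop) and kT/vT (dQ loop) are wrapped in tl.async_task([0]) so a producer warp group issues them, while the tl.dot and softmax work runs under tl.async_task([1, 2]), matching num_consumer_groups=2 and num_buffers_warp_spec=2 in configsBwdWs.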