From 02ef2c2a367ed60351da26fa9c95827f06b5753c Mon Sep 17 00:00:00 2001
From: Manman Ren
Date: Wed, 4 Dec 2024 17:17:29 -0800
Subject: [PATCH] fix fp8

Summary: Re-enable the full autotune sweep over BM, BN, num_stages, and
num_warps for the warp-specialized attention configs in
triton_fused_attention.py, and pass the additional "base" argument to
triton_attention in the fp8_attention operator.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 tritonbench/kernels/triton_fused_attention.py   | 8 ++++----
 tritonbench/operators/fp8_attention/operator.py | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py
index 20c29f70..aa953c9c 100644
--- a/tritonbench/kernels/triton_fused_attention.py
+++ b/tritonbench/kernels/triton_fused_attention.py
@@ -458,10 +458,10 @@ def _attn_fwd_inner_ws(
             num_warps=w,
         )
     )
-    for BM in [128]  # 64, 128]
-    for BN in [128]  # 64, 128]
-    for s in [3]  # 3, 4, 7]
-    for w in [8]  # 4, 8]
+    for BM in [64, 128]
+    for BN in [64, 128]
+    for s in [3, 4, 7]
+    for w in [4, 8]
 ]
 # TMA, WS, and CompPipe
 configsTmaWS = [
diff --git a/tritonbench/operators/fp8_attention/operator.py b/tritonbench/operators/fp8_attention/operator.py
index 0131ca73..71c0ffcc 100644
--- a/tritonbench/operators/fp8_attention/operator.py
+++ b/tritonbench/operators/fp8_attention/operator.py
@@ -110,7 +110,7 @@ def triton_flash_v2(
         triton_q, triton_k, triton_v = self.triton_preprocess(q, k, v)
         # full fp8 will be enabled if type of q,k,v is fp8
         return lambda: triton_attention(
-            triton_q, triton_k, triton_v, False, self.sm_scale, "base"
+            triton_q, triton_k, triton_v, False, self.sm_scale, "base", "base"
         )

     def get_x_val(self, _example_inputs) -> Tuple[int, int, int, int]:
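
Note (not part of the patch): a minimal sketch of how a sweep like the one
re-enabled above typically expands into Triton autotune configs. The
BLOCK_M/BLOCK_N meta-parameter names and the configs_sketch variable are
illustrative assumptions, not taken from the patched file; only the swept
values come from the diff.

import triton

# One triton.Config per combination of tile sizes, pipeline stages, and warp
# counts; the swept values mirror those restored by the patch.
configs_sketch = [
    triton.Config(
        {"BLOCK_M": BM, "BLOCK_N": BN},  # assumed meta-parameter names
        num_stages=s,
        num_warps=w,
    )
    for BM in [64, 128]
    for BN in [64, 128]
    for s in [3, 4, 7]
    for w in [4, 8]
]
# Yields 2 * 2 * 3 * 2 = 24 candidate configs, which would normally be passed
# to a @triton.autotune(configs=..., key=[...]) decorator on the kernel.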