Commit

fix fp8
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
manman-ren committed Dec 5, 2024
1 parent 2867e2f commit 02ef2c2
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions tritonbench/kernels/triton_fused_attention.py
@@ -458,10 +458,10 @@ def _attn_fwd_inner_ws(
             num_warps=w,
         )
     )
-    for BM in [128] # 64, 128]
-    for BN in [128] # 64, 128]
-    for s in [3] # 3, 4, 7]
-    for w in [8] # 4, 8]
+    for BM in [64, 128]
+    for BN in [64, 128]
+    for s in [3, 4, 7]
+    for w in [4, 8]
 ]
 # TMA, WS, and CompPipe
 configsTmaWS = [
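The hunk above re-enables the full autotuning sweep over block sizes, pipeline stages, and warp counts instead of the single pinned configuration. A minimal sketch of how such a sweep is typically wired up with Triton's autotuner is shown below; the kernel body, meta-parameter names, and tuning key are illustrative assumptions, not the repository's actual fused-attention kernel.

import triton
import triton.language as tl

# Sketch only: build one triton.Config per combination of block sizes,
# pipeline stages, and warp count, and let @triton.autotune pick the
# fastest combination for a given problem size.
configs = [
    triton.Config(
        {"BLOCK_M": BM, "BLOCK_N": BN},  # illustrative meta-parameters
        num_stages=s,
        num_warps=w,
    )
    for BM in [64, 128]
    for BN in [64, 128]
    for s in [3, 4, 7]
    for w in [4, 8]
]

@triton.autotune(configs=configs, key=["N_CTX"])  # re-tune when N_CTX changes
@triton.jit
def _toy_kernel(x_ptr, y_ptr, N_CTX, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Placeholder body; the real kernels live in triton_fused_attention.py.
    pid = tl.program_id(0)
    offs = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    mask = offs < N_CTX
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)
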
2 changes: 1 addition & 1 deletion tritonbench/operators/fp8_attention/operator.py
@@ -110,7 +110,7 @@ def triton_flash_v2(
         triton_q, triton_k, triton_v = self.triton_preprocess(q, k, v)
         # full fp8 will be enabled if type of q,k,v is fp8
         return lambda: triton_attention(
-            triton_q, triton_k, triton_v, False, self.sm_scale, "base"
+            triton_q, triton_k, triton_v, False, self.sm_scale, "base", "base"
         )
 
     def get_x_val(self, _example_inputs) -> Tuple[int, int, int, int]:
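The one-line fix above passes a second "base" variant string to triton_attention, presumably to match an extra positional parameter added to the attention entry point. A minimal, hypothetical sketch of the call-site pattern follows; the helper name, parameter names, and the meaning of the two variant strings are assumptions inferred from the diff, not a documented API.

from typing import Any, Callable

def make_flash_v2_callable(
    triton_attention: Callable[..., Any],
    triton_q: Any,
    triton_k: Any,
    triton_v: Any,
    sm_scale: float,
) -> Callable[[], Any]:
    # Hypothetical helper mirroring triton_flash_v2: defer execution behind a
    # zero-argument lambda so the benchmark harness times only the attention call.
    causal = False
    # After this commit, two variant strings are passed; their exact semantics
    # are defined by triton_fused_attention.py and only assumed here.
    return lambda: triton_attention(
        triton_q, triton_k, triton_v, causal, sm_scale, "base", "base"
    )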
