Skip to content

Commit

Permalink
Run warp-specialized FP8 rowwise with --warp_specialization (#122)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #122

Reviewed By: xuzhao9, sijiac

Differential Revision: D67675915
  • Loading branch information
htyu authored and facebook-github-bot committed Dec 31, 2024
1 parent 9363aca commit 3e797f1
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tritonbench/operators/fp8_gemm_rowwise/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:
dest="no_use_persistent",
action="store_true",
)
parser.add_argument("--warp_specialization", action="store_true")
parsed_args = parser.parse_args(args)
return parsed_args

Expand Down Expand Up @@ -131,6 +132,7 @@ def __init__(
self.fp8_fast_accum = addmm_args.fp8_fast_accum
self.use_tma = addmm_args.use_tma
self.no_use_persistent = addmm_args.no_use_persistent
self.warp_specialization = addmm_args.warp_specialization

@register_benchmark(enabled=HAS_TRITON, baseline=True)
def _triton(self, xq, wq, x_scale, w_scale) -> Callable:
Expand All @@ -142,6 +144,7 @@ def _triton(self, xq, wq, x_scale, w_scale) -> Callable:
fp8_fast_accum=self.fp8_fast_accum,
tma_persistent=self.use_tma,
no_use_persistent=self.no_use_persistent,
use_warp_specialization=self.warp_specialization,
)

@register_benchmark(
Expand Down

0 comments on commit 3e797f1

Please sign in to comment.