From 304a97ee6f4f9d08c6ec2e0e96d0de7ccd9d79f7 Mon Sep 17 00:00:00 2001 From: Adnan Akhundov Date: Tue, 19 Nov 2024 13:21:57 -0800 Subject: [PATCH] [user triton] Ignore backend-specific args in the TTIR analysis (#141062) Fixes #140800. On AMD, backend-specific args like `matrix_instr_nonkdim`, `waves_per_eu` and `kpack` are passed either direclty to the kernel or via `triton.Config`, whereas they don't exist as kernel parameters. Native Triton code handles those excessive args [here](https://github.com/triton-lang/triton/blob/a6bb57d6285e723c58e87dd7cba263db6efff789/python/triton/runtime/jit.py#L594-L596). In this PR, we add similar handling to the TTIR analysis code to avoid bailing out. Pull Request resolved: https://github.com/pytorch/pytorch/pull/141062 Approved by: https://github.com/oulgen (cherry picked from commit b740a1b96cac9d1ca65ffaf59e66185306e8afd0) --- torch/_higher_order_ops/triton_kernel_wrap.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 5a1ad4405c5ec3..f6c5946a11febb 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -136,8 +136,23 @@ def generate_ttir(kernel, kwargs): assert isinstance(kernel, JITFunction) + context = triton._C.libtriton.ir.context() + target = triton.runtime.driver.active.get_current_target() + backend = triton.compiler.compiler.make_backend(target) + options = backend.parse_options({}) + + # ignore backend-specific kwargs same way as in the native Triton code + # https://github.com/triton-lang/triton/blob/a6bb57d6285e723c58e87dd7cba263db6efff789/python/triton/runtime/jit.py#L594-L596 + # why this is important for user-defined Triton kernels on AMD: https://github.com/pytorch/pytorch/issues/140800 + for name in list(kwargs): + if name not in kernel.arg_names and name in options.__dict__: + kwargs.pop(name) + if len(kwargs) != len(kernel.arg_names): - raise ValueError("Incorrect number of arguments passed to kernel") + raise ValueError( + "Incorrect number of arguments passed to kernel: " + f"passed {list(kwargs.keys())}, expected {kernel.arg_names}." + ) # Replace all SymExprs with a regular value for TTIR generation # Replace all FakeTensor/TensorBox with real tensors @@ -168,10 +183,6 @@ def generate_ttir(kernel, kwargs): if i not in kernel.constexprs } - context = triton._C.libtriton.ir.context() - target = triton.runtime.driver.active.get_current_target() - backend = triton.compiler.compiler.make_backend(target) - options = backend.parse_options({}) triton._C.libtriton.ir.load_dialects(context) backend.load_dialects(context)