diff --git a/intel_extension_for_pytorch/_inductor/xpu/triton_heuristics.py b/intel_extension_for_pytorch/_inductor/xpu/triton_heuristics.py
index 7847fc0b7..8c7a7ee13 100644
--- a/intel_extension_for_pytorch/_inductor/xpu/triton_heuristics.py
+++ b/intel_extension_for_pytorch/_inductor/xpu/triton_heuristics.py
@@ -162,16 +162,25 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Dict):
         ]
         def_args = [name for name in self.fn.arg_names if name not in cfg.kwargs]
 
+        binary_shared = (
+            binary.shared if hasattr(binary, "shared") else binary.metadata.shared
+        )
+
         scope = {
             "grid_meta": cfg.kwargs,
             "bin": binary,
-            "torch": torch,
-            "set_device": torch.xpu.set_device,
-            "current_device": torch.xpu.current_device,
+            "launch_enter_hook": binary.launch_enter_hook,
+            "launch_exit_hook": binary.launch_exit_hook,
+            "metadata": binary.metadata,
+            "shared": binary_shared,
         }
 
-        scope["runner"] = get_first_attr(binary, "run", "c_wrapper")
-        scope["function"] = get_first_attr(binary, "function", "cu_function")
+        scope["num_warps"] = (
+            binary.num_warps
+            if hasattr(binary, "num_warps")
+            else binary.metadata.num_warps
+        )
+
         scope["cta_args"] = (
             (binary.num_ctas, *get_first_attr(binary, "cluster_dims", "clusterDims"))
             if hasattr(binary, "num_ctas")
@@ -181,15 +190,82 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Dict):
                 else ()
             )
         )
-        scope["num_warps"] = (
-            binary.num_warps
-            if hasattr(binary, "num_warps")
-            else binary.metadata.num_warps
-        )
-        scope["shared"] = (
-            binary.shared if hasattr(binary, "shared") else binary.metadata.shared
+
+        scope["function"] = get_first_attr(binary, "function", "cu_function")
+
+        def get_launch_args_without_kernel_launch_metadata(
+            grid,
+            grid_0,
+            grid_1,
+            grid_2,
+            stream,
+            function,
+            metadata,
+            bin,
+            launch_enter_hook,
+            launch_exit_hook,
+            num_warps,
+            shared,
+            cta_args,
+            args,
+        ):
+            """
+            Construct launch args before CompiledKernel.launch_metadata is added.
+            """
+            return (
+                grid_0,
+                grid_1,
+                grid_2,
+                num_warps,
+                *cta_args,
+                shared,
+                stream,
+                function,
+                launch_enter_hook,
+                launch_exit_hook,
+                metadata,
+            )
+
+        def get_launch_args_with_kernel_launch_metadata(
+            grid,
+            grid_0,
+            grid_1,
+            grid_2,
+            stream,
+            function,
+            metadata,
+            bin,
+            launch_enter_hook,
+            launch_exit_hook,
+            num_warps,
+            shared,
+            cta_args,
+            args,
+        ):
+            """
+            Construct launch args after CompiledKernel.launch_metadata is added
+            by https://github.com/openai/triton/pull/3492 .
+            """
+            return (
+                grid_0,
+                grid_1,
+                grid_2,
+                stream,
+                function,
+                metadata,
+                bin.launch_metadata(grid, stream, *args),
+                launch_enter_hook,
+                launch_exit_hook,
+            )
+
+        scope["get_launch_args"] = (
+            get_launch_args_with_kernel_launch_metadata
+            if hasattr(binary, "launch_metadata")
+            else get_launch_args_without_kernel_launch_metadata
         )
+
+        scope["runner"] = get_first_attr(binary, "run", "c_wrapper")
 
         exec(
             f"""
             def launcher({', '.join(def_args)}, grid, stream):
@@ -199,10 +275,13 @@ def launcher({', '.join(def_args)}, grid, stream):
                     grid_0, grid_1, grid_2 = grid
 
-                runner(grid_0, grid_1, grid_2, num_warps,
-                            *cta_args, shared,
-                            stream, function, None, None, None,
-                            {', '.join(call_args)})
+                args = {', '.join(call_args)},
+                launch_args = get_launch_args(
+                    grid, grid_0, grid_1, grid_2, stream, function,
+                    metadata, bin, launch_enter_hook, launch_exit_hook,
+                    num_warps, shared, cta_args, args
+                )
+                runner(*launch_args, *args)
                 return bin
             """.lstrip(),
             scope,
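
Note: the patch above keeps a single runner(*launch_args, *args) call site in the generated launcher and only swaps how launch_args is built, depending on whether the Triton CompiledKernel exposes launch_metadata (added by https://github.com/openai/triton/pull/3492). Below is a minimal, self-contained sketch of that dispatch for illustration only: _FakeBinary and the literal grid/stream/argument values are hypothetical stand-ins, while the two get_launch_args_* helpers are condensed copies of the ones introduced in the patch.

# Illustrative sketch only; not part of the patch.
from types import SimpleNamespace


class _FakeBinary:
    """Mimics a Triton CompiledKernel that already has launch_metadata (post-#3492)."""

    metadata = SimpleNamespace(num_warps=4, shared=0)
    launch_enter_hook = None
    launch_exit_hook = None
    function = object()  # stands in for the native kernel handle

    def launch_metadata(self, grid, stream, *args):
        return {"grid": grid, "stream": stream, "num_args": len(args)}


def get_launch_args_with_kernel_launch_metadata(
    grid, grid_0, grid_1, grid_2, stream, function, metadata, bin,
    launch_enter_hook, launch_exit_hook, num_warps, shared, cta_args, args,
):
    # New layout: num_warps/cta_args/shared are no longer passed positionally;
    # the runner reads them from metadata, and a launch_metadata object is inserted.
    return (
        grid_0, grid_1, grid_2, stream, function, metadata,
        bin.launch_metadata(grid, stream, *args),
        launch_enter_hook, launch_exit_hook,
    )


def get_launch_args_without_kernel_launch_metadata(
    grid, grid_0, grid_1, grid_2, stream, function, metadata, bin,
    launch_enter_hook, launch_exit_hook, num_warps, shared, cta_args, args,
):
    # Old layout: explicit num_warps, cta_args and shared, no launch metadata object.
    return (
        grid_0, grid_1, grid_2, num_warps, *cta_args, shared, stream, function,
        launch_enter_hook, launch_exit_hook, metadata,
    )


binary = _FakeBinary()
get_launch_args = (
    get_launch_args_with_kernel_launch_metadata
    if hasattr(binary, "launch_metadata")
    else get_launch_args_without_kernel_launch_metadata
)

# The generated launcher then simply does: runner(*launch_args, *args)
launch_args = get_launch_args(
    grid=(8, 1, 1), grid_0=8, grid_1=1, grid_2=1, stream=0,
    function=binary.function, metadata=binary.metadata, bin=binary,
    launch_enter_hook=None, launch_exit_hook=None,
    num_warps=4, shared=0, cta_args=(), args=(1.0, 2.0),
)
print(len(launch_args))  # 9 positional launch args in the new layout

Keeping the layout decision in get_launch_args (chosen once per compiled kernel) rather than branching inside the exec'd launcher means the generated launcher source stays identical across Triton versions.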