[inductor] make inductor work with new triton kernel launch API (#584)
This commit cherry-picks the following changes from PyTorch:
- pytorch/pytorch#123076
- pytorch/pytorch#119450
Stonepia authored Apr 7, 2024
1 parent ed69bf5 commit ffc62ab
Showing 1 changed file with 95 additions and 16 deletions.
111 changes: 95 additions & 16 deletions intel_extension_for_pytorch/_inductor/xpu/triton_heuristics.py
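The gist of the change: newer Triton builds expose launch hooks, metadata, and a launch_metadata callback on the compiled binary and expect a different runner argument layout, so the generated launcher now selects between two argument builders at precompile time. A minimal sketch of that dispatch, condensing the two get_launch_args_* helpers from the diff below; it assumes a hypothetical compiled-kernel object with the attributes the diff uses, and the attribute names are illustrative rather than the exact Triton API:

# Sketch only: condenses the two get_launch_args_* helpers from the diff below.
# Assumes a compiled-kernel object `binary` with the attributes the diff uses;
# attribute names are illustrative, not a definitive Triton API.
def build_launch_args(binary, grid, stream, args):
    grid_0, grid_1, grid_2 = grid
    if hasattr(binary, "launch_metadata"):
        # New launch API (after triton PR 3492): per-call launch metadata.
        return (
            grid_0, grid_1, grid_2,
            stream, binary.function, binary.metadata,
            binary.launch_metadata(grid, stream, *args),
            binary.launch_enter_hook, binary.launch_exit_hook,
        )
    # Old launch API: num_warps, CTA dims, and shared memory passed explicitly.
    return (
        grid_0, grid_1, grid_2,
        binary.num_warps, *getattr(binary, "cta_args", ()), binary.shared,
        stream, binary.function,
        binary.launch_enter_hook, binary.launch_exit_hook, binary.metadata,
    )

In the generated launcher further down, this choice is made once via scope["get_launch_args"], so the per-call code stays identical for both old and new Triton.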
@@ -162,16 +162,25 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Dict):
]
def_args = [name for name in self.fn.arg_names if name not in cfg.kwargs]

binary_shared = (
binary.shared if hasattr(binary, "shared") else binary.metadata.shared
)

scope = {
"grid_meta": cfg.kwargs,
"bin": binary,
"torch": torch,
"set_device": torch.xpu.set_device,
"current_device": torch.xpu.current_device,
"launch_enter_hook": binary.launch_enter_hook,
"launch_exit_hook": binary.launch_exit_hook,
"metadata": binary.metadata,
"shared": binary_shared,
}

scope["runner"] = get_first_attr(binary, "run", "c_wrapper")
scope["function"] = get_first_attr(binary, "function", "cu_function")
scope["num_warps"] = (
binary.num_warps
if hasattr(binary, "num_warps")
else binary.metadata.num_warps
)

scope["cta_args"] = (
(binary.num_ctas, *get_first_attr(binary, "cluster_dims", "clusterDims"))
if hasattr(binary, "num_ctas")
@@ -181,15 +190,82 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Dict):
else ()
)
)
scope["num_warps"] = (
binary.num_warps
if hasattr(binary, "num_warps")
else binary.metadata.num_warps
)
scope["shared"] = (
binary.shared if hasattr(binary, "shared") else binary.metadata.shared
)

scope["function"] = get_first_attr(binary, "function", "cu_function")

def get_launch_args_without_kernel_launch_metadata(
grid,
grid_0,
grid_1,
grid_2,
stream,
function,
metadata,
bin,
launch_enter_hook,
launch_exit_hook,
num_warps,
shared,
cta_args,
args,
):
"""
Construct launch args before CompiledKernel.launch_metadata is added.
"""
return (
grid_0,
grid_1,
grid_2,
num_warps,
*cta_args,
shared,
stream,
function,
launch_enter_hook,
launch_exit_hook,
metadata,
)

def get_launch_args_with_kernel_launch_metadata(
grid,
grid_0,
grid_1,
grid_2,
stream,
function,
metadata,
bin,
launch_enter_hook,
launch_exit_hook,
num_warps,
shared,
cta_args,
args,
):
"""
Construct launch args after CompiledKernel.launch_metadata is added
by https://github.com/openai/triton/pull/3492 .
"""
return (
grid_0,
grid_1,
grid_2,
stream,
function,
metadata,
bin.launch_metadata(grid, stream, *args),
launch_enter_hook,
launch_exit_hook,
)

scope["get_launch_args"] = (
get_launch_args_with_kernel_launch_metadata
if hasattr(binary, "launch_metadata")
else get_launch_args_without_kernel_launch_metadata
)

scope["runner"] = get_first_attr(binary, "run", "c_wrapper")

exec(
f"""
def launcher({', '.join(def_args)}, grid, stream):
@@ -199,10 +275,13 @@ def launcher({', '.join(def_args)}, grid, stream):
grid_0, grid_1, grid_2 = grid
runner(grid_0, grid_1, grid_2, num_warps,
*cta_args, shared,
stream, function, None, None, None,
{', '.join(call_args)})
args = {', '.join(call_args)},
launch_args = get_launch_args(
grid, grid_0, grid_1, grid_2, stream, function,
metadata, bin, launch_enter_hook, launch_exit_hook,
num_warps, shared, cta_args, args
)
runner(*launch_args, *args)
return bin
""".lstrip(),
scope,
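The diff also relies on get_first_attr, e.g. get_first_attr(binary, "run", "c_wrapper") and get_first_attr(binary, "function", "cu_function"), to bridge attribute renames between Triton versions. Its definition is not part of this diff; presumably it returns the first attribute the object actually has, roughly like this sketch (assumed behavior, not copied from the PyTorch/IPEX source):

# Assumed behavior of the get_first_attr helper referenced in the diff above
# (sketch under that assumption, not the actual implementation).
def get_first_attr(obj, *attr_names):
    for name in attr_names:
        if hasattr(obj, name):
            return getattr(obj, name)
    raise AttributeError(f"{obj!r} has none of the attributes {attr_names}")

# Usage mirroring the diff, given a compiled kernel `binary`:
#   runner = get_first_attr(binary, "run", "c_wrapper")
#   function = get_first_attr(binary, "function", "cu_function")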
