From eb06157b195ff2da727da233d1b4e84ff4507427 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 21 Jan 2025 09:26:01 -0800 Subject: [PATCH] [pallas:triton] The lowering now uses PTX instead of Triton IR This change improves the stability and backward compatibility of Pallas Triton calls, because unlike PTX, the Triton dialect has no stability guarantees and does change in practice. A few notes * Pallas Triton no longer delegates compilation to PTX to XLA:GPU. Instead, compilation is done via a new PjRt extension, which uses its own compilation pipeline mirrored after the one in the Triton Python bindings. * The implementation of the old custom call used by Pallas Triton is deprecated and will be removed after 6 months as per [compatibility guarantees] [*] [*]: https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees PiperOrigin-RevId: 717950300 --- jax/_src/lib/triton.py | 58 +++++++++ .../pallas/triton/pallas_call_registration.py | 112 ++++++++++++++---- jax_plugins/cuda/__init__.py | 7 ++ jaxlib/BUILD | 3 +- jaxlib/gpu_plugin_extension.cc | 63 +++++++++- tests/pallas/export_pallas_test.py | 11 +- 6 files changed, 225 insertions(+), 29 deletions(-) diff --git a/jax/_src/lib/triton.py b/jax/_src/lib/triton.py index c0a5202e9dbc..fdd858f52089 100644 --- a/jax/_src/lib/triton.py +++ b/jax/_src/lib/triton.py @@ -12,4 +12,62 @@ # See the License for the specific language governing permissions and # limitations under the License. +import threading +from typing import Protocol + from jaxlib.triton import dialect # noqa: F401 # pytype: disable=import-error + + +class CompilationResult(Protocol): + asm: str + smem_bytes: int + cluster_dim_x: int + cluster_dim_y: int + cluster_dim_z: int + + +class CompilationHandler(Protocol): + + def __call__( + self, + module: bytes, + arch_name: str, + num_warps: int, + num_ctas: int, + num_stages: int, + ) -> CompilationResult: + ... + + +_compilation_handlers: dict[str, CompilationHandler] = {} +_compilation_handlers_lock = threading.Lock() + + +def register_compilation_handler( + platform: str, handler: CompilationHandler +) -> None: + with _compilation_handlers_lock: + if existing_handler := _compilation_handlers.get(platform): + raise RuntimeError( + f'Platform {platform} already has a Triton compilation handler:' + f' {existing_handler}' + ) + _compilation_handlers[platform] = handler + + +def compile( + platform: str, + module: bytes, + arch_name: str, + *, + num_warps: int, + num_ctas: int, + num_stages: int, +) -> CompilationResult: + with _compilation_handlers_lock: + handler = _compilation_handlers.get(platform) + if handler is None: + raise RuntimeError( + f'Platform {platform} does not have a Triton compilation handler' + ) + return handler(module, arch_name, num_warps, num_ctas, num_stages) diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index 59b1b86f33fc..1a212c7bc542 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -17,10 +17,16 @@ from __future__ import annotations import io +import re from typing import Any +import zlib +import jax import jax._src.core as jax_core from jax._src.interpreters import mlir +from jax._src.lib import gpu_triton as triton_kernel_call_lib +from jax._src.lib import triton +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.pallas import core as pallas_core from jax._src.pallas.triton import lowering @@ -51,7 +57,7 @@ def pallas_call_lowering( cost_estimate: pallas_core.CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], ): - del interpret, out_avals + del interpret, cost_estimate, out_avals if grid_mapping.num_dynamic_grid_bounds: raise NotImplementedError( "dynamic grid bounds not supported in the Triton backend" @@ -77,6 +83,11 @@ def pallas_call_lowering( print("The grid mapping for pallas_call {name_and_src_info}:") print(grid_mapping) + # Sanitize the name to conform to NVPTX requirements. We do this here + # to avoid the need to fetch the new name from PTX post compilation. + name_and_src_info = name_and_src_info.replace( + name=re.sub(r"[^a-zA-Z0-9_$]", "_", name_and_src_info.name) + ) lowering_result = lowering.lower_jaxpr_to_triton_module( jaxpr, grid_mapping, name_and_src_info, lowering_platform ) @@ -86,35 +97,88 @@ def pallas_call_lowering( print(module_op.get_asm(enable_debug_info=True, pretty_debug_info=True)) grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid) - out_types = [ + buf = io.BytesIO() + module_op.write_bytecode(buf) + + if jaxlib_version < (0, 5, 1): + # AOT Triton compilation is only available on jaxlib 0.5.1+. + out_types = [ ir.RankedTensorType.get(bm.array_shape_dtype.shape, mlir.dtype_to_ir_type(bm.array_shape_dtype.dtype)) for bm in grid_mapping.block_mappings_output - ] - buf = io.BytesIO() - module_op.write_bytecode(buf) - backend_config = dict( - name=ir.StringAttr.get(name_and_src_info.name), - ir=ir.StringAttr.get(buf.getvalue()), - num_stages=mlir.i32_attr(num_stages), - num_warps=mlir.i32_attr(num_warps), - grid_x=mlir.i32_attr(grid_x), - grid_y=mlir.i32_attr(grid_y), - grid_z=mlir.i32_attr(grid_z), - debug=ir.BoolAttr.get(debug), + ] + backend_config = dict( + name=ir.StringAttr.get(name_and_src_info.name), + ir=ir.StringAttr.get(buf.getvalue()), + num_stages=mlir.i32_attr(num_stages), + num_warps=mlir.i32_attr(num_warps), + grid_x=mlir.i32_attr(grid_x), + grid_y=mlir.i32_attr(grid_y), + grid_z=mlir.i32_attr(grid_z), + debug=ir.BoolAttr.get(debug), + ) + if "serialized_metadata" in (triton_params or {}): + # This field is unstable and may be removed in the future. + if triton_params["serialized_metadata"] is not None: + backend_config["serialized_metadata"] = ir.StringAttr.get( + triton_params["serialized_metadata"] + ) + return mlir.custom_call( + call_target_name="__gpu$xla.gpu.triton", + result_types=out_types, + operands=in_nodes, + backend_config=backend_config, + api_version=4, + operand_layouts=avals_to_layouts(ctx.avals_in), + result_layouts=avals_to_layouts(ctx.avals_out), + operand_output_aliases=dict(input_output_aliases), + ).results + + try: + gpu_device, *_ = jax.local_devices(backend="gpu") + except RuntimeError: + raise NotImplementedError( + "Pallas Triton lowering requires a GPU device." + ) from None + + compilation_result = triton.compile( + lowering_platform.upper(), + buf.getvalue(), + str(gpu_device.compute_capability), + num_warps=num_warps, + num_ctas=1, + num_stages=num_stages, + ) + kernel = triton_kernel_call_lib.TritonKernel( + name_and_src_info.name, + num_warps, + compilation_result.smem_bytes, + compilation_result.asm, + module_op.get_asm(enable_debug_info=True, pretty_debug_info=True), + triton_kernel_call_lib.get_compute_capability(0), + compilation_result.cluster_dim_x, + compilation_result.cluster_dim_y, + compilation_result.cluster_dim_z, + ) + kernel_call = triton_kernel_call_lib.TritonKernelCall( + kernel, + grid_x, + grid_y, + grid_z, + [triton_kernel_call_lib.create_array_parameter(0, 16)] + * (len(ctx.avals_in) + len(ctx.avals_out)), ) - if "serialized_metadata" in (triton_params or {}): - # This field is unstable and may be removed in the future. - if triton_params["serialized_metadata"] is not None: - backend_config["serialized_metadata"] = ir.StringAttr.get( - triton_params["serialized_metadata"] - ) + # TODO(b/392558289): Migrate to ``jax.ffi``. return mlir.custom_call( - call_target_name="__gpu$xla.gpu.triton", - result_types=out_types, + call_target_name="triton_kernel_call", + result_types=[*map(mlir.aval_to_ir_type, ctx.avals_out)], # type: ignore[list-item] operands=in_nodes, - backend_config=backend_config, - api_version=4, + backend_config=zlib.compress( + kernel_call.to_proto( + name_and_src_info.name, + triton_params.get("serialized_metadata") or b"", + ) + ), operand_layouts=avals_to_layouts(ctx.avals_in), result_layouts=avals_to_layouts(ctx.avals_out), operand_output_aliases=dict(input_output_aliases), diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index a09e21c6dd77..68281f4f32b3 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -18,6 +18,7 @@ import os import pathlib +from jax._src.lib import triton from jax._src.lib import xla_client import jax._src.xla_bridge as xb @@ -99,5 +100,11 @@ def initialize(): cuda_plugin_extension.register_custom_type_id, c_api ), ) + triton.register_compilation_handler( + "CUDA", + functools.partial( + cuda_plugin_extension.compile_triton_to_asm, c_api + ), + ) else: logger.warning('cuda_plugin_extension is not found.') diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 843ccb112871..a5680920d808 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -234,8 +234,8 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@nanobind", - "@tsl//tsl/platform:statusor", "@xla//xla:util", "@xla//xla/ffi/api:c_api", "@xla//xla/pjrt:status_casters", @@ -243,6 +243,7 @@ cc_library( "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_helpers", + "@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs", "@xla//xla/python:py_client_gpu", "@xla//xla/tsl/python/lib/core:numpy", ], diff --git a/jaxlib/gpu_plugin_extension.cc b/jaxlib/gpu_plugin_extension.cc index 46263bdcd40c..8863ebf19b39 100644 --- a/jaxlib/gpu_plugin_extension.cc +++ b/jaxlib/gpu_plugin_extension.cc @@ -16,23 +16,28 @@ limitations under the License. #include "jaxlib/gpu_plugin_extension.h" #include +#include +#include #include #include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/ffi/api/c_api.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_ffi_extension.h" #include "xla/pjrt/c/pjrt_c_api_gpu_extension.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/c/pjrt_c_api_triton_extension.h" #include "xla/pjrt/status_casters.h" #include "xla/python/py_client_gpu.h" #include "xla/tsl/python/lib/core/numpy.h" #include "xla/util.h" -#include "tsl/platform/statusor.h" namespace nb = nanobind; @@ -40,6 +45,44 @@ namespace xla { namespace { +struct TritonCompilationResult { + std::string asm_text; + int64_t smem_bytes; + int cluster_dim_x; + int cluster_dim_y; + int cluster_dim_z; +}; + +absl::StatusOr CompileTritonToASM( + const PJRT_Api* c_api, absl::string_view module, + absl::string_view arch_name, int num_warps, int num_ctas, int num_stages) { + const PJRT_Triton_Extension* triton_ext = + pjrt::FindExtension( + c_api, PJRT_Extension_Type::PJRT_Extension_Type_Triton); + if (triton_ext == nullptr) { + return Unimplemented("The plugin does not have a Triton extension."); + } + PJRT_Triton_Compile_Args args; + args.struct_size = PJRT_Triton_Compile_Args_STRUCT_SIZE; + args.module = module.data(); + args.module_size = module.size(); + args.arch_name = arch_name.data(); + args.arch_name_size = arch_name.size(); + args.num_warps = num_warps; + args.num_ctas = num_ctas; + args.num_stages = num_stages; + RETURN_STATUS_IF_PJRT_ERROR(triton_ext->compile(&args), c_api); + auto asm_text = std::string(args.out_asm, args.out_asm_size); + delete[] args.out_asm; + return TritonCompilationResult{ + .asm_text = std::string(args.out_asm, args.out_asm_size), + .smem_bytes = args.out_smem_bytes, + .cluster_dim_x = args.out_cluster_dim_x, + .cluster_dim_y = args.out_cluster_dim_y, + .cluster_dim_z = args.out_cluster_dim_z, + }; +} + absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, const char* fn_name_c_str, size_t fn_name_size, nb::object fn, @@ -170,6 +213,24 @@ nb::dict Registrations() { void BuildGpuPluginExtension(nanobind::module_& m) { tsl::ImportNumpy(); + + nb::class_(m, "TritonCompilationResult") + .def_ro("asm", &TritonCompilationResult::asm_text) + .def_ro("smem_bytes", &TritonCompilationResult::smem_bytes) + .def_ro("cluster_dim_x", &TritonCompilationResult::cluster_dim_x) + .def_ro("cluster_dim_y", &TritonCompilationResult::cluster_dim_y) + .def_ro("cluster_dim_z", &TritonCompilationResult::cluster_dim_z); + + m.def("compile_triton_to_asm", + [](nb::capsule c_api, nb::bytes module, absl::string_view arch_name, + int num_warps, int num_ctas, int num_stages) { + return xla::ValueOrThrow(CompileTritonToASM( + static_cast(c_api.data()), + absl::string_view(static_cast(module.data()), + module.size()), + arch_name, num_warps, num_ctas, num_stages)); + }); + m.def( "register_custom_call_target", [](nb::capsule c_api, nb::object fn_name_py, nb::object fn, diff --git a/tests/pallas/export_pallas_test.py b/tests/pallas/export_pallas_test.py index c1092357c6d2..ca22c221461e 100644 --- a/tests/pallas/export_pallas_test.py +++ b/tests/pallas/export_pallas_test.py @@ -49,13 +49,18 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) )(x, y) + platforms = ["tpu"] + if jtu.device_under_test() == "gpu": + # Pallas Triton requires a GPU device to be available during lowering. + platforms.append("cuda") + a = np.arange(8 * 16, dtype=np.int32).reshape((8, 16)) exp = export.export( add_vectors, - platforms=["tpu", "cuda"], - # The Pallas GPU custom call is not enabled for export by default. + platforms=platforms, + # The Pallas Triton custom call is not enabled for export by default. disabled_checks=[ - export.DisabledSafetyCheck.custom_call("__gpu$xla.gpu.triton") + export.DisabledSafetyCheck.custom_call("triton_kernel_call") ] )(a, a)