diff --git a/jax/_src/lib/triton.py b/jax/_src/lib/triton.py
index c0a5202e9dbc..fdd858f52089 100644
--- a/jax/_src/lib/triton.py
+++ b/jax/_src/lib/triton.py
@@ -12,4 +12,62 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import threading
+from typing import Protocol
+
 from jaxlib.triton import dialect  # noqa: F401  # pytype: disable=import-error
+
+
+class CompilationResult(Protocol):
+  asm: str
+  smem_bytes: int
+  cluster_dim_x: int
+  cluster_dim_y: int
+  cluster_dim_z: int
+
+
+class CompilationHandler(Protocol):
+
+  def __call__(
+      self,
+      module: bytes,
+      arch_name: str,
+      num_warps: int,
+      num_ctas: int,
+      num_stages: int,
+  ) -> CompilationResult:
+    ...
+
+
+_compilation_handlers: dict[str, CompilationHandler] = {}
+_compilation_handlers_lock = threading.Lock()
+
+
+def register_compilation_handler(
+    platform: str, handler: CompilationHandler
+) -> None:
+  with _compilation_handlers_lock:
+    if existing_handler := _compilation_handlers.get(platform):
+      raise RuntimeError(
+          f'Platform {platform} already has a Triton compilation handler:'
+          f' {existing_handler}'
+      )
+    _compilation_handlers[platform] = handler
+
+
+def compile(
+    platform: str,
+    module: bytes,
+    arch_name: str,
+    *,
+    num_warps: int,
+    num_ctas: int,
+    num_stages: int,
+) -> CompilationResult:
+  with _compilation_handlers_lock:
+    handler = _compilation_handlers.get(platform)
+  if handler is None:
+    raise RuntimeError(
+        f'Platform {platform} does not have a Triton compilation handler'
+    )
+  return handler(module, arch_name, num_warps, num_ctas, num_stages)
diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py
index 59b1b86f33fc..1a212c7bc542 100644
--- a/jax/_src/pallas/triton/pallas_call_registration.py
+++ b/jax/_src/pallas/triton/pallas_call_registration.py
@@ -17,10 +17,16 @@
 from __future__ import annotations
 
 import io
+import re
 from typing import Any
+import zlib
 
+import jax
 import jax._src.core as jax_core
 from jax._src.interpreters import mlir
+from jax._src.lib import gpu_triton as triton_kernel_call_lib
+from jax._src.lib import triton
+from jax._src.lib import version as jaxlib_version
 from jax._src.lib.mlir import ir
 from jax._src.pallas import core as pallas_core
 from jax._src.pallas.triton import lowering
@@ -51,7 +57,7 @@ def pallas_call_lowering(
     cost_estimate: pallas_core.CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
 ):
-  del interpret, out_avals
+  del interpret, cost_estimate, out_avals
   if grid_mapping.num_dynamic_grid_bounds:
     raise NotImplementedError(
         "dynamic grid bounds not supported in the Triton backend"
     )
@@ -77,6 +83,11 @@
     print(f"The grid mapping for pallas_call {name_and_src_info}:")
     print(grid_mapping)
 
+  # Sanitize the name to conform to NVPTX requirements. We do this here
+  # to avoid the need to fetch the new name from PTX post compilation.
+  name_and_src_info = name_and_src_info.replace(
+      name=re.sub(r"[^a-zA-Z0-9_$]", "_", name_and_src_info.name)
+  )
   lowering_result = lowering.lower_jaxpr_to_triton_module(
       jaxpr, grid_mapping, name_and_src_info, lowering_platform
   )
@@ -86,35 +97,88 @@
     print(module_op.get_asm(enable_debug_info=True, pretty_debug_info=True))
 
   grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid)
-  out_types = [
+  buf = io.BytesIO()
+  module_op.write_bytecode(buf)
+
+  if jaxlib_version < (0, 5, 1):
+    # AOT Triton compilation is only available on jaxlib 0.5.1+.
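+    # Older jaxlib: keep the previous behaviour of embedding the Triton IR in
+    # the ``__gpu$xla.gpu.triton`` custom call and let XLA compile the kernel
+    # at run time.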
+    out_types = [
       ir.RankedTensorType.get(bm.array_shape_dtype.shape,
                               mlir.dtype_to_ir_type(bm.array_shape_dtype.dtype))
       for bm in grid_mapping.block_mappings_output
-  ]
-  buf = io.BytesIO()
-  module_op.write_bytecode(buf)
-  backend_config = dict(
-      name=ir.StringAttr.get(name_and_src_info.name),
-      ir=ir.StringAttr.get(buf.getvalue()),
-      num_stages=mlir.i32_attr(num_stages),
-      num_warps=mlir.i32_attr(num_warps),
-      grid_x=mlir.i32_attr(grid_x),
-      grid_y=mlir.i32_attr(grid_y),
-      grid_z=mlir.i32_attr(grid_z),
-      debug=ir.BoolAttr.get(debug),
+    ]
+    backend_config = dict(
+        name=ir.StringAttr.get(name_and_src_info.name),
+        ir=ir.StringAttr.get(buf.getvalue()),
+        num_stages=mlir.i32_attr(num_stages),
+        num_warps=mlir.i32_attr(num_warps),
+        grid_x=mlir.i32_attr(grid_x),
+        grid_y=mlir.i32_attr(grid_y),
+        grid_z=mlir.i32_attr(grid_z),
+        debug=ir.BoolAttr.get(debug),
+    )
+    if "serialized_metadata" in (triton_params or {}):
+      # This field is unstable and may be removed in the future.
+      if triton_params["serialized_metadata"] is not None:
+        backend_config["serialized_metadata"] = ir.StringAttr.get(
+            triton_params["serialized_metadata"]
+        )
+    return mlir.custom_call(
+        call_target_name="__gpu$xla.gpu.triton",
+        result_types=out_types,
+        operands=in_nodes,
+        backend_config=backend_config,
+        api_version=4,
+        operand_layouts=avals_to_layouts(ctx.avals_in),
+        result_layouts=avals_to_layouts(ctx.avals_out),
+        operand_output_aliases=dict(input_output_aliases),
+    ).results
+
+  try:
+    gpu_device, *_ = jax.local_devices(backend="gpu")
+  except RuntimeError:
+    raise NotImplementedError(
+        "Pallas Triton lowering requires a GPU device."
+    ) from None
+
+  compilation_result = triton.compile(
+      lowering_platform.upper(),
+      buf.getvalue(),
+      str(gpu_device.compute_capability),
+      num_warps=num_warps,
+      num_ctas=1,
+      num_stages=num_stages,
+  )
+  kernel = triton_kernel_call_lib.TritonKernel(
+      name_and_src_info.name,
+      num_warps,
+      compilation_result.smem_bytes,
+      compilation_result.asm,
+      module_op.get_asm(enable_debug_info=True, pretty_debug_info=True),
+      triton_kernel_call_lib.get_compute_capability(0),
+      compilation_result.cluster_dim_x,
+      compilation_result.cluster_dim_y,
+      compilation_result.cluster_dim_z,
+  )
+  kernel_call = triton_kernel_call_lib.TritonKernelCall(
+      kernel,
+      grid_x,
+      grid_y,
+      grid_z,
+      [triton_kernel_call_lib.create_array_parameter(0, 16)]
+      * (len(ctx.avals_in) + len(ctx.avals_out)),
   )
-  if "serialized_metadata" in (triton_params or {}):
-    # This field is unstable and may be removed in the future.
-    if triton_params["serialized_metadata"] is not None:
-      backend_config["serialized_metadata"] = ir.StringAttr.get(
-          triton_params["serialized_metadata"]
-      )
+  # TODO(b/392558289): Migrate to ``jax.ffi``.
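+  # The kernel was compiled ahead of time above; ship it to the runtime as a
+  # zlib-compressed ``TritonKernelCall`` proto in the custom call's backend
+  # config.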
   return mlir.custom_call(
-      call_target_name="__gpu$xla.gpu.triton",
-      result_types=out_types,
+      call_target_name="triton_kernel_call",
+      result_types=[*map(mlir.aval_to_ir_type, ctx.avals_out)],  # type: ignore[list-item]
       operands=in_nodes,
-      backend_config=backend_config,
-      api_version=4,
+      backend_config=zlib.compress(
+          kernel_call.to_proto(
+              name_and_src_info.name,
+              triton_params.get("serialized_metadata") or b"",
+          )
+      ),
       operand_layouts=avals_to_layouts(ctx.avals_in),
       result_layouts=avals_to_layouts(ctx.avals_out),
       operand_output_aliases=dict(input_output_aliases),
diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py
index a09e21c6dd77..68281f4f32b3 100644
--- a/jax_plugins/cuda/__init__.py
+++ b/jax_plugins/cuda/__init__.py
@@ -18,6 +18,7 @@
 import os
 import pathlib
 
+from jax._src.lib import triton
 from jax._src.lib import xla_client
 import jax._src.xla_bridge as xb
 
@@ -99,5 +100,11 @@ def initialize():
             cuda_plugin_extension.register_custom_type_id, c_api
         ),
     )
+    triton.register_compilation_handler(
+        "CUDA",
+        functools.partial(
+            cuda_plugin_extension.compile_triton_to_asm, c_api
+        ),
+    )
   else:
     logger.warning('cuda_plugin_extension is not found.')
diff --git a/jaxlib/BUILD b/jaxlib/BUILD
index 843ccb112871..a5680920d808 100644
--- a/jaxlib/BUILD
+++ b/jaxlib/BUILD
@@ -234,8 +234,8 @@ cc_library(
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/strings:string_view",
         "@nanobind",
-        "@tsl//tsl/platform:statusor",
         "@xla//xla:util",
         "@xla//xla/ffi/api:c_api",
         "@xla//xla/pjrt:status_casters",
@@ -243,6 +243,7 @@ cc_library(
         "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
         "@xla//xla/pjrt/c:pjrt_c_api_hdrs",
         "@xla//xla/pjrt/c:pjrt_c_api_helpers",
+        "@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
         "@xla//xla/python:py_client_gpu",
         "@xla//xla/tsl/python/lib/core:numpy",
     ],
diff --git a/jaxlib/gpu_plugin_extension.cc b/jaxlib/gpu_plugin_extension.cc
index 46263bdcd40c..8863ebf19b39 100644
--- a/jaxlib/gpu_plugin_extension.cc
+++ b/jaxlib/gpu_plugin_extension.cc
@@ -16,23 +16,28 @@ limitations under the License.
#include "jaxlib/gpu_plugin_extension.h" #include +#include +#include #include #include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/ffi/api/c_api.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_ffi_extension.h" #include "xla/pjrt/c/pjrt_c_api_gpu_extension.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/c/pjrt_c_api_triton_extension.h" #include "xla/pjrt/status_casters.h" #include "xla/python/py_client_gpu.h" #include "xla/tsl/python/lib/core/numpy.h" #include "xla/util.h" -#include "tsl/platform/statusor.h" namespace nb = nanobind; @@ -40,6 +45,44 @@ namespace xla { namespace { +struct TritonCompilationResult { + std::string asm_text; + int64_t smem_bytes; + int cluster_dim_x; + int cluster_dim_y; + int cluster_dim_z; +}; + +absl::StatusOr CompileTritonToASM( + const PJRT_Api* c_api, absl::string_view module, + absl::string_view arch_name, int num_warps, int num_ctas, int num_stages) { + const PJRT_Triton_Extension* triton_ext = + pjrt::FindExtension( + c_api, PJRT_Extension_Type::PJRT_Extension_Type_Triton); + if (triton_ext == nullptr) { + return Unimplemented("The plugin does not have a Triton extension."); + } + PJRT_Triton_Compile_Args args; + args.struct_size = PJRT_Triton_Compile_Args_STRUCT_SIZE; + args.module = module.data(); + args.module_size = module.size(); + args.arch_name = arch_name.data(); + args.arch_name_size = arch_name.size(); + args.num_warps = num_warps; + args.num_ctas = num_ctas; + args.num_stages = num_stages; + RETURN_STATUS_IF_PJRT_ERROR(triton_ext->compile(&args), c_api); + auto asm_text = std::string(args.out_asm, args.out_asm_size); + delete[] args.out_asm; + return TritonCompilationResult{ + .asm_text = std::string(args.out_asm, args.out_asm_size), + .smem_bytes = args.out_smem_bytes, + .cluster_dim_x = args.out_cluster_dim_x, + .cluster_dim_y = args.out_cluster_dim_y, + .cluster_dim_z = args.out_cluster_dim_z, + }; +} + absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, const char* fn_name_c_str, size_t fn_name_size, nb::object fn, @@ -170,6 +213,24 @@ nb::dict Registrations() { void BuildGpuPluginExtension(nanobind::module_& m) { tsl::ImportNumpy(); + + nb::class_(m, "TritonCompilationResult") + .def_ro("asm", &TritonCompilationResult::asm_text) + .def_ro("smem_bytes", &TritonCompilationResult::smem_bytes) + .def_ro("cluster_dim_x", &TritonCompilationResult::cluster_dim_x) + .def_ro("cluster_dim_y", &TritonCompilationResult::cluster_dim_y) + .def_ro("cluster_dim_z", &TritonCompilationResult::cluster_dim_z); + + m.def("compile_triton_to_asm", + [](nb::capsule c_api, nb::bytes module, absl::string_view arch_name, + int num_warps, int num_ctas, int num_stages) { + return xla::ValueOrThrow(CompileTritonToASM( + static_cast(c_api.data()), + absl::string_view(static_cast(module.data()), + module.size()), + arch_name, num_warps, num_ctas, num_stages)); + }); + m.def( "register_custom_call_target", [](nb::capsule c_api, nb::object fn_name_py, nb::object fn, diff --git a/tests/pallas/export_pallas_test.py b/tests/pallas/export_pallas_test.py index c1092357c6d2..ca22c221461e 100644 --- a/tests/pallas/export_pallas_test.py +++ b/tests/pallas/export_pallas_test.py @@ -49,13 +49,18 @@ def add_vectors(x: jax.Array, y: jax.Array) -> 
jax.Array: out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) )(x, y) + platforms = ["tpu"] + if jtu.device_under_test() == "gpu": + # Pallas Triton requires a GPU device to be available during lowering. + platforms.append("cuda") + a = np.arange(8 * 16, dtype=np.int32).reshape((8, 16)) exp = export.export( add_vectors, - platforms=["tpu", "cuda"], - # The Pallas GPU custom call is not enabled for export by default. + platforms=platforms, + # The Pallas Triton custom call is not enabled for export by default. disabled_checks=[ - export.DisabledSafetyCheck.custom_call("__gpu$xla.gpu.triton") + export.DisabledSafetyCheck.custom_call("triton_kernel_call") ] )(a, a)
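
Usage sketch (not part of the patch): a minimal illustration of the compilation-handler registry added in jax/_src/lib/triton.py. ``FakeResult``, ``fake_handler``, and the "MOCK" platform name are hypothetical stand-ins used only for this sketch; the real wiring registers ``cuda_plugin_extension.compile_triton_to_asm`` for the "CUDA" platform, as shown in jax_plugins/cuda/__init__.py above.

    import dataclasses
    from jax._src.lib import triton

    @dataclasses.dataclass(frozen=True)
    class FakeResult:
      # Mirrors the fields of the CompilationResult protocol.
      asm: str
      smem_bytes: int
      cluster_dim_x: int
      cluster_dim_y: int
      cluster_dim_z: int

    def fake_handler(module, arch_name, num_warps, num_ctas, num_stages):
      # A real handler would pass ``module`` (Triton IR bytecode) to the
      # platform's Triton compiler and return the resulting assembly.
      return FakeResult("", 0, 1, 1, 1)

    triton.register_compilation_handler("MOCK", fake_handler)
    result = triton.compile(
        "MOCK", b"...", "9.0", num_warps=4, num_ctas=1, num_stages=3
    )

The registry keeps jax itself free of a hard dependency on any particular plugin: each PJRT plugin registers its own handler at initialization time, exactly as the CUDA plugin does in the diff above.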