Skip to content

Commit

Permalink
[pallas:triton] The lowering now uses PTX instead of Triton IR
Browse files Browse the repository at this point in the history
This change improves the stability and backward compatibility of Pallas Triton
calls, because unlike PTX, the Triton dialect has no stability guarantees
and does change in practice.

A few notes

* Pallas Triton no longer delegates compilation to PTX to XLA:GPU. Instead,
  compilation is done via a new PjRt extension, which uses its own compilation
  pipeline mirrored after the one in the Triton Python bindings.
* The implementation of the old custom call used by Pallas Triton is
  deprecated and will be removed after 6 months as per
  [compatibility guarantees] [*]

[*]: https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees

PiperOrigin-RevId: 717950300
  • Loading branch information
superbobry authored and Google-ML-Automation committed Jan 24, 2025
1 parent 7043b85 commit 1618bb9
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 26 deletions.
58 changes: 58 additions & 0 deletions jax/_src/lib/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,62 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
from typing import Protocol

from jaxlib.triton import dialect # noqa: F401 # pytype: disable=import-error


class CompilationResult(Protocol):
asm: bytes
smem_bytes: int
cluster_dim_x: int
cluster_dim_y: int
cluster_dim_z: int


class CompilationHandler(Protocol):

def __call__(
self,
module: bytes,
arch_name: str,
num_warps: int,
num_ctas: int,
num_stages: int,
) -> CompilationResult:
...


_compilation_handlers: dict[str, CompilationHandler] = {}
_compilation_handlers_lock = threading.Lock()


def register_compilation_handler(
platform: str, handler: CompilationHandler
) -> None:
with _compilation_handlers_lock:
if existing_handler := _compilation_handlers.get(platform):
raise RuntimeError(
f'Platform {platform} already has a Triton compilation handler:'
f' {existing_handler}'
)
_compilation_handlers[platform] = handler


def compile(
platform: str,
module: bytes,
arch_name: str,
*,
num_warps: int,
num_ctas: int,
num_stages: int,
) -> CompilationResult:
with _compilation_handlers_lock:
handler = _compilation_handlers.get(platform)
if handler is None:
raise RuntimeError(
f'Platform {platform} does not have a Triton compilation handler'
)
return handler(module, arch_name, num_warps, num_ctas, num_stages)
106 changes: 82 additions & 24 deletions jax/_src/pallas/triton/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@
from __future__ import annotations

import io
import re
from typing import Any
import zlib

import jax
import jax._src.core as jax_core
from jax._src.interpreters import mlir
from jax._src.lib import gpu_triton as triton_kernel_call_lib
from jax._src.lib import triton
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.pallas import core as pallas_core
from jax._src.pallas.triton import lowering
Expand Down Expand Up @@ -51,7 +57,7 @@ def pallas_call_lowering(
cost_estimate: pallas_core.CostEstimate | None,
out_avals: tuple[jax_core.AbstractValue, ...],
):
del interpret, out_avals
del interpret, cost_estimate, out_avals
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError(
"dynamic grid bounds not supported in the Triton backend"
Expand All @@ -77,6 +83,11 @@ def pallas_call_lowering(
print("The grid mapping for pallas_call {name_and_src_info}:")
print(grid_mapping)

# Sanitize the name to conform to NVPTX requirements. We do this here
# to avoid the need to fetch the new name from PTX post compilation.
name_and_src_info = name_and_src_info.replace(
name=re.sub(r"[^a-zA-Z0-9_$]", "_", name_and_src_info.name)
)
lowering_result = lowering.lower_jaxpr_to_triton_module(
jaxpr, grid_mapping, name_and_src_info, lowering_platform
)
Expand All @@ -86,35 +97,82 @@ def pallas_call_lowering(
print(module_op.get_asm(enable_debug_info=True, pretty_debug_info=True))

grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid)
out_types = [
buf = io.BytesIO()
module_op.write_bytecode(buf)

if jaxlib_version < (0, 5, 1):
# AOT Triton compilation is only available on jaxlib 0.5.1+.
out_types = [
ir.RankedTensorType.get(bm.array_shape_dtype.shape,
mlir.dtype_to_ir_type(bm.array_shape_dtype.dtype))
for bm in grid_mapping.block_mappings_output
]
buf = io.BytesIO()
module_op.write_bytecode(buf)
backend_config = dict(
name=ir.StringAttr.get(name_and_src_info.name),
ir=ir.StringAttr.get(buf.getvalue()),
num_stages=mlir.i32_attr(num_stages),
num_warps=mlir.i32_attr(num_warps),
grid_x=mlir.i32_attr(grid_x),
grid_y=mlir.i32_attr(grid_y),
grid_z=mlir.i32_attr(grid_z),
debug=ir.BoolAttr.get(debug),
]
backend_config = dict(
name=ir.StringAttr.get(name_and_src_info.name),
ir=ir.StringAttr.get(buf.getvalue()),
num_stages=mlir.i32_attr(num_stages),
num_warps=mlir.i32_attr(num_warps),
grid_x=mlir.i32_attr(grid_x),
grid_y=mlir.i32_attr(grid_y),
grid_z=mlir.i32_attr(grid_z),
debug=ir.BoolAttr.get(debug),
)
if "serialized_metadata" in (triton_params or {}):
# This field is unstable and may be removed in the future.
if triton_params["serialized_metadata"] is not None:
backend_config["serialized_metadata"] = ir.StringAttr.get(
triton_params["serialized_metadata"]
)
return mlir.custom_call(
call_target_name="__gpu$xla.gpu.triton",
result_types=out_types,
operands=in_nodes,
backend_config=backend_config,
api_version=4,
operand_layouts=avals_to_layouts(ctx.avals_in),
result_layouts=avals_to_layouts(ctx.avals_out),
operand_output_aliases=dict(input_output_aliases),
).results

gpu_device, *_ = jax.local_devices(backend="gpu")
compilation_result = triton.compile(
lowering_platform.upper(),
buf.getvalue(),
str(gpu_device.compute_capability), # e.g. 7.0
num_warps=num_warps,
num_ctas=1,
num_stages=num_stages,
)
kernel = triton_kernel_call_lib.TritonKernel(
name_and_src_info.name,
num_warps,
compilation_result.smem_bytes,
compilation_result.asm,
module_op.get_asm(enable_debug_info=True, pretty_debug_info=True),
triton_kernel_call_lib.get_compute_capability(0),
compilation_result.cluster_dim_x,
compilation_result.cluster_dim_y,
compilation_result.cluster_dim_z,
)
kernel_call = triton_kernel_call_lib.TritonKernelCall(
kernel,
grid_x,
grid_y,
grid_z,
[triton_kernel_call_lib.create_array_parameter(0, 16)]
* (len(ctx.avals_in) + len(ctx.avals_out)),
)
if "serialized_metadata" in (triton_params or {}):
# This field is unstable and may be removed in the future.
if triton_params["serialized_metadata"] is not None:
backend_config["serialized_metadata"] = ir.StringAttr.get(
triton_params["serialized_metadata"]
)
# TODO(slebedev): Migrate to ``jax.ffi``.
return mlir.custom_call(
call_target_name="__gpu$xla.gpu.triton",
result_types=out_types,
call_target_name="triton_kernel_call",
result_types=[*map(mlir.aval_to_ir_type, ctx.avals_out)], # type: ignore[list-item]
operands=in_nodes,
backend_config=backend_config,
api_version=4,
backend_config=zlib.compress(
kernel_call.to_proto(
name_and_src_info.name,
triton_params.get("serialized_metadata") or b"",
)
),
operand_layouts=avals_to_layouts(ctx.avals_in),
result_layouts=avals_to_layouts(ctx.avals_out),
operand_output_aliases=dict(input_output_aliases),
Expand Down
7 changes: 7 additions & 0 deletions jax_plugins/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
import pathlib

from jax._src.lib import triton
from jax._src.lib import xla_client
import jax._src.xla_bridge as xb

Expand Down Expand Up @@ -99,5 +100,11 @@ def initialize():
cuda_plugin_extension.register_custom_type_id, c_api
),
)
triton.register_compilation_handler(
"CUDA",
functools.partial(
cuda_plugin_extension.compile_triton_to_asm, c_api
),
)
else:
logger.warning('cuda_plugin_extension is not found.')
3 changes: 2 additions & 1 deletion jaxlib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -234,15 +234,16 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@nanobind",
"@tsl//tsl/platform:statusor",
"@xla//xla:util",
"@xla//xla/ffi/api:c_api",
"@xla//xla/pjrt:status_casters",
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
"@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
"@xla//xla/python:py_client_gpu",
"@xla//xla/tsl/python/lib/core:numpy",
],
Expand Down
63 changes: 62 additions & 1 deletion jaxlib/gpu_plugin_extension.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,73 @@ limitations under the License.
#include "jaxlib/gpu_plugin_extension.h"

#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>

#include "nanobind/nanobind.h"
#include "nanobind/stl/string.h" // IWYU pragma: keep
#include "nanobind/stl/string_view.h" // IWYU pragma: keep
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "jaxlib/kernel_nanobind_helpers.h"
#include "xla/ffi/api/c_api.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h"
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
#include "xla/pjrt/c/pjrt_c_api_triton_extension.h"
#include "xla/pjrt/status_casters.h"
#include "xla/python/py_client_gpu.h"
#include "xla/tsl/python/lib/core/numpy.h"
#include "xla/util.h"
#include "tsl/platform/statusor.h"

namespace nb = nanobind;

namespace xla {

namespace {

struct TritonCompilationResult {
std::string asm_text;
int64_t smem_bytes;
int cluster_dim_x;
int cluster_dim_y;
int cluster_dim_z;
};

absl::StatusOr<TritonCompilationResult> CompileTritonToASM(
const PJRT_Api* c_api, absl::string_view module,
absl::string_view arch_name, int num_warps, int num_ctas, int num_stages) {
const PJRT_Triton_Extension* triton_ext =
pjrt::FindExtension<PJRT_Triton_Extension>(
c_api, PJRT_Extension_Type::PJRT_Extension_Type_Triton);
if (triton_ext == nullptr) {
return Unimplemented("The plugin does not have a Triton extension.");
}
PJRT_Triton_Compile_Args args;
args.struct_size = PJRT_Triton_Compile_Args_STRUCT_SIZE;
args.module = module.data();
args.module_size = module.size();
args.arch_name = arch_name.data();
args.arch_name_size = arch_name.size();
args.num_warps = num_warps;
args.num_ctas = num_ctas;
args.num_stages = num_stages;
RETURN_STATUS_IF_PJRT_ERROR(triton_ext->compile(&args), c_api);
auto asm_text = std::string(args.out_asm, args.out_asm_size);
delete[] args.out_asm;
return TritonCompilationResult{
.asm_text = std::string(args.out_asm, args.out_asm_size),
.smem_bytes = args.out_smem_bytes,
.cluster_dim_x = args.out_cluster_dim_x,
.cluster_dim_y = args.out_cluster_dim_y,
.cluster_dim_z = args.out_cluster_dim_z,
};
}

absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api,
const char* fn_name_c_str,
size_t fn_name_size, nb::object fn,
Expand Down Expand Up @@ -170,6 +213,24 @@ nb::dict Registrations() {

void BuildGpuPluginExtension(nanobind::module_& m) {
tsl::ImportNumpy();

nb::class_<TritonCompilationResult>(m, "TritonCompilationResult")
.def_ro("asm", &TritonCompilationResult::asm_text)
.def_ro("smem_bytes", &TritonCompilationResult::smem_bytes)
.def_ro("cluster_dim_x", &TritonCompilationResult::cluster_dim_x)
.def_ro("cluster_dim_y", &TritonCompilationResult::cluster_dim_y)
.def_ro("cluster_dim_z", &TritonCompilationResult::cluster_dim_z);

m.def("compile_triton_to_asm",
[](nb::capsule c_api, nb::bytes module, absl::string_view arch_name,
int num_warps, int num_ctas, int num_stages) {
return xla::ValueOrThrow(CompileTritonToASM(
static_cast<const PJRT_Api*>(c_api.data()),
absl::string_view(static_cast<const char*>(module.data()),
module.size()),
arch_name, num_warps, num_ctas, num_stages));
});

m.def(
"register_custom_call_target",
[](nb::capsule c_api, nb::object fn_name_py, nb::object fn,
Expand Down

0 comments on commit 1618bb9

Please sign in to comment.