[Pallas] Introduce make_kernel_from_pallas (#6713) #6742

Merged · 1 commit · Mar 13, 2024
42 changes: 42 additions & 0 deletions test/test_pallas.py
@@ -135,6 +135,48 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
# the most important fields are present.
self.assertIn("custom_call_config", payload)

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
  # TODO: This test cannot be run individually; let's fix it.
def test_tpu_custom_call_pallas_wrap_add_payload(self):
import jax
import jax.numpy as jnp
import jax._src.pallas.mosaic.pallas_call_registration

from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
x, y = x_ref[...], y_ref[...]
o_ref[...] = x + y

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
return pl.pallas_call(
add_vectors_kernel, out_shape=jax.ShapeDtypeStruct(x.shape,
x.dtype))(x, y)

from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y:
(x.shape, x.dtype))

dtypes = [torch.float32, torch.float
] # TODO: torch.float64, torch.bfloat16, torch.float16 don't work.
for i in range(len(dtypes)):
x = torch.randn((i + 1, i + 1), dtype=dtypes[i]).to("xla")
y = torch.randn((i + 1, i + 1), dtype=dtypes[i]).to("xla")
expected_output = x + y
output = pt_kernel(x, y)
self.assertTrue(torch.allclose(output.cpu(), expected_output.cpu()))

dtypes = [
torch.int32, torch.int
] # TODO: torch.int64, torch.int16, torch.int8, torch.uint8 don't work.
for i in range(len(dtypes)):
x = torch.arange(i + 1, dtype=dtypes[i]).to("xla")
y = torch.arange(i + 1, dtype=dtypes[i]).to("xla")
expected_output = x + y
output = pt_kernel(x, y)
self.assertTrue(torch.allclose(output.cpu(), expected_output.cpu()))


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
55 changes: 54 additions & 1 deletion torch_xla/experimental/custom_kernel.py
@@ -1,7 +1,13 @@
import functools
import jax
import jax.numpy as jnp
import jax._src.pallas.mosaic.pallas_call_registration
import torch
import torch_xla
import torch_xla.core.xla_model as xm

from typing import List
from jax.experimental import pallas as pl
from typing import List, Callable
from torch.library import impl
from torch_xla.core.xla_model import XLA_LIB

@@ -56,3 +62,50 @@ def _extract_backend_config(
if op.name == "stablehlo.custom_call":
return op.backend_config.value
return None


def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
if dtype == torch.float32:
return jnp.float32
elif dtype == torch.float64:
return jnp.float64
elif dtype == torch.float16:
return jnp.float16
elif dtype == torch.bfloat16:
return jnp.bfloat16
elif dtype == torch.int32:
return jnp.int32
elif dtype == torch.int64:
return jnp.int64
elif dtype == torch.int16:
return jnp.int16
elif dtype == torch.int8:
return jnp.int8
elif dtype == torch.uint8:
return jnp.uint8
else:
raise ValueError(f"Unsupported dtype: {dtype}")


def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
# TODO: Maybe we can cache the payload for the same input.
def wrapped_kernel(kernel: Callable, output_shape_dtype_fn: Callable, *args):
jax_args = []
for i, arg in enumerate(args):
if torch.is_tensor(arg):
# ShapedArray doesn't have any storage and thus is very suitable for generating the payload.
jax_meta_tensor = jax.core.ShapedArray(
arg.shape, convert_torch_dtype_to_jax(arg.dtype))
jax_args.append(jax_meta_tensor)
else:
# TODO: We can support more types here.
assert False, f"Unsupported argument type: {type(arg)}"

ir = jax.jit(kernel).lower(*jax_args).compiler_ir()
payload = _extract_backend_config(ir)
output_shape, output_dtype = output_shape_dtype_fn(*args)
output = torch.empty(output_shape, dtype=output_dtype).to(xm.xla_device())
torch_xla._XLAC._xla_tpu_custom_call_(output, args, payload)
return output

return functools.partial(wrapped_kernel, kernel, output_shape_dtype_fn)
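

For reference, a minimal usage sketch of the new helper, assembled from the kernel and the test added in this PR. It assumes a TPU-backed XLA device is available; the (8, 8) shape is illustrative (the test sweeps several small shapes and dtypes).

import jax
import torch
from jax.experimental import pallas as pl
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas


def add_vectors_kernel(x_ref, y_ref, o_ref):
  # Element-wise add inside the Pallas kernel body.
  o_ref[...] = x_ref[...] + y_ref[...]


@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(
      add_vectors_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x, y)


# The second argument maps the torch inputs to the output's (shape, dtype);
# for an element-wise add it simply mirrors the first input.
pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: (x.shape, x.dtype))

x = torch.randn(8, 8, dtype=torch.float32).to("xla")
y = torch.randn(8, 8, dtype=torch.float32).to("xla")
out = pt_kernel(x, y)  # Lowers the Pallas kernel and runs it as a TPU custom call.

The wrapper only needs shape and dtype metadata to build the payload, which is why ShapedArray placeholders (rather than real device buffers) are passed to jax.jit(...).lower() inside wrapped_kernel.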