Skip to content

Commit

Permalink
[Backport] Introduce jax_import_guard (#6794)
Browse files Browse the repository at this point in the history
Summary:
Importing JAX will lock the TPU devices and prevent any pytorch/xla's TPU computations. To address it, we need to acquire the TPU first.

Test Plan:
python test/test_pallas.py
  • Loading branch information
alanwaketan authored Mar 21, 2024
1 parent d987775 commit c3001d6
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 35 deletions.
17 changes: 7 additions & 10 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
import torch_xla
from torch_xla import runtime as xr

if xr.device_type() == 'TPU':
from torch_xla.experimental.custom_kernel import jax_import_guard
jax_import_guard()
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


class PallasTest(unittest.TestCase):

Expand Down Expand Up @@ -111,12 +118,8 @@ def add_one_pallas(output, inputs, payload):

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tpu_custom_call_pallas_extract_add_payload(self):
import jax
import jax.numpy as jnp
import jax._src.pallas.mosaic.pallas_call_registration

from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
x, y = x_ref[...], y_ref[...]
o_ref[...] = x + y
Expand All @@ -136,13 +139,7 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
self.assertIn("custom_call_config", payload)

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
# TODO: This test cannot be ran individually, let's fix it.
def test_tpu_custom_call_pallas_wrap_add_payload(self):
import jax
import jax.numpy as jnp
import jax._src.pallas.mosaic.pallas_call_registration

from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
x, y = x_ref[...], y_ref[...]
Expand Down
58 changes: 33 additions & 25 deletions torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import functools
import jax
import jax.numpy as jnp
import jax._src.pallas.mosaic.pallas_call_registration
import torch
import torch_xla
import torch_xla.core.xla_model as xm

from jax.experimental import pallas as pl
from typing import List, Callable
from torch.library import impl
from torch_xla.core.xla_model import XLA_LIB
Expand Down Expand Up @@ -64,30 +60,42 @@ def _extract_backend_config(
return None


def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
if dtype == torch.float32:
return jnp.float32
elif dtype == torch.float64:
return jnp.float64
elif dtype == torch.float16:
return jnp.float16
elif dtype == torch.bfloat16:
return jnp.bfloat16
elif dtype == torch.int32:
return jnp.int32
elif dtype == torch.int64:
return jnp.int64
elif dtype == torch.int16:
return jnp.int16
elif dtype == torch.int8:
return jnp.int8
elif dtype == torch.uint8:
return jnp.uint8
else:
raise ValueError(f"Unsupported dtype: {dtype}")
def jax_import_guard():
# Somehow, we need to grab the TPU before JAX locks it. Otherwise, any pt-xla TPU operations will hang.
torch_xla._XLAC._init_computation_client()


def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
# Import JAX within the function such that we don't need to call the jax_import_guard()
# in the global scope which could cause problems for xmp.spawn.
jax_import_guard()
import jax
import jax.numpy as jnp
import jax._src.pallas.mosaic.pallas_call_registration
from jax.experimental import pallas as pl

def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
if dtype == torch.float32:
return jnp.float32
elif dtype == torch.float64:
return jnp.float64
elif dtype == torch.float16:
return jnp.float16
elif dtype == torch.bfloat16:
return jnp.bfloat16
elif dtype == torch.int32:
return jnp.int32
elif dtype == torch.int64:
return jnp.int64
elif dtype == torch.int16:
return jnp.int16
elif dtype == torch.int8:
return jnp.int8
elif dtype == torch.uint8:
return jnp.uint8
else:
raise ValueError(f"Unsupported dtype: {dtype}")

# TODO: Maybe we can cache the payload for the same input.
def wrapped_kernel(kernel: Callable, output_shape_dtype_fn: Callable, *args):
jax_args = []
Expand Down

0 comments on commit c3001d6

Please sign in to comment.