From c3001d64caa5a0ff27985eb0f93bfef71c0a195b Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Thu, 21 Mar 2024 10:36:35 -0700
Subject: [PATCH] [Backport] Introduce jax_import_guard (#6794)

Summary:
Importing JAX locks the TPU devices and prevents any subsequent PyTorch/XLA TPU computations. To address this, we acquire the TPU in PyTorch/XLA before JAX is imported.

Test Plan:
python test/test_pallas.py
---
 test/test_pallas.py                     | 17 +++----
 torch_xla/experimental/custom_kernel.py | 58 ++++++++++++++-----------
 2 files changed, 40 insertions(+), 35 deletions(-)

diff --git a/test/test_pallas.py b/test/test_pallas.py
index c623242d384..10b3fddea74 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -8,6 +8,13 @@
 import torch_xla
 from torch_xla import runtime as xr
 
+if xr.device_type() == 'TPU':
+  from torch_xla.experimental.custom_kernel import jax_import_guard
+  jax_import_guard()
+  import jax
+  import jax.numpy as jnp
+  from jax.experimental import pallas as pl
+
 
 class PallasTest(unittest.TestCase):
 
@@ -111,12 +118,8 @@ def add_one_pallas(output, inputs, payload):
 
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_tpu_custom_call_pallas_extract_add_payload(self):
-    import jax
-    import jax.numpy as jnp
     import jax._src.pallas.mosaic.pallas_call_registration
 
-    from jax.experimental import pallas as pl
-
     def add_vectors_kernel(x_ref, y_ref, o_ref):
       x, y = x_ref[...], y_ref[...]
       o_ref[...] = x + y
@@ -136,13 +139,7 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
     self.assertIn("custom_call_config", payload)
 
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
-  # TODO: This test cannot be ran individually, let's fix it.
   def test_tpu_custom_call_pallas_wrap_add_payload(self):
-    import jax
-    import jax.numpy as jnp
-    import jax._src.pallas.mosaic.pallas_call_registration
-
-    from jax.experimental import pallas as pl
 
     def add_vectors_kernel(x_ref, y_ref, o_ref):
       x, y = x_ref[...], y_ref[...]
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 95439a9d2ff..b6b7a304b5b 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -1,12 +1,8 @@
 import functools
-import jax
-import jax.numpy as jnp
-import jax._src.pallas.mosaic.pallas_call_registration
 import torch
 import torch_xla
 import torch_xla.core.xla_model as xm
 
-from jax.experimental import pallas as pl
 from typing import List, Callable
 from torch.library import impl
 from torch_xla.core.xla_model import XLA_LIB
@@ -64,30 +60,42 @@ def _extract_backend_config(
   return None
 
 
-def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
-  if dtype == torch.float32:
-    return jnp.float32
-  elif dtype == torch.float64:
-    return jnp.float64
-  elif dtype == torch.float16:
-    return jnp.float16
-  elif dtype == torch.bfloat16:
-    return jnp.bfloat16
-  elif dtype == torch.int32:
-    return jnp.int32
-  elif dtype == torch.int64:
-    return jnp.int64
-  elif dtype == torch.int16:
-    return jnp.int16
-  elif dtype == torch.int8:
-    return jnp.int8
-  elif dtype == torch.uint8:
-    return jnp.uint8
-  else:
-    raise ValueError(f"Unsupported dtype: {dtype}")
+def jax_import_guard():
+  # We need to grab the TPU before JAX locks it. Otherwise, any pt-xla TPU operations will hang.
+  torch_xla._XLAC._init_computation_client()
 
 
 def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
+  # Import JAX inside this function so that jax_import_guard() does not have to be
+  # called at the global scope, which could cause problems for xmp.spawn.
+  jax_import_guard()
+  import jax
+  import jax.numpy as jnp
+  import jax._src.pallas.mosaic.pallas_call_registration
+  from jax.experimental import pallas as pl
+
+  def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
+    if dtype == torch.float32:
+      return jnp.float32
+    elif dtype == torch.float64:
+      return jnp.float64
+    elif dtype == torch.float16:
+      return jnp.float16
+    elif dtype == torch.bfloat16:
+      return jnp.bfloat16
+    elif dtype == torch.int32:
+      return jnp.int32
+    elif dtype == torch.int64:
+      return jnp.int64
+    elif dtype == torch.int16:
+      return jnp.int16
+    elif dtype == torch.int8:
+      return jnp.int8
+    elif dtype == torch.uint8:
+      return jnp.uint8
+    else:
+      raise ValueError(f"Unsupported dtype: {dtype}")
+
   # TODO: Maybe we can cache the payload for the same input.
   def wrapped_kernel(kernel: Callable, output_shape_dtype_fn: Callable, *args):
     jax_args = []
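
Note (not part of the patch): a minimal sketch of how the guarded-import pattern introduced here is meant to be used from user code, mirroring the top-of-file guard and the add_vectors example in test_pallas.py. The make_kernel_from_pallas call and the single-output form of output_shape_dtype_fn below are assumptions based on the wrapper shown above, not an authoritative API reference.

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    from torch_xla import runtime as xr

    if xr.device_type() == 'TPU':
      # Acquire the TPU via PyTorch/XLA before importing JAX; otherwise JAX
      # locks the devices and later torch_xla TPU computations hang.
      from torch_xla.experimental.custom_kernel import jax_import_guard, make_kernel_from_pallas
      jax_import_guard()
      import jax
      from jax.experimental import pallas as pl

      def add_vectors_kernel(x_ref, y_ref, o_ref):
        o_ref[...] = x_ref[...] + y_ref[...]

      @jax.jit
      def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
        return pl.pallas_call(
            add_vectors_kernel,
            out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x, y)

      # Assumed single-output form: output_shape_dtype_fn returns (shape, dtype).
      pt_kernel = make_kernel_from_pallas(add_vectors,
                                          lambda x, y: (x.shape, x.dtype))
      device = xm.xla_device()
      a = torch.arange(8, dtype=torch.int32, device=device)
      b = torch.arange(8, dtype=torch.int32, device=device)
      out = pt_kernel(a, b)  # runs the Pallas kernel as a TPU custom call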