diff --git a/test/test_pallas.py b/test/test_pallas.py
index f2db81c7d65..58981e471fd 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -248,6 +248,20 @@ def attention(q, k, v):
     self.assertFalse(torch.allclose(o.cpu(), expected_o.cpu()))
     jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
 
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  @unittest.mock.patch.dict(os.environ, {"XLA_TPU_LAYOUT": "0"})
+  @unittest.mock.patch.dict(os.environ, {"XLA_USE_BF16": "1"})
+  def test_flash_attention_wrapper_bf16(self):
+    from torch_xla.experimental.custom_kernel import flash_attention
+
+    q = torch.randn(3, 2, 128, 4).to("xla")
+    k = torch.randn(3, 2, 128, 4).to("xla")
+    v = torch.randn(3, 2, 128, 4).to("xla")
+
+    # No exception being raised.
+    o = flash_attention(q, k, v)
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index c8086374963..a86c2af1b0e 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -1,4 +1,5 @@
 import functools
+import os
 import torch
 import torch_xla
 import torch_xla.core.xla_model as xm
@@ -7,6 +8,8 @@
 from torch.library import impl
 from torch_xla.core.xla_model import XLA_LIB
 
+_XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1"
+
 XLA_LIB.define(
     "tpu_custom_call_(Tensor(a!) output, Tensor[] inputs, str payload) -> ()",)
 
@@ -75,8 +78,12 @@ def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
 
   def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
     if dtype == torch.float32:
+      if _XLA_USE_BF16:
+        return jnp.bfloat16
       return jnp.float32
     elif dtype == torch.float64:
+      if _XLA_USE_BF16:
+        return jnp.bfloat16
       return jnp.float64
     elif dtype == torch.float16:
       return jnp.float16
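
Note (not part of the patch): a minimal standalone sketch of the env-gated dtype mapping that the custom_kernel.py hunk above implements. The helper name to_jax_dtype and its free-standing form are illustrative only; in the patch the equivalent logic lives inside convert_torch_dtype_to_jax within make_kernel_from_pallas.

import os

import jax.numpy as jnp
import torch

# Mirrors the module-level flag added in custom_kernel.py: read once at import time.
_XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1"


def to_jax_dtype(dtype: torch.dtype) -> jnp.dtype:
  # Under XLA_USE_BF16=1, float32/float64 torch tensors are stored as bf16 on
  # the device, so the Pallas kernel should be traced with jnp.bfloat16 too.
  if dtype == torch.float32:
    return jnp.bfloat16 if _XLA_USE_BF16 else jnp.float32
  if dtype == torch.float64:
    return jnp.bfloat16 if _XLA_USE_BF16 else jnp.float64
  if dtype == torch.float16:
    return jnp.float16
  raise ValueError(f"Unsupported dtype: {dtype}")

With XLA_USE_BF16=1 in the environment (as the new test patches it), to_jax_dtype(torch.float32) would return jnp.bfloat16 instead of jnp.float32, which is the behavior the test exercises end to end through flash_attention.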