[Pallas] Support XLA_USE_BF16 (#6817)

Summary:
XLA_USE_BF16=1 makes all internal XLA tensors use BF16, but the torch.Tensor wrappers still report torch.float. To address this, we need to set up the JAX tracers with the correct dtypes so that the correct Mosaic is produced.
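
As a usage sketch (not part of the commit), the user-facing behavior looks roughly like this; the shapes mirror the test below, and reading the variable before torch_xla is imported is an assumption about when it takes effect:

```python
import os
os.environ["XLA_USE_BF16"] = "1"  # assumed: set before torch_xla reads it at import time

import torch
import torch_xla  # registers the "xla" device
from torch_xla.experimental.custom_kernel import flash_attention

# Requires a TPU device (PJRT_DEVICE=TPU), mirroring the test plan below.
# q/k/v report torch.float32, but the underlying XLA buffers are BF16.
q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

# With this change the Pallas kernel is traced with jnp.bfloat16,
# so the generated Mosaic matches the BF16 buffers and no error is raised.
o = flash_attention(q, k, v)
```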

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py -v -k test_flash_attention_wrapper_bf16
alanwaketan authored Mar 26, 2024
1 parent b8a97df commit 73c31db
Showing 2 changed files with 21 additions and 0 deletions.
14 changes: 14 additions & 0 deletions test/test_pallas.py
@@ -248,6 +248,20 @@ def attention(q, k, v):
    self.assertFalse(torch.allclose(o.cpu(), expected_o.cpu()))
    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                   "This test only works on TPUv3+.")
  @unittest.mock.patch.dict(os.environ, {"XLA_TPU_LAYOUT": "0"})
  @unittest.mock.patch.dict(os.environ, {"XLA_USE_BF16": "1"})
  def test_flash_attention_wrapper_bf16(self):
    from torch_xla.experimental.custom_kernel import flash_attention

    q = torch.randn(3, 2, 128, 4).to("xla")
    k = torch.randn(3, 2, 128, 4).to("xla")
    v = torch.randn(3, 2, 128, 4).to("xla")

    # No exception being raised.
    o = flash_attention(q, k, v)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
7 changes: 7 additions & 0 deletions torch_xla/experimental/custom_kernel.py
@@ -1,4 +1,5 @@
import functools
import os
import torch
import torch_xla
import torch_xla.core.xla_model as xm
@@ -7,6 +8,8 @@
from torch.library import impl
from torch_xla.core.xla_model import XLA_LIB

_XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1"

XLA_LIB.define(
"tpu_custom_call_(Tensor(a!) output, Tensor[] inputs, str payload) -> ()",)

@@ -75,8 +78,12 @@ def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):

  def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
    if dtype == torch.float32:
      if _XLA_USE_BF16:
        return jnp.bfloat16
      return jnp.float32
    elif dtype == torch.float64:
      if _XLA_USE_BF16:
        return jnp.bfloat16
      return jnp.float64
    elif dtype == torch.float16:
      return jnp.float16
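
For illustration only (not in the diff), here is where this mapping matters: the wrapper describes the torch inputs to the JAX tracer before lowering, so under XLA_USE_BF16=1 those descriptions must carry jnp.bfloat16 to match the underlying buffers. The helper name `to_jax_spec` below is a hypothetical stand-in, not the library's API:

```python
import jax
import jax.numpy as jnp
import torch

USE_BF16 = True  # stands in for _XLA_USE_BF16 above


def to_jax_spec(t: torch.Tensor) -> jax.ShapeDtypeStruct:
  # Hypothetical stand-in for convert_torch_dtype_to_jax: a float32 tensor is
  # described to the tracer as bfloat16 when XLA_USE_BF16=1.
  dtype = jnp.bfloat16 if (USE_BF16 and t.dtype == torch.float32) else jnp.float32
  return jax.ShapeDtypeStruct(tuple(t.shape), dtype)


q = torch.randn(3, 2, 128, 4)
print(to_jax_spec(q))  # ShapeDtypeStruct(shape=(3, 2, 128, 4), dtype=bfloat16)
```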
