Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pallas] Support XLA_USE_BF16 #6817

Merged
merged 2 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,20 @@ def attention(q, k, v):
self.assertFalse(torch.allclose(o.cpu(), expected_o.cpu()))
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@unittest.mock.patch.dict(os.environ, {"XLA_TPU_LAYOUT": "0"})
@unittest.mock.patch.dict(os.environ, {"XLA_USE_BF16": "1"})
def test_flash_attention_wrapper_bf16(self):
from torch_xla.experimental.custom_kernel import flash_attention

q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

# No exception being raised.
o = flash_attention(q, k, v)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down
7 changes: 7 additions & 0 deletions torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import os
import torch
import torch_xla
import torch_xla.core.xla_model as xm
Expand All @@ -7,6 +8,8 @@
from torch.library import impl
from torch_xla.core.xla_model import XLA_LIB

_XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1"

XLA_LIB.define(
"tpu_custom_call_(Tensor(a!) output, Tensor[] inputs, str payload) -> ()",)

Expand Down Expand Up @@ -75,8 +78,12 @@ def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):

def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
if dtype == torch.float32:
if _XLA_USE_BF16:
return jnp.bfloat16
return jnp.float32
elif dtype == torch.float64:
if _XLA_USE_BF16:
return jnp.bfloat16
return jnp.float64
elif dtype == torch.float16:
return jnp.float16
Expand Down
Loading