[Pallas] Add a bfloat16 flash attention test case (#6810)
alanwaketan authored Mar 25, 2024
1 parent 8240d05 commit 3522be1
Showing 1 changed file with 31 additions and 3 deletions.
34 changes: 31 additions & 3 deletions test/test_pallas.py
@@ -7,6 +7,7 @@

import torch_xla
from torch_xla import runtime as xr
+from torch_xla._internal import tpu

if xr.device_type() == 'TPU':
from torch_xla.experimental.custom_kernel import jax_import_guard
@@ -155,8 +156,9 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y:
(x.shape, x.dtype))

-    dtypes = [torch.float32, torch.float
-    ] # TODO: torch.float64, torch.bfloat16, torch.float16 don't work.
+    dtypes = [
+        torch.float32, torch.float
+    ] # Add doesn't support torch.float64, torch.bfloat16, torch.float16.
for i in range(len(dtypes)):
x = torch.randn((i + 1, i + 1), dtype=dtypes[i]).to("xla")
y = torch.randn((i + 1, i + 1), dtype=dtypes[i]).to("xla")
@@ -166,14 +168,40 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:

    dtypes = [
        torch.int32, torch.int
-    ] # TODO: torch.int64, torch.int16, torch.int8, torch.uint8 don't work.
+    ] # Add doesn't support torch.int64, torch.int16, torch.int8, torch.uint8.
for i in range(len(dtypes)):
x = torch.arange(i + 1, dtype=dtypes[i]).to("xla")
y = torch.arange(i + 1, dtype=dtypes[i]).to("xla")
expected_output = x + y
output = pt_kernel(x, y)
self.assertTrue(torch.allclose(output.cpu(), expected_output.cpu()))

+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  @unittest.mock.patch.dict(os.environ, {"XLA_TPU_LAYOUT": "0"})
+  def test_tpu_custom_call_pallas_wrap_flash_attention(self):
+    from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention
+    from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
+    flash_attention_kernel = make_kernel_from_pallas(
+        flash_attention, lambda q, k, v: (q.shape, q.dtype))
+
+    def attention(q, k, v):
+      attn_weight = q @ k.transpose(-2, -1)
+      attn_weight = nn.functional.softmax(attn_weight, dim=-1)
+      attn_output = attn_weight @ v
+      return attn_output
+
+    q_mini = torch.arange(128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13
+    k_mini = torch.arange(
+        1000, 1000 + 128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13
+    q = q_mini.broadcast_to(3, 2, 128, 4).to("xla")
+    k = k_mini.broadcast_to(3, 2, 128, 4).to("xla")
+    v = torch.ones(3, 2, 128, 4, dtype=torch.bfloat16).to("xla")
+
+    o = flash_attention_kernel(q, k, v)
+    expected_o = attention(q, k, v)
+    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
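For readers who want to try the wrapped kernel outside of unittest, the sketch below mirrors the calling pattern this commit adds: wrap the JAX Pallas flash-attention kernel with make_kernel_from_pallas and compare it against a plain PyTorch attention reference on bfloat16 inputs. It is an illustration only, not part of the commit; it assumes a TPU v3+ device with torch_xla and JAX/Pallas installed, and it carries over the test's setup (XLA_TPU_LAYOUT=0, jax_import_guard before importing Pallas) on the assumption that a standalone script needs the same. Names such as reference_attention and out are illustrative.

# Illustrative sketch only -- not part of this commit.
import os
os.environ["XLA_TPU_LAYOUT"] = "0"  # the test patches this env var; assumed needed here too

import torch
import torch.nn as nn

from torch_xla.experimental.custom_kernel import (jax_import_guard,
                                                  make_kernel_from_pallas)

jax_import_guard()  # as in test_pallas.py: call before importing jax/Pallas
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention

# Wrap the JAX Pallas kernel so it can be called on XLA tensors. The lambda
# maps the inputs to the output (shape, dtype); flash attention's output
# matches q here.
flash_attention_kernel = make_kernel_from_pallas(
    flash_attention, lambda q, k, v: (q.shape, q.dtype))


def reference_attention(q, k, v):
  # Plain softmax(q @ k^T) @ v attention used as the numerical reference.
  attn_weight = nn.functional.softmax(q @ k.transpose(-2, -1), dim=-1)
  return attn_weight @ v


# Structured bfloat16 inputs, taken from the test above, which the test
# compares with torch.allclose at its default tolerances.
q_mini = torch.arange(128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13
k_mini = torch.arange(
    1000, 1000 + 128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13
q = q_mini.broadcast_to(3, 2, 128, 4).to("xla")
k = k_mini.broadcast_to(3, 2, 128, 4).to("xla")
v = torch.ones(3, 2, 128, 4, dtype=torch.bfloat16).to("xla")

out = flash_attention_kernel(q, k, v)
print(torch.allclose(out.cpu(), reference_attention(q, k, v).cpu()))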
