From b3a5948927ab6e0b81c875efa4e13d2862a9544b Mon Sep 17 00:00:00 2001
From: Wonjoo Lee
Date: Thu, 25 Apr 2024 00:37:06 +0000
Subject: [PATCH] Address comments

---
 test/test_pallas.py                     | 12 +++-----
 torch_xla/experimental/custom_kernel.py | 38 +++++++++++--------------
 2 files changed, 20 insertions(+), 30 deletions(-)

diff --git a/test/test_pallas.py b/test/test_pallas.py
index 3e8496afabfb..089394b71d39 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -485,7 +485,6 @@ def test_flash_attention_backward(self):
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
   def test_paged_attention_wrapper(self):
-    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
     from torch_xla.experimental.custom_kernel import paged_attention
     from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention
 
@@ -542,14 +541,12 @@ def test_paged_attention_wrapper(self):
         torch.allclose(
             output.cpu()[seq_lens > 0],
             expected_output.cpu()[seq_lens > 0],
-            atol=1e-1,
-            rtol=1e-1))
-    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+            atol=1e-5,
+            rtol=1e-5))
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
   def test_paged_attention_wrapper_with_dynamo(self):
-    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
     from torch_xla.experimental.custom_kernel import paged_attention
     from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention
 
@@ -619,9 +616,8 @@ def paged_attention_wrapper(q, k, v, seq_lens, page_indices,
         torch.allclose(
             output.cpu()[seq_lens > 0],
             expected_output.cpu()[seq_lens > 0],
-            atol=1e-1,
-            rtol=1e-1))
-    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+            atol=1e-5,
+            rtol=1e-5))
 
 
 if __name__ == '__main__':
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 34ab988a151d..55c6b633877a 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -416,6 +416,20 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
   return output.reshape(batch_size, num_heads, head_dim)
 
 
+def non_xla_attetion(q, k, v):
+  # This will be called when dynamo use fake tensor to construct the fake output.
+  # We need to make sure output tensor's shape is correct.
+  if k.device != torch.device("meta"):
+    warnings.warn(
+        'XLA flash attention should only be applied to tensors on XLA device')
+
+  # perform a regular attention if input tensors are not on XLA device.
+  attn_weight = q @ k.transpose(-2, -1)
+  attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1)
+  attn_output = attn_weight @ v
+  return attn_output
+
+
 XLA_LIB.define(
     "flash_attention(Tensor q, Tensor k, Tensor v, bool casual=False) -> Tensor",
 )
@@ -434,17 +448,7 @@ def flash_attention_non_xla(q: torch.Tensor,
                             k: torch.Tensor,
                             v: torch.Tensor,
                             causal: bool = False):
-  # This will be called when dynamo use fake tensor to construct the fake output.
-  # We need to make sure output tensor's shape is correct.
-  if k.device != torch.device("meta"):
-    warnings.warn(
-        'XLA flash attention should only be applied to tensors on XLA device')
-
-  # perform a regular attention if input tensors are not on XLA device.
-  attn_weight = q @ k.transpose(-2, -1)
-  attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1)
-  attn_output = attn_weight @ v
-  return attn_output
+  return non_xla_attetion(q, k, v)
 
 
 XLA_LIB.define(
@@ -466,14 +470,4 @@ def paged_attention_non_xla(q: torch.Tensor, k_pages: torch.Tensor,
                             v_pages: torch.Tensor, lengths: torch.Tensor,
                             page_indices: torch.Tensor,
                             pages_per_compute_block: int):
-  # This will be called when dynamo use fake tensor to construct the fake output.
-  # We need to make sure output tensor's shape is correct.
-  if k.device != torch.device("meta"):
-    warnings.warn(
-        'XLA paged attention should only be applied to tensors on XLA device')
-
-  # perform a regular attention if input tensors are not on XLA device.
-  attn_weight = q @ k.transpose(-2, -1)
-  attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1)
-  attn_output = attn_weight @ v
-  return attn_output
+  return non_xla_attetion(q, k, v)
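For reference, the refactor above deduplicates the eager-attention bodies of flash_attention_non_xla and paged_attention_non_xla into the single non_xla_attetion helper. Below is a minimal sketch of that fallback path in plain PyTorch, with no torch_xla dependency; the name eager_attention and the tensor shapes are illustrative only. It shows the property the helper relies on: when dynamo traces the registered op with fake ("meta") tensors, the fallback still produces an output with the correct shape.

import torch

def eager_attention(q, k, v):
  # Unscaled softmax(Q @ K^T) @ V, mirroring the helper added in the diff.
  attn_weight = q @ k.transpose(-2, -1)
  attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1)
  return attn_weight @ v

# Shape-only check with meta tensors, as dynamo's fake-tensor tracing would do.
q = torch.empty(3, 8, 128, 64, device="meta")  # (batch, heads, seq_len, head_dim)
k = torch.empty(3, 8, 128, 64, device="meta")
v = torch.empty(3, 8, 128, 64, device="meta")
print(eager_attention(q, k, v).shape)  # torch.Size([3, 8, 128, 64])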