
Commit b3a5948

Address comments
wonjoolee95 committed Apr 25, 2024
1 parent bb29c3a commit b3a5948
Showing 2 changed files with 20 additions and 30 deletions.
12 changes: 4 additions & 8 deletions test/test_pallas.py
@@ -485,7 +485,6 @@ def test_flash_attention_backward(self):
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_paged_attention_wrapper(self):
- jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
from torch_xla.experimental.custom_kernel import paged_attention
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention

@@ -542,14 +541,12 @@ def test_paged_attention_wrapper(self):
torch.allclose(
output.cpu()[seq_lens > 0],
expected_output.cpu()[seq_lens > 0],
- atol=1e-1,
- rtol=1e-1))
- jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+ atol=1e-5,
+ rtol=1e-5))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_paged_attention_wrapper_with_dynamo(self):
- jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
from torch_xla.experimental.custom_kernel import paged_attention
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention

@@ -619,9 +616,8 @@ def paged_attention_wrapper(q, k, v, seq_lens, page_indices,
torch.allclose(
output.cpu()[seq_lens > 0],
expected_output.cpu()[seq_lens > 0],
- atol=1e-1,
- rtol=1e-1))
- jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+ atol=1e-5,
+ rtol=1e-5))


if __name__ == '__main__':
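For context on the assertions above: the test compares torch_xla's paged_attention output against the JAX reference kernel, keeping only batch entries with a non-zero sequence length, now under the tighter 1e-5 tolerances. Below is a minimal standalone sketch of that comparison pattern; the shapes and values are illustrative and not taken from the test.

import torch

# Hypothetical shapes: (batch, num_heads, head_dim) outputs and per-batch sequence lengths.
output = torch.randn(4, 8, 64)
expected_output = output + 1e-6 * torch.randn_like(output)
seq_lens = torch.tensor([0, 5, 17, 3])  # entry 0 has no valid tokens

# Drop batch rows with seq_len == 0 before comparing, as the test does.
close = torch.allclose(
    output[seq_lens > 0],
    expected_output[seq_lens > 0],
    atol=1e-5,
    rtol=1e-5)
print(close)  # True for this tiny perturbation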
38 changes: 16 additions & 22 deletions torch_xla/experimental/custom_kernel.py
@@ -416,6 +416,20 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
return output.reshape(batch_size, num_heads, head_dim)


+ def non_xla_attetion(q, k, v):
+ # This will be called when dynamo uses fake tensors to construct the fake output.
+ # We need to make sure the output tensor's shape is correct.
+ if k.device != torch.device("meta"):
+ warnings.warn(
+ 'XLA flash attention should only be applied to tensors on XLA device')
+
+ # Perform regular attention if the input tensors are not on an XLA device.
+ attn_weight = q @ k.transpose(-2, -1)
+ attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1)
+ attn_output = attn_weight @ v
+ return attn_output


XLA_LIB.define(
"flash_attention(Tensor q, Tensor k, Tensor v, bool casual=False) -> Tensor",
)
@@ -434,17 +448,7 @@ def flash_attention_non_xla(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
causal: bool = False):
- # This will be called when dynamo use fake tensor to construct the fake output.
- # We need to make sure output tensor's shape is correct.
- if k.device != torch.device("meta"):
- warnings.warn(
- 'XLA flash attention should only be applied to tensors on XLA device')
-
- # perform a regular attention if input tensors are not on XLA device.
- attn_weight = q @ k.transpose(-2, -1)
- attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1)
- attn_output = attn_weight @ v
- return attn_output
+ return non_xla_attetion(q, k, v)


XLA_LIB.define(
@@ -466,14 +470,4 @@ def paged_attention_non_xla(q: torch.Tensor, k_pages: torch.Tensor,
v_pages: torch.Tensor, lengths: torch.Tensor,
page_indices: torch.Tensor,
pages_per_compute_block: int):
- # This will be called when dynamo use fake tensor to construct the fake output.
- # We need to make sure output tensor's shape is correct.
- if k.device != torch.device("meta"):
- warnings.warn(
- 'XLA paged attention should only be applied to tensors on XLA device')
-
- # perform a regular attention if input tensors are not on XLA device.
- attn_weight = q @ k.transpose(-2, -1)
- attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1)
- attn_output = attn_weight @ v
- return attn_output
+ return non_xla_attetion(q, k, v)
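A note on the shared fallback introduced above: when dynamo traces these ops, its only job is to produce an output with the correct shape, since the inputs are fake/meta tensors and only shapes propagate; on real non-XLA tensors it additionally serves as a plain reference attention. The sketch below illustrates the shape-only behavior; the reference_attention name and the shapes are illustrative and not part of the commit.

import torch

def reference_attention(q, k, v):
  # Unscaled attention mirroring the fallback: softmax(q @ k^T) @ v.
  attn_weight = torch.nn.functional.softmax(q @ k.transpose(-2, -1), dim=-1)
  return attn_weight @ v

# Meta tensors carry only shape and dtype, much like dynamo's fake tensors.
q = torch.empty(1, 8, 128, 64, device="meta")  # (batch, heads, seq_len, head_dim)
k = torch.empty(1, 8, 128, 64, device="meta")
v = torch.empty(1, 8, 128, 64, device="meta")
out = reference_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 128, 64]) -- shape is all the tracer needs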
