
Commit

Fix typo in warning message
wonjoolee95 committed Apr 25, 2024
1 parent 457e2c6 commit ca36c35
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torch_xla/experimental/custom_kernel.py
@@ -419,12 +419,12 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
   return output.reshape(batch_size, num_heads, head_dim)
 
 
-def non_xla_attetion(q, k, v):
+def non_xla_attetion(q, k, v, attention_type):
   # This will be called when dynamo use fake tensor to construct the fake output.
   # We need to make sure output tensor's shape is correct.
   if k.device != torch.device("meta"):
     warnings.warn(
-        'XLA flash attention should only be applied to tensors on XLA device')
+        f'XLA {attention_type} attention should only be applied to tensors on XLA device')
 
   # perform a regular attention if input tensors are not on XLA device.
   attn_weight = q @ k.transpose(-2, -1)
@@ -451,7 +451,7 @@ def flash_attention_non_xla(q: torch.Tensor,
                             k: torch.Tensor,
                             v: torch.Tensor,
                             causal: bool = False):
-  return non_xla_attetion(q, k, v)
+  return non_xla_attetion(q, k, v, "flash")
 
 
 XLA_LIB.define(
@@ -473,4 +473,4 @@ def paged_attention_non_xla(q: torch.Tensor, k_pages: torch.Tensor,
                             v_pages: torch.Tensor, lengths: torch.Tensor,
                             page_indices: torch.Tensor,
                             pages_per_compute_block: int):
-  return non_xla_attetion(q, k, v)
+  return non_xla_attetion(q, k, v, "paged")
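
For readability, here is a minimal, self-contained sketch of the updated fallback (not part of the commit). The signature, the warning, and the first matmul are taken from the diff above; the softmax and the final product with v are assumptions added only to round out the regular-attention fallback and may not match the actual file verbatim.

import warnings

import torch


def non_xla_attetion(q, k, v, attention_type):
  # Called when dynamo uses fake (meta) tensors to construct the fake output;
  # only the output's shape has to be correct in that case.
  if k.device != torch.device("meta"):
    warnings.warn(
        f'XLA {attention_type} attention should only be applied to tensors on XLA device')

  # Perform a regular attention if the input tensors are not on an XLA device.
  attn_weight = q @ k.transpose(-2, -1)
  # Assumed continuation (not shown in the diff): normalize and apply to v.
  attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1)
  return attn_weight @ v

With this change, flash_attention_non_xla warns about "XLA flash attention" and paged_attention_non_xla warns about "XLA paged attention", instead of both emitting the flash-attention message.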
