diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index e2bdf74ae34f..fa2274cbeac8 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -419,12 +419,12 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
   return output.reshape(batch_size, num_heads, head_dim)
 
 
-def non_xla_attetion(q, k, v):
+def non_xla_attetion(q, k, v, attention_type):
   # This will be called when dynamo use fake tensor to construct the fake output.
   # We need to make sure output tensor's shape is correct.
   if k.device != torch.device("meta"):
     warnings.warn(
-        'XLA flash attention should only be applied to tensors on XLA device')
+        f'XLA {attention_type} attention should only be applied to tensors on XLA device')
 
   # perform a regular attention if input tensors are not on XLA device.
   attn_weight = q @ k.transpose(-2, -1)
@@ -451,7 +451,7 @@ def flash_attention_non_xla(q: torch.Tensor, k: torch.Tensor,
                             v: torch.Tensor,
                             causal: bool = False):
-  return non_xla_attetion(q, k, v)
+  return non_xla_attetion(q, k, v, "flash")
 
 
 XLA_LIB.define(
@@ -473,4 +473,4 @@ def paged_attention_non_xla(q: torch.Tensor, k_pages: torch.Tensor,
                             v_pages: torch.Tensor, lengths: torch.Tensor,
                             page_indices: torch.Tensor,
                             pages_per_compute_block: int):
-  return non_xla_attetion(q, k, v)
+  return non_xla_attetion(q, k, v, "paged")
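
For context, the patch threads an `attention_type` string through the shared non-XLA fallback so the warning names the kernel being faked ("flash" or "paged") instead of always saying "flash". Below is a minimal, self-contained sketch of that fallback behavior under stated assumptions: the `reference_non_xla_attention` name, the softmax step, and the tensor shapes are illustrative only and are not part of this patch or of the torch_xla API.

```python
import warnings
import torch

def reference_non_xla_attention(q, k, v, attention_type):
  # Hypothetical stand-in for the patched fallback in
  # torch_xla/experimental/custom_kernel.py (non_xla_attetion).
  if k.device != torch.device("meta"):
    # The warning now names the specific attention kernel being faked.
    warnings.warn(
        f'XLA {attention_type} attention should only be applied to tensors on XLA device')

  # Plain (non-fused) attention so the output shape matches the real kernel.
  attn_weight = q @ k.transpose(-2, -1)
  attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1)  # assumed step
  return attn_weight @ v

# CPU tensors trigger the kernel-specific warning; shapes here are arbitrary.
q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)
out = reference_non_xla_attention(q, k, v, "paged")
print(out.shape)  # torch.Size([1, 4, 8, 16])
```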