diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 53f0942a69df..42b11d3ea9b9 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -414,7 +414,7 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices, ], payload, [q.shape, output_shape, output_shape], [q_output_dtype, torch.float32, torch.float32]) - return output.reshape(batch_size, num_heads, head_dim) + return output.reshape(batch_size, num_heads, head_dim).to(q.dtype) def non_xla_attetion(q, k, v, attention_type):