From 961dfff7f509f394f83d702a6a683532d77c6ead Mon Sep 17 00:00:00 2001
From: Wonjoo Lee
Date: Thu, 25 Apr 2024 01:02:25 +0000
Subject: [PATCH] Convert output back to qdtype

---
 torch_xla/experimental/custom_kernel.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 53f0942a69df..42b11d3ea9b9 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -414,7 +414,7 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
       ], payload, [q.shape, output_shape, output_shape],
       [q_output_dtype, torch.float32, torch.float32])
 
-  return output.reshape(batch_size, num_heads, head_dim)
+  return output.reshape(batch_size, num_heads, head_dim).to(q.dtype)
 
 
 def non_xla_attetion(q, k, v, attention_type):
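
Context for the one-line fix: the custom call declares its output dtypes as
[q_output_dtype, torch.float32, torch.float32], so the attention output can come
back in a dtype different from q (e.g. float32 for a bfloat16 query). The patch
casts the reshaped result back to q.dtype so callers get the query's dtype. Below
is a minimal sketch of that behavior; paged_attention_sketch, the shapes, and the
zero-filled "kernel output" are hypothetical stand-ins, and only the final
reshape(...).to(q.dtype) mirrors the patched line.

    import torch

    def paged_attention_sketch(q: torch.Tensor) -> torch.Tensor:
        batch_size, num_heads, head_dim = q.shape
        # Stand-in for the kernel result: float32 regardless of q.dtype,
        # like the float32 entries in the declared output dtypes above.
        output = torch.zeros(batch_size * num_heads * head_dim,
                             dtype=torch.float32)
        # The patched line: reshape, then cast back to the query dtype.
        return output.reshape(batch_size, num_heads, head_dim).to(q.dtype)

    q = torch.randn(2, 4, 8, dtype=torch.bfloat16)
    out = paged_attention_sketch(q)
    assert out.dtype == torch.bfloat16  # without .to(q.dtype) this would be float32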