From 457e2c63197e68d21c67cf8dd64d0eea05918941 Mon Sep 17 00:00:00 2001
From: Wonjoo Lee
Date: Thu, 25 Apr 2024 00:45:13 +0000
Subject: [PATCH] Add q_output_dtype

---
 torch_xla/experimental/custom_kernel.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 55c6b633877a..e2bdf74ae34f 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -400,6 +400,9 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
   buffer_index = torch.zeros((1,), dtype=torch.int32).to("xla")
   step = torch.zeros((1,), dtype=torch.int32).to("xla")
   output_shape = torch.Size(list(q.shape[:-1]) + [1])
+  q_output_dtype = torch.float32
+  if (num_heads // num_kv_heads) % 8 != 0:
+    q_output_dtype = q.dtype
 
   output, _, _ = torch_xla._XLAC._xla_tpu_custom_call(
       [
@@ -411,7 +414,7 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
           k_pages,
           v_pages,
       ], payload, [q.shape, output_shape, output_shape],
-      [q.dtype, torch.float32, torch.float32])
+      [q_output_dtype, torch.float32, torch.float32])
 
   return output.reshape(batch_size, num_heads, head_dim)
 