Convert output back to q.dtype
wonjoolee95 committed Apr 25, 2024
1 parent 5778694 commit 961dfff
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torch_xla/experimental/custom_kernel.py
@@ -414,7 +414,7 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
     ], payload, [q.shape, output_shape, output_shape],
     [q_output_dtype, torch.float32, torch.float32])

-  return output.reshape(batch_size, num_heads, head_dim)
+  return output.reshape(batch_size, num_heads, head_dim).to(q.dtype)


 def non_xla_attetion(q, k, v, attention_type):
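The change can be illustrated with a small sketch. Attention kernels often accumulate their output in float32 for numerical stability, and this commit casts that result back to the query's dtype before returning it. The helper name `cast_like_query` below is hypothetical, not part of torch_xla; it only mirrors the added `.to(q.dtype)` call.

```python
import torch

def cast_like_query(output_f32, q):
    # Hypothetical helper: the kernel's float32 accumulation result is
    # cast back to the query's dtype, so the caller receives the same
    # dtype it passed in (as paged_attention now does with .to(q.dtype)).
    return output_f32.to(q.dtype)

q = torch.ones(2, 8, dtype=torch.bfloat16)        # bf16 query
out_f32 = torch.randn(2, 8, dtype=torch.float32)  # float32 kernel output
out = cast_like_query(out_f32, q)
print(out.dtype)  # torch.bfloat16
```

Without the cast, a caller passing bfloat16 inputs would silently receive a float32 tensor, which can break dtype expectations downstream.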
