diff --git a/test/test_tpu_paged_attention_kernel.py b/test/test_tpu_paged_attention_kernel.py
index 80023652d69..d2d0f4f19a9 100644
--- a/test/test_tpu_paged_attention_kernel.py
+++ b/test/test_tpu_paged_attention_kernel.py
@@ -162,7 +162,8 @@ def test_paged_attention_without_query_padding(
         num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
         num_queries_per_compute_block=num_queries_per_compute_block,
     )
-    # actual_output = jax.block_until_ready(actual_output)
+    # Note kernel execution is async. Without blocking, if an error happens in the kernel, the error may point to some irrelevant and confusing places. See https://github.com/pytorch/xla/pull/8356#issuecomment-2486861631
+    actual_output = jax.block_until_ready(actual_output)

     # Run the ref impl.
     expected_output = _ref_jax_extended_paged_attention(
@@ -219,7 +220,7 @@ def test_paged_attention_with_query_padding(
       block_kv_size,
   ):

-    max_kv_len = 2048
+    max_kv_len = 512
     # Set query_len>kv_seq_lens
     query_len = max_kv_len
     batch_size = 3
@@ -259,7 +260,7 @@ def test_paged_attention_with_query_padding(
         num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
         num_queries_per_compute_block=num_queries_per_compute_block,
     )
-    # actual_output = jax.block_until_ready(actual_output)
+    actual_output = jax.block_until_ready(actual_output)

     # Run the ref impl.
     expected_output = _ref_jax_extended_paged_attention(
diff --git a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
index 92724b05bfc..0bb572c49e1 100644
--- a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
+++ b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
@@ -198,13 +198,14 @@ def start_new_sequence():
   o_curr = jax.lax.dot(
       p.astype(v.dtype), v, preferred_element_type=jnp.float32)
   acc_scratch_ref[q_head_idx_per_kv] += o_curr * l_broadcast(l_next_inv_safe)
-  # TODO: To potentially improve the perf, consider not to update o_ref, l_ref, and m_ref at every kv_blk_idx. Instead, use a proper @pl.when(kv_blk_idx == ...) at the last kv_block.
-  o_ref[0, q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype(
-      o_ref.dtype)
-  l_ref[0, q_head_idx_per_kv] = l_scratch_ref[q_head_idx_per_kv].astype(
-      l_ref.dtype)
-  m_ref[0, q_head_idx_per_kv] = m_scratch_ref[q_head_idx_per_kv].astype(
-      m_ref.dtype)
+  @pl.when(kv_blk_idx == kv_len // kv_seq_len_per_kv_compute_blk)
+  def store_to_output():
+    o_ref[0, q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype(
+        o_ref.dtype)
+    l_ref[0, q_head_idx_per_kv] = l_scratch_ref[q_head_idx_per_kv].astype(
+        l_ref.dtype)
+    m_ref[0, q_head_idx_per_kv] = m_scratch_ref[q_head_idx_per_kv].astype(
+        m_ref.dtype)


 def paged_flash_attention_kernel(