diff --git a/test/test_tpu_paged_attention_kernel.py b/test/test_tpu_paged_attention_kernel.py index 80023652d69..d2d0f4f19a9 100644 --- a/test/test_tpu_paged_attention_kernel.py +++ b/test/test_tpu_paged_attention_kernel.py @@ -162,7 +162,8 @@ def test_paged_attention_without_query_padding( num_kv_pages_per_compute_block=num_kv_pages_per_compute_block, num_queries_per_compute_block=num_queries_per_compute_block, ) - # actual_output = jax.block_until_ready(actual_output) + # Note: kernel execution is async. Without blocking, if an error happens in the kernel, the reported error may point to irrelevant and confusing places. See https://github.com/pytorch/xla/pull/8356#issuecomment-2486861631 + actual_output = jax.block_until_ready(actual_output) # Run the ref impl. expected_output = _ref_jax_extended_paged_attention( @@ -219,7 +220,7 @@ def test_paged_attention_with_query_padding( block_kv_size, ): - max_kv_len = 2048 + max_kv_len = 512 # Set query_len>kv_seq_lens query_len = max_kv_len batch_size = 3 @@ -259,7 +260,7 @@ def test_paged_attention_with_query_padding( num_kv_pages_per_compute_block=num_kv_pages_per_compute_block, num_queries_per_compute_block=num_queries_per_compute_block, ) - # actual_output = jax.block_until_ready(actual_output) + actual_output = jax.block_until_ready(actual_output) # Run the ref impl. expected_output = _ref_jax_extended_paged_attention(