Skip to content

Commit

Permalink
fix the test
Browse files Browse the repository at this point in the history
  • Loading branch information
vanbasten23 committed Nov 20, 2024
1 parent d52c6f2 commit 1954394
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions test/test_tpu_paged_attention_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ def test_paged_attention_without_query_padding(
num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
num_queries_per_compute_block=num_queries_per_compute_block,
)
# actual_output = jax.block_until_ready(actual_output)
# Note: kernel execution is asynchronous. Without blocking here, an error raised inside the kernel may be reported at an unrelated and confusing location. See https://github.com/pytorch/xla/pull/8356#issuecomment-2486861631
actual_output = jax.block_until_ready(actual_output)

# Run the ref impl.
expected_output = _ref_jax_extended_paged_attention(
Expand Down Expand Up @@ -219,7 +220,7 @@ def test_paged_attention_with_query_padding(
block_kv_size,
):

max_kv_len = 2048
max_kv_len = 512
# Set query_len>kv_seq_lens
query_len = max_kv_len
batch_size = 3
Expand Down Expand Up @@ -259,7 +260,7 @@ def test_paged_attention_with_query_padding(
num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
num_queries_per_compute_block=num_queries_per_compute_block,
)
# actual_output = jax.block_until_ready(actual_output)
actual_output = jax.block_until_ready(actual_output)

# Run the ref impl.
expected_output = _ref_jax_extended_paged_attention(
Expand Down

0 comments on commit 1954394

Please sign in to comment.