From 0106b2f2488a9bf4487aad50aa95b87f613dfee5 Mon Sep 17 00:00:00 2001
From: Xiongfei Wei
Date: Tue, 3 Dec 2024 22:38:32 +0000
Subject: [PATCH] add a comment

---
 test/test_tpu_paged_attention_kernel.py     | 15 +++++++--------
 .../multi_queries_paged_attention_kernel.py |  7 +++++--
 2 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/test/test_tpu_paged_attention_kernel.py b/test/test_tpu_paged_attention_kernel.py
index 44b6bf3a2c9..3fcab456844 100644
--- a/test/test_tpu_paged_attention_kernel.py
+++ b/test/test_tpu_paged_attention_kernel.py
@@ -223,9 +223,8 @@ def test_paged_attention_with_query_padding(
     # Set query_len>kv_seq_lens
     query_len = max_kv_len
     batch_size = 3
-    # kv_seq_lens = jax.random.randint(
-    #     jax.random.key(0), (batch_size,), 0, max_kv_len)
-    kv_seq_lens = jnp.array([256, 512, 512])
+    kv_seq_lens = jax.random.randint(
+        jax.random.key(0), (batch_size,), 0, max_kv_len)
     effective_q_lens = jax.random.randint(
         jax.random.key(0), (batch_size,), 0, kv_seq_lens)
     for cur_effec_q_len, cur_kv_seq_len in zip(effective_q_lens, kv_seq_lens):
@@ -292,12 +291,10 @@ def test_paged_attention_with_query_padding(
               atol=atol,
               rtol=rtol))
 
-  def test_paged_attention_store_to_output_correctly(
-      self,
-  ):
+  def test_paged_attention_store_to_output_correctly(self,):
     # Make sure the internal FA store_to_output correctly.
     dtype = jnp.float32
-    page_size=16
+    page_size = 16
     num_kv_heads = 8
     q_kv_head_ratio = 4
     head_dim = 256
@@ -308,7 +305,8 @@ def test_paged_attention_store_to_output_correctly(
     query_len = max_kv_len
     batch_size = 3
     # Set various edge case testing the internal flash attention can store_to_output correct
-    kv_seq_lens = jnp.array([block_kv_size-1, block_kv_size+1, 2*block_kv_size])
+    kv_seq_lens = jnp.array(
+        [block_kv_size - 1, block_kv_size + 1, 2 * block_kv_size])
     assert len(kv_seq_lens) == batch_size
     effective_q_lens = jax.random.randint(
         jax.random.key(0), (batch_size,), 0, kv_seq_lens)
@@ -366,5 +364,6 @@
               atol=atol,
               rtol=rtol))
 
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
diff --git a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
index 5e4988a9892..557f8ad5ec3 100644
--- a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
+++ b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
@@ -198,8 +198,10 @@ def start_new_sequence():
   o_curr = jax.lax.dot(
       p.astype(v.dtype), v, preferred_element_type=jnp.float32)
   acc_scratch_ref[q_head_idx_per_kv] += o_curr * l_broadcast(l_next_inv_safe)
-  # @pl.when(kv_blk_idx == kv_len // kv_seq_len_per_kv_compute_blk)
-  @pl.when(kv_blk_idx == pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk)-1)
+  # The condition comes from the guard "@pl.when(kv_blk_idx * compute_blk_size_kv < kv_len)" that controls whether get_kv_and_run_flash_attention runs.
+  # If kv_len=512 and kv_seq_len_per_kv_compute_blk=256, the last kv_blk_idx that needs to store_to_output is 1.
+  # If kv_len=513 and kv_seq_len_per_kv_compute_blk=256, the last kv_blk_idx that needs to store_to_output is 2.
+  @pl.when(kv_blk_idx == pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk) - 1)
   def store_to_output():
     o_ref[0, q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype(
         o_ref.dtype)
@@ -385,6 +387,7 @@ def prefetch_next_block():  # pylint: disable=unused-variable
 
 MIN_BLOCK_SIZE = 128
 
+
 @jax.profiler.annotate_function
 @functools.partial(
     jax.jit,
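
Note (not part of the patch): the ceiling-division boundary described in the new comment can be sanity-checked outside the kernel. The sketch below is a standalone Python illustration; last_store_block is a hypothetical helper that mirrors pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk) - 1 using math.ceil.

# Standalone sketch of the boundary arithmetic in the new @pl.when condition.
# cdiv(kv_len, blk) is the number of kv compute blocks that actually run, so
# the last block index that must store to the output is cdiv(kv_len, blk) - 1.
import math


def last_store_block(kv_len: int, kv_seq_len_per_kv_compute_blk: int) -> int:
  # Ceiling division, matching pl.cdiv in the kernel.
  return math.ceil(kv_len / kv_seq_len_per_kv_compute_blk) - 1


assert last_store_block(512, 256) == 1  # 512 fills exactly two 256-wide blocks.
assert last_store_block(513, 256) == 2  # 513 spills into a third block.
# kv_len // blk would give 2 for kv_len=512, but block 2 never runs because
# the guard kv_blk_idx * compute_blk_size_kv < kv_len fails (2 * 256 is not < 512).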