diff --git a/test/test_tpu_paged_attention_kernel.py b/test/test_tpu_paged_attention_kernel.py
index d2d0f4f19a9..44b6bf3a2c9 100644
--- a/test/test_tpu_paged_attention_kernel.py
+++ b/test/test_tpu_paged_attention_kernel.py
@@ -11,7 +11,7 @@
 
 # Set up paged_attention inputs.
 def _generate_qkv(
-    kv_seq_lens,
+    batch_size,
     page_size,
     max_kv_len,
     query_len,
@@ -23,7 +23,6 @@ def _generate_qkv(
 ):
   assert max_kv_len % page_size == 0
   pages_per_sequence = max_kv_len // page_size
-  batch_size = len(kv_seq_lens)
   total_pages = batch_size * pages_per_sequence
   k1, k2, k3, k4 = jax.random.split(prng_key, 4)
   k_pages = jax.random.normal(
@@ -113,7 +112,7 @@ def setUp(self):
       num_queries_per_compute_block=(16, 32),
       block_kv_size=(128, 256),
   )
-  def test_paged_attention_without_query_padding(
+  def _test_paged_attention_without_query_padding(
       self,
       dtype,
       page_size,
@@ -138,7 +137,7 @@ def test_paged_attention_without_query_padding(
 
     assert max_kv_len <= total_num_pages * page_size
     q, k_pages, v_pages, page_indices = _generate_qkv(
-        kv_seq_lens,
+        batch_size,
         page_size,
         max_kv_len,
         query_len,
@@ -224,8 +223,9 @@ def test_paged_attention_with_query_padding(
     # Set query_len>kv_seq_lens
     query_len = max_kv_len
     batch_size = 3
-    kv_seq_lens = jax.random.randint(
-        jax.random.key(0), (batch_size,), 0, max_kv_len)
+    # kv_seq_lens = jax.random.randint(
+    #     jax.random.key(0), (batch_size,), 0, max_kv_len)
+    kv_seq_lens = jnp.array([256, 512, 512])
     effective_q_lens = jax.random.randint(
         jax.random.key(0), (batch_size,), 0, kv_seq_lens)
     for cur_effec_q_len, cur_kv_seq_len in zip(effective_q_lens, kv_seq_lens):
@@ -235,7 +235,7 @@ def test_paged_attention_with_query_padding(
     total_num_pages = batch_size * pages_per_sequence
     assert max_kv_len <= total_num_pages * page_size
     q, k_pages, v_pages, page_indices = _generate_qkv(
-        kv_seq_lens,
+        batch_size,
         page_size,
         max_kv_len,
         query_len,
@@ -292,6 +292,79 @@ def test_paged_attention_with_query_padding(
             atol=atol,
             rtol=rtol))
 
+  def test_paged_attention_store_to_output_correctly(
+      self,
+  ):
+    # Make sure the internal FA store_to_output works correctly.
+    dtype = jnp.float32
+    page_size = 16
+    num_kv_heads = 8
+    q_kv_head_ratio = 4
+    head_dim = 256
+    num_queries_per_compute_block = 32
+    block_kv_size = 256
+
+    max_kv_len = 512
+    query_len = max_kv_len
+    batch_size = 3
+    # Set edge cases to verify that the internal flash attention can store_to_output correctly.
+    kv_seq_lens = jnp.array([block_kv_size - 1, block_kv_size + 1, 2 * block_kv_size])
+    assert len(kv_seq_lens) == batch_size
+    effective_q_lens = jax.random.randint(
+        jax.random.key(0), (batch_size,), 0, kv_seq_lens)
+    for cur_effec_q_len, cur_kv_seq_len in zip(effective_q_lens, kv_seq_lens):
+      assert cur_effec_q_len <= cur_kv_seq_len, f'The effective query len {cur_effec_q_len} should be less than or equal to the kv_len {cur_kv_seq_len} in the current sequence.'
+
+    pages_per_sequence = max_kv_len // page_size
+    total_num_pages = batch_size * pages_per_sequence
+    assert max_kv_len <= total_num_pages * page_size
+    q, k_pages, v_pages, page_indices = _generate_qkv(
+        batch_size,
+        page_size,
+        max_kv_len,
+        query_len,
+        num_kv_heads,
+        num_kv_heads * q_kv_head_ratio,
+        head_dim,
+        jax.random.key(0),
+        dtype,
+    )
+
+    num_kv_pages_per_compute_block = block_kv_size // page_size
+    actual_output = paged_attention(
+        q,
+        k_pages,
+        v_pages,
+        kv_seq_lens,
+        page_indices,
+        effective_q_lens,
+        num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
+        num_queries_per_compute_block=num_queries_per_compute_block,
+    )
+    actual_output = jax.block_until_ready(actual_output)
+
+    # Run the ref impl.
+    expected_output = _ref_jax_extended_paged_attention(
+        q,
+        k_pages,
+        v_pages,
+        kv_seq_lens,
+        page_indices,
+        effective_q_lens,
+    )
+
+    self.assertEqual(actual_output.shape, expected_output.shape)
+
+    atol = 2e-2
+    rtol = 1e-2
+    for b in range(batch_size):
+      effective_q_len = effective_q_lens[b]
+      self.assertTrue(
+          jnp.allclose(
+              expected_output[b, :effective_q_len],
+              actual_output[b, :effective_q_len],
+              atol=atol,
+              rtol=rtol))
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
diff --git a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
index 0bb572c49e1..5e4988a9892 100644
--- a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
+++ b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
@@ -198,7 +198,8 @@ def start_new_sequence():
   o_curr = jax.lax.dot(
       p.astype(v.dtype), v, preferred_element_type=jnp.float32)
   acc_scratch_ref[q_head_idx_per_kv] += o_curr * l_broadcast(l_next_inv_safe)
-  @pl.when(kv_blk_idx == kv_len // kv_seq_len_per_kv_compute_blk)
+  # @pl.when(kv_blk_idx == kv_len // kv_seq_len_per_kv_compute_blk)
+  @pl.when(kv_blk_idx == pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk) - 1)
   def store_to_output():
     o_ref[0, q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype(
         o_ref.dtype)
@@ -384,7 +385,7 @@ def prefetch_next_block():  # pylint: disable=unused-variable
 
 MIN_BLOCK_SIZE = 128
 
-
+@jax.profiler.annotate_function
 @functools.partial(
     jax.jit,
     static_argnames=[
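
Note on the store_to_output predicate change in multi_queries_paged_attention_kernel.py: assuming the kv-block loop visits block indices 0 through pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk) - 1 (which is what the pl.cdiv-based fix implies), the old predicate kv_len // kv_seq_len_per_kv_compute_blk only names the last visited block when kv_len is not an exact multiple of the block size; for an exact multiple it points one block past the end, so the accumulator is never written to o_ref. The plain-Python sketch below (cdiv is a local stand-in for pl.cdiv, not kernel code) walks the three kv_seq_lens edge cases added in test_paged_attention_store_to_output_correctly:

def cdiv(a, b):
  # Local stand-in for pl.cdiv: ceiling division.
  return (a + b - 1) // b

block_kv_size = 256
for kv_len in (block_kv_size - 1, block_kv_size + 1, 2 * block_kv_size):
  num_blocks = cdiv(kv_len, block_kv_size)    # blocks visited: 0 .. num_blocks - 1
  old_last = kv_len // block_kv_size          # old predicate's target block
  new_last = cdiv(kv_len, block_kv_size) - 1  # new predicate's target block
  print(kv_len, num_blocks, old_last, new_last)
  # 255 -> 1 block,  old_last=0, new_last=0: both fire on the last block.
  # 257 -> 2 blocks, old_last=1, new_last=1: both fire on the last block.
  # 512 -> 2 blocks, old_last=2 (never visited), new_last=1: only the new
  #   predicate stores the output, which is the case the added test exercises.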