
Commit 0106b2f

add a comment

vanbasten23 committed Dec 3, 2024
1 parent 688de8d

Showing 2 changed files with 12 additions and 10 deletions.
15 changes: 7 additions & 8 deletions test/test_tpu_paged_attention_kernel.py
@@ -223,9 +223,8 @@ def test_paged_attention_with_query_padding(
     # Set query_len > kv_seq_lens
     query_len = max_kv_len
     batch_size = 3
-    # kv_seq_lens = jax.random.randint(
-    #     jax.random.key(0), (batch_size,), 0, max_kv_len)
-    kv_seq_lens = jnp.array([256, 512, 512])
+    kv_seq_lens = jax.random.randint(
+        jax.random.key(0), (batch_size,), 0, max_kv_len)
     effective_q_lens = jax.random.randint(
         jax.random.key(0), (batch_size,), 0, kv_seq_lens)
     for cur_effec_q_len, cur_kv_seq_len in zip(effective_q_lens, kv_seq_lens):
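The pattern the new lines use, jax.random.randint with a scalar upper bound for kv_seq_lens and a per-element array bound for effective_q_lens, can be exercised standalone. A minimal sketch of the same call pattern; batch_size and max_kv_len values here are illustrative, not taken from the test:

import jax

batch_size, max_kv_len = 3, 2048  # illustrative values, not from the commit
key = jax.random.key(0)
# One kv length per batch element, uniform in [0, max_kv_len).
kv_seq_lens = jax.random.randint(key, (batch_size,), 0, max_kv_len)
# maxval may be an array: each effective q length is drawn below its row's kv length.
effective_q_lens = jax.random.randint(key, (batch_size,), 0, kv_seq_lens)
print(kv_seq_lens, effective_q_lens)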
@@ -292,12 +291,10 @@ def test_paged_attention_with_query_padding(
             atol=atol,
             rtol=rtol))
 
-  def test_paged_attention_store_to_output_correctly(
-      self,
-  ):
+  def test_paged_attention_store_to_output_correctly(self,):
     # Make sure the internal FA store_to_output works correctly.
     dtype = jnp.float32
-    page_size=16
+    page_size = 16
     num_kv_heads = 8
     q_kv_head_ratio = 4
     head_dim = 256
@@ -308,7 +305,8 @@ def test_paged_attention_store_to_output_correctly(
     query_len = max_kv_len
     batch_size = 3
     # Set various edge cases to test that the internal flash attention store_to_output works correctly.
-    kv_seq_lens = jnp.array([block_kv_size-1, block_kv_size+1, 2*block_kv_size])
+    kv_seq_lens = jnp.array(
+        [block_kv_size - 1, block_kv_size + 1, 2 * block_kv_size])
     assert len(kv_seq_lens) == batch_size
     effective_q_lens = jax.random.randint(
         jax.random.key(0), (batch_size,), 0, kv_seq_lens)
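These three lengths straddle the kv compute-block boundary: block_kv_size - 1 fits in one block, block_kv_size + 1 spills into a second, and 2 * block_kv_size exactly fills two, so the final (possibly partial) block is exercised in each regime. A quick worked check, with block_kv_size = 256 assumed for illustration:

import math

block_kv_size = 256  # assumed for illustration; the test derives it elsewhere
for kv_len in (block_kv_size - 1, block_kv_size + 1, 2 * block_kv_size):
  num_blocks = math.ceil(kv_len / block_kv_size)  # same rounding as pl.cdiv
  print(f"kv_len={kv_len}: {num_blocks} block(s), last block index {num_blocks - 1}")
# kv_len=255: 1 block; kv_len=257: 2 blocks; kv_len=512: 2 blocks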
@@ -366,5 +364,6 @@ def test_paged_attention_store_to_output_correctly(
             atol=atol,
             rtol=rtol))
 
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
7 changes: 5 additions & 2 deletions (paged attention kernel source; filename not shown in this view)
@@ -198,8 +198,10 @@ def start_new_sequence():
     o_curr = jax.lax.dot(p.astype(v.dtype), v, preferred_element_type=jnp.float32)
     acc_scratch_ref[q_head_idx_per_kv] += o_curr * l_broadcast(l_next_inv_safe)
 
-  # @pl.when(kv_blk_idx == kv_len // kv_seq_len_per_kv_compute_blk)
-  @pl.when(kv_blk_idx == pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk)-1)
+  # The condition comes from the guard "@pl.when(kv_blk_idx * compute_blk_size_kv < kv_len)" that controls whether get_kv_and_run_flash_attention runs.
+  # If kv_len=512 and kv_seq_len_per_kv_compute_blk=256, the last kv_blk_idx that needs store_to_output is 1.
+  # If kv_len=513 and kv_seq_len_per_kv_compute_blk=256, the last kv_blk_idx that needs store_to_output is 2.
+  @pl.when(kv_blk_idx == pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk) - 1)
   def store_to_output():
     o_ref[0, q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype(
         o_ref.dtype)
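The arithmetic in the added comment can be verified in plain Python, independent of Pallas. This sketch assumes kv_seq_len_per_kv_compute_blk = 256, as in the comment's examples; cdiv here reimplements pl.cdiv's ceiling division:

def cdiv(a, b):  # ceiling division, matching pl.cdiv
  return -(-a // b)

blk = 256  # kv_seq_len_per_kv_compute_blk, per the comment's examples
for kv_len in (512, 513):
  blocks_run = [i for i in range(8) if i * blk < kv_len]  # blocks the outer guard lets run
  last = cdiv(kv_len, blk) - 1  # block on which store_to_output fires
  assert last == blocks_run[-1]
  print(f"kv_len={kv_len}: blocks {blocks_run}, store_to_output on block {last}")
# kv_len=512 -> blocks [0, 1], store on 1; kv_len=513 -> blocks [0, 1, 2], store on 2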
@@ -385,6 +387,7 @@ def prefetch_next_block():  # pylint: disable=unused-variable

 MIN_BLOCK_SIZE = 128
 
+
 @jax.profiler.annotate_function
 @functools.partial(
     jax.jit,
