Skip to content

Commit

Permalink
Fix an edge case error.
Browse files Browse the repository at this point in the history
  • Loading branch information
vanbasten23 committed Dec 3, 2024
1 parent d503ca5 commit 688de8d
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 9 deletions.
87 changes: 80 additions & 7 deletions test/test_tpu_paged_attention_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

# Set up paged_attention inputs.
def _generate_qkv(
kv_seq_lens,
batch_size,
page_size,
max_kv_len,
query_len,
Expand All @@ -23,7 +23,6 @@ def _generate_qkv(
):
assert max_kv_len % page_size == 0
pages_per_sequence = max_kv_len // page_size
batch_size = len(kv_seq_lens)
total_pages = batch_size * pages_per_sequence
k1, k2, k3, k4 = jax.random.split(prng_key, 4)
k_pages = jax.random.normal(
Expand Down Expand Up @@ -113,7 +112,7 @@ def setUp(self):
num_queries_per_compute_block=(16, 32),
block_kv_size=(128, 256),
)
def test_paged_attention_without_query_padding(
def _test_paged_attention_without_query_padding(
self,
dtype,
page_size,
Expand All @@ -138,7 +137,7 @@ def test_paged_attention_without_query_padding(
assert max_kv_len <= total_num_pages * page_size

q, k_pages, v_pages, page_indices = _generate_qkv(
kv_seq_lens,
batch_size,
page_size,
max_kv_len,
query_len,
Expand Down Expand Up @@ -224,8 +223,9 @@ def test_paged_attention_with_query_padding(
# Set query_len>kv_seq_lens
query_len = max_kv_len
batch_size = 3
kv_seq_lens = jax.random.randint(
jax.random.key(0), (batch_size,), 0, max_kv_len)
# kv_seq_lens = jax.random.randint(
# jax.random.key(0), (batch_size,), 0, max_kv_len)
kv_seq_lens = jnp.array([256, 512, 512])
effective_q_lens = jax.random.randint(
jax.random.key(0), (batch_size,), 0, kv_seq_lens)
for cur_effec_q_len, cur_kv_seq_len in zip(effective_q_lens, kv_seq_lens):
Expand All @@ -235,7 +235,7 @@ def test_paged_attention_with_query_padding(
total_num_pages = batch_size * pages_per_sequence
assert max_kv_len <= total_num_pages * page_size
q, k_pages, v_pages, page_indices = _generate_qkv(
kv_seq_lens,
batch_size,
page_size,
max_kv_len,
query_len,
Expand Down Expand Up @@ -292,6 +292,79 @@ def test_paged_attention_with_query_padding(
atol=atol,
rtol=rtol))

def test_paged_attention_store_to_output_correctly(
self,
):
# Make sure the internal FA store_to_output correctly.
dtype = jnp.float32
page_size=16
num_kv_heads = 8
q_kv_head_ratio = 4
head_dim = 256
num_queries_per_compute_block = 32
block_kv_size = 256

max_kv_len = 512
query_len = max_kv_len
batch_size = 3
# Set various edge case testing the internal flash attention can store_to_output correct
kv_seq_lens = jnp.array([block_kv_size-1, block_kv_size+1, 2*block_kv_size])
assert len(kv_seq_lens) == batch_size
effective_q_lens = jax.random.randint(
jax.random.key(0), (batch_size,), 0, kv_seq_lens)
for cur_effec_q_len, cur_kv_seq_len in zip(effective_q_lens, kv_seq_lens):
assert cur_effec_q_len <= cur_kv_seq_len, f'The effective query len {cur_effec_q_len} should be less than or equal to the kv_len {cur_kv_seq_len} in the current sequence.'

pages_per_sequence = max_kv_len // page_size
total_num_pages = batch_size * pages_per_sequence
assert max_kv_len <= total_num_pages * page_size
q, k_pages, v_pages, page_indices = _generate_qkv(
batch_size,
page_size,
max_kv_len,
query_len,
num_kv_heads,
num_kv_heads * q_kv_head_ratio,
head_dim,
jax.random.key(0),
dtype,
)

num_kv_pages_per_compute_block = block_kv_size // page_size
actual_output = paged_attention(
q,
k_pages,
v_pages,
kv_seq_lens,
page_indices,
effective_q_lens,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
)
actual_output = jax.block_until_ready(actual_output)

# Run the ref impl.
expected_output = _ref_jax_extended_paged_attention(
q,
k_pages,
v_pages,
kv_seq_lens,
page_indices,
effective_q_lens,
)

self.assertEqual(actual_output.shape, expected_output.shape)

atol = 2e-2
rtol = 1e-2
for b in range(batch_size):
effective_q_len = effective_q_lens[b]
self.assertTrue(
jnp.allclose(
expected_output[b, :effective_q_len],
actual_output[b, :effective_q_len],
atol=atol,
rtol=rtol))

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ def start_new_sequence():
o_curr = jax.lax.dot(p.astype(v.dtype), v, preferred_element_type=jnp.float32)
acc_scratch_ref[q_head_idx_per_kv] += o_curr * l_broadcast(l_next_inv_safe)

@pl.when(kv_blk_idx == kv_len // kv_seq_len_per_kv_compute_blk)
# @pl.when(kv_blk_idx == kv_len // kv_seq_len_per_kv_compute_blk)
@pl.when(kv_blk_idx == pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk)-1)
def store_to_output():
o_ref[0, q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype(
o_ref.dtype)
Expand Down Expand Up @@ -384,7 +385,7 @@ def prefetch_next_block(): # pylint: disable=unused-variable

MIN_BLOCK_SIZE = 128


@jax.profiler.annotate_function
@functools.partial(
jax.jit,
static_argnames=[
Expand Down

0 comments on commit 688de8d

Please sign in to comment.