Skip to content

Commit

Permalink
Remove the constraint of pages_per_sequence%num_kv_pages_per_compute_…
Browse files Browse the repository at this point in the history
…block==0
  • Loading branch information
vanbasten23 committed Dec 5, 2024
1 parent 4c99d21 commit 7ac2cf1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
4 changes: 2 additions & 2 deletions test/test_tpu_paged_attention_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def setUp(self):
q_kv_head_ratio=(1, 4, 8),
head_dim=(128, 256),
num_queries_per_compute_block=(16, 32),
block_kv_size=(128, 256),
block_kv_size=(128, 192, 256),
)
def test_paged_attention_without_query_padding(
self,
Expand Down Expand Up @@ -206,7 +206,7 @@ def test_paged_attention_without_query_padding(
q_kv_head_ratio=(1, 4, 8),
head_dim=(128, 256),
num_queries_per_compute_block=(16, 32),
block_kv_size=(128, 256),
block_kv_size=(128, 192, 256),
)
def test_paged_attention_with_query_padding(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -475,10 +475,6 @@ def paged_attention(
raise ValueError(
f"{num_kv_pages_per_compute_block=} should be smaller or equal to {pages_per_sequence=}"
)
if pages_per_sequence % num_kv_pages_per_compute_block != 0:
raise ValueError(
"num_kv_pages_per_compute_block must be divisible by pages per sequence. Got"
f" {pages_per_sequence=} and {num_kv_pages_per_compute_block=}.")
if num_q_heads % num_kv_heads != 0:
raise ValueError(
"Number of Q heads must be divisible by number of KV heads. Got"
Expand All @@ -491,8 +487,8 @@ def paged_attention(
num_kv_heads,
pl.cdiv(query_len, num_queries_per_compute_block
), # how many compute blocks we need to loop the query_len
pages_per_sequence //
num_kv_pages_per_compute_block, # how many compute blocks we need to loop the kv_len
pl.cdiv(pages_per_sequence, num_kv_pages_per_compute_block
), # how many compute blocks we need to loop the kv_len
)

# out_shape
Expand Down

0 comments on commit 7ac2cf1

Please sign in to comment.