From 7ac2cf17e55ef6dd26c89e6613ac37b47f0fd728 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 5 Dec 2024 21:54:18 +0000 Subject: [PATCH] Remove the constraint of pages_per_sequence%num_kv_pages_per_compute_block==0 --- test/test_tpu_paged_attention_kernel.py | 4 ++-- .../multi_queries_paged_attention_kernel.py | 8 ++------ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/test/test_tpu_paged_attention_kernel.py b/test/test_tpu_paged_attention_kernel.py index 746439ba4d0..74a287611cc 100644 --- a/test/test_tpu_paged_attention_kernel.py +++ b/test/test_tpu_paged_attention_kernel.py @@ -110,7 +110,7 @@ def setUp(self): q_kv_head_ratio=(1, 4, 8), head_dim=(128, 256), num_queries_per_compute_block=(16, 32), - block_kv_size=(128, 256), + block_kv_size=(128, 192, 256), ) def test_paged_attention_without_query_padding( self, @@ -206,7 +206,7 @@ def test_paged_attention_without_query_padding( q_kv_head_ratio=(1, 4, 8), head_dim=(128, 256), num_queries_per_compute_block=(16, 32), - block_kv_size=(128, 256), + block_kv_size=(128, 192, 256), ) def test_paged_attention_with_query_padding( self, diff --git a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py index 557f8ad5ec3..14b432f92e6 100644 --- a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py +++ b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py @@ -475,10 +475,6 @@ def paged_attention( raise ValueError( f"{num_kv_pages_per_compute_block=} should be smaller or equal to {pages_per_sequence=}" ) - if pages_per_sequence % num_kv_pages_per_compute_block != 0: - raise ValueError( - "num_kv_pages_per_compute_block must be divisible by pages per sequence. Got" - f" {pages_per_sequence=} and {num_kv_pages_per_compute_block=}.") if num_q_heads % num_kv_heads != 0: raise ValueError( "Number of Q heads must be divisible by number of KV heads. Got" @@ -491,8 +487,8 @@ def paged_attention( num_kv_heads, pl.cdiv(query_len, num_queries_per_compute_block ), # how many compute blocks we need to loop the query_len - pages_per_sequence // - num_kv_pages_per_compute_block, # how many compute blocks we need to loop the kv_len + pl.cdiv(pages_per_sequence, num_kv_pages_per_compute_block + ), # how many compute blocks we need to loop the kv_len ) # out_shape