From 6bb10fb0e5d6bd8c29fb134d908cc0958e9f3f6d Mon Sep 17 00:00:00 2001
From: Xiongfei Wei
Date: Tue, 3 Dec 2024 22:51:01 +0000
Subject: [PATCH 1/6] skip based on causal mask

---
 .../multi_queries_paged_attention_kernel.py   | 54 +++++++++++++------
 1 file changed, 38 insertions(+), 16 deletions(-)

diff --git a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
index 557f8ad5ec3..f5cf8226cef 100644
--- a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
+++ b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
@@ -15,6 +15,7 @@
 
 DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)
 
+
 class MultiPageAsyncCopyDescriptor:
   """Descriptor for async copy of multiple K/V pages from HBM."""
 
@@ -176,7 +177,7 @@ def start_new_sequence():
 
   alpha = jnp.exp(m_prev - m_next)  # Shape [block_q, 128]
 
-  l_corr = alpha * l_prev
+  l_corr = alpha * l_prev  # Shape [block_q, 128]
 
   l_next = jnp.sum(p, axis=1)[:, None] + l_corr  # Shape [block_q, 128]
 
@@ -192,16 +193,20 @@ def start_new_sequence():
   l_scratch_ref[q_head_idx_per_kv] = l_next
   m_scratch_ref[q_head_idx_per_kv] = m_next
 
-  l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next)
+  l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next)  # [block_q, 128]
+
   acc_scratch_ref[q_head_idx_per_kv] *= l_broadcast(l_corr * l_next_inv_safe)
   # Note Matmul operandlhs must have a shape divisible by (16, 1)
-  o_curr = jax.lax.dot(p.astype(v.dtype), v, preferred_element_type=jnp.float32)
+  o_curr = jax.lax.dot(p.astype(v.dtype), v, preferred_element_type=jnp.float32)  # [block_q, 128]
+
   acc_scratch_ref[q_head_idx_per_kv] += o_curr * l_broadcast(l_next_inv_safe)
+  # @pl.when(jnp.logical_and(q_blk_idx == 0, kv_blk_idx == 1))
+  # def _():
+  #   pl.debug_print('xw32 line201')
+  #   blk_debug_blkq_128_ref[:] = acc_scratch_ref[q_head_idx_per_kv]
 
-  # The condition comes from the one "@pl.when(kv_blk_idx * compute_blk_size_kv < kv_len)" controlling if we should run the function get_kv_and_run_flash_attention.
-  # If kv_len=512, kv_seq_len_per_kv_compute_blk=256, then last kv_blk_idx that we need to store_to_output is 1.
-  # If kv_len=513, kv_seq_len_per_kv_compute_blk=256, then last kv_blk_idx that we need to store_to_output is 2.
- @pl.when(kv_blk_idx == pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk) - 1) + #@pl.when(kv_blk_idx == pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk)-1) + @pl.when(jnp.logical_or(kv_blk_idx == pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk)-1, jnp.logical_not(_below_or_on_diag(q_blk_idx, num_queries_per_compute_block, kv_blk_idx+1, kv_seq_len_per_kv_compute_blk, effective_q_len, kv_len)))) def store_to_output(): o_ref[0, q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype( o_ref.dtype) @@ -210,6 +215,18 @@ def store_to_output(): m_ref[0, q_head_idx_per_kv] = m_scratch_ref[q_head_idx_per_kv].astype( m_ref.dtype) +# partially working +# def _below_or_on_diag(q_blk_idx, q_blk_size, kv_blk_idx, kv_blk_size, effective_q_len, effective_kv_len): +# return ((q_blk_idx+1) * q_blk_size - 1) <= (q_blk_size-1) + (kv_blk_idx * kv_blk_size) + (effective_kv_len - effective_q_len) + +def _below_or_on_diag(q_blk_idx, q_blk_size, kv_blk_idx, kv_blk_size, effective_q_len, effective_kv_len): + return ((q_blk_idx+1) * q_blk_size - 1) >= (kv_blk_idx * kv_blk_size) - (effective_kv_len - effective_q_len) + +# Original version in FA +# def below_or_on_diag(r, r_blk_size, c, c_blk_size): +# # A block is considered below or on diagonal as long as the bottom left +# # corner of the block is below or on diagonal. +# return ((r + 1) * r_blk_size - 1) > (c * c_blk_size) def paged_flash_attention_kernel( lengths_ref, # [batch_size] jax.Array the length of each example @@ -228,7 +245,6 @@ def paged_flash_attention_kernel( o_ref, l_ref, m_ref, - # scratch space k_vmem_buffer, # (2, num_kv_pages_per_compute_block, num_kv_heads, head_dim) k_scales_vmem_buffer, v_vmem_buffer, # (2, num_kv_pages_per_compute_block, num_kv_heads, head_dim) @@ -241,7 +257,6 @@ def paged_flash_attention_kernel( pages_per_sequence: int, # Note [bs, pages_per_sequence] = page_indices.shape batch_size: int, num_kv_pages_per_compute_block: int, - num_queries_per_compute_block: int, mask_value: float, query_len: int, ): @@ -256,11 +271,15 @@ def paged_flash_attention_kernel( b_q_ref, num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim = q_ref.shape num_kv_heads, total_num_pages, page_size, head_dim = k_pages_hbm_ref.shape compute_blk_size_kv = page_size * num_kv_pages_per_compute_block - kv_len = lengths_ref[b] + effective_kv_len = lengths_ref[b] + effective_q_len = effective_q_lens_ref[b] - # TODO: think about skip the work when we know the causal mask would mask all (e.g. when the whole kv_blk is after the whole q_blk) # Get the K and V for the current batch and current kv head. 
- @pl.when(kv_blk_idx * compute_blk_size_kv < kv_len) + should_run = jnp.logical_and(kv_blk_idx * compute_blk_size_kv < effective_kv_len, _below_or_on_diag(q_blk_idx, num_queries_per_compute_block, kv_blk_idx, compute_blk_size_kv, effective_q_len, effective_kv_len)) + # pl.debug_print('xw32 line272 kv_blk_idx={}, compute_blk_size_kv={}, (kv_blk_idx * compute_blk_size_kv)={}, effective_kv_len={}', kv_blk_idx, compute_blk_size_kv, (kv_blk_idx * compute_blk_size_kv), effective_kv_len) + # pl.debug_print('xw32 line273 q_blk_idx={}, num_queries_per_compute_block={}, kv_blk_idx={}, compute_blk_size_kv={}, effective_q_len={}, effective_kv_len={}', q_blk_idx, num_queries_per_compute_block, kv_blk_idx, compute_blk_size_kv, effective_q_len, effective_kv_len) + #@pl.when(kv_blk_idx * compute_blk_size_kv < effective_kv_len) + @pl.when(should_run) def get_kv_and_run_flash_attention(): # Loop over num_q_heads_per_kv_head and use the same K and V def compute_block_indices(b, kv_head_idx, q_blk_idx, kv_blk_idx): @@ -289,7 +308,7 @@ def advance_to_next_non_zero_length(): ) def advance_kv_head_idx(): - # assumption: kv_blk_idx * compute_blk_size_kv >= lengths_ref[b] + # assumption: kv_blk_idx * compute_blk_size_kv >= lengths_ref[b], or the block is to the right of the mask. next_kv_head_idx = kv_head_idx + 1 return lax.cond( q_blk_idx == num_q_blks - 1, @@ -297,7 +316,9 @@ def advance_kv_head_idx(): (b, next_kv_head_idx, 0), advance_b), lambda: (b, kv_head_idx, 0)) - return lax.cond(kv_blk_idx * compute_blk_size_kv < lengths_ref[b], lambda: + # return lax.cond(kv_blk_idx * compute_blk_size_kv < lengths_ref[b], lambda: + # (b, kv_head_idx, kv_blk_idx), advance_kv_head_idx) + return lax.cond(jnp.logical_and(kv_blk_idx * compute_blk_size_kv < lengths_ref[b], _below_or_on_diag(q_blk_idx,num_queries_per_compute_block,kv_blk_idx,compute_blk_size_kv, effective_q_lens_ref[b], lengths_ref[b])), lambda: (b, kv_head_idx, kv_blk_idx), advance_kv_head_idx) def create_kv_async_copy_descriptors(b, kv_head_idx, kv_blk_idx, @@ -342,6 +363,8 @@ def prefetch_first_block(): # pylint: disable=unused-variable next_b, next_kv_head_idx, next_kv_blk_idx = compute_block_indices( b, kv_head_idx, q_blk_idx, kv_blk_idx + 1) + pl.debug_print('xw32 line354, working on b={}, kv_head_idx={}, q_blk_idx={},kv_blk_idx={}', b, kv_head_idx, q_blk_idx, kv_blk_idx) + pl.debug_print('xw32 line355, working on next_b={}, next_kv_head_idx={}, q_blk_idx={},next_kv_blk_idx={}', next_b, next_kv_head_idx, q_blk_idx, next_kv_blk_idx) @pl.when(next_b < batch_size) def prefetch_next_block(): # pylint: disable=unused-variable @@ -369,7 +392,7 @@ def prefetch_next_block(): # pylint: disable=unused-variable o_ref, # [1, num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] l_ref, m_ref, - l_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] + l_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] m_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] acc_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] num_kv_pages_per_compute_block=num_kv_pages_per_compute_block, @@ -586,7 +609,6 @@ def lm_index_map(batch_index, kv_head_index, q_seq_blk_idx, *_): pages_per_sequence=pages_per_sequence, batch_size=batch_size, num_kv_pages_per_compute_block=num_kv_pages_per_compute_block, - num_queries_per_compute_block=num_queries_per_compute_block, mask_value=mask_value, query_len=query_len), 
grid_spec=pltpu.PrefetchScalarGridSpec( From 0283a26a12cc23641fd9b3b3c5f2f0cee4613956 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Tue, 3 Dec 2024 23:28:44 +0000 Subject: [PATCH 2/6] run linter --- .../multi_queries_paged_attention_kernel.py | 66 ++++++++++--------- 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py index f5cf8226cef..fda42212f00 100644 --- a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py +++ b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py @@ -15,7 +15,6 @@ DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) - class MultiPageAsyncCopyDescriptor: """Descriptor for async copy of multiple K/V pages from HBM.""" @@ -193,20 +192,24 @@ def start_new_sequence(): l_scratch_ref[q_head_idx_per_kv] = l_next m_scratch_ref[q_head_idx_per_kv] = m_next - l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next) # [block_q, 128] + l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, + 1.0 / l_next) # [block_q, 128] acc_scratch_ref[q_head_idx_per_kv] *= l_broadcast(l_corr * l_next_inv_safe) # Note Matmul operandlhs must have a shape divisible by (16, 1) - o_curr = jax.lax.dot(p.astype(v.dtype), v, preferred_element_type=jnp.float32) # [block_q, 128] + o_curr = jax.lax.dot( + p.astype(v.dtype), v, + preferred_element_type=jnp.float32) # [block_q, 128] acc_scratch_ref[q_head_idx_per_kv] += o_curr * l_broadcast(l_next_inv_safe) - # @pl.when(jnp.logical_and(q_blk_idx == 0, kv_blk_idx == 1)) - # def _(): - # pl.debug_print('xw32 line201') - # blk_debug_blkq_128_ref[:] = acc_scratch_ref[q_head_idx_per_kv] - #@pl.when(kv_blk_idx == pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk)-1) - @pl.when(jnp.logical_or(kv_blk_idx == pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk)-1, jnp.logical_not(_below_or_on_diag(q_blk_idx, num_queries_per_compute_block, kv_blk_idx+1, kv_seq_len_per_kv_compute_blk, effective_q_len, kv_len)))) + is_last_kv_blk_idx = kv_blk_idx == pl.cdiv(kv_len, + kv_seq_len_per_kv_compute_blk) - 1 + is_next_kv_blk_masked_out = jnp.logical_not( + _block_below_or_on_diag(q_blk_idx, num_queries_per_compute_block, + kv_blk_idx + 1, kv_seq_len_per_kv_compute_blk, + effective_q_len, kv_len)) + @pl.when(jnp.logical_or(is_last_kv_blk_idx, is_next_kv_blk_masked_out)) def store_to_output(): o_ref[0, q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype( o_ref.dtype) @@ -215,18 +218,12 @@ def store_to_output(): m_ref[0, q_head_idx_per_kv] = m_scratch_ref[q_head_idx_per_kv].astype( m_ref.dtype) -# partially working -# def _below_or_on_diag(q_blk_idx, q_blk_size, kv_blk_idx, kv_blk_size, effective_q_len, effective_kv_len): -# return ((q_blk_idx+1) * q_blk_size - 1) <= (q_blk_size-1) + (kv_blk_idx * kv_blk_size) + (effective_kv_len - effective_q_len) -def _below_or_on_diag(q_blk_idx, q_blk_size, kv_blk_idx, kv_blk_size, effective_q_len, effective_kv_len): - return ((q_blk_idx+1) * q_blk_size - 1) >= (kv_blk_idx * kv_blk_size) - (effective_kv_len - effective_q_len) +def _block_below_or_on_diag(q_blk_idx, q_blk_size, kv_blk_idx, kv_blk_size, + effective_q_len, effective_kv_len): + return ((q_blk_idx + 1) * q_blk_size - 1) >= (kv_blk_idx * kv_blk_size) - ( + effective_kv_len - effective_q_len) -# Original version in FA -# def below_or_on_diag(r, r_blk_size, c, c_blk_size): -# # A block is considered below or on diagonal as long as the bottom left -# 
# corner of the block is below or on diagonal. -# return ((r + 1) * r_blk_size - 1) > (c * c_blk_size) def paged_flash_attention_kernel( lengths_ref, # [batch_size] jax.Array the length of each example @@ -275,10 +272,12 @@ def paged_flash_attention_kernel( effective_q_len = effective_q_lens_ref[b] # Get the K and V for the current batch and current kv head. - should_run = jnp.logical_and(kv_blk_idx * compute_blk_size_kv < effective_kv_len, _below_or_on_diag(q_blk_idx, num_queries_per_compute_block, kv_blk_idx, compute_blk_size_kv, effective_q_len, effective_kv_len)) - # pl.debug_print('xw32 line272 kv_blk_idx={}, compute_blk_size_kv={}, (kv_blk_idx * compute_blk_size_kv)={}, effective_kv_len={}', kv_blk_idx, compute_blk_size_kv, (kv_blk_idx * compute_blk_size_kv), effective_kv_len) - # pl.debug_print('xw32 line273 q_blk_idx={}, num_queries_per_compute_block={}, kv_blk_idx={}, compute_blk_size_kv={}, effective_q_len={}, effective_kv_len={}', q_blk_idx, num_queries_per_compute_block, kv_blk_idx, compute_blk_size_kv, effective_q_len, effective_kv_len) - #@pl.when(kv_blk_idx * compute_blk_size_kv < effective_kv_len) + should_run = jnp.logical_and( + kv_blk_idx * compute_blk_size_kv < effective_kv_len, + _block_below_or_on_diag(q_blk_idx, num_queries_per_compute_block, + kv_blk_idx, compute_blk_size_kv, effective_q_len, + effective_kv_len)) + @pl.when(should_run) def get_kv_and_run_flash_attention(): # Loop over num_q_heads_per_kv_head and use the same K and V @@ -316,10 +315,13 @@ def advance_kv_head_idx(): (b, next_kv_head_idx, 0), advance_b), lambda: (b, kv_head_idx, 0)) - # return lax.cond(kv_blk_idx * compute_blk_size_kv < lengths_ref[b], lambda: - # (b, kv_head_idx, kv_blk_idx), advance_kv_head_idx) - return lax.cond(jnp.logical_and(kv_blk_idx * compute_blk_size_kv < lengths_ref[b], _below_or_on_diag(q_blk_idx,num_queries_per_compute_block,kv_blk_idx,compute_blk_size_kv, effective_q_lens_ref[b], lengths_ref[b])), lambda: - (b, kv_head_idx, kv_blk_idx), advance_kv_head_idx) + return lax.cond( + jnp.logical_and( + kv_blk_idx * compute_blk_size_kv < lengths_ref[b], + _block_below_or_on_diag(q_blk_idx, num_queries_per_compute_block, + kv_blk_idx, compute_blk_size_kv, + effective_q_lens_ref[b], lengths_ref[b])), + lambda: (b, kv_head_idx, kv_blk_idx), advance_kv_head_idx) def create_kv_async_copy_descriptors(b, kv_head_idx, kv_blk_idx, buffer_index): @@ -363,8 +365,12 @@ def prefetch_first_block(): # pylint: disable=unused-variable next_b, next_kv_head_idx, next_kv_blk_idx = compute_block_indices( b, kv_head_idx, q_blk_idx, kv_blk_idx + 1) - pl.debug_print('xw32 line354, working on b={}, kv_head_idx={}, q_blk_idx={},kv_blk_idx={}', b, kv_head_idx, q_blk_idx, kv_blk_idx) - pl.debug_print('xw32 line355, working on next_b={}, next_kv_head_idx={}, q_blk_idx={},next_kv_blk_idx={}', next_b, next_kv_head_idx, q_blk_idx, next_kv_blk_idx) + pl.debug_print( + 'xw32 line354, working on b={}, kv_head_idx={}, q_blk_idx={},kv_blk_idx={}', + b, kv_head_idx, q_blk_idx, kv_blk_idx) + pl.debug_print( + 'xw32 line355, working on next_b={}, next_kv_head_idx={}, q_blk_idx={},next_kv_blk_idx={}', + next_b, next_kv_head_idx, q_blk_idx, next_kv_blk_idx) @pl.when(next_b < batch_size) def prefetch_next_block(): # pylint: disable=unused-variable @@ -392,7 +398,7 @@ def prefetch_next_block(): # pylint: disable=unused-variable o_ref, # [1, num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] l_ref, m_ref, - l_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] + 
l_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] m_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE] acc_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] num_kv_pages_per_compute_block=num_kv_pages_per_compute_block, From b4b32a1d2f53da5159ff7566b783d1322befec1e Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 4 Dec 2024 00:13:27 +0000 Subject: [PATCH 3/6] clean up --- .../multi_queries_paged_attention_kernel.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py index fda42212f00..e9a7b67db00 100644 --- a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py +++ b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py @@ -242,6 +242,7 @@ def paged_flash_attention_kernel( o_ref, l_ref, m_ref, + # scratch space k_vmem_buffer, # (2, num_kv_pages_per_compute_block, num_kv_heads, head_dim) k_scales_vmem_buffer, v_vmem_buffer, # (2, num_kv_pages_per_compute_block, num_kv_heads, head_dim) @@ -307,7 +308,7 @@ def advance_to_next_non_zero_length(): ) def advance_kv_head_idx(): - # assumption: kv_blk_idx * compute_blk_size_kv >= lengths_ref[b], or the block is to the right of the mask. + # assumption: kv_blk_idx * compute_blk_size_kv >= lengths_ref[b], or the block is above the causal mask diagonal. next_kv_head_idx = kv_head_idx + 1 return lax.cond( q_blk_idx == num_q_blks - 1, @@ -365,12 +366,6 @@ def prefetch_first_block(): # pylint: disable=unused-variable next_b, next_kv_head_idx, next_kv_blk_idx = compute_block_indices( b, kv_head_idx, q_blk_idx, kv_blk_idx + 1) - pl.debug_print( - 'xw32 line354, working on b={}, kv_head_idx={}, q_blk_idx={},kv_blk_idx={}', - b, kv_head_idx, q_blk_idx, kv_blk_idx) - pl.debug_print( - 'xw32 line355, working on next_b={}, next_kv_head_idx={}, q_blk_idx={},next_kv_blk_idx={}', - next_b, next_kv_head_idx, q_blk_idx, next_kv_blk_idx) @pl.when(next_b < batch_size) def prefetch_next_block(): # pylint: disable=unused-variable From ab0eae6db59be56e3d12505a0cc4f9ea3a774021 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 4 Dec 2024 03:38:16 +0000 Subject: [PATCH 4/6] add comment and linter --- .../pallas_kernels/multi_queries_paged_attention_kernel.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py index e9a7b67db00..8d206022593 100644 --- a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py +++ b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py @@ -209,6 +209,7 @@ def start_new_sequence(): _block_below_or_on_diag(q_blk_idx, num_queries_per_compute_block, kv_blk_idx + 1, kv_seq_len_per_kv_compute_blk, effective_q_len, kv_len)) + @pl.when(jnp.logical_or(is_last_kv_blk_idx, is_next_kv_blk_masked_out)) def store_to_output(): o_ref[0, q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype( @@ -219,8 +220,12 @@ def store_to_output(): m_ref.dtype) +# If the inputs are 0, 0, 32, 256, 64, 257, the >= compares the x-coordinate of element (31, 0) and (-193, 0) +# If the inputs are 0, 1, 32, 256, 64, 257, the >= compares the x-coordinate of element (31, 256) and (63, 256) where (63, 256) is the diagonal on row 
63.
 def _block_below_or_on_diag(q_blk_idx, q_blk_size, kv_blk_idx, kv_blk_size,
                             effective_q_len, effective_kv_len):
+  # A block is considered below or on diagonal as long as the bottom left
+  # corner of the block is below or on diagonal.
   return ((q_blk_idx + 1) * q_blk_size - 1) >= (kv_blk_idx * kv_blk_size) - (
       effective_kv_len - effective_q_len)

From 0869e997b686beb2278f289f8de94a99aa0ba84b Mon Sep 17 00:00:00 2001
From: Xiongfei Wei
Date: Wed, 4 Dec 2024 19:29:49 +0000
Subject: [PATCH 5/6] revised the comment

---
 .../multi_queries_paged_attention_kernel.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
index 8d206022593..9dc9955f8b1 100644
--- a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
+++ b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
@@ -220,12 +220,12 @@ def store_to_output():
       m_ref.dtype)
 
-# If the inputs are 0, 0, 32, 256, 64, 257, the >= compares the x-coordinate of element (31, 0) and (-193, 0)
-# If the inputs are 0, 1, 32, 256, 64, 257, the >= compares the x-coordinate of element (31, 256) and (63, 256) where (63, 256) is the diagonal on row 63.
+# A block is considered below or on diagonal as long as the bottom left
+# corner of the block is below or on diagonal.
+# If the inputs are 0, 32, 0, 256, 64, 257, the block's bottom left corner is (31, 0). For that column (0), the diagonal element is (-193, 0). We compare (>=) the x-coordinates of the corner and the diagonal element (31 and -193).
+# If the inputs are 0, 32, 1, 256, 64, 257, the block's bottom left corner is (31, 256). For that column (256), the diagonal element is (63, 256). We compare (>=) the x-coordinates of the corner and the diagonal element (31 and 63).
 def _block_below_or_on_diag(q_blk_idx, q_blk_size, kv_blk_idx, kv_blk_size,
                             effective_q_len, effective_kv_len):
-  # A block is considered below or on diagonal as long as the bottom left
-  # corner of the block is below or on diagonal.
   return ((q_blk_idx + 1) * q_blk_size - 1) >= (kv_blk_idx * kv_blk_size) - (
       effective_kv_len - effective_q_len)

From 56064e79e0f8b2c1f59ae9ed121c3e8ee868a824 Mon Sep 17 00:00:00 2001
From: Xiongfei Wei
Date: Thu, 5 Dec 2024 03:52:19 +0000
Subject: [PATCH 6/6] rebased

---
 .../multi_queries_paged_attention_kernel.py      | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
index 9dc9955f8b1..84d6ad530e5 100644
--- a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
+++ b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
@@ -148,11 +148,12 @@ def start_new_sequence():
 
   q_index = q_blk_idx * num_queries_per_compute_block
   kv_index = kv_blk_idx * kv_seq_len_per_kv_compute_blk
-  kv_len = lengths_ref[b]
+  effective_kv_len = lengths_ref[b]
   effective_q_len = effective_q_lens_ref[b]
-  row_ids = (kv_len - effective_q_len) + q_index + jax.lax.broadcasted_iota(
-      jnp.int32,
-      (num_queries_per_compute_block, kv_seq_len_per_kv_compute_blk), 0)
+  row_ids = (
+      effective_kv_len - effective_q_len) + q_index + jax.lax.broadcasted_iota(
+          jnp.int32,
+          (num_queries_per_compute_block, kv_seq_len_per_kv_compute_blk), 0)
   col_ids = kv_index + jax.lax.broadcasted_iota(
       jnp.int32,
       (num_queries_per_compute_block, kv_seq_len_per_kv_compute_blk), 1)
@@ -203,12 +204,15 @@ def start_new_sequence():
 
   acc_scratch_ref[q_head_idx_per_kv] += o_curr * l_broadcast(l_next_inv_safe)
 
-  is_last_kv_blk_idx = kv_blk_idx == pl.cdiv(kv_len,
+  # The condition comes from the check that controls whether we run get_kv_and_run_flash_attention.
+  # If effective_kv_len=512, kv_seq_len_per_kv_compute_blk=256, then the last kv_blk_idx at which we need to store_to_output is 1.
+  # If effective_kv_len=513, kv_seq_len_per_kv_compute_blk=256, then the last kv_blk_idx at which we need to store_to_output is 2.
+  is_last_kv_blk_idx = kv_blk_idx == pl.cdiv(effective_kv_len,
                                              kv_seq_len_per_kv_compute_blk) - 1
   is_next_kv_blk_masked_out = jnp.logical_not(
       _block_below_or_on_diag(q_blk_idx, num_queries_per_compute_block,
                               kv_blk_idx + 1, kv_seq_len_per_kv_compute_blk,
-                              effective_q_len, kv_len))
+                              effective_q_len, effective_kv_len))
 
   @pl.when(jnp.logical_or(is_last_kv_blk_idx, is_next_kv_blk_masked_out))
   def store_to_output():
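
Notes on the math and the skip logic in this series (standalone Python sketches; none of the code below is from the patches themselves).

The arithmetic that patch 1 annotates (alpha, l_corr, l_next, l_next_inv_safe) is the standard online-softmax recurrence: acc always holds the already-normalized partial output, so each new kv block first rescales the old accumulator by l_corr / l_next and then adds the new block's contribution scaled by 1 / l_next. A minimal NumPy sketch of that recurrence, with plain arrays standing in for the kernel's Pallas refs and scratch buffers and shapes simplified ([n_q, 1] running stats instead of [block_q, 128]):

import numpy as np

def online_softmax_attention(q, k, v, block_kv):
  """Streaming softmax over kv blocks; mirrors the kernel's m/l/acc updates."""
  num_q, head_dim = q.shape
  m_prev = np.full((num_q, 1), -np.inf)  # running row max
  l_prev = np.zeros((num_q, 1))          # running softmax denominator
  acc = np.zeros((num_q, head_dim))      # running, already-normalized output
  for start in range(0, k.shape[0], block_kv):
    s = q @ k[start:start + block_kv].T          # scores for this kv block
    m_curr = s.max(axis=1, keepdims=True)
    m_next = np.maximum(m_prev, m_curr)
    p = np.exp(s - m_next)
    alpha = np.exp(m_prev - m_next)              # rescales the old statistics
    l_corr = alpha * l_prev
    l_next = p.sum(axis=1, keepdims=True) + l_corr
    l_next_inv_safe = np.where(l_next == 0.0, 1.0, 1.0 / l_next)
    # Same two-step update as the kernel: rescale old acc, add new block.
    acc = acc * (l_corr * l_next_inv_safe)
    acc = acc + (p @ v[start:start + block_kv]) * l_next_inv_safe
    m_prev, l_prev = m_next, l_next
  return acc

rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(4, 8)), rng.normal(size=(10, 8)), rng.normal(size=(10, 8))
s = q @ k.T
ref = np.exp(s - s.max(axis=1, keepdims=True))
ref = (ref / ref.sum(axis=1, keepdims=True)) @ v
assert np.allclose(online_softmax_attention(q, k, v, block_kv=3), ref)

The l_next == 0.0 guard matters for fully masked rows: rescaling by 1.0 instead of dividing by zero leaves the accumulator at zero.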
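
Patches 1 and 5 reduce block skipping to one corner test: a (q block, kv block) tile can contribute only if its bottom-left element lies on or below the shifted diagonal, where the shift effective_kv_len - effective_q_len aligns the last query row with the last kv position. A small standalone sketch that enumerates should_run for the worked numbers in the patch-5 comments (q_blk_size=32, kv_blk_size=256, effective_q_len=64, effective_kv_len=257); the predicate body is copied from _block_below_or_on_diag:

def block_below_or_on_diag(q_blk_idx, q_blk_size, kv_blk_idx, kv_blk_size,
                           effective_q_len, effective_kv_len):
  # Bottom-left corner of the tile is (row, col) =
  # ((q_blk_idx + 1) * q_blk_size - 1, kv_blk_idx * kv_blk_size); the diagonal
  # element in that column sits at row col - (effective_kv_len - effective_q_len).
  return ((q_blk_idx + 1) * q_blk_size - 1) >= (kv_blk_idx * kv_blk_size) - (
      effective_kv_len - effective_q_len)

effective_q_len, effective_kv_len = 64, 257
q_blk_size, kv_blk_size = 32, 256
num_q_blks = -(-effective_q_len // q_blk_size)     # cdiv -> 2
num_kv_blks = -(-effective_kv_len // kv_blk_size)  # cdiv -> 2
for q_blk_idx in range(num_q_blks):
  for kv_blk_idx in range(num_kv_blks):
    # Mirrors the kernel's should_run: the kv block must hold real tokens
    # and must not lie entirely above the causal diagonal.
    should_run = (kv_blk_idx * kv_blk_size < effective_kv_len and
                  block_below_or_on_diag(q_blk_idx, q_blk_size, kv_blk_idx,
                                         kv_blk_size, effective_q_len,
                                         effective_kv_len))
    print(f"q_blk={q_blk_idx} kv_blk={kv_blk_idx} should_run={should_run}")
# q_blk=0 kv_blk=0 should_run=True
# q_blk=0 kv_blk=1 should_run=False   (corner row 31 < diagonal row 63)
# q_blk=1 kv_blk=0 should_run=True
# q_blk=1 kv_blk=1 should_run=True

Only (q_blk_idx=0, kv_blk_idx=1) is skipped, matching the comment's 31 >= 63 example. The same predicate evaluated at kv_blk_idx + 1 is what lets store_to_output in patches 1 and 6 fire early: once every later kv block for this query block is masked out, the accumulator is final even though kv_blk_idx is not the last block.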
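
Patch 6's renaming hunk also shows how the element-level causal mask indices are built: broadcasted_iota along dim 0 gives per-query row ids offset by effective_kv_len - effective_q_len, and along dim 1 gives kv column ids. A rough NumPy equivalent follows; the excerpt only shows the index construction, so the final col_ids <= row_ids comparison here is an assumption based on standard causal masking, not taken from the diff:

import numpy as np

def causal_mask_block(q_blk_idx, kv_blk_idx, block_q, block_kv,
                      effective_q_len, effective_kv_len):
  q_index = q_blk_idx * block_q
  kv_index = kv_blk_idx * block_kv
  # np.arange stands in for jax.lax.broadcasted_iota along dims 0 and 1.
  row_ids = (effective_kv_len - effective_q_len) + q_index + np.arange(
      block_q)[:, None]
  col_ids = kv_index + np.arange(block_kv)[None, :]
  # Assumed comparison: a query may attend to kv positions at or before it.
  return col_ids <= row_ids

mask = causal_mask_block(
    0, 1, 32, 256, effective_q_len=64, effective_kv_len=257)
assert not mask.any()  # tile (0, 1) is fully masked -> safe to skip entirely

For block (0, 1) in the numbers above, the mask is entirely False, which is exactly the tile the block-level predicate skips: _block_below_or_on_diag is the bottom-left-corner test of this element-level mask.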