From 5270d7c451021edd6260f14dd7934419a9d06a44 Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Thu, 5 Dec 2024 15:57:28 -0800 Subject: [PATCH] Skip the computation if the current block lies above the causal mask diagonal. (#8448) --- .../multi_queries_paged_attention_kernel.py | 66 ++++++++++++++----- 1 file changed, 49 insertions(+), 17 deletions(-) diff --git a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py index 557f8ad5ec3..84d6ad530e5 100644 --- a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py +++ b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py @@ -148,11 +148,12 @@ def start_new_sequence(): q_index = q_blk_idx * num_queries_per_compute_block kv_index = kv_blk_idx * kv_seq_len_per_kv_compute_blk - kv_len = lengths_ref[b] + effective_kv_len = lengths_ref[b] effective_q_len = effective_q_lens_ref[b] - row_ids = (kv_len - effective_q_len) + q_index + jax.lax.broadcasted_iota( - jnp.int32, - (num_queries_per_compute_block, kv_seq_len_per_kv_compute_blk), 0) + row_ids = ( + effective_kv_len - effective_q_len) + q_index + jax.lax.broadcasted_iota( + jnp.int32, + (num_queries_per_compute_block, kv_seq_len_per_kv_compute_blk), 0) col_ids = kv_index + jax.lax.broadcasted_iota( jnp.int32, (num_queries_per_compute_block, kv_seq_len_per_kv_compute_blk), 1) @@ -176,7 +177,7 @@ def start_new_sequence(): alpha = jnp.exp(m_prev - m_next) # Shape [block_q, 128] - l_corr = alpha * l_prev + l_corr = alpha * l_prev # Shape [block_q, 128] l_next = jnp.sum(p, axis=1)[:, None] + l_corr # Shape [block_q, 128] @@ -192,16 +193,28 @@ def start_new_sequence(): l_scratch_ref[q_head_idx_per_kv] = l_next m_scratch_ref[q_head_idx_per_kv] = m_next - l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next) + l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, + 1.0 / l_next) # [block_q, 128] + acc_scratch_ref[q_head_idx_per_kv] *= l_broadcast(l_corr * l_next_inv_safe) # Note Matmul operandlhs must have a shape divisible by (16, 1) - o_curr = jax.lax.dot(p.astype(v.dtype), v, preferred_element_type=jnp.float32) + o_curr = jax.lax.dot( + p.astype(v.dtype), v, + preferred_element_type=jnp.float32) # [block_q, 128] + acc_scratch_ref[q_head_idx_per_kv] += o_curr * l_broadcast(l_next_inv_safe) - # The condition comes from the one "@pl.when(kv_blk_idx * compute_blk_size_kv < kv_len)" controlling if we should run the function get_kv_and_run_flash_attention. + # The condition comes from the check controlling if we should run the function get_kv_and_run_flash_attention. # If kv_len=512, kv_seq_len_per_kv_compute_blk=256, then last kv_blk_idx that we need to store_to_output is 1. # If kv_len=513, kv_seq_len_per_kv_compute_blk=256, then last kv_blk_idx that we need to store_to_output is 2. - @pl.when(kv_blk_idx == pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk) - 1) + is_last_kv_blk_idx = kv_blk_idx == pl.cdiv(effective_kv_len, + kv_seq_len_per_kv_compute_blk) - 1 + is_next_kv_blk_masked_out = jnp.logical_not( + _block_below_or_on_diag(q_blk_idx, num_queries_per_compute_block, + kv_blk_idx + 1, kv_seq_len_per_kv_compute_blk, + effective_q_len, effective_kv_len)) + + @pl.when(jnp.logical_or(is_last_kv_blk_idx, is_next_kv_blk_masked_out)) def store_to_output(): o_ref[0, q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype( o_ref.dtype) @@ -211,6 +224,16 @@ def store_to_output(): m_ref.dtype) +# A block is considered below or on diagonal as long as the bottom left +# corner of the block is below or on diagonal. +# If the inputs are 0, 32, 0, 256, 64, 257, the block's bottom left corner is (31, 0). For that column(0), the diagonal element is (-193, 0). We check(>=) the x-coordinate of the corner and the diagonal element (31 and -193) +# If the inputs are 0, 32, 1, 256, 64, 257, the block's bottom left corner is (31, 256). For that column(256), the diagonal element is (63, 256). We check(>=) the x-coordinate of the corner and the diagonal element (31 and 63). +def _block_below_or_on_diag(q_blk_idx, q_blk_size, kv_blk_idx, kv_blk_size, + effective_q_len, effective_kv_len): + return ((q_blk_idx + 1) * q_blk_size - 1) >= (kv_blk_idx * kv_blk_size) - ( + effective_kv_len - effective_q_len) + + def paged_flash_attention_kernel( lengths_ref, # [batch_size] jax.Array the length of each example # 1d vector, results from page_indices.reshape(-1) where originally page_indices.shape=[batch_size, pages_per_sequence] @@ -241,7 +264,6 @@ def paged_flash_attention_kernel( pages_per_sequence: int, # Note [bs, pages_per_sequence] = page_indices.shape batch_size: int, num_kv_pages_per_compute_block: int, - num_queries_per_compute_block: int, mask_value: float, query_len: int, ): @@ -256,11 +278,17 @@ def paged_flash_attention_kernel( b_q_ref, num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim = q_ref.shape num_kv_heads, total_num_pages, page_size, head_dim = k_pages_hbm_ref.shape compute_blk_size_kv = page_size * num_kv_pages_per_compute_block - kv_len = lengths_ref[b] + effective_kv_len = lengths_ref[b] + effective_q_len = effective_q_lens_ref[b] - # TODO: think about skip the work when we know the causal mask would mask all (e.g. when the whole kv_blk is after the whole q_blk) # Get the K and V for the current batch and current kv head. - @pl.when(kv_blk_idx * compute_blk_size_kv < kv_len) + should_run = jnp.logical_and( + kv_blk_idx * compute_blk_size_kv < effective_kv_len, + _block_below_or_on_diag(q_blk_idx, num_queries_per_compute_block, + kv_blk_idx, compute_blk_size_kv, effective_q_len, + effective_kv_len)) + + @pl.when(should_run) def get_kv_and_run_flash_attention(): # Loop over num_q_heads_per_kv_head and use the same K and V def compute_block_indices(b, kv_head_idx, q_blk_idx, kv_blk_idx): @@ -289,7 +317,7 @@ def advance_to_next_non_zero_length(): ) def advance_kv_head_idx(): - # assumption: kv_blk_idx * compute_blk_size_kv >= lengths_ref[b] + # assumption: kv_blk_idx * compute_blk_size_kv >= lengths_ref[b], or the block is above the causal mask diagonal. next_kv_head_idx = kv_head_idx + 1 return lax.cond( q_blk_idx == num_q_blks - 1, @@ -297,8 +325,13 @@ def advance_kv_head_idx(): (b, next_kv_head_idx, 0), advance_b), lambda: (b, kv_head_idx, 0)) - return lax.cond(kv_blk_idx * compute_blk_size_kv < lengths_ref[b], lambda: - (b, kv_head_idx, kv_blk_idx), advance_kv_head_idx) + return lax.cond( + jnp.logical_and( + kv_blk_idx * compute_blk_size_kv < lengths_ref[b], + _block_below_or_on_diag(q_blk_idx, num_queries_per_compute_block, + kv_blk_idx, compute_blk_size_kv, + effective_q_lens_ref[b], lengths_ref[b])), + lambda: (b, kv_head_idx, kv_blk_idx), advance_kv_head_idx) def create_kv_async_copy_descriptors(b, kv_head_idx, kv_blk_idx, buffer_index): @@ -586,7 +619,6 @@ def lm_index_map(batch_index, kv_head_index, q_seq_blk_idx, *_): pages_per_sequence=pages_per_sequence, batch_size=batch_size, num_kv_pages_per_compute_block=num_kv_pages_per_compute_block, - num_queries_per_compute_block=num_queries_per_compute_block, mask_value=mask_value, query_len=query_len), grid_spec=pltpu.PrefetchScalarGridSpec(