Commit 5270d7c: Skip the computation if the current block lies above the causal mask diagonal. (#8448)

vanbasten23 authored Dec 5, 2024
1 parent a48ff6a
Showing 1 changed file with 49 additions and 17 deletions.
@@ -148,11 +148,12 @@ def start_new_sequence():

q_index = q_blk_idx * num_queries_per_compute_block
kv_index = kv_blk_idx * kv_seq_len_per_kv_compute_blk
- kv_len = lengths_ref[b]
+ effective_kv_len = lengths_ref[b]
  effective_q_len = effective_q_lens_ref[b]
- row_ids = (kv_len - effective_q_len) + q_index + jax.lax.broadcasted_iota(
-     jnp.int32,
-     (num_queries_per_compute_block, kv_seq_len_per_kv_compute_blk), 0)
+ row_ids = (
+     effective_kv_len - effective_q_len) + q_index + jax.lax.broadcasted_iota(
+         jnp.int32,
+         (num_queries_per_compute_block, kv_seq_len_per_kv_compute_blk), 0)
col_ids = kv_index + jax.lax.broadcasted_iota(
jnp.int32,
(num_queries_per_compute_block, kv_seq_len_per_kv_compute_blk), 1)
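
The offset `effective_kv_len - effective_q_len` shifts the causal diagonal so the last query row lines up with the last kv column. Below is a minimal NumPy sketch of the id grids built above, assuming the conventional keep-rule `row_ids >= col_ids` (the mask-application step itself is outside this hunk):

```python
import numpy as np

def block_causal_keep(q_blk_idx, kv_blk_idx, blk_q, blk_kv,
                      effective_q_len, effective_kv_len):
  # Mirrors the kernel's row_ids/col_ids construction with NumPy iota.
  q_index = q_blk_idx * blk_q
  kv_index = kv_blk_idx * blk_kv
  row_ids = (effective_kv_len - effective_q_len) + q_index + np.arange(
      blk_q)[:, None]
  col_ids = kv_index + np.arange(blk_kv)[None, :]
  return row_ids >= col_ids  # True where the query may attend

# 4 queries over 8 kv tokens, single block: the last query row sees all 8
# keys, earlier rows see progressively fewer.
print(block_causal_keep(0, 0, 4, 8, effective_q_len=4, effective_kv_len=8))
```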
@@ -176,7 +177,7 @@ def start_new_sequence():

alpha = jnp.exp(m_prev - m_next) # Shape [block_q, 128]

- l_corr = alpha * l_prev
+ l_corr = alpha * l_prev  # Shape [block_q, 128]

l_next = jnp.sum(p, axis=1)[:, None] + l_corr # Shape [block_q, 128]

@@ -192,16 +193,28 @@ def start_new_sequence():
l_scratch_ref[q_head_idx_per_kv] = l_next
m_scratch_ref[q_head_idx_per_kv] = m_next

- l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next)
+ l_next_inv_safe = jnp.where(l_next == 0.0, 1.0,
+                             1.0 / l_next)  # [block_q, 128]

acc_scratch_ref[q_head_idx_per_kv] *= l_broadcast(l_corr * l_next_inv_safe)
  # Note: the Matmul lhs operand must have a shape divisible by (16, 1).
- o_curr = jax.lax.dot(p.astype(v.dtype), v, preferred_element_type=jnp.float32)
+ o_curr = jax.lax.dot(
+     p.astype(v.dtype), v,
+     preferred_element_type=jnp.float32)  # [block_q, 128]

acc_scratch_ref[q_head_idx_per_kv] += o_curr * l_broadcast(l_next_inv_safe)

- # The condition comes from the one "@pl.when(kv_blk_idx * compute_blk_size_kv < kv_len)" controlling if we should run the function get_kv_and_run_flash_attention.
+ # The condition comes from the check controlling whether we should run the function get_kv_and_run_flash_attention.
  # If kv_len=512 and kv_seq_len_per_kv_compute_blk=256, then the last kv_blk_idx for which we need to store_to_output is 1.
  # If kv_len=513 and kv_seq_len_per_kv_compute_blk=256, then the last kv_blk_idx for which we need to store_to_output is 2.
- @pl.when(kv_blk_idx == pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk) - 1)
+ is_last_kv_blk_idx = kv_blk_idx == pl.cdiv(effective_kv_len,
+                                            kv_seq_len_per_kv_compute_blk) - 1
+ is_next_kv_blk_masked_out = jnp.logical_not(
+     _block_below_or_on_diag(q_blk_idx, num_queries_per_compute_block,
+                             kv_blk_idx + 1, kv_seq_len_per_kv_compute_blk,
+                             effective_q_len, effective_kv_len))
+
+ @pl.when(jnp.logical_or(is_last_kv_blk_idx, is_next_kv_blk_masked_out))
def store_to_output():
o_ref[0, q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype(
o_ref.dtype)
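
Restating the new store condition as plain Python may help; this hypothetical `should_store` wrapper uses `math.ceil` in place of `pl.cdiv` and the `_block_below_or_on_diag` helper added further down in this diff. The accumulator is flushed either on the last kv block inside `effective_kv_len`, or as soon as the next kv block would sit entirely above the diagonal:

```python
import math

def should_store(q_blk_idx, kv_blk_idx, blk_q, blk_kv,
                 effective_q_len, effective_kv_len):
  is_last_kv_blk = kv_blk_idx == math.ceil(effective_kv_len / blk_kv) - 1
  next_blk_masked_out = not _block_below_or_on_diag(
      q_blk_idx, blk_q, kv_blk_idx + 1, blk_kv,
      effective_q_len, effective_kv_len)
  return is_last_kv_blk or next_blk_masked_out

# Matches the comment above: with kv_len=513 and a block size of 256,
# math.ceil(513 / 256) - 1 == 2, so kv block index 2 is the last store;
# with kv_len=512 it is math.ceil(512 / 256) - 1 == 1.
```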
@@ -211,6 +224,16 @@ def store_to_output():
m_ref.dtype)


+ # A block is considered below or on the diagonal as long as the bottom-left
+ # corner of the block is below or on the diagonal.
+ # If the inputs are 0, 32, 0, 256, 64, 257, the block's bottom-left corner is (31, 0). For that column (0), the diagonal element is (-193, 0). We compare (>=) the x-coordinates of the corner and the diagonal element (31 and -193).
+ # If the inputs are 0, 32, 1, 256, 64, 257, the block's bottom-left corner is (31, 256). For that column (256), the diagonal element is (63, 256). We compare (>=) the x-coordinates of the corner and the diagonal element (31 and 63).
+ def _block_below_or_on_diag(q_blk_idx, q_blk_size, kv_blk_idx, kv_blk_size,
+                             effective_q_len, effective_kv_len):
+   return ((q_blk_idx + 1) * q_blk_size - 1) >= (kv_blk_idx * kv_blk_size) - (
+       effective_kv_len - effective_q_len)


def paged_flash_attention_kernel(
    lengths_ref,  # [batch_size] jax.Array: the length of each example
    # 1D vector, the result of page_indices.reshape(-1), where originally page_indices.shape = [batch_size, pages_per_sequence]
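
The two worked examples in the comment check out numerically; a pure-Python replica of the helper makes them executable (the kernel evaluates the same expression on traced JAX integers, where the arithmetic and `>=` behave identically):

```python
def _block_below_or_on_diag(q_blk_idx, q_blk_size, kv_blk_idx, kv_blk_size,
                            effective_q_len, effective_kv_len):
  # Bottom-left corner row of the q block vs. the diagonal row at the kv
  # block's first column, shifted by the length difference.
  return ((q_blk_idx + 1) * q_blk_size - 1) >= (kv_blk_idx * kv_blk_size) - (
      effective_kv_len - effective_q_len)

assert _block_below_or_on_diag(0, 32, 0, 256, 64, 257)      # 31 >= -193
assert not _block_below_or_on_diag(0, 32, 1, 256, 64, 257)  # 31 >= 63 fails
```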
@@ -241,7 +264,6 @@ def paged_flash_attention_kernel(
pages_per_sequence: int, # Note [bs, pages_per_sequence] = page_indices.shape
batch_size: int,
num_kv_pages_per_compute_block: int,
- num_queries_per_compute_block: int,
mask_value: float,
query_len: int,
):
@@ -256,11 +278,17 @@ def paged_flash_attention_kernel(
b_q_ref, num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim = q_ref.shape
num_kv_heads, total_num_pages, page_size, head_dim = k_pages_hbm_ref.shape
compute_blk_size_kv = page_size * num_kv_pages_per_compute_block
- kv_len = lengths_ref[b]
+ effective_kv_len = lengths_ref[b]
+ effective_q_len = effective_q_lens_ref[b]

- # TODO: think about skipping the work when we know the causal mask would mask everything out (e.g. when the whole kv_blk is after the whole q_blk).
  # Get the K and V for the current batch and current kv head.
- @pl.when(kv_blk_idx * compute_blk_size_kv < kv_len)
+ should_run = jnp.logical_and(
+     kv_blk_idx * compute_blk_size_kv < effective_kv_len,
+     _block_below_or_on_diag(q_blk_idx, num_queries_per_compute_block,
+                             kv_blk_idx, compute_blk_size_kv, effective_q_len,
+                             effective_kv_len))
+
+ @pl.when(should_run)
def get_kv_and_run_flash_attention():
# Loop over num_q_heads_per_kv_head and use the same K and V
def compute_block_indices(b, kv_head_idx, q_blk_idx, kv_blk_idx):
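
To see where the savings come from, enumerating the blocks that pass the new `should_run` gate for a hypothetical square case (`query_len = kv_len = 512`, 128-wide blocks, reusing the helper above) yields the expected lower-triangular pattern; previously every block ran and above-diagonal blocks were merely masked to the mask value:

```python
blk_q, blk_kv, q_len, kv_len = 128, 128, 512, 512
for q_blk in range(q_len // blk_q):
  active = [
      kv_blk for kv_blk in range(kv_len // blk_kv)
      if kv_blk * blk_kv < kv_len and _block_below_or_on_diag(
          q_blk, blk_q, kv_blk, blk_kv, q_len, kv_len)
  ]
  print(f"q_blk {q_blk}: kv blocks {active}")
# q_blk 0: kv blocks [0]
# q_blk 1: kv blocks [0, 1]
# q_blk 2: kv blocks [0, 1, 2]
# q_blk 3: kv blocks [0, 1, 2, 3]
```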
@@ -289,16 +317,21 @@ def advance_to_next_non_zero_length():
)

def advance_kv_head_idx():
- # assumption: kv_blk_idx * compute_blk_size_kv >= lengths_ref[b]
+ # assumption: kv_blk_idx * compute_blk_size_kv >= lengths_ref[b], or the block is above the causal mask diagonal.
next_kv_head_idx = kv_head_idx + 1
return lax.cond(
q_blk_idx == num_q_blks - 1,
lambda: lax.cond(next_kv_head_idx < num_kv_heads, lambda:
(b, next_kv_head_idx, 0), advance_b), lambda:
(b, kv_head_idx, 0))

- return lax.cond(kv_blk_idx * compute_blk_size_kv < lengths_ref[b], lambda:
-                 (b, kv_head_idx, kv_blk_idx), advance_kv_head_idx)
+ return lax.cond(
+     jnp.logical_and(
+         kv_blk_idx * compute_blk_size_kv < lengths_ref[b],
+         _block_below_or_on_diag(q_blk_idx, num_queries_per_compute_block,
+                                 kv_blk_idx, compute_blk_size_kv,
+                                 effective_q_lens_ref[b], lengths_ref[b])),
+     lambda: (b, kv_head_idx, kv_blk_idx), advance_kv_head_idx)

def create_kv_async_copy_descriptors(b, kv_head_idx, kv_blk_idx,
buffer_index):
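
The same usefulness test also steers the scalar-prefetch index maps. A rough sketch of the decision above, as a hypothetical `next_kv_indices` helper (simplified: the real code threads this through nested `lax.cond` calls, only bumps the kv head on the last q block, and advances `b` when the heads wrap):

```python
def next_kv_indices(b, kv_head_idx, q_blk_idx, kv_blk_idx, blk_q, blk_kv,
                    effective_q_len, effective_kv_len, num_kv_heads):
  # Stay on the current (b, kv_head, kv_blk) while the block still
  # contributes attention; otherwise move on to the next kv head.
  still_useful = (kv_blk_idx * blk_kv < effective_kv_len and
                  _block_below_or_on_diag(q_blk_idx, blk_q, kv_blk_idx,
                                          blk_kv, effective_q_len,
                                          effective_kv_len))
  if still_useful:
    return b, kv_head_idx, kv_blk_idx
  return b, (kv_head_idx + 1) % num_kv_heads, 0
```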
@@ -586,7 +619,6 @@ def lm_index_map(batch_index, kv_head_index, q_seq_blk_idx, *_):
pages_per_sequence=pages_per_sequence,
batch_size=batch_size,
num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
- num_queries_per_compute_block=num_queries_per_compute_block,
mask_value=mask_value,
query_len=query_len),
grid_spec=pltpu.PrefetchScalarGridSpec(
