Skip the computation if the current block lies above the causal mask diagonal. #8448

Merged: 6 commits, Dec 5, 2024
@@ -148,11 +148,12 @@ def start_new_sequence():

q_index = q_blk_idx * num_queries_per_compute_block
kv_index = kv_blk_idx * kv_seq_len_per_kv_compute_blk
kv_len = lengths_ref[b]
effective_kv_len = lengths_ref[b]
effective_q_len = effective_q_lens_ref[b]
row_ids = (kv_len - effective_q_len) + q_index + jax.lax.broadcasted_iota(
jnp.int32,
(num_queries_per_compute_block, kv_seq_len_per_kv_compute_blk), 0)
row_ids = (
effective_kv_len - effective_q_len) + q_index + jax.lax.broadcasted_iota(
jnp.int32,
(num_queries_per_compute_block, kv_seq_len_per_kv_compute_blk), 0)
col_ids = kv_index + jax.lax.broadcasted_iota(
jnp.int32,
(num_queries_per_compute_block, kv_seq_len_per_kv_compute_blk), 1)
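
For reference, the row/column ids above realize a causal mask for a query prefix: queries are shifted right by (effective_kv_len - effective_q_len) so the last query lines up with the last key. A minimal NumPy sketch of the same indexing (the masking step itself, allowing attention where col_ids <= row_ids, is the usual causal convention and is assumed here rather than shown in this hunk):

import numpy as np

def causal_block_mask(q_blk_idx, kv_blk_idx, blk_q, blk_kv,
                      effective_q_len, effective_kv_len):
  q_index = q_blk_idx * blk_q
  kv_index = kv_blk_idx * blk_kv
  # Global row id of each query, shifted by the query/key length difference.
  row_ids = (effective_kv_len - effective_q_len) + q_index + np.arange(blk_q)[:, None]
  # Global column id of each key in this kv block.
  col_ids = kv_index + np.arange(blk_kv)[None, :]
  return col_ids <= row_ids  # True where a query is allowed to attend
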
@@ -176,7 +177,7 @@ def start_new_sequence():

alpha = jnp.exp(m_prev - m_next) # Shape [block_q, 128]

l_corr = alpha * l_prev
l_corr = alpha * l_prev # Shape [block_q, 128]

l_next = jnp.sum(p, axis=1)[:, None] + l_corr # Shape [block_q, 128]
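
These lines are the standard flash-attention online-softmax update. A plain NumPy sketch of the same bookkeeping (shapes reduced to [block_q, 1] instead of the kernel's lane-broadcast [block_q, 128]):

import numpy as np

def update_softmax_stats(m_prev, l_prev, s_block):
  # s_block: raw scores for the current kv block, shape [block_q, block_kv].
  m_curr = s_block.max(axis=1, keepdims=True)     # row max of this block
  m_next = np.maximum(m_prev, m_curr)             # new running max
  alpha = np.exp(m_prev - m_next)                 # rescales the old running sum
  p = np.exp(s_block - m_next)                    # unnormalized probabilities
  l_corr = alpha * l_prev
  l_next = p.sum(axis=1, keepdims=True) + l_corr  # new running sum
  return m_next, l_next, l_corr, p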

@@ -192,16 +193,28 @@ def start_new_sequence():
l_scratch_ref[q_head_idx_per_kv] = l_next
m_scratch_ref[q_head_idx_per_kv] = m_next

l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next)
l_next_inv_safe = jnp.where(l_next == 0.0, 1.0,
1.0 / l_next) # [block_q, 128]

acc_scratch_ref[q_head_idx_per_kv] *= l_broadcast(l_corr * l_next_inv_safe)
# Note: the matmul operand lhs must have a shape divisible by (16, 1)
o_curr = jax.lax.dot(p.astype(v.dtype), v, preferred_element_type=jnp.float32)
o_curr = jax.lax.dot(
p.astype(v.dtype), v,
preferred_element_type=jnp.float32) # [block_q, 128]

acc_scratch_ref[q_head_idx_per_kv] += o_curr * l_broadcast(l_next_inv_safe)
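
The accumulator update above divides by l_next, so the jnp.where guard is needed: a row whose scores are all masked out has l_next == 0, and a plain reciprocal would turn the accumulator into inf/nan. A simplified paraphrase (again with [block_q, 1] statistics rather than the kernel's [block_q, 128] lane layout):

import numpy as np

def update_accumulator(acc, p, v, l_corr, l_next):
  l_next_inv_safe = np.where(l_next == 0.0, 1.0, 1.0 / l_next)  # guard fully masked rows
  acc = acc * (l_corr * l_next_inv_safe)   # rescale the previous partial output
  acc = acc + (p @ v) * l_next_inv_safe    # add this kv block's contribution
  return acc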

# The condition comes from the one "@pl.when(kv_blk_idx * compute_blk_size_kv < kv_len)" controlling if we should run the function get_kv_and_run_flash_attention.
# The condition comes from the check that controls whether we run get_kv_and_run_flash_attention.
# If kv_len=512, kv_seq_len_per_kv_compute_blk=256, then last kv_blk_idx that we need to store_to_output is 1.
# If kv_len=513, kv_seq_len_per_kv_compute_blk=256, then last kv_blk_idx that we need to store_to_output is 2.
@pl.when(kv_blk_idx == pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk) - 1)
is_last_kv_blk_idx = kv_blk_idx == pl.cdiv(effective_kv_len,
kv_seq_len_per_kv_compute_blk) - 1
is_next_kv_blk_masked_out = jnp.logical_not(
_block_below_or_on_diag(q_blk_idx, num_queries_per_compute_block,
kv_blk_idx + 1, kv_seq_len_per_kv_compute_blk,
effective_q_len, effective_kv_len))

@pl.when(jnp.logical_or(is_last_kv_blk_idx, is_next_kv_blk_masked_out))
def store_to_output():
o_ref[0, q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype(
o_ref.dtype)
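
The "last kv block" arithmetic from the comment above is a plain ceiling division, which is what pl.cdiv computes; the two examples can be checked directly:

def last_kv_blk_idx(effective_kv_len, kv_blk_size):
  return -(-effective_kv_len // kv_blk_size) - 1  # ceil(effective_kv_len / kv_blk_size) - 1

assert last_kv_blk_idx(512, 256) == 1
assert last_kv_blk_idx(513, 256) == 2
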
@@ -211,6 +224,16 @@ def store_to_output():
m_ref.dtype)


# A block is considered below or on the diagonal as long as the bottom-left
# corner of the block is below or on the diagonal.
# If the inputs are 0, 32, 0, 256, 64, 257, the block's bottom-left corner is (31, 0). For that column (0), the diagonal element is (-193, 0). We compare (>=) the first coordinate of the corner with that of the diagonal element (31 and -193).
# If the inputs are 0, 32, 1, 256, 64, 257, the block's bottom-left corner is (31, 256). For that column (256), the diagonal element is (63, 256). We compare (>=) the first coordinate of the corner with that of the diagonal element (31 and 63).
def _block_below_or_on_diag(q_blk_idx, q_blk_size, kv_blk_idx, kv_blk_size,
effective_q_len, effective_kv_len):
return ((q_blk_idx + 1) * q_blk_size - 1) >= (kv_blk_idx * kv_blk_size) - (
effective_kv_len - effective_q_len)
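
The two worked examples in the comment can be checked directly against the helper (the diagonal offset here is effective_kv_len - effective_q_len = 257 - 64 = 193):

# q block 0 covers rows 0..31; its bottom-left corner row is 31.
assert _block_below_or_on_diag(0, 32, 0, 256, 64, 257)      # 31 >= 0 - 193
assert not _block_below_or_on_diag(0, 32, 1, 256, 64, 257)  # 31 >= 256 - 193 is false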


def paged_flash_attention_kernel(
lengths_ref, # [batch_size] jax.Array the length of each example
# 1d vector, results from page_indices.reshape(-1) where originally page_indices.shape=[batch_size, pages_per_sequence]
@@ -241,7 +264,6 @@ def paged_flash_attention_kernel(
pages_per_sequence: int, # Note [bs, pages_per_sequence] = page_indices.shape
batch_size: int,
num_kv_pages_per_compute_block: int,
num_queries_per_compute_block: int,
mask_value: float,
query_len: int,
):
@@ -256,11 +278,17 @@ def paged_flash_attention_kernel(
b_q_ref, num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim = q_ref.shape
num_kv_heads, total_num_pages, page_size, head_dim = k_pages_hbm_ref.shape
compute_blk_size_kv = page_size * num_kv_pages_per_compute_block
kv_len = lengths_ref[b]
effective_kv_len = lengths_ref[b]
effective_q_len = effective_q_lens_ref[b]

# TODO: consider skipping the work when we know the causal mask would mask everything (e.g. when the whole kv_blk is after the whole q_blk)
# Get the K and V for the current batch and current kv head.
@pl.when(kv_blk_idx * compute_blk_size_kv < kv_len)
should_run = jnp.logical_and(
kv_blk_idx * compute_blk_size_kv < effective_kv_len,
_block_below_or_on_diag(q_blk_idx, num_queries_per_compute_block,
kv_blk_idx, compute_blk_size_kv, effective_q_len,
effective_kv_len))

@pl.when(should_run)
def get_kv_and_run_flash_attention():
# Loop over num_q_heads_per_kv_head and use the same K and V
def compute_block_indices(b, kv_head_idx, q_blk_idx, kv_blk_idx):
@@ -289,16 +317,21 @@ def advance_to_next_non_zero_length():
)

def advance_kv_head_idx():
# assumption: kv_blk_idx * compute_blk_size_kv >= lengths_ref[b]
# assumption: kv_blk_idx * compute_blk_size_kv >= lengths_ref[b], or the block is above the causal mask diagonal.
next_kv_head_idx = kv_head_idx + 1
return lax.cond(
q_blk_idx == num_q_blks - 1,
lambda: lax.cond(next_kv_head_idx < num_kv_heads, lambda:
(b, next_kv_head_idx, 0), advance_b), lambda:
(b, kv_head_idx, 0))

return lax.cond(kv_blk_idx * compute_blk_size_kv < lengths_ref[b], lambda:
(b, kv_head_idx, kv_blk_idx), advance_kv_head_idx)
return lax.cond(
jnp.logical_and(
kv_blk_idx * compute_blk_size_kv < lengths_ref[b],
_block_below_or_on_diag(q_blk_idx, num_queries_per_compute_block,
kv_blk_idx, compute_blk_size_kv,
effective_q_lens_ref[b], lengths_ref[b])),
lambda: (b, kv_head_idx, kv_blk_idx), advance_kv_head_idx)
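
In plain Python, the decision this index map makes can be paraphrased roughly as follows (a sketch only: advance_b / advance_to_next_non_zero_length is simplified to just stepping the batch index, and the nested lax.cond calls are written as ordinary control flow):

def next_kv_block_indices(b, kv_head_idx, q_blk_idx, kv_blk_idx, kv_len,
                          effective_q_len, compute_blk_size_kv,
                          num_queries_per_compute_block, num_q_blks, num_kv_heads):
  still_needed = (kv_blk_idx * compute_blk_size_kv < kv_len and
                  _block_below_or_on_diag(q_blk_idx, num_queries_per_compute_block,
                                          kv_blk_idx, compute_blk_size_kv,
                                          effective_q_len, kv_len))
  if still_needed:
    return b, kv_head_idx, kv_blk_idx        # keep fetching K/V for this block
  # The block is past the sequence end or entirely above the causal diagonal.
  if q_blk_idx == num_q_blks - 1:
    if kv_head_idx + 1 < num_kv_heads:
      return b, kv_head_idx + 1, 0           # restart at kv block 0 of the next kv head
    return b + 1, 0, 0                       # simplified stand-in for advance_b
  return b, kv_head_idx, 0                   # same kv head, restart at kv block 0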

def create_kv_async_copy_descriptors(b, kv_head_idx, kv_blk_idx,
buffer_index):
@@ -586,7 +619,6 @@ def lm_index_map(batch_index, kv_head_index, q_seq_blk_idx, *_):
pages_per_sequence=pages_per_sequence,
batch_size=batch_size,
num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
num_queries_per_compute_block=num_queries_per_compute_block,
mask_value=mask_value,
query_len=query_len),
grid_spec=pltpu.PrefetchScalarGridSpec(