Skip the computation if the current block lies above the causal mask diagonal. #8448

Merged
merged 6 commits into from Dec 5, 2024
Changes from 1 commit
skip based on causal mask
vanbasten23 committed Dec 5, 2024
commit 6bb10fb0e5d6bd8c29fb134d908cc0958e9f3f6d
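Before getting into the diff itself, a small standalone sketch of the idea in the PR title may help: with a causal mask, an entire (query block, kv block) tile can be skipped whenever even its bottom-left corner, the least-masked position in the tile, lies above the (shifted) diagonal. The sizes below are illustrative and are not taken from this PR.

```python
import numpy as np

# Illustrative sizes; the real kernel reads these from lengths_ref / effective_q_lens_ref.
q_len, kv_len = 512, 1024
q_blk, kv_blk = 256, 256
offset = kv_len - q_len  # the causal diagonal is shifted when kv_len > q_len

def block_is_needed(q_blk_idx, kv_blk_idx):
    # Keep the tile iff its bottom-left corner is on or below the shifted causal
    # diagonal; otherwise every entry of the tile is masked and it can be skipped.
    return (q_blk_idx + 1) * q_blk - 1 >= kv_blk_idx * kv_blk - offset

grid = np.array([[block_is_needed(qi, ki)
                  for ki in range(-(-kv_len // kv_blk))]
                 for qi in range(-(-q_len // q_blk))])
print(grid.astype(int))
# [[1 1 1 0]   <- query block 0 can skip the last kv block
#  [1 1 1 1]]  <- query block 1 needs all of them
```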
@@ -15,6 +15,7 @@
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)



class MultiPageAsyncCopyDescriptor:
"""Descriptor for async copy of multiple K/V pages from HBM."""

@@ -176,7 +177,7 @@ def start_new_sequence():

alpha = jnp.exp(m_prev - m_next) # Shape [block_q, 128]

l_corr = alpha * l_prev
l_corr = alpha * l_prev # Shape [block_q, 128]

l_next = jnp.sum(p, axis=1)[:, None] + l_corr # Shape [block_q, 128]

@@ -192,16 +193,20 @@ def start_new_sequence():
l_scratch_ref[q_head_idx_per_kv] = l_next
m_scratch_ref[q_head_idx_per_kv] = m_next

l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next)
l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next) # [block_q, 128]

acc_scratch_ref[q_head_idx_per_kv] *= l_broadcast(l_corr * l_next_inv_safe)
# Note: the matmul lhs operand must have a shape divisible by (16, 1)
o_curr = jax.lax.dot(p.astype(v.dtype), v, preferred_element_type=jnp.float32)
o_curr = jax.lax.dot(p.astype(v.dtype), v, preferred_element_type=jnp.float32) # [block_q, 128]

acc_scratch_ref[q_head_idx_per_kv] += o_curr * l_broadcast(l_next_inv_safe)
# @pl.when(jnp.logical_and(q_blk_idx == 0, kv_blk_idx == 1))
# def _():
# pl.debug_print('xw32 line201')
# blk_debug_blkq_128_ref[:] = acc_scratch_ref[q_head_idx_per_kv]

# This condition mirrors the "@pl.when(kv_blk_idx * compute_blk_size_kv < kv_len)" guard that controls whether get_kv_and_run_flash_attention runs.
# If kv_len=512, kv_seq_len_per_kv_compute_blk=256, then last kv_blk_idx that we need to store_to_output is 1.
# If kv_len=513, kv_seq_len_per_kv_compute_blk=256, then last kv_blk_idx that we need to store_to_output is 2.
@pl.when(kv_blk_idx == pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk) - 1)
#@pl.when(kv_blk_idx == pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk)-1)
@pl.when(jnp.logical_or(kv_blk_idx == pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk)-1, jnp.logical_not(_below_or_on_diag(q_blk_idx, num_queries_per_compute_block, kv_blk_idx+1, kv_seq_len_per_kv_compute_blk, effective_q_len, kv_len))))
def store_to_output():
o_ref[0, q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype(
o_ref.dtype)
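The worked numbers in the comment above can be checked with a few lines of plain Python. This is a standalone sketch, not the kernel: `cdiv` mirrors `pl.cdiv` (ceiling division) and `below_or_on_diag` mirrors the `_below_or_on_diag` helper added further down in this diff. It also shows why the extra `logical_or` branch is needed: with the causal skip, the last kv block a query block actually processes can come before the last in-bounds kv block, so the output has to be stored at that earlier point.

```python
def cdiv(a, b):  # same rounding as pl.cdiv
    return -(-a // b)

def below_or_on_diag(q_blk_idx, q_blk_size, kv_blk_idx, kv_blk_size, q_len, kv_len):
    # Keep the tile iff its bottom-left corner is on or below the shifted diagonal.
    return (q_blk_idx + 1) * q_blk_size - 1 >= kv_blk_idx * kv_blk_size - (kv_len - q_len)

blk = 256
assert cdiv(512, blk) - 1 == 1  # kv_len=512 -> last kv block index is 1
assert cdiv(513, blk) - 1 == 2  # kv_len=513 -> last kv block index is 2

# With the causal skip, the last *processed* kv block can come earlier than the
# last in-bounds one. E.g. q_len=512, kv_len=1024, 256-wide blocks:
q_len, kv_len = 512, 1024
for q_blk in range(cdiv(q_len, blk)):
    last = max(kv_blk for kv_blk in range(cdiv(kv_len, blk))
               if below_or_on_diag(q_blk, blk, kv_blk, blk, q_len, kv_len))
    print(q_blk, last)  # query block 0 stores at kv block 2, query block 1 at kv block 3
```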
@@ -210,6 +215,18 @@ def store_to_output():
m_ref[0, q_head_idx_per_kv] = m_scratch_ref[q_head_idx_per_kv].astype(
m_ref.dtype)

# partially working
# def _below_or_on_diag(q_blk_idx, q_blk_size, kv_blk_idx, kv_blk_size, effective_q_len, effective_kv_len):
# return ((q_blk_idx+1) * q_blk_size - 1) <= (q_blk_size-1) + (kv_blk_idx * kv_blk_size) + (effective_kv_len - effective_q_len)

def _below_or_on_diag(q_blk_idx, q_blk_size, kv_blk_idx, kv_blk_size, effective_q_len, effective_kv_len):
return ((q_blk_idx+1) * q_blk_size - 1) >= (kv_blk_idx * kv_blk_size) - (effective_kv_len - effective_q_len)

# Original version in FA
# def below_or_on_diag(r, r_blk_size, c, c_blk_size):
# # A block is considered below or on diagonal as long as the bottom left
# # corner of the block is below or on diagonal.
# return ((r + 1) * r_blk_size - 1) > (c * c_blk_size)
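As a quick sanity check that the block-level predicate above agrees with the element-level causal mask, the following NumPy-only snippet (standalone and illustrative; it assumes the block sizes divide the sequence lengths exactly) verifies that a tile is skipped precisely when every entry in it would be masked out.

```python
import numpy as np

def _below_or_on_diag(q_blk_idx, q_blk_size, kv_blk_idx, kv_blk_size,
                      effective_q_len, effective_kv_len):
    return ((q_blk_idx + 1) * q_blk_size - 1) >= (
        kv_blk_idx * kv_blk_size) - (effective_kv_len - effective_q_len)

q_len, kv_len, q_blk, kv_blk = 512, 1024, 256, 256
# Element-level causal mask: query i may attend to kv position j iff j <= i + (kv_len - q_len).
row = np.arange(q_len)[:, None]
col = np.arange(kv_len)[None, :]
mask = col <= row + (kv_len - q_len)

for qi in range(q_len // q_blk):
    for ki in range(kv_len // kv_blk):
        tile = mask[qi * q_blk:(qi + 1) * q_blk, ki * kv_blk:(ki + 1) * kv_blk]
        keep = _below_or_on_diag(qi, q_blk, ki, kv_blk, q_len, kv_len)
        # The tile is skipped only when every entry in it is masked out.
        assert keep == tile.any()
```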

def paged_flash_attention_kernel(
lengths_ref, # [batch_size] jax.Array the length of each example
@@ -228,7 +245,6 @@ def paged_flash_attention_kernel(
o_ref,
l_ref,
m_ref,
# scratch space
k_vmem_buffer, # (2, num_kv_pages_per_compute_block, num_kv_heads, head_dim)
k_scales_vmem_buffer,
v_vmem_buffer, # (2, num_kv_pages_per_compute_block, num_kv_heads, head_dim)
@@ -241,7 +257,6 @@ def paged_flash_attention_kernel(
pages_per_sequence: int, # Note [bs, pages_per_sequence] = page_indices.shape
batch_size: int,
num_kv_pages_per_compute_block: int,
num_queries_per_compute_block: int,
mask_value: float,
query_len: int,
):
@@ -256,11 +271,15 @@ def paged_flash_attention_kernel(
b_q_ref, num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim = q_ref.shape
num_kv_heads, total_num_pages, page_size, head_dim = k_pages_hbm_ref.shape
compute_blk_size_kv = page_size * num_kv_pages_per_compute_block
kv_len = lengths_ref[b]
effective_kv_len = lengths_ref[b]
effective_q_len = effective_q_lens_ref[b]

# TODO: think about skipping the work when we know the causal mask would mask everything (e.g. when the whole kv_blk comes after the whole q_blk)
# Get the K and V for the current batch and current kv head.
@pl.when(kv_blk_idx * compute_blk_size_kv < kv_len)
should_run = jnp.logical_and(kv_blk_idx * compute_blk_size_kv < effective_kv_len, _below_or_on_diag(q_blk_idx, num_queries_per_compute_block, kv_blk_idx, compute_blk_size_kv, effective_q_len, effective_kv_len))
# pl.debug_print('xw32 line272 kv_blk_idx={}, compute_blk_size_kv={}, (kv_blk_idx * compute_blk_size_kv)={}, effective_kv_len={}', kv_blk_idx, compute_blk_size_kv, (kv_blk_idx * compute_blk_size_kv), effective_kv_len)
# pl.debug_print('xw32 line273 q_blk_idx={}, num_queries_per_compute_block={}, kv_blk_idx={}, compute_blk_size_kv={}, effective_q_len={}, effective_kv_len={}', q_blk_idx, num_queries_per_compute_block, kv_blk_idx, compute_blk_size_kv, effective_q_len, effective_kv_len)
#@pl.when(kv_blk_idx * compute_blk_size_kv < effective_kv_len)
@pl.when(should_run)
def get_kv_and_run_flash_attention():
# Loop over num_q_heads_per_kv_head and use the same K and V
def compute_block_indices(b, kv_head_idx, q_blk_idx, kv_blk_idx):
@@ -289,15 +308,17 @@ def advance_to_next_non_zero_length():
)

def advance_kv_head_idx():
# assumption: kv_blk_idx * compute_blk_size_kv >= lengths_ref[b]
# assumption: kv_blk_idx * compute_blk_size_kv >= lengths_ref[b], or the block is to the right of the mask.
next_kv_head_idx = kv_head_idx + 1
return lax.cond(
q_blk_idx == num_q_blks - 1,
lambda: lax.cond(next_kv_head_idx < num_kv_heads, lambda:
(b, next_kv_head_idx, 0), advance_b), lambda:
(b, kv_head_idx, 0))

return lax.cond(kv_blk_idx * compute_blk_size_kv < lengths_ref[b], lambda:
# return lax.cond(kv_blk_idx * compute_blk_size_kv < lengths_ref[b], lambda:
# (b, kv_head_idx, kv_blk_idx), advance_kv_head_idx)
return lax.cond(jnp.logical_and(kv_blk_idx * compute_blk_size_kv < lengths_ref[b], _below_or_on_diag(q_blk_idx,num_queries_per_compute_block,kv_blk_idx,compute_blk_size_kv, effective_q_lens_ref[b], lengths_ref[b])), lambda:
(b, kv_head_idx, kv_blk_idx), advance_kv_head_idx)
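For readers tracing the prefetch logic, here is a plain-Python sketch (illustrative names, not the actual Pallas/JAX code) of how `compute_block_indices` now advances: it stays on the current kv head only while the next kv block is both within the sequence length and on or below the causal diagonal; otherwise it restarts at kv block 0 of the next head, or of the next batch once the last query block and last kv head are done (simplified here; the real `advance_b` also skips zero-length sequences).

```python
def next_block_indices(b, kv_head_idx, q_blk_idx, kv_blk_idx,
                       kv_len, q_len, kv_blk_size, q_blk_size,
                       num_kv_heads, num_q_blks):
    def below_or_on_diag(q_blk, kv_blk):
        return (q_blk + 1) * q_blk_size - 1 >= kv_blk * kv_blk_size - (kv_len - q_len)

    if kv_blk_idx * kv_blk_size < kv_len and below_or_on_diag(q_blk_idx, kv_blk_idx):
        return b, kv_head_idx, kv_blk_idx   # keep streaming kv blocks for this head
    if q_blk_idx == num_q_blks - 1:         # this head is finished for every query block
        if kv_head_idx + 1 < num_kv_heads:
            return b, kv_head_idx + 1, 0
        return b + 1, 0, 0                  # simplified: the real kernel skips zero-length batches
    return b, kv_head_idx, 0                # same head, the next query block starts over
```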

def create_kv_async_copy_descriptors(b, kv_head_idx, kv_blk_idx,
@@ -342,6 +363,8 @@ def prefetch_first_block(): # pylint: disable=unused-variable

next_b, next_kv_head_idx, next_kv_blk_idx = compute_block_indices(
b, kv_head_idx, q_blk_idx, kv_blk_idx + 1)
pl.debug_print('xw32 line354, working on b={}, kv_head_idx={}, q_blk_idx={},kv_blk_idx={}', b, kv_head_idx, q_blk_idx, kv_blk_idx)
pl.debug_print('xw32 line355, working on next_b={}, next_kv_head_idx={}, q_blk_idx={},next_kv_blk_idx={}', next_b, next_kv_head_idx, q_blk_idx, next_kv_blk_idx)

@pl.when(next_b < batch_size)
def prefetch_next_block(): # pylint: disable=unused-variable
@@ -369,7 +392,7 @@ def prefetch_next_block(): # pylint: disable=unused-variable
o_ref, # [1, num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim]
l_ref,
m_ref,
l_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE]
l_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE]
m_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, MIN_BLOCK_SIZE]
acc_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim]
num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
@@ -586,7 +609,6 @@ def lm_index_map(batch_index, kv_head_index, q_seq_blk_idx, *_):
pages_per_sequence=pages_per_sequence,
batch_size=batch_size,
num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
num_queries_per_compute_block=num_queries_per_compute_block,
mask_value=mask_value,
query_len=query_len),
grid_spec=pltpu.PrefetchScalarGridSpec(