Skip to content

Commit

Permalink
rebased
Browse files Browse the repository at this point in the history
  • Loading branch information
vanbasten23 committed Dec 5, 2024
1 parent 0869e99 commit 56064e7
Showing 1 changed file with 10 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,12 @@ def start_new_sequence():

q_index = q_blk_idx * num_queries_per_compute_block
kv_index = kv_blk_idx * kv_seq_len_per_kv_compute_blk
kv_len = lengths_ref[b]
effective_kv_len = lengths_ref[b]
effective_q_len = effective_q_lens_ref[b]
row_ids = (kv_len - effective_q_len) + q_index + jax.lax.broadcasted_iota(
jnp.int32,
(num_queries_per_compute_block, kv_seq_len_per_kv_compute_blk), 0)
row_ids = (
effective_kv_len - effective_q_len) + q_index + jax.lax.broadcasted_iota(
jnp.int32,
(num_queries_per_compute_block, kv_seq_len_per_kv_compute_blk), 0)
col_ids = kv_index + jax.lax.broadcasted_iota(
jnp.int32,
(num_queries_per_compute_block, kv_seq_len_per_kv_compute_blk), 1)
Expand Down Expand Up @@ -203,12 +204,15 @@ def start_new_sequence():

acc_scratch_ref[q_head_idx_per_kv] += o_curr * l_broadcast(l_next_inv_safe)

is_last_kv_blk_idx = kv_blk_idx == pl.cdiv(kv_len,
# The condition comes from the check controlling if we should run the function get_kv_and_run_flash_attention.
# If kv_len=512, kv_seq_len_per_kv_compute_blk=256, then last kv_blk_idx that we need to store_to_output is 1.
# If kv_len=513, kv_seq_len_per_kv_compute_blk=256, then last kv_blk_idx that we need to store_to_output is 2.
is_last_kv_blk_idx = kv_blk_idx == pl.cdiv(effective_kv_len,
kv_seq_len_per_kv_compute_blk) - 1
is_next_kv_blk_masked_out = jnp.logical_not(
_block_below_or_on_diag(q_blk_idx, num_queries_per_compute_block,
kv_blk_idx + 1, kv_seq_len_per_kv_compute_blk,
effective_q_len, kv_len))
effective_q_len, effective_kv_len))

@pl.when(jnp.logical_or(is_last_kv_blk_idx, is_next_kv_blk_masked_out))
def store_to_output():
Expand Down

0 comments on commit 56064e7

Please sign in to comment.