Version 0.0.14 Adding GPU Special Funcs
erfanzar committed Nov 27, 2023
1 parent 45ed828 commit 6c9b246
Showing 9 changed files with 2,902 additions and 9 deletions.
2 changes: 1 addition & 1 deletion fjformer/__init__.py
@@ -50,4 +50,4 @@
    count_num_params
)

-__version__ = '0.0.13'
+__version__ = '0.0.14'
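The only substantive change in this file is the version bump. A quick sanity check (a minimal sketch, assuming fjformer is installed from this commit):

```python
# Minimal check that the installed package reflects the bump.
import fjformer

assert fjformer.__version__ == "0.0.14"
```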
11 changes: 4 additions & 7 deletions fjformer/attention/flash_attention_tpu.py
@@ -24,6 +24,7 @@
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
+from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_kernel_single_batch_single_step

DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
NUM_LANES = 128
@@ -271,7 +272,6 @@ def _flash_attention_kernel(q_idx_chunk_start, k_idx_chunk_start, q_tile_ref, *a
  block_b = q_tile_ref.shape[0]
  # If we're not going to tile the softmax, then we can avoid a bunch of VPU ops.
  if kwargs["block_k"] == kwargs["kv_seq_len"]:
-    assert False
    kernel = _flash_attention_kernel_single_batch_single_step
  else:
    kernel = _flash_attention_kernel_single_batch
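Dropping the `assert False` makes the single-step kernel reachable again. The in-code comment's point: when one k/v block spans the whole sequence (`block_k == kv_seq_len`), the softmax max and normalizer come out in a single pass, with none of the rescaling state the tiled path must carry. A hedged sketch of the difference in plain `jnp` (illustration only, not the Pallas kernel's actual code):

```python
import jax.numpy as jnp

def softmax_stats_one_shot(scores):
    # block_k == kv_seq_len: every score is visible at once, so the row max
    # and normalizer are computed directly, with no carried state.
    m = scores.max(axis=-1, keepdims=True)
    l = jnp.exp(scores - m).sum(axis=-1, keepdims=True)
    return m, l

def softmax_stats_streaming(score_blocks):
    # block_k < kv_seq_len: a running max and normalizer are rescaled as each
    # k/v block arrives, which is the extra VPU work the comment refers to.
    m = jnp.full((score_blocks[0].shape[0], 1), -jnp.inf)
    l = jnp.zeros_like(m)
    for s in score_blocks:
        m_new = jnp.maximum(m, s.max(axis=-1, keepdims=True))
        l = l * jnp.exp(m - m_new) + jnp.exp(s - m_new).sum(axis=-1, keepdims=True)
        m = m_new
    return m, l
```

Feeding the same scores in whole or split into blocks yields identical `(m, l)` pairs, which is why the single-block case can skip the loop entirely.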
@@ -551,11 +551,9 @@ def lm_index_map(batch_index, head_index, q_seq_index, _, q_idx_ref, k_idx_ref):
  q_segment_ids_spec = kv_segment_ids_spec = None
  q_segment_ids = kv_segment_ids = None
  if segment_ids is not None:
-    assert False
-
    def q_segment_ids_index_map(batch_index, head_index, q_seq_index, _):
      del head_index
-      return (batch_index, q_seq_index, 0)
+      return batch_index, q_seq_index, 0

    def kv_segment_ids_index_map(
        batch_index, head_index, q_seq_index, kv_seq_index
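Un-stubbing this branch (here and in its backward-pass twin below) enables the segment-ids path, which supports packed sequences: each token carries a segment id and attends only to tokens with the same id. A hedged illustration of that masking rule, with made-up shapes and values:

```python
import jax.numpy as jnp

# Two sequences of lengths 3 and 5 packed into one batch row of length 8.
segment_ids = jnp.array([[0, 0, 0, 1, 1, 1, 1, 1]], dtype=jnp.int32)

# Query position i may attend to key position j only when the ids match;
# this boolean mask is the effect the kernel's segment-id inputs produce.
mask = segment_ids[:, :, None] == segment_ids[:, None, :]  # (batch, q, kv)
```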
@@ -1241,11 +1239,10 @@ def ab_index_map(batch_index, head_index, q_seq_index, kv_seq_index, q_idx_ref,
  q_segment_ids_spec = kv_segment_ids_spec = None
  q_segment_ids = kv_segment_ids = None
  if segment_ids is not None:
-    assert False

    def q_segment_ids_index_map(batch_index, head_index, q_seq_index, _):
      del head_index
-      return (batch_index, q_seq_index, 0)
+      return batch_index, q_seq_index, 0

    def kv_segment_ids_index_map(
        batch_index, head_index, q_seq_index, kv_seq_index
@@ -1263,7 +1260,7 @@ def kv_segment_ids_index_map(
        )
      else:
        next_kv_index = kv_seq_index
-      return (batch_index, 0, next_kv_index)
+      return batch_index, 0, next_kv_index

    q_segment_ids_spec = pl.BlockSpec(
        q_segment_ids_index_map, (1, block_q_major, NUM_LANES)
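These index maps return block coordinates: Pallas scales each coordinate by the matching entry of the block shape to locate the slab to load, which is why `kv_segment_ids_index_map` can pin the middle coordinate to 0. A toy `pallas_call` demonstrating the mechanism, written against the same old-style `pl.BlockSpec(index_map, block_shape)` argument order this file uses (a sketch, not kernel code):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
    # Each grid step sees only the (1, 128) block chosen by the index map.
    o_ref[...] = x_ref[...]

x = jnp.arange(4 * 128, dtype=jnp.float32).reshape(4, 128)

out = pl.pallas_call(
    copy_kernel,
    grid=(4,),
    # Block coordinates (i, 0) select rows [i : i + 1], columns [0 : 128].
    in_specs=[pl.BlockSpec(lambda i: (i, 0), (1, 128))],
    out_specs=pl.BlockSpec(lambda i: (i, 0), (1, 128)),
    out_shape=jax.ShapeDtypeStruct((4, 128), jnp.float32),
)(x)

assert (out == x).all()
```

On a machine without a TPU, adding `interpret=True` to the `pallas_call` runs the same sketch in the interpreter.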