Integrates ragged attention to JetStream Pytorch (AI-Hypercomputer#93)
* Stable version of ragged attention.

* Converts the attention output dtype to match q's.

* Fixes the typo for the ragged attention.

* Provides the default value for partition_by_axis.

* Provides mesh to the shard_map.

* Fixes typo.

* Fixes typo, should be start instead of start_pos.

* Should use "//" instead of "/" to get int results.

* Use block size // 2 as the starting current position for better initial performance. Fix a typo: use jax.lax.div instead of jnp.div.

* Updates the run_interactive script to use the correct result token processing API from JetStream.

* Fix typo, should use token_utils.process_result_token.

* Fix typo.

* Fixes the sampled tokens list.

* Use text_tokens_to_str to convert the output tokens.

* Reshape the precomputed grid indices to 1D. Removes dense_attention_quantized and uses an option to control whether quantization is applied. Use the new torch_xla2 API.

* Should check if X is None instead of if X

* Fix the dense_attention not returning data.

* Reshape the kv scaler to 3 dims for ragged attention.

* Cannot stop the input_pos counter from increasing since we are using a ring buffer; stopping it would cause errors.

* Adds starting_position and profiling_prefill for better testing and benchmarking.

* Move flags in scripts to a common function (AI-Hypercomputer#92)

* refactor flags

* clean up

* fix run_server

* move common flags to global

* format

* update

* update readme

* update run_interactive

* Stable version of ragged attention.

* Fix the merge conflicts

* Fixes the missing pieces after resolving merge conflicts. Adds a couple of new flags for debugging and performance tuning.

* Integrates ragged attention into Gemma too.

* Somehow had some local changes to run_interactive; reverting them to align with main.

* Set the default value for the newly added parameters.

* Adds more descriptions to the ragged attention index precomputation function (see the sketch after this list).

* Merges the quantized ragged attention kernel with the non quantized version.

* Moves the attention calculation to attention.py for better code structure.

* Fix run issues after refactoring.

* Fix the quantized version for ragged attention.

* Fix test_attention by adding default values for the newly added arguments; the error was about missing positional arguments.

* Fixes unit tests; changes the Transformer model call argument order (input_pos) back to the original to avoid unnecessary issues.

* Format attention_kernel.py

* Add descriptions to ragged attention outputs.

* Fix quantization tests by adding default values to the quantization kernel class.

* Reformat attention_kernel.py; black's formatting doesn't comply with the pylint rules.

* Ignores the R0913 (too many arguments) lint error for the ragged attention kernel. Fixes other lint errors.

* Ignore R0903: Too few public methods. Fix lint errors.

* Fix the rest of the lint errors.
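
A sketch of the index precomputation mentioned above: the Pallas kernel's index maps read two flattened tables, ragged_batch_index and ragged_block_index, indexed by b * (seq_len // bk) + i. The helper below is illustrative only; it builds the identity mapping, whereas the repository's actual precomputation function (not shown in this diff) remaps grid slots based on each sequence's start and end.

import jax.numpy as jnp


def identity_ragged_indices(batch_size: int, seq_len: int, bk: int):
  """Illustrative identity mapping for the flattened (batch, block) grid.

  The kernel looks up ragged_batch_index[index] and ragged_block_index[index]
  with index = b * (seq_len // bk) + i, so both tables are 1-D arrays of
  length batch_size * (seq_len // bk). This sketch maps every slot to itself;
  the real precomputation reorders or skips blocks per sequence.
  """
  num_blocks = seq_len // bk
  ragged_batch_index = jnp.repeat(
      jnp.arange(batch_size, dtype=jnp.int32), num_blocks
  )
  ragged_block_index = jnp.tile(
      jnp.arange(num_blocks, dtype=jnp.int32), batch_size
  )
  return ragged_batch_index, ragged_block_index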

---------

Co-authored-by: Siyuan Liu <[email protected]>
wang2yn84 and lsy323 authored May 23, 2024
1 parent 65c39d4 commit 517d847
Showing 9 changed files with 679 additions and 66 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
@@ -1,2 +1,2 @@
[MESSAGES CONTROL]
disable=C0114,R0801,E1102,W0613,R1711,too-many-locals
disable=C0114,R0801,R0903,R0913,E1102,W0613,R1711,too-many-locals
347 changes: 347 additions & 0 deletions jetstream_pt/attention_kernel.py
@@ -0,0 +1,347 @@
import functools
import math

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.shard_map import shard_map

import torch
import torch.nn.functional as F

import numpy as np

DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)


def ragged_flash_attention_kernel(
start_ref,
end_ref,
line_end_ref,
pre_b_ref,
pre_i_ref,
q_ref,
k_ref,
v_ref,
k_scaler_ref,
v_scaler_ref,
o_ref, # outputs
m_ref, # row max
l_ref, # propagation coefficient
bk: int,
mask_value: float,
normalize_var: bool,
quantized: bool,
):
"""Pallas kernel for flash attention."""
with jax.named_scope("attention_kernel"):
b, i = pl.program_id(0), pl.program_id(1)

@pl.when(i == 0)
def init():
with jax.named_scope("init"):
m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
l_ref[...] = jnp.zeros_like(l_ref)
o_ref[...] = jnp.zeros_like(o_ref)

length = line_end_ref[b]
start = start_ref[b]
end = end_ref[b]

@pl.when(jnp.logical_and(i * bk < length, start != end))
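# Run only for blocks that start inside this sequence's valid length; empty sequences (start == end) are skipped.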
def run():
with jax.named_scope("run_qk"):
q = q_ref[...].astype(jnp.float32)
k = k_ref[...].astype(jnp.float32)
v = v_ref[...].astype(jnp.float32)
m_prev, l_prev = m_ref[...], l_ref[...]

qk = jax.lax.dot_general(
q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32
)
if normalize_var:
qk = qk / jnp.sqrt(k.shape[-1])
if quantized:
qk = qk * k_scaler_ref[...]
with jax.named_scope("run_mask"):
start = start_ref[b]
end = end_ref[b]
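# Valid positions lie in [start, end); when the ring-buffer cache wraps (start > end), the valid region is [start, seq_len) plus [0, end), hence the OR-based mask below.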
iota = jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1)
mask_start_lt_end = jnp.logical_and(
i * bk + iota >= start, i * bk + iota < end
).astype(jnp.int32)
mask_start_gt_end = jnp.logical_or(
i * bk + iota >= start, i * bk + iota < end
).astype(jnp.int32)
# mask = jax.lax.cond(start <= end, lambda: mask_start_lt_end, lambda: mask_start_gt_end)
mask = jnp.where(start <= end, mask_start_lt_end, mask_start_gt_end)

qk = qk + jnp.where(mask, 0.0, mask_value)

with jax.named_scope("run_softmax"):
m_curr = qk.max(axis=-1)

s_curr = jnp.exp(qk - m_curr[..., None])

l_curr = jax.lax.broadcast_in_dim(
s_curr.sum(axis=-1), l_prev.shape, (0,)
)
if quantized:
s_curr = s_curr * v_scaler_ref[...]
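# Flash-attention style online softmax: rescale the running max, denominator, and output by alpha/beta so key/value blocks can be accumulated incrementally.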
o_curr_times_l_curr = jnp.dot(s_curr, v)
m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,))
m_next = jnp.maximum(m_prev, m_curr)
alpha = jnp.exp(m_prev - m_next)
beta = jnp.exp(m_curr - m_next)
l_next = alpha * l_prev + beta * l_curr
l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next)

m_ref[...], l_ref[...] = m_next, l_next_safe
o_ref[...] = (
(l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr)
/ l_next_safe
).astype(o_ref.dtype)


@functools.partial(
jax.jit, static_argnames=["bk", "mask_value", "normalize_var"]
)
def ragged_mqa(
q: jax.Array,
k: jax.Array,
v: jax.Array,
start: jax.Array,
end: jax.Array,
k_scaler: jax.Array | None = None,
v_scaler: jax.Array | None = None,
ragged_batch_index=None,
ragged_block_index=None,
bk: int = 512,
mask_value: float = DEFAULT_MASK_VALUE,
normalize_var: bool = True,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
"""Ragged multi query attention."""
with jax.named_scope("ragged_mqa"):
batch_size, num_heads, head_dim = q.shape
seq_len = k.shape[1]

def kv_index_map(
b,
i,
start_ref,
end_ref,
line_end_ref,
ragged_batch_index_ref,
ragged_block_index_ref,
):
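# Remap the flattened (batch, block) grid slot through the precomputed ragged index tables so the kernel loads the K/V block it should actually process.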
index = b * (seq_len // bk) + i
return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0

def q_index_map(
b,
i,
start_ref,
end_ref,
line_end_ref,
ragged_batch_index_ref,
ragged_block_index_ref,
):
index = b * (seq_len // bk) + i
return ragged_batch_index_ref[index], 0, 0

def scaler_index_map(b, i, *_):
return b, 0, i

line_end = jnp.where(start < end, end, seq_len - 1)

in_specs = [
pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
pl.BlockSpec(kv_index_map, (None, bk, head_dim)),
pl.BlockSpec(kv_index_map, (None, bk, head_dim)),
]
inputs = (
start,
end,
line_end,
ragged_batch_index,
ragged_block_index,
q,
k,
v,
)
quantized = False
if k_scaler is not None:
in_specs = in_specs + [
pl.BlockSpec(scaler_index_map, (None, 1, bk)),
pl.BlockSpec(scaler_index_map, (None, 1, bk)),
]
inputs = inputs + (k_scaler, v_scaler)
quantized = True

out, m, l = pl.pallas_call(
functools.partial(
ragged_flash_attention_kernel,
bk=bk,
mask_value=mask_value,
normalize_var=normalize_var,
quantized=quantized,
),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=5,
in_specs=in_specs,
out_specs=[
pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
],
grid=(batch_size, seq_len // bk),
),
compiler_params={"dimension_semantics": ("parallel", "arbitrary")},
out_shape=[
q,
jax.ShapeDtypeStruct(
(batch_size, num_heads, head_dim), jnp.float32
),
jax.ShapeDtypeStruct(
(batch_size, num_heads, head_dim), jnp.float32
),
],
)(*inputs)
return out, (m[..., 0], l[..., 0])


@functools.partial(
jax.jit, static_argnames=["bk", "mask_value", "normalize_var", "shard_axis"]
)
def ragged_mha(
q: jax.Array,
k: jax.Array,
v: jax.Array,
start: jax.Array,
end: jax.Array,
ragged_batch_index: jax.Array,
ragged_block_index: jax.Array,
k_scaler: jax.Array | None = None,
v_scaler: jax.Array | None = None,
bk: int = 512,
mask_value: float = DEFAULT_MASK_VALUE,
normalize_var: bool = True,
shard_axis: int = 1,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
"""Ragged multi head attention.
Args:
q: A [batch_size, compute_dim, num_heads, head_dim] jax.Array.
k: A [batch_size, num_heads, seq_len, head_dim] jax.Array or
PartitionQuantizedTensor.
v: A [batch_size, num_heads, seq_len, head_dim] jax.Array or
PartitionQuantizedTensor.
start: A i32[batch_size] jax.Array
end: A i32[batch_size] jax.Array
bk: An integer that is the sequence block size.
logit_cap: An optional float that caps logits via tanh. By default there is
no logit capping.
mask_value: The value used for padding in attention. By default it is a very
negative floating point number.
out_dtype: An optional dtype for the output. If not provided, the output
dtype will be q's dtype.
Returns:
The output of attention([batch_size, num_heads, compute_dim, head_dim]),
along with the max logit ([batch_size, num_heads, compute_dim, 1]) and
softmax denominator ([batch_size, num_heads, compute_dim, 1]).
"""
mask_value = DEFAULT_MASK_VALUE
if k_scaler is None:
replicated_in_axes = 4
replicated_inputs = (ragged_batch_index, ragged_block_index)
else:
replicated_in_axes = 6
replicated_inputs = (
jnp.squeeze(k_scaler, -1),
jnp.squeeze(v_scaler, -1),
ragged_batch_index,
ragged_block_index,
)

with jax.named_scope("ragged_mha_vmap"):
out, (m, l) = jax.vmap(
functools.partial(
ragged_mqa,
bk=bk,
mask_value=mask_value,
normalize_var=normalize_var,
# out_dtype=out_dtype,
),
in_axes=(
shard_axis,
shard_axis,
shard_axis,
*([None] * replicated_in_axes),
),
out_axes=shard_axis,
)(q, k, v, start, end, *replicated_inputs)
return out, (m, l)


def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
"""The vanilla attention kernel implementation."""

bsz, _, _, head_dim = xq.shape
with jax.named_scope("attn_mat1"):
## Attention start
# scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim)
scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim)
if k_scaler is not None:
scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2]))
if mask is not None:
# if mask.shape != (1,1,16,16):
# breakpoint()
scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen)
with jax.named_scope("attn_soft"):
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
if v_scaler is not None:
scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2]))

with jax.named_scope("attn_mat2"):
# output = torch.einsum(
# "ikjm,ikml->ikjl", scores, values
# ) # (bs, n_local_heads, seqlen, head_dim)
output = torch.einsum("ikjm,ikml->ikjl", scores, values)
return output


class RaggedAttentionKernel:
"""Ragged attention kernel."""

def __init__(self, env, input_specs, output_specs, sharding_axis):
self.binded_ragged_mha = functools.partial(
ragged_mha, bk=env.block_size, shard_axis=sharding_axis
)
self.binded_ragged_mha = shard_map(
self.binded_ragged_mha, env.mesh, input_specs, output_specs, check_rep=False
)
self.binded_ragged_mha = jax.jit(self.binded_ragged_mha)

def __call__(
self,
xq,
keys,
values,
start,
end,
ragged_batch_index,
ragged_block_index,
k_scaler=None,
v_scaler=None,
):
return self.binded_ragged_mha(
xq,
keys,
values,
start,
end,
ragged_batch_index,
ragged_block_index,
k_scaler,
v_scaler,
)
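
For orientation, a minimal sketch of calling ragged_mha directly; the shapes, dtypes, and identity grid mapping below are illustrative assumptions, and the Pallas kernel itself requires a TPU backend. In the engine the kernel is reached through RaggedAttentionKernel, which additionally wraps it in shard_map over the device mesh.

import jax.numpy as jnp

from jetstream_pt.attention_kernel import ragged_mha

# Illustrative sizes only, not engine defaults.
B, H, S, D, bk = 2, 8, 1024, 128, 512
num_blocks = S // bk

q = jnp.ones((B, H, 1, D), dtype=jnp.bfloat16)  # one decode token per sequence
k = jnp.ones((B, H, S, D), dtype=jnp.bfloat16)  # ring-buffer KV cache
v = jnp.ones((B, H, S, D), dtype=jnp.bfloat16)
start = jnp.zeros((B,), dtype=jnp.int32)  # valid region per sequence is [start, end)
end = jnp.array([700, 300], dtype=jnp.int32)

# Identity grid mapping, flattened to 1-D (see the precomputation sketch above).
ragged_batch_index = jnp.repeat(jnp.arange(B, dtype=jnp.int32), num_blocks)
ragged_block_index = jnp.tile(jnp.arange(num_blocks, dtype=jnp.int32), B)

out, (m, l) = ragged_mha(
    q, k, v, start, end, ragged_batch_index, ragged_block_index, bk=bk
)
print(out.shape)  # (2, 8, 1, 128); same layout and dtype as q
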
22 changes: 22 additions & 0 deletions jetstream_pt/config.py
@@ -58,6 +58,26 @@
lambda value: value in _VALID_QUANTIZATION_TYPE,
f"quantize_type is invalid, supported quantization types are {_VALID_QUANTIZATION_TYPE}",
)
flags.DEFINE_bool(
"profiling_prefill",
False,
"Whether to profile the prefill, "
"if set to false, profile generate function only",
required=False,
)
flags.DEFINE_bool(
"ragged_mha",
False,
"Whether to enable Ragged multi head attention",
required=False,
)
flags.DEFINE_integer(
"starting_position",
512,
"The starting position of decoding, "
"for performance tuning and debugging only",
required=False,
)


def create_quantization_config_from_flags():
@@ -112,6 +132,8 @@ def create_engine_from_config_flags():
max_cache_length=FLAGS.max_cache_length,
sharding_config=sharding_file_name,
shard_on_batch=FLAGS.shard_on_batch,
ragged_mha=FLAGS.ragged_mha,
starting_position=FLAGS.starting_position,
)

print("Initialize engine", time.perf_counter() - start)