Adding and editing efficient_attention
erfanzar committed Nov 3, 2023
1 parent a55c629 commit 4fe8033
Showing 5 changed files with 90 additions and 45 deletions.
2 changes: 1 addition & 1 deletion fjformer/__init__.py
@@ -1,5 +1,5 @@
from .attention import (dot_product_attention_multiquery, dot_product_attention_multihead,
dot_product_attention_queries_per_head, blockwise_dot_product_attention)
dot_product_attention_queries_per_head, efficient_attention)
from .load import (
load_and_convert_checkpoint_to_torch, float_tensor_to_dtype, read_ckpt, save_ckpt, StreamingCheckpointer
)
2 changes: 1 addition & 1 deletion fjformer/attention/__init__.py
@@ -1,3 +1,3 @@
from .blockwise_attention import blockwise_dot_product_attention
from .efficient_attention import efficient_attention
from .flash_attention_0 import dot_product_attention_multihead, dot_product_attention_multiquery, \
dot_product_attention_queries_per_head
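
Taken together, the two __init__.py diffs above swap the exported blockwise_dot_product_attention symbol for efficient_attention. A minimal sketch of the resulting import surface, assuming the package is installed as fjformer and based only on these two diffs:

    from fjformer import efficient_attention
    # equivalently, via the attention subpackage:
    # from fjformer.attention import efficient_attention

The large diff that follows — presumably the attention kernel module itself; its filename header is not visible in this view — renames blockwise_dot_product_attention to efficient_attention and reworks its chunking.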
@@ -1,50 +1,92 @@
import functools
from typing import NamedTuple
import jax
import jax.lax as lax
import jax.numpy as jnp
from functools import partial
from typing import NamedTuple
from einops import rearrange
import chex
from jax import lax

'''
Compute attention blockwise without materializing the full attention matrix, as initially
proposed by Rabe et al. 2021 (https://arxiv.org/abs/2112.05682). Dao et al. 2022
(https://arxiv.org/abs/2205.14135) propose a CUDA-efficient implementation, and Liu et al. 2023
(https://arxiv.org/abs/2305.19370) propose computing the attention, the FFN, and the loss
blockwise as well, enabling up to 4x longer sequences.
'''

class Carry(NamedTuple):
numerator: chex.Array
denominator: chex.Array
max_so_far: chex.Array


def blockwise_dot_product_attention(query, key, value, bias, deterministic,
dropout_rng, attn_pdrop, causal, query_chunk_size,
key_chunk_size, dtype, policy, precision, float32_logits):
def efficient_attention(
query: chex.Array,
key: chex.Array,
value: chex.Array,
bias: chex.Array = None,
deterministic: bool = True,
dropout_rng: chex.PRNGKey = None,
attention_drop_rate: float = 0.0,
causal: bool = True,
query_chunk_size: int = 1024,
key_chunk_size: int = 1024,
dtype: chex.ArrayDType = jnp.float32,
policy=jax.checkpoint_policies.nothing_saveable(),
precision=None,
float32_logits: bool = True,
prevent_cse: bool = True,
):
"""
:param query: Array Shape [batch,Q Sequence length,num attention heads, head dims]
:param key: Array Shape [batch,KV Sequence length,num KV attention heads, head dims]
:param value: Array Shape [batch,KV Sequence length,num KV attention heads, head dims]
:param bias: Bias To be added
:param deterministic: bool (whenever use dropout or no)
:param dropout_rng: RNG Dropout
:param attention_drop_rate:
:param causal: Is Decoder or Causal
:param query_chunk_size: Chunk size used for query
:param key_chunk_size: Chunk size used for key
:param dtype: DataType
:param policy: Gradient Checkpoint Policy
:param precision: PrecisionLike
:param float32_logits:
:param prevent_cse:
:return:
"""
query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
q_len = query.shape[1]
kv_len = key.shape[1]
if float32_logits:
query = query.astype(jnp.float32)
key = key.astype(jnp.float32)
query = rearrange(query, 'b (c n) h d -> n b c h d', c=query_chunk_size)
key, value = map(lambda t: rearrange(t, 'b (c n) h d -> n b c h d', c=key_chunk_size), (key, value))
num_q, batch, _, num_heads, dim_per_head = query.shape
num_kv = key.shape[0]

for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
assert bias_dim == 1 or bias_dim == broadcast_dim
if not deterministic and attn_pdrop > 0.0:

batch, q_len, num_heads, dim_per_head = query.shape
batch, kv_len, kv_heads, dim_per_head = key.shape
batch, kv_len, kv_heads, dim_per_head = value.shape

num_q = q_len // query_chunk_size
num_kv = kv_len // key_chunk_size
query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
key = key.reshape((batch, num_kv, key_chunk_size, kv_heads, dim_per_head))
value = value.reshape((batch, num_kv, key_chunk_size, kv_heads, dim_per_head))

query = jnp.moveaxis(query, 1, 0)
key = jnp.moveaxis(key, 1, 0)
value = jnp.moveaxis(value, 1, 0)

if bias is not None:
for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
assert bias_dim == 1 or bias_dim == broadcast_dim
if not deterministic and attention_drop_rate > 0.0:
attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
attn_dropout = jax.random.bernoulli(attn_dropout_rng, attention_drop_rate, (batch, num_heads, q_len, kv_len))
else:
attn_dropout = None

_chunk_bias_fn = partial(
_chunk_bias_fn = functools.partial(
_chunk_attention_bias,
query_chunk_size, key_chunk_size, bias, deterministic,
attn_dropout, attn_pdrop, causal, dtype)
attn_dropout, attention_drop_rate, causal, dtype)

def _query_chunk_attention(args):
def scan_attention(args):
query_chunk, query_chunk_idx = args

@partial(jax.checkpoint, prevent_cse=False, policy=policy)
def summarize_chunk(carry, args):
@functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
def scan_kv_block(carry, args):
key_chunk, value_chunk, key_chunk_idx = args
(numerator, denominator, prev_max_score) = carry
attn_weights = jnp.einsum('bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision)
@@ -57,7 +99,7 @@ def summarize_chunk(carry, args):
max_score = jax.lax.stop_gradient(max_score)
exp_weights = jnp.exp(attn_weights - max_score)
exp_values = jnp.einsum(
'bqhv,bvhf->bqhf', exp_weights, value_chunk, precision=precision
'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
)
correction = jnp.exp(prev_max_score - max_score)
numerator = numerator * correction + exp_values
@@ -72,7 +114,7 @@ def skip_upper_half(carry, args):
return jax.lax.cond(
skip_block,
lambda carry, args: (carry, None),
summarize_chunk,
scan_kv_block,
carry,
args,
)
@@ -89,22 +131,25 @@ def skip_upper_half(carry, args):
return outputs

_, res = lax.scan(
lambda _, x: ((), _query_chunk_attention(x)),
lambda _, x: ((), scan_attention(x)),
(), xs=(query, jnp.arange(0, num_q))
)
res = rearrange(res, 'n b c h d -> b (n c) h d')
return res


class Carry(NamedTuple):
numerator: chex.Array
denominator: chex.Array
max_so_far: chex.Array


def _chunk_attention_bias(query_chunk_size, key_chunk_size,
bias, deterministic, attn_dropout, attn_pdrop, causal,
dtype, query_chunk_idx, key_chunk_idx):
def _chunk_attention_bias(
query_chunk_size: int,
key_chunk_size: int,
bias: chex.Array,
deterministic: bool,
attn_dropout: chex.Array,
attention_drop_rate: float,
causal: bool,
dtype: chex.ArrayDType,
query_chunk_idx: int,
key_chunk_idx: int
):
query_offset = query_chunk_idx * query_chunk_size
key_offset = key_chunk_idx * key_chunk_size
chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
@@ -123,7 +168,7 @@ def _chunk_attention_bias(query_chunk_size, key_chunk_size,
causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)

if not deterministic and attn_pdrop > 0.0:
if not deterministic and attention_drop_rate > 0.0:
attn_dropout_slice = lax.dynamic_slice(
attn_dropout,
start_indices=(0, 0, query_offset, key_offset),
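
The carry of (numerator, denominator, max_so_far) that scan_kv_block updates above is the streaming ("online") softmax from Rabe et al. 2021: each new key/value chunk rescales the running sums by exp(prev_max - new_max), so the full attention matrix is never materialized. A self-contained sketch of that accumulation for a single attention head with no batch dimension, written against plain jax.numpy for illustration (this is not the library code above, which additionally handles bias, causal masking, dropout, and gradient checkpointing):

    import jax.numpy as jnp

    def streaming_attention(q, k, v, chunk):
        # q: [q_len, d]; k, v: [kv_len, d]; kv_len is assumed divisible by `chunk`.
        num = jnp.zeros_like(q)                          # running numerator (weighted values)
        den = jnp.zeros((q.shape[0], 1))                 # running softmax denominator
        m = jnp.full((q.shape[0], 1), -jnp.inf)          # running max logit per query
        for start in range(0, k.shape[0], chunk):
            k_c, v_c = k[start:start + chunk], v[start:start + chunk]
            s = q @ k_c.T / jnp.sqrt(q.shape[-1])        # logits for this chunk
            m_new = jnp.maximum(m, s.max(axis=-1, keepdims=True))
            p = jnp.exp(s - m_new)                       # chunk-local unnormalized weights
            corr = jnp.exp(m - m_new)                    # rescale previously accumulated sums
            num = num * corr + p @ v_c
            den = den * corr + p.sum(axis=-1, keepdims=True)
            m = m_new
        return num / den                                 # equals softmax(q k^T / sqrt(d)) v

The kernel above performs the same accumulation per (batch, head), using lax.scan over the key/value chunks and jax.checkpoint around each block.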
2 changes: 1 addition & 1 deletion fjformer/func/__init__.py
@@ -1 +1 @@
from ._func import average_metrics, global_norm, transpose
from ._func import average_metrics, global_norm, transpose, fused_softmax
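
The diff above only re-exports fused_softmax from fjformer._func; its implementation is not part of this commit, so the following is purely an assumption about what a fused, numerically stable softmax of this kind typically computes (a sketch, not the library's code):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def stable_softmax(x, axis=-1):
        # Hypothetical stand-in: subtract the max before exponentiating so large
        # logits do not overflow; under jit, XLA can fuse much of this chain.
        x_max = jax.lax.stop_gradient(jnp.max(x, axis=axis, keepdims=True))
        unnormalized = jnp.exp(x - x_max)
        return unnormalized / jnp.sum(unnormalized, axis=axis, keepdims=True)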
4 changes: 2 additions & 2 deletions setup.py
@@ -7,13 +7,13 @@

setuptools.setup(
name="fjformer",
version='0.0.0',
version='0.0.1',
author="Erfan Zare Chavoshi",
author_email="[email protected]",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/erfanzar/",
packages=setuptools.find_packages('fjformer'),
packages=setuptools.find_packages(),
install_requires=[
"numpy",
"jax>=0.4.10",
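
With the packaging fix above (find_packages() with no argument now discovers the fjformer subpackages) and the signature shown in this commit's kernel diff, a call to the new function might look like the following. The shapes and chunk sizes are illustrative only, and the reshape-based chunking requires both sequence lengths to be divisible by their chunk sizes:

    import jax
    import jax.numpy as jnp
    from fjformer import efficient_attention  # re-exported at the top level by this commit

    batch, q_len, kv_len, heads, head_dim = 2, 2048, 2048, 8, 64
    q_key, k_key, v_key = jax.random.split(jax.random.PRNGKey(0), 3)

    query = jax.random.normal(q_key, (batch, q_len, heads, head_dim), dtype=jnp.float32)
    key = jax.random.normal(k_key, (batch, kv_len, heads, head_dim), dtype=jnp.float32)
    value = jax.random.normal(v_key, (batch, kv_len, heads, head_dim), dtype=jnp.float32)

    # Chunk sizes must evenly divide q_len and kv_len for the reshape-based chunking.
    out = efficient_attention(
        query, key, value,
        causal=True,
        query_chunk_size=512,
        key_chunk_size=512,
    )
    print(out.shape)  # expected: (2, 2048, 8, 64)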
