From 4fe8033460d6d64b58c8ad8f3f428e6c12fdb253 Mon Sep 17 00:00:00 2001
From: erfanzar
Date: Fri, 3 Nov 2023 16:33:22 +0330
Subject: [PATCH] Adding and editing `efficient_attention`

---
 fjformer/__init__.py                           |   2 +-
 fjformer/attention/__init__.py                 |   2 +-
 ...se_attention.py => efficient_attention.py}  | 125 ++++++++++++------
 fjformer/func/__init__.py                      |   2 +-
 setup.py                                       |   4 +-
 5 files changed, 90 insertions(+), 45 deletions(-)
 rename fjformer/attention/{blockwise_attention.py => efficient_attention.py} (57%)

diff --git a/fjformer/__init__.py b/fjformer/__init__.py
index 6b45269..7ba2253 100644
--- a/fjformer/__init__.py
+++ b/fjformer/__init__.py
@@ -1,5 +1,5 @@
 from .attention import (dot_product_attention_multiquery, dot_product_attention_multihead,
-                        dot_product_attention_queries_per_head, blockwise_dot_product_attention)
+                        dot_product_attention_queries_per_head, efficient_attention)
 from .load import (
     load_and_convert_checkpoint_to_torch, float_tensor_to_dtype, read_ckpt, save_ckpt, StreamingCheckpointer
 )
diff --git a/fjformer/attention/__init__.py b/fjformer/attention/__init__.py
index e5c905c..1d5f7dc 100644
--- a/fjformer/attention/__init__.py
+++ b/fjformer/attention/__init__.py
@@ -1,3 +1,3 @@
-from .blockwise_attention import blockwise_dot_product_attention
+from .efficient_attention import efficient_attention
 from .flash_attention_0 import dot_product_attention_multihead, dot_product_attention_multiquery, \
     dot_product_attention_queries_per_head
diff --git a/fjformer/attention/blockwise_attention.py b/fjformer/attention/efficient_attention.py
similarity index 57%
rename from fjformer/attention/blockwise_attention.py
rename to fjformer/attention/efficient_attention.py
index 6309d31..a27eca2 100644
--- a/fjformer/attention/blockwise_attention.py
+++ b/fjformer/attention/efficient_attention.py
@@ -1,50 +1,92 @@
+import functools
+from typing import NamedTuple
 import jax
+import jax.lax as lax
 import jax.numpy as jnp
-from functools import partial
-from typing import NamedTuple
 from einops import rearrange
 import chex
-from jax import lax
-'''
-Compute attention blockwise without materializing the full attention matrix, initially proposed in https://arxiv.org/abs/2112.05682 Rabe et al. 2021;
-https://arxiv.org/abs/2205.14135 Dao et al. 2022 proposes a CUDA efficient implementation;
-https://arxiv.org/abs/2305.19370 Liu et al. 2023 proposes blockwise computing both attention and FFN, as well as loss function, enabling 4x longer sequences. 
-'''
+
+class Carry(NamedTuple):
+    numerator: chex.Array
+    denominator: chex.Array
+    max_so_far: chex.Array
 
 
-def blockwise_dot_product_attention(query, key, value, bias, deterministic,
-                                    dropout_rng, attn_pdrop, causal, query_chunk_size,
-                                    key_chunk_size, dtype, policy, precision, float32_logits):
+def efficient_attention(
+        query: chex.Array,
+        key: chex.Array,
+        value: chex.Array,
+        bias: chex.Array = None,
+        deterministic: bool = True,
+        dropout_rng: chex.PRNGKey = None,
+        attention_drop_rate: float = 0.0,
+        causal: bool = True,
+        query_chunk_size: int = 1024,
+        key_chunk_size: int = 1024,
+        dtype: chex.ArrayDType = jnp.float32,
+        policy=jax.checkpoint_policies.nothing_saveable(),
+        precision=None,
+        float32_logits: bool = True,
+        prevent_cse: bool = True,
+):
+    """
+    Compute dot-product attention over query/key chunks without materializing the full attention matrix.
+
+    :param query: Array of shape [batch, Q sequence length, num attention heads, head dims]
+    :param key: Array of shape [batch, KV sequence length, num KV attention heads, head dims]
+    :param value: Array of shape [batch, KV sequence length, num KV attention heads, head dims]
+    :param bias: Optional attention bias, broadcastable to [batch, num heads, Q sequence length, KV sequence length]
+    :param deterministic: If True, disable attention dropout
+    :param dropout_rng: PRNG key used for attention dropout
+    :param attention_drop_rate: Dropout rate applied to the attention weights
+    :param causal: Whether to apply a causal (decoder) mask
+    :param query_chunk_size: Chunk size along the query sequence axis (Q sequence length must be divisible by it)
+    :param key_chunk_size: Chunk size along the key/value sequence axis (KV sequence length must be divisible by it)
+    :param dtype: Data type used for the attention computation
+    :param policy: Gradient checkpointing policy passed to jax.checkpoint
+    :param precision: Matmul precision (PrecisionLike) passed to jnp.einsum
+    :param float32_logits: Whether to compute the attention logits in float32
+    :param prevent_cse: Whether to prevent common subexpression elimination inside jax.checkpoint
+    :return: Attention output of shape [batch, Q sequence length, num attention heads, head dims]
+    """
     query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
-    q_len = query.shape[1]
-    kv_len = key.shape[1]
     if float32_logits:
         query = query.astype(jnp.float32)
         key = key.astype(jnp.float32)
-    query = rearrange(query, 'b (c n) h d -> n b c h d', c=query_chunk_size)
-    key, value = map(lambda t: rearrange(t, 'b (c n) h d -> n b c h d', c=key_chunk_size), (key, value))
-    num_q, batch, _, num_heads, dim_per_head = query.shape
-    num_kv = key.shape[0]
-
-    for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
-        assert bias_dim == 1 or bias_dim == broadcast_dim
-    if not deterministic and attn_pdrop > 0.0:
+
+    batch, q_len, num_heads, dim_per_head = query.shape
+    batch, kv_len, kv_heads, dim_per_head = key.shape
+    batch, kv_len, kv_heads, dim_per_head = value.shape
+
+    num_q = q_len // query_chunk_size
+    num_kv = kv_len // key_chunk_size
+    query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
+    key = key.reshape((batch, num_kv, key_chunk_size, kv_heads, dim_per_head))
+    value = value.reshape((batch, num_kv, key_chunk_size, kv_heads, dim_per_head))
+
+    query = jnp.moveaxis(query, 1, 0)
+    key = jnp.moveaxis(key, 1, 0)
+    value = jnp.moveaxis(value, 1, 0)
+
+    if bias is not None:
+        for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
+            assert bias_dim == 1 or bias_dim == broadcast_dim
+    if not deterministic and attention_drop_rate > 0.0:
         attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
-        attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
+        attn_dropout = jax.random.bernoulli(attn_dropout_rng, attention_drop_rate, (batch, num_heads, q_len, kv_len))
     else:
         attn_dropout = None
-    _chunk_bias_fn = partial(
+    _chunk_bias_fn = functools.partial(
         _chunk_attention_bias,
         query_chunk_size, key_chunk_size, bias, deterministic,
-        attn_dropout, attn_pdrop, causal, dtype)
+        attn_dropout, attention_drop_rate, causal, dtype)
 
-    def _query_chunk_attention(args):
+    def scan_attention(args):
         query_chunk, query_chunk_idx = args
 
-        @partial(jax.checkpoint, prevent_cse=False, policy=policy)
-        def summarize_chunk(carry, args):
+        @functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
+        def scan_kv_block(carry, args):
             key_chunk, value_chunk, key_chunk_idx = args
             (numerator, denominator, prev_max_score) = carry
             attn_weights = jnp.einsum('bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision)
@@ -57,7 +99,7 @@ def summarize_chunk(carry, args):
             max_score = jax.lax.stop_gradient(max_score)
             exp_weights = jnp.exp(attn_weights - max_score)
             exp_values = jnp.einsum(
-                'bqhv,bvhf->bqhf', exp_weights, value_chunk, precision=precision
+                'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
             )
             correction = jnp.exp(prev_max_score - max_score)
             numerator = numerator * correction + exp_values
@@ -72,7 +114,7 @@ def skip_upper_half(carry, args):
             return jax.lax.cond(
                 skip_block,
                 lambda carry, args: (carry, None),
-                summarize_chunk,
+                scan_kv_block,
                 carry,
                 args,
             )
@@ -89,22 +131,25 @@ def skip_upper_half(carry, args):
         return outputs
 
     _, res = lax.scan(
-        lambda _, x: ((), _query_chunk_attention(x)),
+        lambda _, x: ((), scan_attention(x)),
         (), xs=(query, jnp.arange(0, num_q))
     )
     res = rearrange(res, 'n b c h d -> b (n c) h d')
     return res
 
 
-class Carry(NamedTuple):
-    numerator: chex.Array
-    denominator: chex.Array
-    max_so_far: chex.Array
-
-
-def _chunk_attention_bias(query_chunk_size, key_chunk_size,
-                          bias, deterministic, attn_dropout, attn_pdrop, causal,
-                          dtype, query_chunk_idx, key_chunk_idx):
+def _chunk_attention_bias(
+        query_chunk_size: int,
+        key_chunk_size: int,
+        bias: chex.Array,
+        deterministic: bool,
+        attn_dropout: chex.Array,
+        attention_drop_rate: float,
+        causal: bool,
+        dtype: chex.ArrayDType,
+        query_chunk_idx: int,
+        key_chunk_idx: int
+):
     query_offset = query_chunk_idx * query_chunk_size
     key_offset = key_chunk_idx * key_chunk_size
     chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
@@ -123,7 +168,7 @@ def _chunk_attention_bias(query_chunk_size, key_chunk_size,
         causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
         chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)
 
-    if not deterministic and attn_pdrop > 0.0:
+    if not deterministic and attention_drop_rate > 0.0:
         attn_dropout_slice = lax.dynamic_slice(
             attn_dropout,
             start_indices=(0, 0, query_offset, key_offset),
diff --git a/fjformer/func/__init__.py b/fjformer/func/__init__.py
index f3956f8..3dec5e3 100644
--- a/fjformer/func/__init__.py
+++ b/fjformer/func/__init__.py
@@ -1 +1 @@
-from ._func import average_metrics, global_norm, transpose
\ No newline at end of file
+from ._func import average_metrics, global_norm, transpose, fused_softmax
diff --git a/setup.py b/setup.py
index 664428d..a5caa24 100644
--- a/setup.py
+++ b/setup.py
@@ -7,13 +7,13 @@ setuptools.setup(
     name="fjformer",
-    version='0.0.0',
+    version='0.0.1',
    author="Erfan Zare Chavoshi",
    author_email="erfanzare82@yahoo.com",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/erfanzar/",
-    packages=setuptools.find_packages('fjformer'),
+    packages=setuptools.find_packages(),
    install_requires=[
        "numpy",
        "jax>=0.4.10",
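
Usage sketch (not part of the patch): a minimal, hypothetical call to the renamed `efficient_attention` entry point exported by the patched `fjformer/attention/__init__.py`. The batch size, sequence length, head count, head dimension, and chunk sizes below are illustrative assumptions, chosen so that the sequence length is divisible by the chunk sizes as the reshape in the function requires.

import jax
import jax.numpy as jnp

from fjformer.attention import efficient_attention  # re-exported by the patched __init__.py

# Illustrative sizes (assumptions for this sketch, not values taken from the patch).
batch, seq_len, num_heads, head_dim = 2, 2048, 8, 64
q_rng, k_rng, v_rng = jax.random.split(jax.random.PRNGKey(0), 3)

query = jax.random.normal(q_rng, (batch, seq_len, num_heads, head_dim), dtype=jnp.float32)
key = jax.random.normal(k_rng, (batch, seq_len, num_heads, head_dim), dtype=jnp.float32)
value = jax.random.normal(v_rng, (batch, seq_len, num_heads, head_dim), dtype=jnp.float32)

out = efficient_attention(
    query,
    key,
    value,
    bias=None,             # no additive attention bias in this sketch
    deterministic=True,    # attention dropout disabled
    causal=True,           # lower-triangular (decoder) masking
    query_chunk_size=512,  # seq_len must be divisible by the chunk sizes
    key_chunk_size=512,
    dtype=jnp.float32,
)

# Output keeps the query layout: [batch, Q sequence length, num attention heads, head dims].
assert out.shape == (batch, seq_len, num_heads, head_dim)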