diff --git a/fjformer/__init__.py b/fjformer/__init__.py
index 82b2aea..3fb3940 100644
--- a/fjformer/__init__.py
+++ b/fjformer/__init__.py
@@ -50,4 +50,4 @@
     count_num_params
 )
 
-__version__ = '0.0.13'
+__version__ = '0.0.14'
diff --git a/fjformer/attention/flash_attention_tpu.py b/fjformer/attention/flash_attention_tpu.py
index a47e440..017b046 100644
--- a/fjformer/attention/flash_attention_tpu.py
+++ b/fjformer/attention/flash_attention_tpu.py
@@ -24,6 +24,7 @@
 from jax.experimental import pallas as pl
 from jax.experimental.pallas import tpu as pltpu
 import jax.numpy as jnp
+from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_kernel_single_batch_single_step
 
 DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
 NUM_LANES = 128
@@ -271,7 +272,6 @@ def _flash_attention_kernel(q_idx_chunk_start, k_idx_chunk_start, q_tile_ref, *a
   block_b = q_tile_ref.shape[0]
   # If we're not going to tile the softmax, then we can avoid a bunch of VPU ops.
   if kwargs["block_k"] == kwargs["kv_seq_len"]:
-    assert False
     kernel = _flash_attention_kernel_single_batch_single_step
   else:
     kernel = _flash_attention_kernel_single_batch
@@ -551,11 +551,9 @@ def lm_index_map(batch_index, head_index, q_seq_index, _, q_idx_ref, k_idx_ref):
   q_segment_ids_spec = kv_segment_ids_spec = None
   q_segment_ids = kv_segment_ids = None
   if segment_ids is not None:
-    assert False
-
     def q_segment_ids_index_map(batch_index, head_index, q_seq_index, _):
       del head_index
-      return (batch_index, q_seq_index, 0)
+      return batch_index, q_seq_index, 0
 
     def kv_segment_ids_index_map(
         batch_index, head_index, q_seq_index, kv_seq_index
@@ -1241,11 +1239,10 @@ def ab_index_map(batch_index, head_index, q_seq_index, kv_seq_index, q_idx_ref,
   q_segment_ids_spec = kv_segment_ids_spec = None
   q_segment_ids = kv_segment_ids = None
   if segment_ids is not None:
-    assert False
 
     def q_segment_ids_index_map(batch_index, head_index, q_seq_index, _):
       del head_index
-      return (batch_index, q_seq_index, 0)
+      return batch_index, q_seq_index, 0
 
     def kv_segment_ids_index_map(
         batch_index, head_index, q_seq_index, kv_seq_index
@@ -1263,7 +1260,7 @@ def kv_segment_ids_index_map(
         )
       else:
         next_kv_index = kv_seq_index
-      return (batch_index, 0, next_kv_index)
+      return batch_index, 0, next_kv_index
 
     q_segment_ids_spec = pl.BlockSpec(
         q_segment_ids_index_map, (1, block_q_major, NUM_LANES)
diff --git a/fjformer/attention/jax_flash_attn_tpu.py b/fjformer/attention/jax_flash_attn_tpu.py
new file mode 100644
index 0000000..5b720b8
--- /dev/null
+++ b/fjformer/attention/jax_flash_attn_tpu.py
@@ -0,0 +1,1678 @@
+# Copyright 2023 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Flash Attention TPU kernel."""
+from __future__ import annotations
+
+import dataclasses
+import functools
+from typing import Any, NamedTuple
+
+import jax
+from jax import lax
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
+import jax.numpy as jnp
+
+DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
+NUM_LANES = 128
+NUM_SUBLANES = 8
+
+
+class SegmentIds(NamedTuple):
+  """SegmentIds for Q and KV sequences.
+
+  SegmentIds are used to generate the segment mask, which prevents attention
+  between different segments in the input sequence. Each array is a list of
+  ids (integers). Only tokens with the same id can attend to each other.
+
+  Attributes:
+    q: segment ids along the Q sequence.
+    kv: segment ids along the KV sequence.
+  """
+
+  q: jax.Array  # [q_seq_len]
+  kv: jax.Array  # [kv_seq_len]
+
+
+@dataclasses.dataclass(frozen=True)
+class BlockSizes:
+  """Tile sizes parameterizing FlashAttention kernels.
+
+  These parameters have a negligible effect on numerics, but greatly affect
+  performance.
+  """
+  block_q: int
+  block_k_major: int
+  block_k: int
+  block_b: int
+
+  block_q_major_dkv: int | None = None
+  block_k_major_dkv: int | None = None
+  block_k_dkv: int | None = None
+  block_q_dkv: int | None = None
+
+  block_k_major_dq: int | None = None
+  block_k_dq: int | None = None
+  block_q_dq: int | None = None
+
+  def __post_init__(self):
+    def verify_major_minor(prefix, suffix, major, minor):
+      if minor > major:
+        raise ValueError(
+            f"{prefix}{suffix}={minor} should be smaller than"
+            f" {prefix}_major{suffix}={major}"
+        )
+      if major % minor != 0:
+        raise ValueError(
+            f"{prefix}{suffix}={minor} should divide"
+            f" {prefix}_major{suffix}={major}"
+        )
+
+    verify_major_minor("block_k", "", self.block_k_major, self.block_k)
+    if self.block_q_major_dkv is not None and self.block_q_dkv is not None:
+      verify_major_minor(
+          "block_q", "_dkv", self.block_q_major_dkv, self.block_q_dkv
+      )
+    if self.block_k_major_dkv is not None and self.block_k_dkv is not None:
+      verify_major_minor(
+          "block_k", "_dkv", self.block_k_major_dkv, self.block_k_dkv
+      )
+    if self.block_k_major_dq is not None and self.block_k_dq is not None:
+      verify_major_minor(
+          "block_k", "_dq", self.block_k_major_dq, self.block_k_dq
+      )
+
+  @property
+  def has_backward_blocks(self) -> bool:
+    backward_blocks = (
+        self.block_q_major_dkv,
+        self.block_k_major_dkv,
+        self.block_q_dkv,
+        self.block_k_dkv,
+        self.block_k_major_dq,
+        self.block_k_dq,
+        self.block_q_dq,
+    )
+    return all(b is not None for b in backward_blocks)
+
+  @classmethod
+  def get_default(cls, batch_size, num_heads, q_seq_len, kv_len, d_model):
+    # TODO(apaszke,sharadmv): Select better parameters based on a heuristic.
+    del batch_size, num_heads, q_seq_len, kv_len, d_model  # Unused.
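+    # The all-128 defaults keep every tile aligned with the TPU lane width
+    # (NUM_LANES == 128) and trivially satisfy the divisibility checks in
+    # __post_init__.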
+ return BlockSizes( + block_q=128, + block_k_major=128, + block_k=128, + block_b=1, + block_q_major_dkv=128, + block_k_major_dkv=128, + block_k_dkv=128, + block_q_dkv=128, + block_k_major_dq=128, + block_k_dq=128, + block_q_dq=128, + ) + + +@functools.partial( + jax.jit, + static_argnames=[ + "causal", + "sm_scale", + "block_sizes", + "debug", + ], +) +def flash_attention( + q, # [batch_size, num_heads, q_seq_len, d_model] + k, # [batch_size, num_heads, kv_seq_len, d_model] + v, # [batch_size, num_heads, kv_seq_len, d_model] + ab=None, # [batch_size, num_heads, q_seq_len, kv_seq_len] + segment_ids=None, # q of [batch_size, q_seq_len] and kv of [batch_size, kv_seq_len] + *, + causal: bool = False, + sm_scale: float = 1.0, + block_sizes: BlockSizes | None = None, + debug: bool = False, +): + batch_size, num_heads, q_seq_len, d_model = q.shape + batch_size_k, num_heads_k, kv_seq_len, d_model_k = k.shape + batch_size_v, num_heads_v, kv_seq_len_v, d_model_v = v.shape + if batch_size != batch_size_k or batch_size != batch_size_v: + raise ValueError( + f"Batch size mismatch: got {batch_size}, {batch_size_k} and" + f" {batch_size_v} (for q, k, v respectively)" + ) + if num_heads != num_heads_k or num_heads != num_heads_v: + raise ValueError( + f"Head count mismatch: got {num_heads}, {num_heads_k}," + f" {num_heads_v} (for q, k, v respectively)" + ) + if d_model != d_model_k: + raise ValueError( + f"Model dimension mismatch: got {d_model} and {d_model_k} (for q and k" + " respectively)" + ) + if d_model != d_model_v: + raise NotImplementedError( + "V model dimension unequal to KV model dimension unsupported" + ) + if kv_seq_len != kv_seq_len_v: + raise ValueError( + f"KV sequence length mismatch: got {kv_seq_len} and {kv_seq_len_v}" + ) + if block_sizes is None: + block_sizes = BlockSizes.get_default( + batch_size, num_heads, q_seq_len, kv_seq_len, d_model + ) + return _flash_attention( + q, k, v, ab, segment_ids, False, causal, sm_scale, block_sizes, debug + ) + + +@functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 10)) +def _flash_attention( + q, + k, + v, + ab, + segment_ids, + save_residuals, + causal, + sm_scale, + block_sizes, + debug, +): + return _flash_attention_impl( + q, + k, + v, + ab, + segment_ids, + save_residuals, + causal, + sm_scale, + block_sizes.block_b, + block_sizes.block_q, + block_sizes.block_k_major, + block_sizes.block_k, + debug, + ) + + +def _flash_attention_fwd( + q, + k, + v, + ab, + segment_ids, + save_residuals, + causal, + sm_scale, + block_sizes, + debug, +): + if save_residuals: + raise NotImplementedError("Higher-order AD not supported") + o, l, m = _flash_attention( + q, k, v, ab, segment_ids, True, causal, sm_scale, block_sizes, debug + ) + return o, (q, k, v, ab, segment_ids, o, l, m) + + +def _flash_attention_bwd( + save_residuals: bool, + causal: bool, + sm_scale: float, + block_sizes: BlockSizes, + debug: bool, + residuals, + do, +): + """VJP rule for FlashAttention.""" + if save_residuals: + raise NotImplementedError("Higher-order AD not supported") + (q, k, v, ab, segment_ids, o, l, m) = residuals + if not block_sizes.has_backward_blocks: + raise ValueError( + "Program is being differentiated, but not all backward blocks are" + " specified" + ) + + di = jnp.sum( + o.astype(jnp.float32) * do.astype(jnp.float32), axis=-1 + ) # [batch_size, num_heads, q_seq_len] + + dk, dv = _flash_attention_bwd_dkv( + q, + k, + v, + ab, + segment_ids, + l, + m, + do, + di, + block_q_major=block_sizes.block_q_major_dkv, + 
block_k_major=block_sizes.block_k_major_dkv, + block_k=block_sizes.block_k_dkv, + block_q=block_sizes.block_q_dkv, + sm_scale=sm_scale, + causal=causal, + mask_value=DEFAULT_MASK_VALUE, + debug=debug, + ) + + dq, ds = _flash_attention_bwd_dq( + q, + k, + v, + ab, + segment_ids, + l, + m, + do, + di, + block_q_major=block_sizes.block_q_dq, + block_k_major=block_sizes.block_k_major_dq, + block_k=block_sizes.block_k_dq, + sm_scale=sm_scale, + causal=causal, + mask_value=DEFAULT_MASK_VALUE, + debug=debug, + ) + return dq, dk, dv, ds, None + + +_flash_attention.defvjp(fwd=_flash_attention_fwd, bwd=_flash_attention_bwd) + +MIN_BLOCK_SIZE = 128 +TRANS_B_DIM_NUMBERS = (((1,), (1,)), ((), ())) + + +def below_or_on_diag(r, r_blk_size, c, c_blk_size): + # A block is considered below or on diagonal as long as the bottom left + # corner of the block is below or on diagonal. + return ((r + 1) * r_blk_size - 1) > (c * c_blk_size) + + +def _flash_attention_kernel(q_tile_ref, *args, **kwargs): + block_b = q_tile_ref.shape[0] + # If we're not going to tile the softmax, then we can avoid a bunch of VPU ops. + if kwargs["block_k"] == kwargs["kv_seq_len"]: + kernel = _flash_attention_kernel_single_batch_single_step + else: + kernel = _flash_attention_kernel_single_batch + for batch_idx in range(block_b): + kernel((batch_idx, 0), q_tile_ref, *args, **kwargs) + + +def _flash_attention_kernel_single_batch( + batch_idx: tuple[int, ...], + q_tile_ref, + k_tile_ref, + v_tile_ref, + ab_tile_ref, + q_segment_ids_tile_ref, + kv_segment_ids_tile_ref, # Input arrays + o_tile_ref, # Output arrays + m_scratch_ref, + l_scratch_ref, + acc_scratch_ref, + l_ref: Any | None = None, + m_ref: Any | None = None, + *, + causal, + sm_scale, + block_k, + kv_seq_len, + mask_value, +): + block_k_major = k_tile_ref.shape[2] + block_q = q_tile_ref.shape[2] + head_dim = q_tile_ref.shape[-1] + + kv_seq_idx = pl.program_id(3) + + @pl.when(kv_seq_idx == 0) + def start_new_sequence(): + m_scratch_ref[batch_idx] = jnp.full( + m_scratch_ref.shape[2:], -jnp.inf, jnp.float32 + ) + l_scratch_ref[batch_idx] = jnp.zeros(l_scratch_ref.shape[2:], jnp.float32) + acc_scratch_ref[batch_idx] = jnp.zeros( + acc_scratch_ref.shape[2:], jnp.float32 + ) + + q_seq_idx = pl.program_id(2) + if causal: + should_run = below_or_on_diag(q_seq_idx, block_q, kv_seq_idx, block_k_major) + else: + should_run = True + + @pl.when(should_run) + def run(): + @functools.partial( + lax.fori_loop, 0, block_k_major // block_k, init_val=None + ) + def body(i, _): + m_prev = m_scratch_ref[batch_idx] + l_prev = l_scratch_ref[batch_idx] + q = q_tile_ref[batch_idx] # [block_q, head_dim] + start_k = i * block_k + k = pl.load( + k_tile_ref, (*batch_idx, pl.dslice(start_k, block_k), slice(None)) + ) # [block_k, head_dim] + + s = jax.lax.dot_general( + q, k, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32 + ) # [block_q, block_k] + + # Add attention bias if needed. + # TODO(tanburn) Should the attention bias be added before or after + # multiplication by sm_scale? + if ab_tile_ref is not None: + ab = pl.load( + ab_tile_ref, + (*batch_idx, pl.dslice(None), pl.dslice(start_k, block_k)) + ).astype(jnp.float32) + s += ab + + if sm_scale != 1.0: + s *= sm_scale + + mask = None + if q_segment_ids_tile_ref is not None: + repeats, rem = divmod(block_k, NUM_LANES) + if rem: + raise NotImplementedError( + f"kv block size must be a multiple of {NUM_LANES}" + ) + q_segment_ids = pltpu.repeat( + q_segment_ids_tile_ref[batch_idx[0]], repeats, axis=1 + ) # [block_q, block_k]. 
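+ # Q segment ids are kept as a [block_q, NUM_LANES] tile; repeating them along
+ # the lane axis gives [block_q, block_k] so the equality test below broadcasts
+ # against the [1, block_k] KV ids.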
+ kv_segment_ids = pl.load( + kv_segment_ids_tile_ref, + (batch_idx[0], pl.dslice(1), pl.dslice(start_k, block_k)), + ) # [1, block_k]. + mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) + + if causal: + mask_shape = (block_q, block_k) + row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0) + row_ids += q_seq_idx * block_q + col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1) + col_ids += kv_seq_idx * block_k_major + start_k + causal_mask = col_ids <= row_ids + mask = ( + causal_mask if mask is None else jnp.logical_and(mask, causal_mask) + ) + + s = s if mask is None else s + jnp.where(mask, 0.0, mask_value) + + m_curr = jnp.max(s, axis=1)[:, None] # Row max, shape [block_q, 1]. + m_next = jnp.maximum(m_prev, m_curr) # Shape [block_q, 128]. + + block_k_repeats, rem = divmod(block_k, MIN_BLOCK_SIZE) + if rem: + raise NotImplementedError( + f"{block_k=} should be a multiple of {MIN_BLOCK_SIZE}" + ) + p = jnp.exp(s - pltpu.repeat(m_next, block_k_repeats, 1)) + + alpha = jnp.exp(m_prev - m_next) # Shape [block_q, 128]. + + l_corr = alpha * l_prev + + l_next = jnp.sum(p, axis=1)[:, None] + l_corr # Shape [block_q, 128] + + head_dim_repeats, rem = divmod(head_dim, MIN_BLOCK_SIZE) + l_broadcast = lambda l: pltpu.repeat(l, head_dim_repeats, 1) + if rem: + if head_dim_repeats == 0: + l_broadcast = lambda l: l[:, :head_dim] + else: + raise NotImplementedError( + f"{head_dim=} should be a multiple of {MIN_BLOCK_SIZE} if larger" + ) + l_scratch_ref[batch_idx] = l_next + m_scratch_ref[batch_idx] = m_next + + l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next) + acc_scratch_ref[batch_idx] *= l_broadcast(l_corr * l_next_inv_safe) + v = pl.load( + v_tile_ref, (*batch_idx, pl.dslice(start_k, block_k), slice(None)) + ) + o_curr = jax.lax.dot( + p.astype(v.dtype), v, preferred_element_type=jnp.float32 + ) + acc_scratch_ref[batch_idx] += o_curr * l_broadcast(l_next_inv_safe) + + @pl.when(kv_seq_idx == (kv_seq_len // block_k_major) - 1) + def store_output(): + o_tile_ref[batch_idx] = acc_scratch_ref[batch_idx].astype(o_tile_ref.dtype) + if l_ref is not None: + l_ref[batch_idx] = l_scratch_ref[batch_idx].astype(l_ref.dtype) + if m_ref is not None: + m_ref[batch_idx] = m_scratch_ref[batch_idx].astype(m_ref.dtype) + + +def _flash_attention_kernel_single_batch_single_step( + batch_idx: tuple[int, ...], + q_tile_ref, + k_tile_ref, + v_tile_ref, + ab_tile_ref, + q_segment_ids_tile_ref, + kv_segment_ids_tile_ref, # Input arrays + o_tile_ref, # Output arrays + m_scratch_ref, + l_scratch_ref, + acc_scratch_ref, + l_ref: Any | None = None, + m_ref: Any | None = None, + *, + causal, + sm_scale, + block_k, + kv_seq_len, + mask_value, +): + block_k_major = k_tile_ref.shape[2] + block_q = q_tile_ref.shape[2] + + scratch_refs = (m_scratch_ref, l_scratch_ref, acc_scratch_ref) + assert all(ref is None for ref in scratch_refs) + assert kv_seq_len == block_k_major == block_k + + q = q_tile_ref[batch_idx] # [block_q, head_dim] + k = k_tile_ref[batch_idx] # [block_k, head_dim] + s = jax.lax.dot_general( + q, k, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32 + ) # [block_q, block_k] + + if ab_tile_ref is not None: + s += ab_tile_ref[batch_idx].astype(jnp.float32) + if sm_scale != 1.0: + s *= sm_scale + + mask = None + if q_segment_ids_tile_ref is not None: + repeats, rem = divmod(block_k, NUM_LANES) + if rem: + raise NotImplementedError( + f"kv block size must be a multiple of {NUM_LANES}" + ) + q_segment_ids = pl.load( + q_segment_ids_tile_ref, (batch_idx[0],) + ) # [block_q, 
NUM_LANES]. + q_segment_ids = pltpu.repeat( + q_segment_ids, repeats, axis=1 + ) # [block_q, block_k]. + kv_segment_ids = pl.load( + kv_segment_ids_tile_ref, (batch_idx[0], pl.dslice(1)) + ) # [1, block_k]. + mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) + + if causal: + q_seq_idx = pl.program_id(2) + mask_shape = (block_q, block_k) + row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0) + row_ids += q_seq_idx * block_q + col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1) + causal_mask = col_ids <= row_ids + mask = causal_mask if mask is None else jnp.logical_and(mask, causal_mask) + s = s if mask is None else s + jnp.where(mask, 0.0, mask_value) + + m = jnp.max(s, axis=1)[:, None] + p = jnp.exp(s - m) + l = jnp.sum(p, axis=1)[:, None] + p /= l + + if m_ref is not None: + m_ref[batch_idx] = lax.broadcast_in_dim(m, m_ref.shape[2:], range(2)) + if l_ref is not None: + l_ref[batch_idx] = lax.broadcast_in_dim(l, l_ref.shape[2:], range(2)) + + v = v_tile_ref[batch_idx] + o_tile_ref[batch_idx] = jax.lax.dot( + p.astype(v.dtype), v, preferred_element_type=jnp.float32 + ).astype(o_tile_ref.dtype) + + +def _flash_attention_impl( + q, + k, + v, + ab, + segment_ids, + save_residuals, + causal, + sm_scale, + block_b, + block_q, + block_k_major, + block_k, + debug, +): + batch_size, num_heads, q_seq_len, head_dim = q.shape + _, _, kv_seq_len, _ = k.shape + _verify_block("block_q", "q_seq_len", block_q, q_seq_len, should_divide=False) + _verify_block("block_k_major", "kv_seq_len", block_k_major, kv_seq_len) + _verify_block("block_k", "kv_seq_len", block_k, kv_seq_len) + _verify_block("block_b", "batch", block_b, batch_size, should_divide=False) + + # TODO(apaszke): Tile over heads as well. + grid = ( + pl.cdiv(batch_size, block_b), + num_heads, + pl.cdiv(q_seq_len, block_q), + kv_seq_len // block_k_major, + ) + + def q_index_map(batch_index, head_index, q_seq_index, _): + return (batch_index, head_index, q_seq_index, 0) + + def kv_index_map(batch_index, head_index, q_seq_index, kv_seq_index): + if causal: + # If the kv block is skipped, prefetch the next valid kv block, i.e. the + # 0th one to be used for the next block_q rows. + next_kv_index = lax.select( + below_or_on_diag(q_seq_index, block_q, kv_seq_index, block_k_major), + kv_seq_index, + 0, + ) + else: + next_kv_index = kv_seq_index + return (batch_index, head_index, next_kv_index, 0) + + def ab_index_map(batch_index, head_index, q_seq_index, kv_seq_index): + if causal: + should_run = below_or_on_diag( + q_seq_index, block_q, kv_seq_index, block_k_major + ) + # If the ab block is skipped, prefetch the next valid ab block, i.e. the + # 0th kv to be used for the next block_q rows. 
+ next_q_index = lax.select( + should_run, + q_seq_index, + lax.select( + q_seq_index == (q_seq_len // block_q) - 1, 0, q_seq_index + 1 + ), + ) + next_kv_index = lax.select(should_run, kv_seq_index, 0) + else: + next_q_index = q_seq_index + next_kv_index = kv_seq_index + + return (batch_index, head_index, next_q_index, next_kv_index) + + def o_index_map(batch_index, head_index, q_seq_index, _): + return (batch_index, head_index, q_seq_index, 0) + + def lm_index_map(batch_index, head_index, q_seq_index, _): + return (batch_index, head_index, q_seq_index, 0) + + kernel = functools.partial( + _flash_attention_kernel, + causal=causal, + mask_value=DEFAULT_MASK_VALUE, + sm_scale=sm_scale, + block_k=block_k, + kv_seq_len=kv_seq_len, + ) + out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype) + out_shape = [out_shape] + out_specs = [pl.BlockSpec(o_index_map, (block_b, 1, block_q, head_dim))] + + if block_k != kv_seq_len: + scratch_shape = functools.partial(jax.ShapeDtypeStruct, dtype=jnp.float32) + m_scratch = scratch_shape((block_b, 1, block_q, MIN_BLOCK_SIZE)) + l_scratch = scratch_shape((block_b, 1, block_q, MIN_BLOCK_SIZE)) + acc_scratch = scratch_shape((block_b, 1, block_q, head_dim)) + out_shape += [m_scratch, l_scratch, acc_scratch] + out_specs += [ + pl.BlockSpec(lambda *_: (0, 0, 0, 0), m_scratch.shape), + pl.BlockSpec(lambda *_: (0, 0, 0, 0), l_scratch.shape), + pl.BlockSpec(lambda *_: (0, 0, 0, 0), acc_scratch.shape), + ] + else: + out_shape += [None, None, None] + out_specs += [None, None, None] + + if save_residuals: + out_specs = [ + *out_specs, + pl.BlockSpec(lm_index_map, (block_b, 1, block_q, MIN_BLOCK_SIZE)), + pl.BlockSpec(lm_index_map, (block_b, 1, block_q, MIN_BLOCK_SIZE)), + ] + l = jax.ShapeDtypeStruct( + (batch_size, num_heads, q_seq_len, MIN_BLOCK_SIZE), dtype=jnp.float32 + ) + m = jax.ShapeDtypeStruct( + (batch_size, num_heads, q_seq_len, MIN_BLOCK_SIZE), dtype=jnp.float32 + ) + out_shape = (*out_shape, l, m) + + ab_block_spec = ( + pl.BlockSpec(ab_index_map, (block_b, 1, block_q, block_k_major)) + if ab is not None else None) + + q_segment_ids_spec = kv_segment_ids_spec = None + q_segment_ids = kv_segment_ids = None + if segment_ids is not None: + + def q_segment_ids_index_map(batch_index, head_index, q_seq_index, _): + del head_index + return (batch_index, q_seq_index, 0) + + def kv_segment_ids_index_map( + batch_index, head_index, q_seq_index, kv_seq_index + ): + del head_index + if causal: + next_kv_index = lax.select( + below_or_on_diag(q_seq_index, block_q, kv_seq_index, block_k_major), + kv_seq_index, + 0, + ) + else: + next_kv_index = kv_seq_index + return (batch_index, 0, next_kv_index) + + q_segment_ids_spec = pl.BlockSpec( + q_segment_ids_index_map, (block_b, block_q, NUM_LANES) + ) + kv_segment_ids_spec = pl.BlockSpec( + kv_segment_ids_index_map, (block_b, NUM_SUBLANES, block_k_major) + ) + + q_segment_ids = jax.lax.broadcast_in_dim( + segment_ids.q, + (batch_size, q_seq_len, NUM_LANES), + ( + 0, + 1, + ), + ) + kv_segment_ids = jax.lax.broadcast_in_dim( + segment_ids.kv, + (batch_size, NUM_SUBLANES, kv_seq_len), + ( + 0, + 2, + ), + ) + + in_specs = [ + pl.BlockSpec(q_index_map, (block_b, 1, block_q, head_dim)), + pl.BlockSpec(kv_index_map, (block_b, 1, block_k_major, head_dim)), + pl.BlockSpec(kv_index_map, (block_b, 1, block_k_major, head_dim)), + ab_block_spec, + q_segment_ids_spec, + kv_segment_ids_spec, + ] + + o, *aux = pl.pallas_call( + kernel, + out_shape=out_shape, + in_specs=in_specs, + out_specs=out_specs, + grid=grid, + debug=debug, + 
mosaic_params=dict( + dimension_semantics=("parallel", "parallel", "parallel", "arbitrary") + ), + )(q, k, v, ab, q_segment_ids, kv_segment_ids) + if save_residuals: + l, m = (v[..., 0] for v in aux[-2:]) + return (o, l, m) + else: + return o + + +def _flash_attention_dkv_kernel( + q_tile_ref, + k_tile_ref, + v_tile_ref, + ab_tile_ref, + q_segment_ids_tile_ref, + kv_segment_ids_tile_ref, + l_tile_ref, + m_tile_ref, + do_tile_ref, + di_tile_ref, + dk_tile_ref, + dv_tile_ref, + dk_scratch_ref, + dv_scratch_ref, + *, + sm_scale: float, + causal: bool, + mask_value: float, + q_seq_len: int, + block_q: int, + block_k: int, +): + _, _, block_q_major, _ = q_tile_ref.shape + _, _, block_k_major, _ = k_tile_ref.shape + + q_seq_index = pl.program_id(axis=3) + kv_seq_index = pl.program_id(axis=2) + + @pl.when(q_seq_index == 0) + def start_new_sequence(): + dk_scratch_ref[:, :] = jnp.zeros(dk_scratch_ref.shape, dk_scratch_ref.dtype) + dv_scratch_ref[:, :] = jnp.zeros(dv_scratch_ref.shape, dv_scratch_ref.dtype) + + def q_body(j, _): + start_q = j * block_q + + def k_body(i, _): + start_k = i * block_k + k = pl.load(k_tile_ref, (0, 0, pl.ds(start_k, block_k), slice(None))) + v = pl.load(v_tile_ref, (0, 0, pl.ds(start_k, block_k), slice(None))) + q = pl.load(q_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) + ) # [block_q, head_dim] + l = pl.load(l_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) + ) # [block_q, 128] + m = pl.load(m_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) + ) # [block_q, 128] + do = pl.load(do_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) + ) # [block_q, 128] + di = pl.load(di_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) + ).astype(jnp.float32) # [block_q, 128] + + capped_logits = lax.dot_general( + q, k, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32 + ) # [block_q_major, block_k] + + if ab_tile_ref is not None: + ab = pl.load( + ab_tile_ref, + ( + 0, + 0, + pl.dslice(j * block_q, block_q), + pl.dslice(i * block_k, block_k), + ), + ).astype(jnp.float32) + capped_logits += ab + + if sm_scale != 1.0: + capped_logits *= sm_scale + + mask = None + if q_segment_ids_tile_ref is not None: + repeats, rem = divmod(block_k, NUM_LANES) + if rem: + raise NotImplementedError( + ) + q_segment_ids = pl.load( + q_segment_ids_tile_ref, (0, pl.ds(start_q, block_q), slice(None)) + ) # [block_q, NUM_LANES]. + q_segment_ids = pltpu.repeat( + q_segment_ids, repeats, axis=1 + ) # [block_q, block_k]. + kv_segment_ids = pl.load( + kv_segment_ids_tile_ref, (slice(None), 0, pl.ds(start_k, block_k)) + ) # [1, block_k]. 
+ mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) + + if causal: + mask_shape = (block_q, block_k) + row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0) + row_ids += q_seq_index * block_q_major + start_q + col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1) + col_ids += kv_seq_index * block_k_major + start_k + causal_mask = col_ids <= row_ids + mask = ( + causal_mask if mask is None else jnp.logical_and(mask, causal_mask) + ) + + capped_logits = ( + capped_logits + if mask is None + else capped_logits + jnp.where(mask, 0.0, mask_value) + ) + + p = jnp.exp( + capped_logits - pltpu.repeat(m, block_k // MIN_BLOCK_SIZE, axis=1) + ) + p = p * pltpu.repeat( + 1 / l, block_k // MIN_BLOCK_SIZE, axis=1 + ) # [block_q_major, block_k_major] + dv = lax.dot(p.T.astype(do.dtype), do, preferred_element_type=jnp.float32) + pl.store(dv_scratch_ref, (pl.ds(start_k, block_k), slice(None)), + pl.load(dv_scratch_ref, (pl.ds(start_k, block_k), slice(None))) + + dv.astype(dv_scratch_ref.dtype)) + + # di: [block_q, 128] + # do: [block_q, head_dim] + # v: [block_k_major, head_dim] + dp = lax.dot_general( + do, v, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32 + ) + ds = (dp - pltpu.repeat(di, block_k // MIN_BLOCK_SIZE, axis=1)) * p + + if sm_scale != 1.0: + ds = ds * sm_scale + + # ds: [block_q_major, block_k_major] + # q: [block_q_major, head_dim] + dk = lax.dot(ds.T.astype(do.dtype), q, preferred_element_type=jnp.float32) + pl.store(dk_scratch_ref, (pl.ds(start_k, block_k), slice(None)), + pl.load(dk_scratch_ref, (pl.ds(start_k, block_k), slice(None))) + + dk.astype(dk_scratch_ref.dtype)) + + lax.fori_loop(0, block_k_major // block_k, k_body, None) + + if causal: + should_run = below_or_on_diag( + q_seq_index, block_q_major, kv_seq_index, block_k_major + ) + else: + should_run = True + + @pl.when(should_run) + def run(): + lax.fori_loop(0, block_q_major // block_q, q_body, None) + + @pl.when(q_seq_index == q_seq_len // block_q_major - 1) + def end_of_q_sequence(): + dv_tile_ref[0, 0, :, :] = dv_scratch_ref[...].astype(dv_tile_ref) + dk_tile_ref[0, 0, :, :] = dk_scratch_ref[...].astype(dk_tile_ref) + + +def _flash_attention_bwd_dkv( + q, + k, + v, + ab, + segment_ids, + l, + m, + do, + di, + *, + block_q_major: int | None, + block_q: int | None, + block_k_major: int | None, + block_k: int | None, + sm_scale: float, + causal: bool = False, + mask_value: float = DEFAULT_MASK_VALUE, + debug: bool = False, +): + batch_size, num_heads, q_seq_len, head_dim = q.shape + _, _, kv_seq_len, _ = k.shape + _verify_block("block_q_major_dkv", "q_seq_len", block_q_major, q_seq_len) + _verify_block("block_q_dkv", "q_seq_len", block_q, q_seq_len) + _verify_block("block_k_major_dkv", "kv_seq_len", block_k_major, kv_seq_len) + _verify_block("block_k_dkv", "kv_seq_len", block_k, kv_seq_len) + + # Broadcast out scalar values + m = jnp.broadcast_to(m[..., None], (*m.shape, MIN_BLOCK_SIZE)) + l = jnp.broadcast_to(l[..., None], (*l.shape, MIN_BLOCK_SIZE)) + # Preprocess contraction for bwd pass + di = jnp.broadcast_to(di[..., None], (*di.shape, MIN_BLOCK_SIZE)) + + # kv index needs to be before q index since q index is the contractng + # dimension. + grid = ( + batch_size, + num_heads, + kv_seq_len // block_k_major, + q_seq_len // block_q_major, + ) + + def qo_index_map(batch_index, head_index, kv_seq_index, q_seq_index): + if causal: + # If the q block is skipped, stay at the 0th q block. 
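+ # Under the causal mask, q blocks that lie entirely above the diagonal for
+ # this kv block contribute nothing to its dk/dv, so the kernel skips them;
+ # the index map parks those iterations at q block 0 instead of prefetching
+ # data that will not be used.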
+ next_q_index = lax.select( + below_or_on_diag( + q_seq_index, block_q_major, kv_seq_index, block_k_major + ), + q_seq_index, + 0, + ) + else: + next_q_index = q_seq_index + + return (batch_index, head_index, next_q_index, 0) + + qo_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, head_dim)) + assert qo_spec.block_shape is not None + assert q.ndim == len(qo_spec.block_shape) + do_spec = qo_spec + assert do.ndim == len(qo_spec.block_shape) + + def kv_index_map(batch_index, head_index, kv_seq_index, _): + return (batch_index, head_index, kv_seq_index, 0) + + kv_spec = pl.BlockSpec(kv_index_map, (1, 1, block_k_major, head_dim)) + assert kv_spec.block_shape is not None + assert k.ndim == len(kv_spec.block_shape) + assert v.ndim == len(kv_spec.block_shape) + + def lm_index_map(batch_index, head_index, _, q_seq_index): + return (batch_index, head_index, q_seq_index, 0) + + lm_spec = pl.BlockSpec(lm_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE)) + assert lm_spec.block_shape is not None + assert l.ndim == len(lm_spec.block_shape) + assert m.ndim == len(lm_spec.block_shape) + + di_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE)) + assert di_spec.block_shape is not None + assert di.ndim == len(di_spec.block_shape) + + def ab_index_map(batch_index, head_index, kv_seq_index, q_seq_index): + return (batch_index, head_index, q_seq_index, kv_seq_index) + + dab_spec = ( + pl.BlockSpec(ab_index_map, (1, 1, block_q_major, block_k_major)) + if ab is not None + else None + ) + + q_segment_ids_spec = kv_segment_ids_spec = None + q_segment_ids = kv_segment_ids = None + if segment_ids is not None: + + def q_segment_ids_index_map( + batch_index, head_index, kv_seq_index, q_seq_index + ): + del head_index + if causal: + next_q_index = lax.select( + below_or_on_diag( + q_seq_index, block_q_major, kv_seq_index, block_k_major + ), + q_seq_index, + 0, + ) + else: + next_q_index = q_seq_index + return (batch_index, next_q_index, 0) + + def kv_segment_ids_index_map(batch_index, head_index, kv_seq_index, _): + del head_index + return (batch_index, 0, kv_seq_index) + + q_segment_ids_spec = pl.BlockSpec( + q_segment_ids_index_map, (1, block_q_major, NUM_LANES) + ) + kv_segment_ids_spec = pl.BlockSpec( + kv_segment_ids_index_map, (1, NUM_SUBLANES, block_k_major) + ) + + q_segment_ids = jax.lax.broadcast_in_dim( + segment_ids.q, + (batch_size, q_seq_len, NUM_LANES), + ( + 0, + 1, + ), + ) + kv_segment_ids = jax.lax.broadcast_in_dim( + segment_ids.kv, + (batch_size, NUM_SUBLANES, kv_seq_len), + ( + 0, + 2, + ), + ) + + in_specs = [ + qo_spec, + kv_spec, + kv_spec, + dab_spec, + q_segment_ids_spec, + kv_segment_ids_spec, + lm_spec, + lm_spec, + do_spec, + di_spec, + ] + + out_shapes = [ + jax.ShapeDtypeStruct((batch_size, num_heads, kv_seq_len, head_dim), + k.dtype), + jax.ShapeDtypeStruct((batch_size, num_heads, kv_seq_len, head_dim), + v.dtype), + jax.ShapeDtypeStruct((block_k_major, head_dim), jnp.float32), + jax.ShapeDtypeStruct((block_k_major, head_dim), jnp.float32), + ] + + def dkv_index_map(batch_index, head_index, kv_seq_index, _): + return (batch_index, head_index, kv_seq_index, 0) + + dkv_spec = pl.BlockSpec(dkv_index_map, (1, 1, block_k_major, head_dim)) + out_specs = [ + dkv_spec, dkv_spec, + pl.BlockSpec(lambda *_: (0, 0), (block_k_major, head_dim)), + pl.BlockSpec(lambda *_: (0, 0), (block_k_major, head_dim)), + ] + + kernel = functools.partial( + _flash_attention_dkv_kernel, + block_q=block_q, + block_k=block_k, + sm_scale=sm_scale, + causal=causal, + mask_value=mask_value, + 
q_seq_len=q_seq_len, + ) + name_scope = f"flash_mha_bwd_dkv_{block_q_major=}_{block_q=}_{block_k_major=}_{block_k=}" + with jax.named_scope(name_scope): + dk, dv, _, _ = pl.pallas_call( + kernel, + in_specs=in_specs, # type: ignore + out_shape=out_shapes, + out_specs=out_specs, + grid=grid, + debug=debug, + mosaic_params=dict( + dimension_semantics=( + "parallel", + "parallel", + "parallel", + "arbitrary", + ) + ), + )(q, k, v, ab, q_segment_ids, kv_segment_ids, l, m, do, di) + assert dk.shape == k.shape + assert dv.shape == v.shape + return dk, dv + + +def _flash_attention_dq_kernel( + q_tile_ref, + k_tile_ref, + v_tile_ref, + ab_tile_ref, + q_segment_ids_tile_ref, + kv_segment_ids_tile_ref, + l_tile_ref, + m_tile_ref, + do_tile_ref, + di_tile_ref, + dq_tile_ref, + dq_scratch_ref, + ds_tile_ref, + *, + sm_scale: float, + causal: bool, + mask_value: float, + kv_seq_len: int, + block_k: int, +): + _, _, block_k_major, _ = k_tile_ref.shape + _, _, block_q_major, _ = q_tile_ref.shape + + kv_seq_index = pl.program_id(axis=3) + q_seq_index = pl.program_id(axis=2) + + @pl.when(kv_seq_index == 0) + def start_new_sequence(): + dq_scratch_ref[:, :] = jnp.zeros(dq_scratch_ref.shape, dq_scratch_ref.dtype) + + def body(i, _): + k_slice = pl.ds(i * block_k, block_k) + q = q_tile_ref[0, 0, :, :] + k = pl.load( + k_tile_ref, (0, 0, k_slice, slice(None)), + ) # [block_k, head_dim] + v = pl.load( + v_tile_ref, (0, 0, k_slice, slice(None)), + ) # [block_k, head_dim] + l = l_tile_ref[0, 0, :, :] # [block_q_major, 128] + m = m_tile_ref[0, 0, :, :] # [block_q_major, 128] + do = do_tile_ref[0, 0, :, :] # [block_q_major, head_dim] + di = di_tile_ref[0, 0, :].astype(jnp.float32) # [block_q_major, 128] + + capped_logits = jax.lax.dot_general( + q, k, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32 + ) + + if ab_tile_ref is not None: + ab = pl.load( + ab_tile_ref, (0, 0, pl.dslice(None), pl.dslice(i * block_k, block_k)) + ).astype(jnp.float32) + capped_logits += ab + + if sm_scale != 1.0: + capped_logits *= sm_scale + + mask = None + if q_segment_ids_tile_ref is not None: + repeats, rem = divmod(block_k, NUM_LANES) + if rem: + raise NotImplementedError( + f"kv block size must be a multiple of {NUM_LANES}" + ) + q_segment_ids = pltpu.repeat( + q_segment_ids_tile_ref[0], repeats, axis=1 + ) # [block_q, block_k]. + kv_segment_ids = pl.load( + kv_segment_ids_tile_ref, (slice(None), 0, k_slice) + ) # [1, block_k]. 
+ mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) + + if causal: + mask_shape = (block_q_major, block_k) + row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0) + row_ids += q_seq_index * block_q_major + col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1) + col_ids += kv_seq_index * block_k_major + i * block_k + causal_mask = col_ids <= row_ids + mask = causal_mask if mask is None else jnp.logical_and(mask, causal_mask) + capped_logits = ( + capped_logits + if mask is None + else capped_logits + jnp.where(mask, 0.0, mask_value) + ) + + p = jnp.exp( + capped_logits - pltpu.repeat(m, block_k // MIN_BLOCK_SIZE, axis=1) + ) + p = p * pltpu.repeat( + 1 / l, block_k // MIN_BLOCK_SIZE, axis=1 + ) # [block_q_major, block_k] + + # di: [block_q_major, 128] + # do: [block_q_major, head_dim] + # v: [block_k_major, head_dim] + dp = jax.lax.dot_general( + do, + v, + TRANS_B_DIM_NUMBERS, + preferred_element_type=jnp.float32, + ) + ds = (dp - pltpu.repeat(di, block_k // MIN_BLOCK_SIZE, axis=1)) * p + # dp = jnp.dot(do, v.T) + # ds = (dp - (dp * p).sum(axis=1)[:, None]) * p + + if sm_scale != 1.0: + ds = ds * sm_scale + + if ds_tile_ref is not None: + pl.store( + ds_tile_ref, + (0, 0, pl.dslice(None), pl.dslice(i * block_k, block_k)), + ds.astype(ds_tile_ref.dtype), + ) + + # dp: [block_q_major, block_k] + # k: [block_k, head_dim] + dq_scratch_ref[:, :] += lax.dot( + ds.astype(k.dtype), + k, + preferred_element_type=jnp.float32, + ).astype(dq_scratch_ref.dtype) + + if causal: + should_run = below_or_on_diag( + q_seq_index, block_q_major, kv_seq_index, block_k_major + ) + should_not_run = lax.select(should_run, False, True) + else: + should_run = True + should_not_run = False # type: ignore + + @pl.when(should_run) + def run(): + lax.fori_loop(0, block_k_major // block_k, body, None) + + @pl.when(should_not_run) + def zero_out_ds(): + if ds_tile_ref is not None: + ds_tile_ref[...] = jnp.zeros_like(ds_tile_ref) + + @pl.when(kv_seq_index == kv_seq_len // block_k_major - 1) + def end_of_kv_sequence(): + dq_tile_ref[0, 0, :, :] = dq_scratch_ref[...].astype(dq_tile_ref) + dq_scratch_ref[...] = jnp.zeros_like(dq_scratch_ref) + + +def _flash_attention_bwd_dq( + q, + k, + v, + ab, + segment_ids, + l, + m, + do, + di, + *, + block_q_major: int | None, + block_k_major: int | None, + block_k: int | None, + sm_scale: float, + causal: bool, + mask_value: float, + debug: bool, +): + batch_size, num_heads, q_seq_len, head_dim = q.shape + _, _, kv_seq_len, _ = k.shape + _verify_block("block_q_dq", "q_seq_len", block_q_major, q_seq_len) + _verify_block("block_k_major_dq", "kv_seq_len", block_k_major, kv_seq_len) + _verify_block("block_k_dq", "block_k", block_k, kv_seq_len) + + # Broadcast out scalar values + m = jnp.broadcast_to(m[..., None], (*m.shape, MIN_BLOCK_SIZE)) + l = jnp.broadcast_to(l[..., None], (*l.shape, MIN_BLOCK_SIZE)) + # Preprocess contraction for bwd pass + di = jnp.broadcast_to(di[..., None], (*di.shape, block_k_major)) + + grid = ( + batch_size, + num_heads, + q_seq_len // block_q_major, + kv_seq_len // block_k_major, + ) + + def qo_index_map(batch_index, head_index, q_seq_index, _): + return (batch_index, head_index, q_seq_index, 0) + + qo_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, head_dim)) + do_spec = qo_spec + + def kv_index_map(batch_index, head_index, q_seq_index, kv_seq_index): + if causal: + # If the kv block is skipped, prefetch the next valid kv block, i.e. the + # 0th one to be used for the next block_q rows. 
+ next_kv_index = lax.select( + below_or_on_diag( + q_seq_index, block_q_major, kv_seq_index, block_k_major + ), + kv_seq_index, + 0, + ) + else: + next_kv_index = kv_seq_index + return (batch_index, head_index, next_kv_index, 0) + + kv_spec = pl.BlockSpec(kv_index_map, (1, 1, block_k_major, head_dim)) + assert kv_spec.block_shape is not None + assert k.ndim == len(kv_spec.block_shape) + assert v.ndim == len(kv_spec.block_shape) + + def lm_index_map(batch_index, head_index, q_seq_index, _): + return (batch_index, head_index, q_seq_index, 0) + + lm_spec = pl.BlockSpec(lm_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE)) + assert lm_spec.block_shape is not None + assert l.ndim == len(lm_spec.block_shape) + assert m.ndim == len(lm_spec.block_shape) + + di_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE)) + assert di_spec.block_shape is not None + assert di.ndim == len(di_spec.block_shape) + + def ab_index_map(batch_index, head_index, q_seq_index, kv_seq_index): + return (batch_index, head_index, q_seq_index, kv_seq_index) + + dab_spec = ( + pl.BlockSpec(ab_index_map, (1, 1, block_q_major, block_k_major)) + if ab is not None + else None + ) + + q_segment_ids_spec = kv_segment_ids_spec = None + q_segment_ids = kv_segment_ids = None + if segment_ids is not None: + + def q_segment_ids_index_map(batch_index, head_index, q_seq_index, _): + del head_index + return (batch_index, q_seq_index, 0) + + def kv_segment_ids_index_map( + batch_index, head_index, q_seq_index, kv_seq_index + ): + del head_index + if causal: + # If the kv block is skipped, prefetch the next valid kv block, i.e. the + # 0th one to be used for the next block_q rows. + next_kv_index = lax.select( + below_or_on_diag( + q_seq_index, block_q_major, kv_seq_index, block_k_major + ), + kv_seq_index, + 0, + ) + else: + next_kv_index = kv_seq_index + return (batch_index, 0, next_kv_index) + + q_segment_ids_spec = pl.BlockSpec( + q_segment_ids_index_map, (1, block_q_major, NUM_LANES) + ) + kv_segment_ids_spec = pl.BlockSpec( + kv_segment_ids_index_map, (1, NUM_SUBLANES, block_k_major) + ) + + q_segment_ids = jax.lax.broadcast_in_dim( + segment_ids.q, + (batch_size, q_seq_len, NUM_LANES), + ( + 0, + 1, + ), + ) + kv_segment_ids = jax.lax.broadcast_in_dim( + segment_ids.kv, + (batch_size, NUM_SUBLANES, kv_seq_len), + ( + 0, + 2, + ), + ) + + in_specs = [ + qo_spec, + kv_spec, + kv_spec, + dab_spec, + q_segment_ids_spec, + kv_segment_ids_spec, + lm_spec, + lm_spec, + do_spec, + di_spec, + ] + + out_shapes = [ + jax.ShapeDtypeStruct(q.shape, q.dtype), + jax.ShapeDtypeStruct((block_q_major, head_dim), jnp.float32), + jax.ShapeDtypeStruct(ab.shape, ab.dtype) if ab is not None else None, + ] + dq_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, head_dim)) + out_specs = [ + dq_spec, + pl.BlockSpec(lambda *_: (0, 0), (block_q_major, head_dim)), + dab_spec, + ] + + kernel = functools.partial( + _flash_attention_dq_kernel, + sm_scale=sm_scale, + causal=causal, + mask_value=mask_value, + block_k=block_k, + kv_seq_len=kv_seq_len, + ) + name_scope = f"flash_mha_bwd_dq_{block_q_major=}_{block_k_major=}_{block_k=}" + with jax.named_scope(name_scope): + dq, _, ds = pl.pallas_call( + kernel, + in_specs=in_specs, # type: ignore + out_shape=out_shapes, + out_specs=out_specs, # type: ignore + grid=grid, + debug=debug, + mosaic_params=dict( + dimension_semantics=( + "parallel", + "parallel", + "parallel", + "arbitrary", + ) + ), + )(q, k, v, ab, q_segment_ids, kv_segment_ids, l, m, do, di) + + # dab is just ds + return dq, ds + + 
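For orientation, here is a minimal usage sketch of the entry point defined in this file (illustrative only, not part of the patch): the shapes, dtypes, and two-segment packing below are assumptions, and the import path simply mirrors the new module's location.

    import jax
    import jax.numpy as jnp
    from fjformer.attention.jax_flash_attn_tpu import (
        BlockSizes, SegmentIds, flash_attention,
    )

    batch, heads, q_len, kv_len, d = 2, 4, 1024, 1024, 128
    kq, kk, kv_key = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(kq, (batch, heads, q_len, d), dtype=jnp.bfloat16)
    k = jax.random.normal(kk, (batch, heads, kv_len, d), dtype=jnp.bfloat16)
    v = jax.random.normal(kv_key, (batch, heads, kv_len, d), dtype=jnp.bfloat16)

    # Two packed sequences per batch row: tokens attend only within their segment.
    seg = jnp.where(jnp.arange(q_len) < q_len // 2, 0, 1)
    seg = jnp.broadcast_to(seg, (batch, q_len))

    out = flash_attention(
        q, k, v,
        segment_ids=SegmentIds(q=seg, kv=seg),  # ids: [batch, seq_len]
        causal=True,
        sm_scale=1.0 / d ** 0.5,
        block_sizes=BlockSizes.get_default(batch, heads, q_len, kv_len, d),
    )  # -> [batch, heads, q_len, d]

Because get_default fills in all of the backward block sizes, the custom VJP defined above also lets jax.grad flow through this call.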
+# For autograd testing. +def mha_reference_no_custom_vjp( + q, + k, + v, + ab: jax.Array | None = None, + segment_ids: SegmentIds | None = None, + *, + causal: bool = False, + mask_value: float = DEFAULT_MASK_VALUE, + sm_scale: float = 1.0, + save_residuals: bool = False, +): + logits = jnp.einsum("bhqc,bhkc->bhqk", q, k) + if ab is not None: + logits += ab + if sm_scale != 1.0: + logits *= sm_scale + + mask = None + if segment_ids is not None: + mask = segment_ids.q[:, :, None] == segment_ids.kv[:, None, :] + mask = mask[:, None, :, :] + + if causal: + _, _, q_seq_len, _ = q.shape + _, _, kv_seq_len, _ = k.shape + mask_shape = (q_seq_len, kv_seq_len) + row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0) + col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1) + causal_mask = (col_ids <= row_ids)[None, None, :, :] + mask = causal_mask if mask is None else jnp.logical_and(mask, causal_mask) + + logits = logits if mask is None else logits + jnp.where(mask, 0.0, mask_value) + + m = logits.max(axis=-1) + unnormalized = jnp.exp(logits - m[..., None]) + l = unnormalized.sum(axis=-1) + weights = unnormalized / l[..., None] + out = jnp.einsum("bhqk,bhkc->bhqc", weights, v) + if save_residuals: + return out, l, m + return out + + +@functools.partial( + jax.jit, static_argnames=["causal", "mask_value", "sm_scale"] +) +@jax.default_matmul_precision("bfloat16") +def mha_reference( + q, + k, + v, + ab, + segment_ids: SegmentIds | None = None, + causal: bool = False, + mask_value: float = DEFAULT_MASK_VALUE, + sm_scale=1.0, +): + return _mha_reference( + q, + k, + v, + ab, + segment_ids, + causal=causal, + mask_value=mask_value, + sm_scale=sm_scale, + save_residuals=False, + ) + + +@functools.partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8)) +def _mha_reference( + q, + k, + v, + ab, + segment_ids: SegmentIds | None, + causal: bool, + mask_value: float, + sm_scale: float, + save_residuals: bool, +): + return mha_reference_no_custom_vjp( + q, + k, + v, + ab, + segment_ids, + causal=causal, + mask_value=mask_value, + sm_scale=sm_scale, + save_residuals=save_residuals, + ) + + +def _mha_reference_fwd( + q, + k, + v, + ab, + segment_ids: SegmentIds | None, + causal: bool, + mask_value: float, + sm_scale: float, + save_residuals: bool, +): + if save_residuals: + raise NotImplementedError + res = _mha_reference( + q, + k, + v, + ab, + segment_ids, + causal=causal, + mask_value=mask_value, + sm_scale=sm_scale, + save_residuals=True, + ) + assert isinstance(res, tuple) + out, l, m = res + return out, (q, k, v, ab, segment_ids, out, l, m) + + +@functools.partial( + jax.jit, + static_argnames=[ + "causal", + "mask_value", + "sm_scale", + ], +) +def mha_reference_bwd( + q, + k, + v, + ab, + segment_ids: SegmentIds | None, + o, + l, + m, + do, + causal: bool = False, + mask_value: float = DEFAULT_MASK_VALUE, + sm_scale: float = 1.0, +): + if sm_scale != 1.0: + raise NotImplementedError + + logits = jnp.einsum( + "bhqc,bhkc->bhqk", + q.astype(jnp.float32), + k.astype(jnp.float32), + ) + if ab is not None: + logits += ab + + mask = None + if segment_ids is not None: + mask = segment_ids.q[:, :, None] == segment_ids.kv[:, None, :] + mask = mask[:, None, :, :] + + if causal: + _, _, q_seq_len, _ = q.shape + _, _, kv_seq_len, _ = k.shape + mask_shape = (q_seq_len, kv_seq_len) + row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0) + col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1) + causal_mask = (col_ids <= row_ids)[None, None, :, :] + mask = causal_mask if mask is None else 
jnp.logical_and(mask, causal_mask) + + logits = logits if mask is None else logits + jnp.where(mask, 0.0, mask_value) + + unnormalized = jnp.exp(logits - m[..., None]) + p = unnormalized / l[..., None] + dv = jnp.einsum("bhpt,bhpd->bhtd", p, do.astype(jnp.float32)).astype(v.dtype) + + dp = jnp.einsum( + "bhpd,bhtd->bhpt", do.astype(jnp.float32), v.astype(jnp.float32) + ) + + di = jnp.sum(o.astype(jnp.float32) * do.astype(jnp.float32), axis=-1)[ + ..., None + ] # [batch_size, num_heads, q_seq_len] + + ds = (dp - di) * p + dk = jnp.einsum("bhsd,bhst->bhtd", q.astype(jnp.float32), ds).astype(k.dtype) + dq = jnp.einsum("bhst,bhtd->bhsd", ds, k.astype(jnp.float32)).astype(q.dtype) + + # dab is just ds + dab = ds if ab is not None else None + return dq, dk, dv, dab + + +def _mha_reference_bwd( + causal: bool, + mask_value: float, + sm_scale: float, + save_residuals: bool, + residuals, + do, +): + del save_residuals + q, k, v, ab, segment_ids, o, l, m = residuals + dq, dk, dv, dab = mha_reference_bwd( + q, + k, + v, + ab, + segment_ids, + o, + l, + m, + do, + causal=causal, + mask_value=mask_value, + sm_scale=sm_scale, + ) + return dq, dk, dv, dab, None + + +_mha_reference.defvjp(fwd=_mha_reference_fwd, bwd=_mha_reference_bwd) + + +def _verify_block(block_name, dim_name, block, dim, should_divide=True): + if block > dim: + raise ValueError( + f"{block_name}={block} should be smaller or equal to {dim_name}={dim}" + ) + if should_divide and dim % block != 0: + raise ValueError( + f"{dim_name}={dim} should be divisible by {block_name}={block}" + ) diff --git a/fjformer/gpu_pallas/__init__.py b/fjformer/gpu_pallas/__init__.py new file mode 100644 index 0000000..08562ac --- /dev/null +++ b/fjformer/gpu_pallas/__init__.py @@ -0,0 +1,4 @@ +from .attention import mha as gpu_flash_attention +from .softmax import softmax +from .layer_norm import layer_norm +from .rms_norm import rms_norm diff --git a/fjformer/gpu_pallas/attention.py b/fjformer/gpu_pallas/attention.py new file mode 100644 index 0000000..e32b157 --- /dev/null +++ b/fjformer/gpu_pallas/attention.py @@ -0,0 +1,571 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module containing fused attention forward and backward pass.""" +from __future__ import annotations + +import functools +from typing import Any, Optional + +import jax +from jax import lax +from jax.experimental import pallas as pl +import jax.numpy as jnp +import numpy as np + +DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) + + +def mha_forward_kernel( + q_ref, + k_ref, + v_ref, # Input arrays + segment_ids_ref: jax.Array | None, # segment_id arrays + o_ref: Any, # Output + *residual_refs: Any, # Residual outputs + sm_scale: float, + causal: bool, + block_q: int, + block_d: int, + block_k: int, +): + seq_len = q_ref.shape[0] + start_q = pl.program_id(0) + + # acc is the buffer where we accumulate the output on sram. + # m_i and l_i (see FlashAttention paper) are updated during the k,v loop. 
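+ # m_i is the running row-wise max of the logits and l_i the running softmax
+ # denominator; both start "empty" (-inf and 0) and are rescaled whenever a
+ # new kv block raises the max (the online-softmax recurrence).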
+ m_i = jnp.zeros(block_q, dtype=jnp.float32) - float('inf') + l_i = jnp.zeros(block_q, dtype=jnp.float32) + # acc is the buffer where we accumulate the output on sram. + acc = jnp.zeros((block_q, block_d), dtype=jnp.float32) + + # Load q: it will stay in L1 throughout. Indices form a matrix because we + # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. + # q tile has shape [block_q, block_d], block_d == head_dim. + q = pl.load(q_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None))) + q_segment_ids = ( + None + if segment_ids_ref is None + else pl.load(segment_ids_ref, (pl.dslice(start_q * block_q, block_q),)) + ) + + # In FlashAttention algorithm 1 there are 2 loops: slow over tiles of kv (size + # (Bc == block_k here), and fast over blocks of q (size Br == block_q here). + # Here we only loop over blocks of kv to process entire seq_len, the loop over + # blocks of q is carried out by the grid. + def body(start_k, carry): + acc, m_prev, l_prev = carry + + k = pl.load(k_ref, (pl.dslice(start_k * block_k, block_k), slice(None))) + kv_segment_ids = ( + None + if segment_ids_ref is None + else pl.load(segment_ids_ref, (pl.dslice(start_k * block_k, block_k),)) + ) + qk = jnp.zeros([block_q, block_k], dtype=jnp.float32) + qk += pl.dot(q, k.T) # [block_q, block_k] + if sm_scale != 1.: + qk *= sm_scale # [block_q, block_k] + + # Bring closer to XLA:GPU numerics. + qk = qk.astype(q_ref.dtype) + qk = qk.astype(jnp.float32) + + if causal or segment_ids_ref is not None: + mask = None + if segment_ids_ref is not None: + mask = segment_mask(q_segment_ids, kv_segment_ids) + if causal: + span_q = start_q * block_q + jnp.arange(block_q) + span_k = start_k * block_k + jnp.arange(block_k) + causal_mask = span_q[:, None] >= span_k[None, :] + mask = ( + causal_mask if mask is None else jnp.logical_and(mask, causal_mask) + ) + # Apply mask to qk. + qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) + m_curr = jnp.maximum(jnp.max(qk, axis=1), m_prev) + l_prev *= jnp.exp(m_prev - m_curr) + p = jnp.exp(qk - m_curr[:, None]) + l_curr = jnp.sum(p, axis=1) + l_prev + + l_rcp = 1. / l_curr + p = p * l_rcp[:, None] + acc *= (l_prev * l_rcp)[:, None] + p = p.astype(jnp.float16) + + v = pl.load(v_ref, (pl.dslice(start_k * block_k, block_k), pl.dslice(block_d))) + acc = acc + pl.dot(p.astype(v.dtype), v) + return acc, m_curr, l_curr + + if causal: + # Ceildiv (`pl.cdiv` and `//` do not work due to type of start_q) + upper_bound = lax.div(block_q * (start_q + 1) + block_k - 1, block_k) + else: + upper_bound = pl.cdiv(seq_len, block_k) # type: ignore + acc, m_i, l_i = lax.fori_loop(0, upper_bound, body, + (acc, m_i, l_i)) + + if residual_refs: + l_ref, m_ref = residual_refs + pl.store(l_ref, (pl.ds(start_q * block_q, block_q),), l_i) + pl.store(m_ref, (pl.ds(start_q * block_q, block_q),), m_i) + # Write output to dram. 
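+ # acc is already normalized: p and acc are rescaled by 1 / l_curr on every
+ # loop iteration, so the result can be cast and stored directly.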
+ acc = acc.astype(o_ref.dtype) + pl.store(o_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None)), acc) + + +def segment_mask( + q_segment_ids: jax.Array, + kv_segment_ids: jax.Array, +): + # [B, T, 1] or [T, 1] + q_segment_ids = jnp.expand_dims(q_segment_ids, axis=-1) + # [B, 1, S] or [1, S] + if kv_segment_ids.ndim == 1: + kv_segment_ids = jnp.expand_dims(kv_segment_ids, axis=0) + else: + kv_segment_ids = jnp.expand_dims(kv_segment_ids, axis=1) + return jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) + + +@functools.partial( + jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12, 13] +) +@functools.partial( + jax.jit, + static_argnames=[ + "sm_scale", + "causal", + "block_q", + "block_k", + "backward_pass_impl", + "num_warps", + "num_stages", + "grid", + "interpret", + "debug", + ], +) +def mha( + q, + k, + v, + segment_ids: jnp.ndarray | None, + sm_scale: float = 1.0, + causal: bool = False, + block_q: int = 128, + block_k: int = 128, + backward_pass_impl: str = "triton", + num_warps: Optional[int] = None, + num_stages: int = 2, + grid=None, + interpret: bool = False, + debug: bool = False, +): + del backward_pass_impl + batch_size, seq_len, num_heads, head_dim = q.shape + block_q = min(block_q, seq_len) + block_k = min(block_k, seq_len) + # Heuristics. + grid_ = grid + if grid_ is None: + grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads) + + num_warps_ = num_warps + if num_warps_ is None: + num_warps_ = 4 if head_dim <= 64 else 8 + kernel = functools.partial(mha_forward_kernel, sm_scale=sm_scale, + block_q=block_q, block_k=block_k, + block_d=head_dim, + causal=causal) + + in_specs = [ + pl.BlockSpec( + lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + ] + in_specs.append( + None # type: ignore[arg-type] + if segment_ids is None + else pl.BlockSpec(lambda _, j, k: (j, 0), (None, seq_len)) + ) + out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype) + return pl.pallas_call( + kernel, + grid=grid_, + in_specs=in_specs, + out_specs=pl.BlockSpec( + lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + num_warps=num_warps_, + num_stages=num_stages, + out_shape=out_shape, + debug=debug, + interpret=interpret, + name="mha_forward", + )(q, k, v, segment_ids) + + +def _mha_forward( + q, + k, + v, + segment_ids: jax.Array | None, + sm_scale: float, + causal: bool, + block_q: int, + block_k: int, + backward_pass_impl: str, + num_warps: Optional[int], + num_stages: int, + grid: Any, + interpret: bool, + debug: bool, +): + del backward_pass_impl + batch_size, seq_len, num_heads, head_dim = q.shape + block_q = min(block_q, seq_len) + block_k = min(block_k, seq_len) + # Heuristics. 
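+ # Same heuristics as in `mha` above: one program per (q block, batch, head),
+ # and 8 warps instead of 4 once head_dim exceeds 64.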
+ grid_ = grid + if grid_ is None: + grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads) + + num_warps_ = num_warps + if num_warps_ is None: + num_warps_ = 4 if head_dim <= 64 else 8 + kernel = functools.partial(mha_forward_kernel, sm_scale=sm_scale, + causal=causal, block_q=block_q, block_k=block_k, + block_d=head_dim) + out_shape = [ + jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), # out + jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len), # l + dtype=jnp.float32), + jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len), # m + dtype=jnp.float32) + ] + in_specs = [ + pl.BlockSpec( + lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + ] + in_specs.append( + None # type: ignore[arg-type] + if segment_ids is None + else pl.BlockSpec(lambda _, j, k: (j, 0), (None, seq_len)) + ) + out, l, m = pl.pallas_call( + kernel, + grid=grid_, + in_specs=in_specs, + out_specs=[ + pl.BlockSpec( + lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + ], + num_warps=num_warps_, + num_stages=num_stages, + out_shape=out_shape, + debug=debug, + interpret=interpret, + name="mha_forward", + )(q, k, v, segment_ids) + return out, (q, k, v, segment_ids, out, l, m) + + +def _preprocess_backward_kernel(out_ref, dout_ref, l_ref, + new_dout_ref, delta_ref, *, + block_q: int): + pid_m = pl.program_id(0) + + off_m = pl.ds(pid_m * block_q, block_q) + # load + o = pl.load(out_ref, (off_m, slice(None))).astype(jnp.float32) + do = pl.load(dout_ref, (off_m, slice(None))).astype(jnp.float32) + denom = pl.load(l_ref, (off_m,)).astype(jnp.float32) + # compute + do = do / denom[:, None] + delta = jnp.sum(o * do, axis=1) + # write-back + pl.store(new_dout_ref, (off_m, slice(None)), + do.astype(new_dout_ref.dtype)) + pl.store(delta_ref, (off_m,), delta.astype(delta_ref.dtype)) + + +def _preprocess_backward(out, do, l, block_q: int, + debug: bool, interpret: bool): + batch_size, seq_len, num_heads, head_dim = out.shape + out_shape = [ + jax.ShapeDtypeStruct(do.shape, do.dtype), + jax.ShapeDtypeStruct(l.shape, l.dtype), + ] + do_scaled, delta = pl.pallas_call( + functools.partial(_preprocess_backward_kernel, block_q=block_q), + grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads), + in_specs=[ + pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + ], + out_specs=[ + pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + ], + num_warps=4, + num_stages=3, + out_shape=out_shape, + debug=debug, + interpret=interpret, + name="mha_preprocess_backward")(out, do, l) + return do_scaled, delta + + +def mha_backward_kernel( + # Inputs + q_ref, + k_ref, + v_ref, + segment_ids_ref: jax.Array | None, + out_ref, + do_scaled_ref, + l_ref, + m_ref, + delta_ref, + _, + # Outputs + dq_ref, + dk_ref, + dv_ref, + *, + sm_scale: float, + causal: bool, + block_q: int, + block_d: int, + block_k: int, +): + del out_ref, l_ref # Not needed + seq_len = q_ref.shape[0] + + def outer_loop(start_k, _): + + dv = jnp.zeros([block_k, block_d], 
dtype=jnp.float32) + dk = jnp.zeros([block_k, block_d], dtype=jnp.float32) + k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None))) + v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None))) + span_k = start_k * block_k + jnp.arange(block_k) + kv_segment_ids = ( + None + if segment_ids_ref is None + else pl.load(segment_ids_ref, (pl.ds(start_k * block_k, block_k),)) + ) + + def inner_loop(start_q, carry): + dv, dk = carry + q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None))) + qk = pl.dot(q, k.T) + qk = qk.astype(q_ref.dtype) + qk = qk.astype(jnp.float32) + if sm_scale != 1.0: + qk *= sm_scale + + q_segment_ids = ( + None + if segment_ids_ref is None + else pl.load(segment_ids_ref, (pl.ds(start_q * block_q, block_q),)) + ) + + if causal or segment_ids_ref is not None: + mask = None + if segment_ids_ref is not None: + mask = segment_mask(q_segment_ids, kv_segment_ids) + + if causal: + span_q = start_q * block_q + jnp.arange(block_q) + causal_mask = span_q[:, None] >= span_k[None, :] + mask = ( + causal_mask + if mask is None + else jnp.logical_and(mask, causal_mask) + ) + qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) + + m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),)) + p = jnp.exp(qk - m[:, None]) + do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None))) + dv = dv + pl.dot(p.astype(do.dtype).T, do) + di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),)) + dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None] + dp = dp + pl.dot(do, v.T) + ds = p * dp + if sm_scale != 1.0: + ds = ds * sm_scale + dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q) + dq = pl.load(dq_ref, (pl.ds(start_q * block_q, block_q), + slice(None)), eviction_policy="evict_last") + dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype) + pl.store(dq_ref, (pl.ds(start_q * block_q, block_q), + slice(None)), dq, eviction_policy="evict_last") + return dv, dk + + if causal: + lower_bound = lax.div(start_k * block_k, block_q) + else: + lower_bound = 0 + dv, dk = lax.fori_loop(lower_bound, pl.cdiv(seq_len, block_q), inner_loop, + (dv, dk)) + pl.store(dv_ref, (pl.ds(start_k * block_k, block_k), + slice(None)), dv.astype(dv_ref.dtype)) + pl.store(dk_ref, (pl.ds(start_k * block_k, block_k), + slice(None)), dk.astype(dk_ref.dtype)) + + lax.fori_loop(0, pl.cdiv(seq_len, block_k), outer_loop, None) + + +def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, + backward_pass_impl: str, num_warps: Optional[int], + num_stages: int, grid: Any, interpret: bool, + debug: bool, res, do): + del num_warps, num_stages, grid + q, k, v, segment_ids, out, l, m = res + + batch_size, seq_len, num_heads, head_dim = q.shape + block_q = min(block_q, seq_len) + block_k = min(block_k, seq_len) + do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret) + + if backward_pass_impl == "xla": + return jax.vjp( + functools.partial(mha_reference, sm_scale=sm_scale, causal=causal), + q, + k, + v, + segment_ids, + )[1](do) + elif backward_pass_impl == "triton": + # We accumulate into dq so we need to initialize it to zeros. 
+ dq = jnp.zeros(q.shape, jnp.float32) + out_shapes = [ + jax.ShapeDtypeStruct(dq.shape, dq.dtype), + jax.ShapeDtypeStruct(k.shape, k.dtype), + jax.ShapeDtypeStruct(v.shape, v.dtype), + ] + + in_specs = [ + pl.BlockSpec( + lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec( + lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + ] + if segment_ids is None: + in_specs.insert(3, None) # type: ignore[arg-type] + input_output_aliases = {8: 0} + else: + in_specs.insert(3, pl.BlockSpec(lambda j, k: (j, 0), (None, seq_len))) + input_output_aliases = {9: 0} + grid = (batch_size, num_heads) + # TODO(sharadmv): figure out why num_warps=8 doesn't work! + num_warps = 4 + dq, dk, dv = pl.pallas_call( + functools.partial( + mha_backward_kernel, + block_q=block_q, + block_d=head_dim, + block_k=block_k, + sm_scale=sm_scale, + causal=causal, + ), + grid=grid, + out_shape=out_shapes, + in_specs=in_specs, + out_specs=[ + pl.BlockSpec( + lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + ], + name="mha_backward", + debug=debug, + interpret=interpret, + num_warps=num_warps, + num_stages=1, + input_output_aliases=input_output_aliases, + )(q, k, v, segment_ids, out, do_scaled, l, m, delta, dq) + else: + raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}") + return dq.astype(q.dtype), dk, dv, None + + +mha.defvjp(_mha_forward, _mha_backward) + + +@functools.partial(jax.jit, static_argnames=['sm_scale', 'causal']) +def mha_reference( + q, + k, + v, + segment_ids: jnp.ndarray | None, + sm_scale=1.0, + causal: bool = False, +): + q_seq_len = q.shape[1] + kv_seq_len = k.shape[1] + logits = jnp.einsum('bqhc,bkhc->bhqk', q, k).astype(jnp.float32) + mask = None + if segment_ids is not None: + mask = jnp.expand_dims(segment_mask(segment_ids, segment_ids), 1) + mask = jnp.broadcast_to(mask, logits.shape) + if causal: + causal_mask = jnp.tril(jnp.ones((1, 1, q_seq_len, kv_seq_len), dtype=bool)) + causal_mask = jnp.broadcast_to(causal_mask, logits.shape) + mask = causal_mask if mask is None else jnp.logical_and(mask, causal_mask) + logits = logits if mask is None else jnp.where(mask, logits, float("-inf")) + weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype) + return jnp.einsum('bhqk,bkhc->bqhc', weights, v) diff --git a/fjformer/gpu_pallas/layer_norm.py b/fjformer/gpu_pallas/layer_norm.py new file mode 100644 index 0000000..a89b31f --- /dev/null +++ b/fjformer/gpu_pallas/layer_norm.py @@ -0,0 +1,290 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module containing fused layer norm forward and backward pass.""" + +import functools + +from typing import Optional + +import jax +from jax import lax +import jax.numpy as jnp +from jax._src.lax.control_flow.for_loop import for_loop + +from jax.experimental import pallas as pl + + +def layer_norm_forward_kernel( + x_ref, weight_ref, bias_ref, # Input arrays + o_ref, mean_ref=None, rstd_ref=None, # Output arrays + *, eps: float, block_size: int): + n_col = x_ref.shape[0] + + def mean_body(i, acc_ref): + col_idx = i * block_size + jnp.arange(block_size) + mask = col_idx < n_col + a = pl.load(x_ref, (col_idx,), mask=mask, other=0., + eviction_policy="evict_last").astype(jnp.float32) + acc_ref[:] += a + + mean = for_loop(pl.cdiv(n_col, block_size), mean_body, + jnp.zeros(block_size)).sum() / n_col + + def var_body(i, acc_ref): + col_idx = i * block_size + jnp.arange(block_size) + mask = col_idx < n_col + a = pl.load(x_ref, (col_idx,), mask=mask, other=0., + eviction_policy="evict_last").astype(jnp.float32) + a = jnp.where(mask, a - mean, 0.) + acc_ref[:] += a * a + + var = for_loop(pl.cdiv(n_col, block_size), var_body, + jnp.zeros(block_size)).sum() / n_col + rstd = 1 / jnp.sqrt(var + eps) + if mean_ref is not None: + mean_ref[...] = mean.astype(mean_ref.dtype) + if rstd_ref is not None: + rstd_ref[...] = rstd.astype(rstd_ref.dtype) + + def body(i, _): + col_idx = i * block_size + jnp.arange(block_size) + mask = col_idx < n_col + weight = pl.load(weight_ref, (col_idx,), mask=mask) + bias = pl.load(bias_ref, (col_idx,), mask=mask) + x = pl.load(x_ref, (col_idx,), mask=mask, other=0., + eviction_policy="evict_first").astype(jnp.float32) + out = (x - mean) * rstd * weight + bias + pl.store(o_ref, (col_idx,), out.astype(o_ref.dtype), mask=mask) + + for_loop(pl.cdiv(n_col, block_size), body, ()) + + +def layer_norm_forward( + x, weight, bias, + num_warps: Optional[int] = None, + num_stages: Optional[int] = 3, + eps: float = 1e-5, + backward_pass_impl: str = 'triton', + interpret: bool = False): + del num_stages + del backward_pass_impl + n = x.shape[-1] + # Triton heuristics + # Less than 64KB per feature: enqueue fused kernel + max_fused_size = 65536 // x.dtype.itemsize + block_size = min(max_fused_size, pl.next_power_of_2(n)) + block_size = min(max(block_size, 128), 4096) + num_warps = min(max(block_size // 256, 1), 8) + + kernel = functools.partial(layer_norm_forward_kernel, eps=eps, + block_size=block_size) + out_shape = [ + jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype), + jax.ShapeDtypeStruct(shape=(), dtype=x.dtype), + jax.ShapeDtypeStruct(shape=(), dtype=x.dtype) + ] + method = pl.pallas_call(kernel, num_warps=num_warps, + grid=(), out_shape=out_shape, debug=False, + interpret=interpret, name='ln_forward') + + method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None)) + out, mean, rstd = method(x, weight, bias) + return out, (x, weight, bias, mean, rstd) + + +def layer_norm_backward_kernel_dx( + # Inputs + x_ref, weight_ref, bias_ref, do_ref, + mean_ref, rstd_ref, + # Outputs + dx_ref, + *, eps: float, block_size: int): + n_col = x_ref.shape[0] + + 
def mean_body(i, acc_ref): + col_idx = i * block_size + jnp.arange(block_size) + mask = col_idx < n_col + a = pl.load(x_ref, (col_idx,), mask=mask, other=0., + eviction_policy="evict_last").astype(jnp.float32) + dout = pl.load(do_ref, (col_idx,), mask=mask, other=0., + eviction_policy="evict_last").astype(jnp.float32) + weight = pl.load(weight_ref, (col_idx,), mask=mask, other=0., + eviction_policy="evict_last").astype(jnp.float32) + a_hat = (a - mean_ref[...]) * rstd_ref[...] + wdout = weight * dout + mean1_acc_ref, mean2_acc_ref = acc_ref + mean1_acc_ref[:] += a_hat * wdout + mean2_acc_ref[:] += wdout + + mean = for_loop(pl.cdiv(n_col, block_size), mean_body, + (jnp.zeros(block_size), jnp.zeros(block_size))) + mean1, mean2 = mean + mean1 = mean1.sum() / n_col + mean2 = mean2.sum() / n_col + + def dx_body(i, acc_ref): + col_idx = i * block_size + jnp.arange(block_size) + mask = col_idx < n_col + a = pl.load(x_ref, (col_idx,), mask=mask, other=0., + eviction_policy="evict_last").astype(jnp.float32) + dout = pl.load(do_ref, (col_idx,), mask=mask, other=0., + eviction_policy="evict_last").astype(jnp.float32) + weight = pl.load(weight_ref, (col_idx,), mask=mask, other=0., + eviction_policy="evict_last").astype(jnp.float32) + a_hat = (a - mean_ref[...]) * rstd_ref[...] + wdout = weight * dout + da = (wdout - (a_hat * mean1 + mean2)) * rstd_ref[...] + pl.store(dx_ref, (col_idx,), da.astype(dx_ref.dtype), mask=mask) + + for_loop(pl.cdiv(n_col, block_size), dx_body, ()) + + +def layer_norm_backward_kernel_dw_db( + # Inputs + x_ref, weight_ref, bias_ref, do_ref, + mean_ref, rstd_ref, + # Outputs + dw_ref, db_ref, + *, eps: float, block_m: int, block_n: int): + m, n_col = x_ref.shape + j = pl.program_id(0) + col_idx = j * block_n + jnp.arange(block_n) + col_mask = col_idx < n_col + + def body(i, acc_ref): + row_idx = i * block_m + jnp.arange(block_m) + row_mask = row_idx < m + mask = row_mask[:, None] & col_mask[None, :] + a = pl.load( + x_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0 + ).astype(jnp.float32) + dout = pl.load( + do_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0 + ).astype(jnp.float32) + mean = pl.load(mean_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32) + rstd = pl.load(rstd_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32) + a_hat = (a - mean[:, None]) * rstd[:, None] + dw_acc_ref, db_acc_ref = acc_ref + dw_acc_ref[:] += (dout * a_hat).sum(axis=0) + db_acc_ref[:] += dout.sum(axis=0) + + dw_acc, db_acc = for_loop(pl.cdiv(m, block_m), body, (jnp.zeros(block_n), jnp.zeros(block_n))) + pl.store(dw_ref, (col_idx,), dw_acc.astype(dw_ref.dtype), mask=col_mask) + pl.store(db_ref, (col_idx,), db_acc.astype(db_ref.dtype), mask=col_mask) + + +def layer_norm_backward( + num_warps: Optional[int], + num_stages: Optional[int], + eps: float, + backward_pass_impl: str, + interpret: bool, + res, do): + del num_stages + x, weight, bias, mean, rstd = res + if backward_pass_impl == 'xla': + return jax.vjp(layer_norm_reference, x, weight, bias)[1](do) + + *shape_prefix, n = x.shape + reshaped_x = x.reshape((-1, n)) + reshaped_mean = mean.reshape((-1,)) + reshaped_rstd = rstd.reshape((-1,)) + reshaped_do = do.reshape((-1, n)) + # Triton heuristics + # Less than 64KB per feature: enqueue fused kernel + max_fused_size = 65536 // x.dtype.itemsize + block_size = min(max_fused_size, pl.next_power_of_2(n)) + block_size = min(max(block_size, 128), 4096) + num_warps = min(max(block_size // 256, 1), 8) + + # layer_norm_backward_kernel_dx parallel over batch 
dims + kernel = functools.partial(layer_norm_backward_kernel_dx, eps=eps, + block_size=block_size) + out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) + method = pl.pallas_call(kernel, num_warps=num_warps, + grid=(), out_shape=out_shape_dx, debug=False, + interpret=interpret, name='ln_backward_dx') + + method = jax.vmap(method, in_axes=(0, None, None, 0, 0, 0)) + dx = method(reshaped_x, weight, bias, reshaped_do, reshaped_mean, reshaped_rstd) + dx = dx.reshape((*shape_prefix, n)) + + # layer_norm_backward_kernel_dw_db reduce over batch dims + # Triton heuristics + if n > 10240: + block_n = 128 + block_m = 32 + num_warps = 4 + else: + # maximize occupancy for small N + block_n = 16 + block_m = 16 + num_warps = 8 + kernel = functools.partial(layer_norm_backward_kernel_dw_db, eps=eps, + block_m=block_m, block_n=block_n) + out_shape_dwbias = [ + jax.ShapeDtypeStruct(shape=weight.shape, dtype=weight.dtype), + jax.ShapeDtypeStruct(shape=bias.shape, dtype=bias.dtype) + ] + grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),) + method = pl.pallas_call(kernel, num_warps=num_warps, + grid=grid_, out_shape=out_shape_dwbias, debug=False, + interpret=interpret, name='ln_backward_dw_db') + dw, dbias = method(reshaped_x, weight, bias, reshaped_do, reshaped_mean, reshaped_rstd) + return dx, dw, dbias + + +@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7]) +@functools.partial(jax.jit, static_argnames=["num_warps", "num_stages", + "num_stages", "eps", + "backward_pass_impl", + "interpret"]) +def layer_norm( + x, weight, bias, + num_warps: Optional[int] = None, + num_stages: Optional[int] = 3, + eps: float = 1e-5, + backward_pass_impl: str = 'triton', + interpret: bool = False): + n = x.shape[-1] + # Triton heuristics + # Less than 64KB per feature: enqueue fused kernel + max_fused_size = 65536 // x.dtype.itemsize + block_size = min(max_fused_size, pl.next_power_of_2(n)) + block_size = min(max(block_size, 128), 4096) + num_warps = min(max(block_size // 256, 1), 8) + + kernel = functools.partial(layer_norm_forward_kernel, eps=eps, + block_size=block_size) + out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) + method = pl.pallas_call(kernel, num_warps=num_warps, num_stages=num_stages, + grid=(), out_shape=out_shape, debug=False, + interpret=interpret) + method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None)) + return method(x, weight, bias) + + +layer_norm.defvjp(layer_norm_forward, layer_norm_backward) + + +@functools.partial(jax.jit, static_argnames=["eps"]) +@functools.partial(jax.vmap, in_axes=(0, None, None), out_axes=0) +def layer_norm_reference(x, weight, bias, *, eps: float = 1e-5): + mean = jnp.mean(x, axis=1) + mean2 = jnp.mean(jnp.square(x), axis=1) + var = jnp.maximum(0., mean2 - jnp.square(mean)) + y = x - mean[:, None] + mul = lax.rsqrt(var + eps) + return y * mul[:, None] * weight[None] + bias[None] diff --git a/fjformer/gpu_pallas/rms_norm.py b/fjformer/gpu_pallas/rms_norm.py new file mode 100644 index 0000000..52dcdf4 --- /dev/null +++ b/fjformer/gpu_pallas/rms_norm.py @@ -0,0 +1,267 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module containing rms forward and backward pass.""" + +import functools + +from typing import Optional + +import jax +from jax import lax +import jax.numpy as jnp +from jax._src.lax.control_flow.for_loop import for_loop + +from jax.experimental import pallas as pl + + +def rms_norm_forward_kernel( + x_ref, weight_ref, bias_ref, # Input arrays + o_ref, rstd_ref=None, # Output arrays + *, eps: float, block_size: int): + n_col = x_ref.shape[0] + + def var_body(i, acc_ref): + col_idx = i * block_size + jnp.arange(block_size) + mask = col_idx < n_col + a = pl.load(x_ref, (col_idx,), mask=mask, other=0., + eviction_policy="evict_last").astype(jnp.float32) + a = jnp.where(mask, a, 0.) + acc_ref[:] += a * a + + var = for_loop(pl.cdiv(n_col, block_size), var_body, + jnp.zeros(block_size)).sum() / n_col + rstd = 1 / jnp.sqrt(var + eps) + if rstd_ref is not None: + rstd_ref[...] = rstd.astype(rstd_ref.dtype) + + def body(i, _): + col_idx = i * block_size + jnp.arange(block_size) + mask = col_idx < n_col + weight = pl.load(weight_ref, (col_idx,), mask=mask) + bias = pl.load(bias_ref, (col_idx,), mask=mask) + x = pl.load(x_ref, (col_idx,), mask=mask, other=0., + eviction_policy="evict_first").astype(jnp.float32) + out = x * rstd * weight + bias + pl.store(o_ref, (col_idx,), out.astype(o_ref.dtype), mask=mask) + + for_loop(pl.cdiv(n_col, block_size), body, ()) + + +def rms_norm_forward( + x, weight, bias, + num_warps: Optional[int] = None, + num_stages: Optional[int] = 3, + eps: float = 1e-5, + backward_pass_impl: str = 'triton', + interpret: bool = False): + del num_stages + del backward_pass_impl + n = x.shape[-1] + # Triton heuristics + # Less than 64KB per feature: enqueue fused kernel + max_fused_size = 65536 // x.dtype.itemsize + block_size = min(max_fused_size, pl.next_power_of_2(n)) + block_size = min(max(block_size, 128), 4096) + num_warps = min(max(block_size // 256, 1), 8) + + kernel = functools.partial(rms_norm_forward_kernel, eps=eps, + block_size=block_size) + out_shape = [ + jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype), + jax.ShapeDtypeStruct(shape=(), dtype=x.dtype) + ] + method = pl.pallas_call(kernel, num_warps=num_warps, + grid=(), out_shape=out_shape, debug=False, + interpret=interpret, name='rms_forward') + + method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None)) + out, rstd = method(x, weight, bias) + return out, (x, weight, bias, rstd) + + +def rms_norm_backward_kernel_dx( + # Inputs + x_ref, weight_ref, bias_ref, do_ref, + rstd_ref, + # Outputs + dx_ref, + *, eps: float, block_size: int): + n_col = x_ref.shape[0] + + def mean_body(i, c1_acc_ref): + col_idx = i * block_size + jnp.arange(block_size) + mask = col_idx < n_col + a = pl.load(x_ref, (col_idx,), mask=mask, other=0., + eviction_policy="evict_last").astype(jnp.float32) + dout = pl.load(do_ref, (col_idx,), mask=mask, other=0., + eviction_policy="evict_last").astype(jnp.float32) + weight = pl.load(weight_ref, (col_idx,), mask=mask, other=0., + eviction_policy="evict_last").astype(jnp.float32) + a_hat = a * rstd_ref[...] 
+ wdout = weight * dout + c1_acc_ref[:] += a_hat * wdout + + c1 = for_loop(pl.cdiv(n_col, block_size), mean_body, jnp.zeros(block_size)) + c1 = c1.sum() / n_col + + def dx_body(i, acc_ref): + col_idx = i * block_size + jnp.arange(block_size) + mask = col_idx < n_col + a = pl.load(x_ref, (col_idx,), mask=mask, other=0., + eviction_policy="evict_last").astype(jnp.float32) + dout = pl.load(do_ref, (col_idx,), mask=mask, other=0., + eviction_policy="evict_last").astype(jnp.float32) + weight = pl.load(weight_ref, (col_idx,), mask=mask, other=0., + eviction_policy="evict_last").astype(jnp.float32) + a_hat = a * rstd_ref[...] + wdout = weight * dout + da = (wdout - (a_hat * c1)) * rstd_ref[...] + pl.store(dx_ref, (col_idx,), da.astype(dx_ref.dtype), mask=mask) + + for_loop(pl.cdiv(n_col, block_size), dx_body, ()) + + +def rms_norm_backward_kernel_dw_db( + # Inputs + x_ref, weight_ref, bias_ref, do_ref, + rstd_ref, + # Outputs + dw_ref, db_ref, + *, eps: float, block_m: int, block_n: int): + m, n_col = x_ref.shape + j = pl.program_id(0) + col_idx = j * block_n + jnp.arange(block_n) + col_mask = col_idx < n_col + + def body(i, acc_ref): + row_idx = i * block_m + jnp.arange(block_m) + row_mask = row_idx < m + mask = row_mask[:, None] & col_mask[None, :] + a = pl.load( + x_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0 + ).astype(jnp.float32) + dout = pl.load( + do_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0 + ).astype(jnp.float32) + rstd = pl.load(rstd_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32) + a_hat = a * rstd[:, None] + dw_acc_ref, db_acc_ref = acc_ref + dw_acc_ref[:] += (dout * a_hat).sum(axis=0) + db_acc_ref[:] += dout.sum(axis=0) + + dw_acc, db_acc = for_loop(pl.cdiv(m, block_m), body, (jnp.zeros(block_n), jnp.zeros(block_n))) + pl.store(dw_ref, (col_idx,), dw_acc.astype(dw_ref.dtype), mask=col_mask) + pl.store(db_ref, (col_idx,), db_acc.astype(db_ref.dtype), mask=col_mask) + + +def rms_norm_backward( + num_warps: Optional[int], + num_stages: Optional[int], + eps: float, + backward_pass_impl: str, + interpret: bool, + res, do): + del num_stages + x, weight, bias, rstd = res + if backward_pass_impl == 'xla': + return jax.vjp(rms_norm_reference, x, weight, bias)[1](do) + + *shape_prefix, n = x.shape + reshaped_x = x.reshape((-1, n)) + reshaped_rstd = rstd.reshape((-1,)) + reshaped_do = do.reshape((-1, n)) + # Triton heuristics + # Less than 64KB per feature: enqueue fused kernel + max_fused_size = 65536 // x.dtype.itemsize + block_size = min(max_fused_size, pl.next_power_of_2(n)) + block_size = min(max(block_size, 128), 4096) + num_warps = min(max(block_size // 256, 1), 8) + + # rms_norm_backward_kernel_dx parallel over batch dims + kernel = functools.partial(rms_norm_backward_kernel_dx, eps=eps, + block_size=block_size) + out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) + method = pl.pallas_call(kernel, num_warps=num_warps, + grid=(), out_shape=out_shape_dx, debug=False, + interpret=interpret, name='ln_backward_dx') + + method = jax.vmap(method, in_axes=(0, None, None, 0, 0)) + dx = method(reshaped_x, weight, bias, reshaped_do, reshaped_rstd) + dx = dx.reshape((*shape_prefix, n)) + + # rms_norm_backward_kernel_dw_db reduce over batch dims + # Triton heuristics + if n > 10240: + block_n = 128 + block_m = 32 + num_warps = 4 + else: + # maximize occupancy for small N + block_n = 16 + block_m = 16 + num_warps = 8 + kernel = functools.partial(rms_norm_backward_kernel_dw_db, eps=eps, + block_m=block_m, block_n=block_n) + 
out_shape_dwbias = [ + jax.ShapeDtypeStruct(shape=weight.shape, dtype=weight.dtype), + jax.ShapeDtypeStruct(shape=bias.shape, dtype=bias.dtype) + ] + grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),) + method = pl.pallas_call(kernel, num_warps=num_warps, + grid=grid_, out_shape=out_shape_dwbias, debug=False, + interpret=interpret, name='ln_backward_dw_db') + dw, dbias = method(reshaped_x, weight, bias, reshaped_do, reshaped_rstd) + return dx, dw, dbias + + +@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7]) +@functools.partial(jax.jit, static_argnames=["num_warps", "num_stages", + "num_stages", "eps", + "backward_pass_impl", + "interpret"]) +def rms_norm( + x, weight, bias, + num_warps: Optional[int] = None, + num_stages: Optional[int] = 3, + eps: float = 1e-5, + backward_pass_impl: str = 'triton', + interpret: bool = False): + n = x.shape[-1] + # Triton heuristics + # Less than 64KB per feature: enqueue fused kernel + max_fused_size = 65536 // x.dtype.itemsize + block_size = min(max_fused_size, pl.next_power_of_2(n)) + block_size = min(max(block_size, 128), 4096) + num_warps = min(max(block_size // 256, 1), 8) + + kernel = functools.partial(rms_norm_forward_kernel, eps=eps, + block_size=block_size) + out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) + method = pl.pallas_call(kernel, num_warps=num_warps, num_stages=num_stages, + grid=(), out_shape=out_shape, debug=False, + interpret=interpret) + method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None)) + return method(x, weight, bias) + + +rms_norm.defvjp(rms_norm_forward, rms_norm_backward) + + +@functools.partial(jax.jit, static_argnames=["eps"]) +@functools.partial(jax.vmap, in_axes=(0, None, None), out_axes=0) +def rms_norm_reference(x, weight, bias, *, eps: float = 1e-5): + var = jnp.mean(jnp.square(x), axis=1) + mul = lax.rsqrt(var + eps) + return x * mul[:, None] * weight[None] + bias[None] diff --git a/fjformer/gpu_pallas/softmax.py b/fjformer/gpu_pallas/softmax.py new file mode 100644 index 0000000..a536209 --- /dev/null +++ b/fjformer/gpu_pallas/softmax.py @@ -0,0 +1,86 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Pallas softmax kernel.""" +import functools + +import jax +import jax.numpy as jnp +from jax.experimental import pallas as pl + + +def _vmappable_softmax_kernel( + # inputs + input_ref, + # outputs + probs_ref, + *, + # block information + # It is assumed that block_row >= row_len + block_row: int, +): + row_len = input_ref.shape[-1] + + mask = jnp.arange(block_row) < row_len + row = pl.load( + input_ref, (pl.dslice(0, block_row),), mask=mask, other=-float("inf") + ) + + row_max = jnp.max(row, axis=0) + numerator = jnp.exp((row - row_max).astype(jnp.float32)) + denominator = jnp.sum(numerator, axis=0) + + pl.store( + probs_ref, (pl.dslice(0, block_row),), + (numerator / denominator).astype(probs_ref.dtype), + mask=mask + ) + + +@functools.partial(jax.jit, static_argnames=["axis", "num_warps", "interpret", + "debug"]) +def softmax( + x: jax.Array, *, axis: int = -1, num_warps: int = 4, + interpret: bool = False, debug: bool = False +) -> jax.Array: + """Computes the softmax of the input array along the specified axis. + + Args: + x: input array + axis: the axis along which to perform the computation + num_warps: the number of warps to use for executing the Triton kernel + interpret: whether to interpret the kernel using pallas + debug: whether to use pallas in debug mode + + Returns: + The result of the softmax operation over the specified axis of x. + """ + axis = axis if axis >= 0 else len(x.shape) + axis + if axis != len(x.shape) - 1: + raise NotImplementedError( + "reductions along non-trailing dimension unsupported") + + row_len = x.shape[-1] + + block_row = pl.next_power_of_2(row_len) + out_shape = jax.ShapeDtypeStruct(shape=(row_len,), dtype=x.dtype) + + kernel = functools.partial(_vmappable_softmax_kernel, block_row=block_row) + f = pl.pallas_call(kernel, num_warps=num_warps, num_stages=1, grid=(), + out_shape=out_shape, debug=debug, interpret=interpret) + + for _ in range(len(x.shape) - 1): + f = jax.vmap(f) + + return f(x) diff --git a/setup.py b/setup.py index b3c98f4..6af0856 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setuptools.setup( name="fjformer", - version='0.0.13', + version='0.0.14', author="Erfan Zare Chavoshi", author_email="erfanzare82@yahoo.com", long_description=long_description,
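
A minimal usage sketch for the new fjformer.gpu_pallas kernels, assuming the module paths introduced by this diff (fjformer/gpu_pallas/layer_norm.py, rms_norm.py, softmax.py); the shapes and values are illustrative only, and interpret=True keeps the example runnable without a Triton-capable GPU.

import jax
import jax.numpy as jnp

from fjformer.gpu_pallas.layer_norm import layer_norm, layer_norm_reference
from fjformer.gpu_pallas.rms_norm import rms_norm
from fjformer.gpu_pallas.softmax import softmax

# Illustrative shapes: (batch, seq_len, hidden) activations with 1D scale/shift.
batch, seq_len, hidden = 2, 128, 512
x = jax.random.normal(jax.random.PRNGKey(0), (batch, seq_len, hidden), dtype=jnp.float32)
weight = jnp.ones((hidden,), dtype=jnp.float32)
bias = jnp.zeros((hidden,), dtype=jnp.float32)

# interpret=True routes the Pallas kernels through the interpreter (no GPU needed).
y_ln = layer_norm(x, weight, bias, interpret=True)
y_rms = rms_norm(x, weight, bias, interpret=True)
probs = softmax(x, axis=-1, interpret=True)

# The fused kernels should track their pure-JAX references closely.
print(jnp.max(jnp.abs(y_ln - layer_norm_reference(x, weight, bias))))
print(jnp.max(jnp.abs(probs - jax.nn.softmax(x, axis=-1))))

A similar call pattern applies to the mha kernel earlier in the diff: it takes q, k, v of shape (batch, seq_len, num_heads, head_dim) plus an optional segment_ids array, with sm_scale, causal, and the block sizes as static options.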