Adding and improving documentation 🧬
erfanzar committed Nov 29, 2023
1 parent 5585e50 commit e12a40d
Showing 16 changed files with 1,050 additions and 96 deletions.
104 changes: 19 additions & 85 deletions README.md
@@ -1,99 +1,33 @@
# FJFormer

a package for custom JAX/Flax functions and utilities

Welcome to FJFormer - A collection of useful functions and utilities for Flax and JAX!

## Overview

FJFormer is a collection of functions and utilities that can help with various tasks when using Flax and JAX. It
includes checkpoint savers, partitioning tools, and other helpful functions.

## Features

Here are some of the features included in FJFormer:

Checkpoint saver: This tool provides an easy way to save and restore checkpoints during training. You can specify how
often to save checkpoints, where to store them, and more.

Partitioning tools: FJFormer includes several tools for partitioning data across multiple devices or nodes. These tools
can help you optimize the performance of your models on clusters or distributed systems.
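
As a rough, generic illustration of the kind of device-level partitioning these tools build on (plain `jax.sharding`, not an FJFormer-specific API), consider:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over all available devices (works even on a single device).
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))

# Shard a (batch, features) array along the "dp" mesh axis; features stay replicated.
x = jnp.arange(16.0).reshape(8, 2)
x_sharded = jax.device_put(x, NamedSharding(mesh, PartitionSpec("dp", None)))

print(x_sharded.sharding)  # shows how the array is laid out across devices
```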

Other utilities: FJFormer also includes a variety of other helpful functions and utilities.

## Getting Started

To get started with FJFormer, simply install the package using pip:

```shell
pip install fjformer
```

Once installed, you can import the package and start using its functions and utilities. For example, here is how you can
use the checkpoint saver to load a model:

```python
from fjformer import StreamingCheckpointer

ckpt = StreamingCheckpointer.load_trainstate_checkpoint('params::<path to model>')

```

Or simply get an optimizer, for example Adafactor with a cosine scheduler:

```python
from jax import numpy as jnp
from fjformer.optimizers import get_adafactor_with_cosine_scheduler

optimizer, scheduler = get_adafactor_with_cosine_scheduler(
    steps=5000,
    learning_rate=5e-5,
    weight_decay=1e-1,
    min_dim_size_to_factor=128,
    decay_rate=0.8,
    decay_offset=0,
    multiply_by_parameter_scale=True,
    clipping_threshold=1.0,
    momentum=None,
    dtype_momentum=jnp.float32,
    weight_decay_rate=None,
    eps=1e-30,
    factored=True,
    weight_decay_mask=None,
)

```

Or get AdamW with a linear scheduler:

```python
from fjformer.optimizers import get_adamw_with_linear_scheduler

optimizer, scheduler = get_adamw_with_linear_scheduler(
    steps=5000,
    learning_rate_start=5e-5,
    learning_rate_end=1e-5,
    b1=0.9,
    b2=0.999,
    eps=1e-8,
    eps_root=0.0,
    weight_decay=1e-1,
    mu_dtype=None,
)

```
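
Either helper returns an optimizer you can plug into a regular training step. The sketch below reuses the `optimizer` returned above and assumes it follows the standard `optax.GradientTransformation` interface (which the Optax-style arguments suggest, though that is an assumption here):

```python
import jax
import jax.numpy as jnp
import optax

# Toy parameters and loss, only to show the update loop.
params = {"w": jnp.ones((4,)), "b": jnp.zeros(())}

def loss_fn(p, x, y):
    pred = x @ p["w"] + p["b"]
    return jnp.mean((pred - y) ** 2)

opt_state = optimizer.init(params)  # `optimizer` comes from the call above

@jax.jit
def train_step(p, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(p, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, p)
    p = optax.apply_updates(p, updates)
    return p, opt_state, loss

params, opt_state, loss = train_step(params, opt_state, jnp.ones((8, 4)), jnp.zeros((8,)))
```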

## Documentation

Documentation is available [here](https://erfanzar.github.io/fjformer/docs).

The goal of FJFormer is to make your life easier when working with Flax and JAX. Whether you are training a new model,
fine-tuning an existing one, or just exploring the capabilities of these powerful frameworks, FJFormer offers:

- FlashAttention on `TPU/GPU` (see the usage sketch at the end of this section)
- Bit computations for 8-, 6-, and 4-bit Flax models
- Smart dataset loading
- Built-in helper and loss functions
- GPU Pallas (Triton-like) implementations of `Softmax`, `FlashAttention`, `RMSNorm`, and `LayerNorm`
- Distributed and sharded model loaders and checkpoint savers
- Monitoring utilities for *TPU/GPU/CPU* memory footprint
- Easy-to-use optimizers with schedulers
- Partitioning utilities

Most of these features are fully documented, so FJFormer has plenty to offer in its own right; it is not just a
computation backend for [EasyDel](https://github.com/erfanzar/EasyDel).
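
As a taste of the attention utilities, here is a minimal sketch of calling the TPU FlashAttention kernel exported from `fjformer.attention`. The `(batch, heads, seq_len, head_dim)` layout and the `causal` keyword follow the upstream Pallas kernel this module appears to be based on; treat them as assumptions rather than documented API:

```python
import jax
import jax.numpy as jnp
from fjformer.attention import tpu_flash_attention  # exported by fjformer.attention

batch, heads, seq_len, head_dim = 1, 8, 1024, 64
key = jax.random.PRNGKey(0)
# Hypothetical shapes/dtypes; adjust to your model.
q = jax.random.normal(key, (batch, heads, seq_len, head_dim), dtype=jnp.bfloat16)
k = jax.random.normal(key, (batch, heads, seq_len, head_dim), dtype=jnp.bfloat16)
v = jax.random.normal(key, (batch, heads, seq_len, head_dim), dtype=jnp.bfloat16)

# Assumed signature (mirrors the Pallas TPU kernel); runs on TPU backends.
out = tpu_flash_attention(q, k, v, causal=True)
```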

## Contributing

FJFormer is an open-source project, and contributions are always welcome! If you have a feature request, bug report, or
just want to help out with development, please check out our GitHub repository and feel free to submit a pull request or
open an issue.

Thank you for using FJFormer, and happy training!
2 changes: 1 addition & 1 deletion fjformer/attention/__init__.py
@@ -25,5 +25,5 @@
dot_product_attention_queries_per_head
from .flash_attention import ring_attention, ring_attention_standard, ring_flash_attention_gpu, \
ring_flash_attention_tpu, blockwise_ffn, blockwise_attn
from .jax_flash_attn_tpu import flash_attention as tpu_flash_attention, BlockSizes
from .jax_flash_attn_gpu import mha as gpu_flash_attention
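
After this change, the package-level imports look like the following (the inline comments describe what each export appears to be, based on the module names in this diff):

```python
from fjformer.attention import (
    tpu_flash_attention,   # Pallas-based TPU flash-attention kernel
    gpu_flash_attention,   # Triton-style GPU flash-attention (`mha`)
    BlockSizes,            # block-size configuration used by the TPU kernel
)
```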
17 changes: 17 additions & 0 deletions fjformer/attention/efficient_attention.py
@@ -150,6 +150,23 @@ def _chunk_attention_bias(
query_chunk_idx: int,
key_chunk_idx: int
):
"""
The _chunk_attention_bias function is used to compute the attention bias for a single chunk of
the query and key tensors. The function takes in the following arguments:
:param query_chunk_size: int: Determine the size of the query chunk
:param key_chunk_size: int: Determine the size of the key_chunk
:param bias: chex.Array: Mask out the attention weights
:param deterministic: bool: Determine whether to use dropout or not
:param attn_dropout: chex.Array: Drop out attention weights
:param attention_drop_rate: float: Determine the dropout rate for attention
:param causal: bool: Determine if the attention is causal or not
:param dtype: chex.ArrayDType: Specify the data type of the array
:param query_chunk_idx: int: Select the query chunk
:param key_chunk_idx: int: Determine the key_offset
:return: A chunk of the attention bias
"""
query_offset = query_chunk_idx * query_chunk_size
key_offset = key_chunk_idx * key_chunk_size
chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
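
For intuition, here is a standalone sketch of what a causal bias for one (query chunk, key chunk) pair looks like, given the offsets computed above. It is illustrative only and does not reproduce the dropout or user-supplied-bias handling of the real function:

```python
import jax.numpy as jnp

def causal_chunk_bias(query_chunk_size, key_chunk_size, query_chunk_idx, key_chunk_idx,
                      dtype=jnp.float32):
    """Additive bias masking out keys that lie in the future of each query position."""
    query_offset = query_chunk_idx * query_chunk_size
    key_offset = key_chunk_idx * key_chunk_size
    q_pos = query_offset + jnp.arange(query_chunk_size)[:, None]
    k_pos = key_offset + jnp.arange(key_chunk_size)[None, :]
    not_allowed = k_pos > q_pos                      # True where attention must be blocked
    big_neg = jnp.finfo(dtype).min
    return jnp.where(not_allowed, big_neg, 0.0).astype(dtype)  # (query_chunk, key_chunk)

# Query chunk 1 attends to key chunk 0: everything is in the past, so the bias is all zeros.
print(causal_chunk_bias(4, 4, query_chunk_idx=1, key_chunk_idx=0))
```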
76 changes: 76 additions & 0 deletions fjformer/attention/flash_attention.py
@@ -703,6 +703,23 @@ def scan_kv_block(carry, idx):

@partial(jax.custom_vjp, nondiff_argnums=[4, 5, 6])
def ring_flash_attention_tpu(q, k, v, attn_bias, axis_name, float32_logits, blockwise_kwargs):
"""
The ring_flash_attention_tpu function is a TPU-specific implementation of the ring_flash_attention function.
It takes in the same arguments as ring_flash_attention, but it also takes in blockwise_kwargs, which are used to
specify how to partition the input tensors into blocks for parallel processing on TPUs. The blockwise kwargs are:
- num_blocks: number of blocks to split each dimension into (default 1)
- dims: list of dimensions that will be split into blocks (default [0])
:param q: Compute the attention weights
:param k: Define the key, and the v parameter is used to define the value
:param v: Compute the attention weights and then used to multiply with the attention weights
:param attn_bias: Mask out the attention weights for certain positions
:param axis_name: Specify the axis along which the attention is applied
:param float32_logits: Determine whether to use float32 or int32 for the logits
:param blockwise_kwargs: Pass the blockwise_compute function to ring_flash_attention
:return: A tuple
"""
y, _ = _ring_flash_attention_fwd_tpu(q, k, v, attn_bias, axis_name, float32_logits, blockwise_kwargs)
return y
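
The `y, _ = ...; return y` pattern above is the standard `jax.custom_vjp` idiom: the public function runs the forward pass, discards the residuals, and `defvjp` later wires in the real forward/backward pair. A self-contained toy example of the same pattern (the names here are illustrative, not FJFormer APIs):

```python
from functools import partial

import jax


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scaled_square(x, scale):
    y, _ = _scaled_square_fwd(x, scale)
    return y

def _scaled_square_fwd(x, scale):
    y = scale * x ** 2
    return y, (x,)                      # residuals saved for the backward pass

def _scaled_square_bwd(scale, res, g):  # non-diff args come first in the backward rule
    (x,) = res
    return (g * 2.0 * scale * x,)       # one cotangent per differentiable argument

scaled_square.defvjp(_scaled_square_fwd, _scaled_square_bwd)

print(jax.grad(scaled_square)(3.0, 2.0))  # 12.0
```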

@@ -711,6 +728,21 @@ def ring_flash_attention_tpu(q, k, v, attn_bias, axis_name, float32_logits, bloc


def ring_flash_dummy_gpu(q, k, v, attn_bias, axis_name, float32_logits, blockwise_kwargs):
"""
The ring_flash_dummy_gpu function is a dummy function that does not actually perform any computation.
It is used to determine the GPU memory usage of the ring_flash_gpu function, which performs multi-head attention.
The ring_flash_dummy_gpu function has exactly the same inputs and outputs as ring flash gpu, but it does not perform any computation.
:param q: Compute the query_chunk_size
:param k: Calculate the output of the ring_flash_dummy function
:param v: Compute the output
:param attn_bias: Mask the attention
:param axis_name: Specify which axis to split the tensor along
:param float32_logits: Determine whether to use float32 or float64 for the q and k matrices
:param blockwise_kwargs: Pass parameters to the ring_flash_dummy_gpu function
:return: A tuple of (output, none)
"""
if float32_logits:
q, k = q.astype(jnp.float32), k.astype(jnp.float32)
attn_bias = attn_bias[:, 0, 0] # (batch, q_len)
Expand All @@ -723,6 +755,23 @@ def ring_flash_dummy_gpu(q, k, v, attn_bias, axis_name, float32_logits, blockwis


def _ring_flash_attention_fwd_gpu(q, k, v, attn_bias, axis_name, float32_logits, blockwise_kwargs):
"""
The _ring_flash_attention_fwd_gpu function is a GPU-only implementation of the ring flash attention
forward pass. It takes in query, key, and value tensors (q, k, v), an attention bias tensor (attn_bias),
and axis name for the blockwise dimension. The function then performs a scan over all blocks in the blockwise
dimension to compute output values for each query position. This is done by first computing logits using q and k;
then applying softmax to get probabilities; then multiplying these probabilities with v to get output values.
:param q: Compute the logits
:param k: Compute the logits
:param v: Compute the output, but what is it?
:param attn_bias: Mask out certain positions in the attention matrix
:param axis_name: Specify the axis along which to perform ring flash attention
:param float32_logits: Determine whether to use float32 or bfloat16 for the logits
:param blockwise_kwargs: Pass in the parameters for the _mha_forward function
:return: The output of the attention
"""
if float32_logits:
q, k = q.astype(jnp.float32), k.astype(jnp.float32)
batch, q_len, num_heads, dim_per_head = q.shape
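
For reference, the quantity the scan ultimately computes is ordinary softmax attention over `(batch, seq_len, heads, head_dim)` tensors. A naive, non-blockwise version in plain JAX (shown only for the math; scaling and bias conventions inside the fused kernel may differ):

```python
import jax
import jax.numpy as jnp

def reference_attention(q, k, v, attn_bias=None):
    """Naive attention for (batch, seq, heads, head_dim) inputs."""
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    if attn_bias is not None:
        logits = logits + attn_bias                 # additive mask / bias
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)

key = jax.random.PRNGKey(0)
q = k = v = jax.random.normal(key, (1, 8, 2, 16))
print(reference_attention(q, k, v).shape)  # (1, 8, 2, 16)
```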
@@ -781,6 +830,18 @@ def scan_kv_block(carry, idx):


def _ring_flash_attention_bwd_gpu(axis_name, float32_logits, blockwise_kwargs, res, g):
"""
The _ring_flash_attention_bwd_gpu function is a GPU-specific implementation of the backward pass for
the ring flash attention mechanism. It takes in the following arguments:
:param axis_name: Specify the axis that is being sharded
:param float32_logits: Determine whether to use the float32 or float16 version of the function
:param blockwise_kwargs: Pass the blockwise_kwargs dictionary to the _ring_flash_attention function
:param res: Pass in the result of the forward pass
:param g: Pass the gradient of the output to this function
:return: A tuple of 4 elements
"""
del float32_logits
o, q, k, v, attn_bias, l, m = res
batch, kv_len, num_heads, dim_per_head = k.shape
@@ -834,6 +895,21 @@ def scan_kv_block(carry, idx):

@partial(jax.custom_vjp, nondiff_argnums=[4, 5, 6])
def ring_flash_attention_gpu(q, k, v, attn_bias, axis_name, float32_logits, blockwise_kwargs):
"""
The ring_flash_attention_gpu function is a GPU-optimized version of the ring_flash_attention function.
It uses the same algorithm as ring_flash_attention, but it has been optimized for speed on GPUs.
The main difference between this function and its CPU counterpart is that this one uses blockwise matrix multiplication to compute attention scores in parallel across multiple blocks of data, rather than computing them sequentially across all data points. This allows us to take advantage of the massive parallelism available on modern GPUs.
:param q: Query the key-value pairs
:param k: Determine the number of heads
:param v: Pass the values to the attention function
:param attn_bias: Mask the attention weights
:param axis_name: Specify which axis to perform the attention on
:param float32_logits: Determine whether to use float32 or float16
:param blockwise_kwargs: Pass the blockwise_attention function parameters
:return: The output of the attention layer
"""
y, _ = _ring_flash_attention_fwd_gpu(q, k, v, attn_bias, axis_name, float32_logits, blockwise_kwargs)
return y
