Skip to content

Commit

Permalink
[Backport] Update Pallas user guide (#6965)
Browse files Browse the repository at this point in the history
  • Loading branch information
alanwaketan authored Apr 24, 2024
1 parent b94c014 commit 6f93cc1
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions docs/pallas.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Custom Kernels via Pallas

With the rise of OpenAI [triton](https://openai.com/research/triton), custom kernels become more and more popular in the GPU community, for instance, the introduction of [FlashAttention](https://github.com/Dao-AILab/flash-attention) and [PagedAttention](https://blog.vllm.ai/2023/06/20/vllm.html). In order to provide the feature parity in the TPU world, Google has introduced [Pallas](http://go/jax-pallas) and [Mosaic](http://go/mosaic-tpu). For PyTorch/XLA to continue pushing the performance in TPU, we have to support custom kernels, and the best way is through Pallas and Mosaic. The design doc is [TBA]().

Let's assume you have a Pallas kernel defined as follow:
```python3
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

def add_vectors_kernel(x_ref, y_ref, o_ref):
x, y = x_ref[...], y_ref[...]
o_ref[...] = x + y

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
return pl.pallas_call(add_vectors_kernel,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x, y)
```

## Adopt the above kernel to be compatible with PyTorch/XLA

Example usage:
```python3
q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

# Adopts any Pallas kernel
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])
output = pt_kernel(q, k)
```
For simple kernels, the adoption is just as simple as one liner. For more complicated kernels, you can refer to our Flash Attention implementation for details.

## Use built-in kernels

Besides manually wrapping external Pallas kernels, there are built-in kernels where the adoptions are done by PyTorch/XLA already.

Example usage:
```python3
# Use built-in kernels
from torch_xla.experimental.custom_kernel import flash_attention
output = flash_attention(q, k, v)
```

You can just use it like any other torch.ops.

## HuggingFace Llama 3 Example
We have a fork of HF Llama 3 to demonstrate a potential integration [here](https://github.com/pytorch-tpu/transformers/tree/alanwaketan/flash_attention).

## Dependencies
The Pallas integration depends on JAX to function. However, not every JAX version is compatible with your installed PyTorch/XLA. To install the proper JAX:
```bash
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
```

0 comments on commit 6f93cc1

Please sign in to comment.