Enable PagedAttention through Pallas #6912

Merged: 15 commits merged on Apr 25, 2024
165 changes: 165 additions & 0 deletions test/test_pallas.py
@@ -10,6 +10,8 @@
from torch_xla import runtime as xr
from torch_xla._internal import tpu

import numpy as np

if xr.device_type() == 'TPU':
from torch_xla.experimental.custom_kernel import jax_import_guard
jax_import_guard()
@@ -26,6 +28,32 @@ def _attention(self, q, k, v):
    attn_output = attn_weight @ v
    return attn_output

  # The following helper functions prefixed with _pagedattention are used for PagedAttention unit tests.
  # Reference: https://github.com/google/jax/blob/main/tests/pallas/paged_attention_kernel_test.py
  def _pagedattention_generate_qkv(
      self,
      seq_lens,
      page_size,
      max_seq_len,
      num_kv_heads,
      num_heads,
      head_dim,
      dtype=torch.float32,
  ):
    assert max_seq_len % page_size == 0
    pages_per_sequence = max_seq_len // page_size
    batch_size = len(seq_lens)
    total_pages = batch_size * pages_per_sequence
    k_pages = torch.randn(
        num_kv_heads, total_pages, page_size, head_dim, dtype=dtype)
    v_pages = torch.randn(
        num_kv_heads, total_pages, page_size, head_dim, dtype=dtype)
    page_indices = torch.randperm(
        batch_size * pages_per_sequence, dtype=torch.int32)
    page_indices = page_indices.reshape(batch_size, pages_per_sequence)
    q = torch.randn(batch_size, num_heads, head_dim, dtype=dtype)
    return q, k_pages, v_pages, page_indices
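For readers unfamiliar with the paged KV-cache layout, here is a small standalone sketch (not part of the test file, with made-up sizes) of the shapes this helper returns; the names mirror the helper's arguments.

import torch

# Illustrative (assumed) sizes, mirroring the helper above.
seq_lens = torch.tensor([192, 64], dtype=torch.int32)  # tokens already cached per sequence
page_size, max_seq_len, num_kv_heads, num_heads, head_dim = 64, 256, 2, 4, 128

pages_per_sequence = max_seq_len // page_size           # 4 pages per sequence
batch_size = seq_lens.numel()                           # 2 sequences
total_pages = batch_size * pages_per_sequence           # 8 physical pages in the cache

# The paged KV cache stores fixed-size pages; page_indices maps each sequence's
# logical pages to physical pages.
k_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim)
v_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim)
page_indices = torch.randperm(total_pages, dtype=torch.int32).reshape(
    batch_size, pages_per_sequence)
q = torch.randn(batch_size, num_heads, head_dim)        # one query token per sequence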

  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
  def test_tpu_custom_call_pallas_add(self):
    # This payload is generated by the following Pallas code:
@@ -454,6 +482,143 @@ def test_flash_attention_backward(self):
      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                   "This test only works on TPUv4+.")
  def test_paged_attention_wrapper(self):
    from torch_xla.experimental.custom_kernel import paged_attention
    from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention

    max_kv_len = 2048
    block_size = 512
    page_size = 64
    num_kv_heads = 8
    q_kv_head_ratio = 8
    head_dim = 256
    dtype = torch.float32
    seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32)

    q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
        seq_lens,
        page_size,
        max_kv_len,
        num_kv_heads,
        num_kv_heads * q_kv_head_ratio,
        head_dim,
    )

    q_xla = q.to("xla")
    k_pages_xla = k_pages.to("xla")
    v_pages_xla = v_pages.to("xla")
    seq_lens_xla = seq_lens.to("xla")
    page_indices_xla = page_indices.to("xla")

    output = paged_attention(
        q_xla,
        k_pages_xla,
        v_pages_xla,
        seq_lens_xla,
        page_indices_xla,
        pages_per_compute_block=block_size // page_size,
    )

Review thread on the seq_lens argument:

Collaborator: Can you explain what these seq_lens are? Are these the previous tokens for each batch in k, v?

Author: Yep, that is my understanding -- the seq_lens here equal the number of tokens that have already been processed for each batch element. Reference: https://docs.vllm.ai/en/latest/dev/kernel/paged_attention.html#concepts.
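To make the discussion above concrete, here is a small hedged sketch (plain PyTorch, made-up sizes, not part of the test) of how seq_lens and page_indices describe the paged KV cache: for sequence b, the first ceil(seq_lens[b] / page_size) entries of page_indices[b] point at the physical pages that hold its keys and values.

import math
import torch

# Assumed, illustrative sizes.
num_kv_heads, page_size, head_dim = 2, 4, 8
seq_lens = torch.tensor([6, 3], dtype=torch.int32)      # tokens cached per sequence
page_indices = torch.tensor([[5, 2, 0], [7, 1, 3]], dtype=torch.int32)
k_pages = torch.randn(num_kv_heads, 8, page_size, head_dim)

def gather_keys(b):
  """Reconstruct the contiguous keys for sequence b from its pages."""
  n_tokens = int(seq_lens[b])
  n_pages = math.ceil(n_tokens / page_size)
  pages = k_pages[:, page_indices[b, :n_pages].long()]  # [heads, n_pages, page, dim]
  keys = pages.reshape(num_kv_heads, n_pages * page_size, head_dim)
  return keys[:, :n_tokens]                             # drop padding in the last page

print(gather_keys(0).shape)  # torch.Size([2, 6, 8])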

    q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
    k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
    v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
    seq_lens_jax = jnp.array(seq_lens.numpy(), dtype=jnp.int32)
    page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
    expected_output = torch.from_numpy(
        np.array(
            jax_paged_attention(
                q_jax,
                k_pages_jax,
                v_pages_jax,
                seq_lens_jax,
                page_indices_jax,
                pages_per_compute_block=block_size // page_size,
            )))

    # Only compare batch entries with a non-zero sequence length; rows with
    # seq_len == 0 have no cached tokens to attend to.
    self.assertTrue(
        torch.allclose(
            output.cpu()[seq_lens > 0],
            expected_output.cpu()[seq_lens > 0],
            atol=1e-5,
            rtol=1e-5))

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                   "This test only works on TPUv4+.")
  def test_paged_attention_wrapper_with_dynamo(self):
    from torch_xla.experimental.custom_kernel import paged_attention
    from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention

    max_kv_len = 2048
    block_size = 512
    page_size = 64
    num_kv_heads = 8
    q_kv_head_ratio = 8
    head_dim = 256
    dtype = torch.float32
    seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32)

    q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
        seq_lens,
        page_size,
        max_kv_len,
        num_kv_heads,
        num_kv_heads * q_kv_head_ratio,
        head_dim,
    )

    q_xla = q.to("xla")
    k_pages_xla = k_pages.to("xla")
    v_pages_xla = v_pages.to("xla")
    seq_lens_xla = seq_lens.to("xla")
    page_indices_xla = page_indices.to("xla")

    # The wrapper must forward its own arguments (rather than close over the
    # outer tensors) so that dynamo traces the inputs it is actually given.
    def paged_attention_wrapper(q, k, v, seq_lens, page_indices,
                                pages_per_compute_block):
      return paged_attention(
          q,
          k,
          v,
          seq_lens,
          page_indices,
          pages_per_compute_block=pages_per_compute_block,
      )

    compiled_paged_attention = torch.compile(
        paged_attention_wrapper, backend="openxla")
    output = compiled_paged_attention(
        q_xla,
        k_pages_xla,
        v_pages_xla,
        seq_lens_xla,
        page_indices_xla,
        pages_per_compute_block=block_size // page_size,
    )

    q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
    k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
    v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
    seq_lens_jax = jnp.array(seq_lens.numpy(), dtype=jnp.int32)
    page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
    expected_output = torch.from_numpy(
        np.array(
            jax_paged_attention(
                q_jax,
                k_pages_jax,
                v_pages_jax,
                seq_lens_jax,
                page_indices_jax,
                pages_per_compute_block=block_size // page_size,
            )))

    self.assertTrue(
        torch.allclose(
            output.cpu()[seq_lens > 0],
            expected_output.cpu()[seq_lens > 0],
            atol=1e-5,
            rtol=1e-5))


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
93 changes: 83 additions & 10 deletions torch_xla/experimental/custom_kernel.py
@@ -371,6 +371,67 @@ def flash_attention(
  return FlashAttention.apply(q, k, v, causal, partition_spec, mesh)


Review thread on this function's signature (marked as resolved by wonjoolee95):

Collaborator: The original kernel has this thing called q_dtype_for_kernel_launch? What does it do? Should we copy that as well?

Author (wonjoolee95): In the original kernel, the q_dtype_for_kernel_launch is always either jnp.float32 or q's dtype. In our case, I was expecting the passed-in q's dtype to be torch.float32, so the q_dtype_for_kernel_launch would always be float32.

Collaborator: No, I don't think that will be the case for the actual workflow. It could be bf16 or even int8, etc.

Author (wonjoolee95): I see, makes sense. Just updated to handle q_dtype_for_kernel_launch, following jax's kernel -- https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py#L393. I can follow up in another PR to add some more unit tests for different dtypes for q.


def paged_attention(q, k_pages, v_pages, lengths, page_indices,
                    pages_per_compute_block):
  # Import JAX within the function such that we don't need to call the jax_import_guard()
  # in the global scope, which could cause problems for xmp.spawn.
  jax_import_guard()
  from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention

  payload, tensor_args = trace_pallas(
      paged_attention,
      q,
      k_pages,
      v_pages,
      lengths,
      page_indices,
      pages_per_compute_block=pages_per_compute_block,
      static_argnames=["pages_per_compute_block"],
  )

  batch_size, num_heads, head_dim = q.shape
  num_kv_heads, _, page_size, head_dim_k = k_pages.shape
  batch_size_paged_indices, pages_per_sequence = page_indices.shape

  # Mirror the q_dtype_for_kernel_launch handling of the upstream JAX kernel
  # (see the review thread above).
  q_output_dtype = torch.float32
  if (num_heads // num_kv_heads) % 8 != 0:
    q = q.reshape(batch_size, num_heads, 1, head_dim)
    q_output_dtype = q.dtype

  page_indices_reshaped = page_indices.reshape(-1)
  buffer_index = torch.zeros((1,), dtype=torch.int32).to("xla")
  step = torch.zeros((1,), dtype=torch.int32).to("xla")
  output_shape = torch.Size(list(q.shape[:-1]) + [1])

  # The Pallas kernel produces the attention output plus two auxiliary buffers;
  # only the first result is used here.
  output, _, _ = torch_xla._XLAC._xla_tpu_custom_call(
      [
          lengths,
          page_indices_reshaped,
          buffer_index,
          step,
          q,
          k_pages,
          v_pages,
      ], payload, [q.shape, output_shape, output_shape],
      [q_output_dtype, torch.float32, torch.float32])

  return output.reshape(batch_size, num_heads, head_dim).to(q.dtype)
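A hedged usage sketch tied to the dtype discussion above (tensor sizes are made up and it assumes a TPU XLA device is available): the wrapper accepts a non-float32 query and casts the result back to q.dtype on return.

import torch
from torch_xla.experimental.custom_kernel import paged_attention

# Assumed sizes: 4 sequences, 64 query heads, 8 KV heads, 16 pages of 64 tokens, head_dim 128.
q = torch.randn(4, 64, 128, dtype=torch.bfloat16).to("xla")
k_pages = torch.randn(8, 16, 64, 128, dtype=torch.bfloat16).to("xla")
v_pages = torch.randn(8, 16, 64, 128, dtype=torch.bfloat16).to("xla")
lengths = torch.tensor([64, 128, 192, 256], dtype=torch.int32).to("xla")
page_indices = torch.arange(16, dtype=torch.int32).reshape(4, 4).to("xla")

out = paged_attention(q, k_pages, v_pages, lengths, page_indices,
                      pages_per_compute_block=4)
assert out.shape == (4, 64, 128)      # [batch, num_heads, head_dim]
assert out.dtype == torch.bfloat16    # output is cast back to q.dtype on return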


def non_xla_attetion(q, k, v, attention_type):
  # This will be called when dynamo uses fake tensors to construct the fake output.
  # We need to make sure the output tensor's shape is correct.
  if k.device != torch.device("meta"):
    warnings.warn(
        f'XLA {attention_type} attention should only be applied to tensors on XLA device'
    )

  # Perform regular attention if the input tensors are not on an XLA device.
  attn_weight = q @ k.transpose(-2, -1)
  attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1)
  attn_output = attn_weight @ v
  return attn_output
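As an illustration of the meta-device path (a hedged sketch with made-up shapes, assuming the function is importable as shown): dynamo's fake tensors live on the meta device, so only shape propagation happens here and no warning is emitted.

import torch
from torch_xla.experimental.custom_kernel import non_xla_attetion

# Meta tensors carry no data; only shapes and dtypes propagate.
q = torch.empty(2, 16, 64, device="meta")
k = torch.empty(2, 16, 64, device="meta")
v = torch.empty(2, 16, 64, device="meta")

out = non_xla_attetion(q, k, v, "flash")
assert out.shape == (2, 16, 64)  # shape matches a real attention output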


XLA_LIB.define(
    "flash_attention(Tensor q, Tensor k, Tensor v, bool casual=False) -> Tensor",
)
@@ -389,14 +450,26 @@ def flash_attention_non_xla(q: torch.Tensor,
                            k: torch.Tensor,
                            v: torch.Tensor,
                            causal: bool = False):
  return non_xla_attetion(q, k, v, "flash")

XLA_LIB.define(
    "paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int pages_per_compute_block) -> Tensor",
)


@impl(XLA_LIB, "paged_attention", "XLA")
def paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor,
                        v_pages: torch.Tensor, lengths: torch.Tensor,
                        page_indices: torch.Tensor,
                        pages_per_compute_block: int):
  return paged_attention(q, k_pages, v_pages, lengths, page_indices,
                         pages_per_compute_block)


@impl(XLA_LIB, "paged_attention", "CompositeExplicitAutograd")
def paged_attention_non_xla(q: torch.Tensor, k_pages: torch.Tensor,
                            v_pages: torch.Tensor, lengths: torch.Tensor,
                            page_indices: torch.Tensor,
                            pages_per_compute_block: int):
  return non_xla_attetion(q, k_pages, v_pages, "paged")
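A brief hedged sketch of how the registered op might be invoked once this module has been imported (the torch.ops namespace name assumes XLA_LIB is registered under "xla"; tensor sizes are made up and mirror the earlier sketch):

import torch
import torch_xla.experimental.custom_kernel  # registers the paged_attention op

q = torch.randn(4, 64, 128).to("xla")
k_pages = torch.randn(8, 16, 64, 128).to("xla")
v_pages = torch.randn(8, 16, 64, 128).to("xla")
lengths = torch.tensor([64, 128, 192, 256], dtype=torch.int32).to("xla")
page_indices = torch.arange(16, dtype=torch.int32).reshape(4, 4).to("xla")

# Dispatches to paged_attention_xla for XLA tensors and to
# paged_attention_non_xla otherwise (e.g. for dynamo's fake tensors).
out = torch.ops.xla.paged_attention(q, k_pages, v_pages, lengths, page_indices, 4)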