Skip to content

Commit

Permalink
Update unit tests and fix typos
Browse files Browse the repository at this point in the history
  • Loading branch information
wonjoolee95 committed Apr 16, 2024
1 parent 9dac216 commit f32836e
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 32 deletions.
131 changes: 108 additions & 23 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ def _attention(self, q, k, v):
return attn_output

# The following helper functions prefixed with _pagedattention are used for PagedAttention unit tests
# Reference: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
# Reference: https://github.com/google/jax/blob/main/tests/pallas/paged_attention_kernel_test.py
def _pagedattention_generate_qkv(
self,
seq_lens,
page_size,
max_seq_len,
Expand All @@ -37,45 +38,44 @@ def _pagedattention_generate_qkv(
head_dim,
dtype=torch.float32,
):
assert max_seq_len % page_size == 0
# assert max_seq_len % page_size == 0
pages_per_sequence = max_seq_len // page_size
batch_size = len(seq_lens)
total_pages = batch_size * pages_per_sequence
k1, k2, k3, k4 = jax.random.split(prng_key, 4)
k_pages = torch.randn(
num_kv_heads, total_pages, page_size, head_dim, dtype=dtype)
v_pages = torch.randn(
num_kv_heads, total_pages, page_size, head_dim, dtype=dtype)
page_indices = torch.randperm(
batch_size * pages_per_sequence, dtype=torch.int32)
batch_size * pages_per_sequence, dtype=torch.int64)
page_indices = page_indices.reshape(batch_size, pages_per_sequence)
q = torch.randn(batch_size, num_heads, head_dim, dtype=dtype)
return q, k_pages, v_pages, page_indices

def _pagedattention_reconstruct_kv(page_indices, pages):
def _pagedattention_reconstruct_kv(self, page_indices, pages):
batch_size = page_indices.shape[0]
num_heads, _, _, head_dim = pages.shape

def per_sequence_page_gather(pages, page_indices):
return torch.gather(
torch_pages, dim=1, index=torch_page_indices.unsqueeze(1))
return torch.index_select(pages, dim=1, index=page_indices)

gathered = torch.vmap(
per_sequence_page_gather, in_dims=(None, 0))(pages, page_indices)

return gathered.reshape(batch_size, num_heads, -1, head_dim)

def _pagedattention_grouped_query_attention_reference(q, k, v, lengths):
def _pagedattention_grouped_query_attention_reference(self, q, k, v, lengths):
batch_size, num_heads, head_dim = q.shape
_, num_kv_heads, max_seq_len, _ = k.shape
assert k.shape == v.shape
assert num_heads % num_kv_heads == 0
q = q.reshape(batch_size, num_kv_heads, num_heads // num_kv_heads, head_dim)
logits = torch.einsum("bhgd, bhtd -> bhgt", q.float(), k.float())
logits = torch.einsum("bhgd,bhtd->bhgt", q.float(), k.float())
mask = torch.arange(max_seq_len)[None, :] < lengths[:, None]
mask_value = -0.7 * torch.finfo(torch.float32).max
logits = logits.masked_fill(~mask, mask_value)
mask_value = -0.7 * float(torch.finfo(torch.float32).max)
logits = logits.masked_fill(mask[:, None, None, :], mask_value)
weights = torch.softmax(logits, dim=-1)
o = torch.einsum("bhgt, bhtd -> bhgd", weights, v.to(weights.dtype))
o = torch.einsum("bhgt,bhtd->bhgd", weights.to(v.dtype), v)
return o.reshape(batch_size, num_heads, head_dim)

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
Expand Down Expand Up @@ -534,19 +534,104 @@ def test_flash_attention_backward(self):
def test_tpu_custom_call_pallas_wrap_paged_attention(self):
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
paged_attention_kernel = make_kernel_from_pallas(
paged_attention, lambda q, k, v: [(q.shape, q.dtype)])

q_mini = torch.arange(128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13
k_mini = torch.arange(
1000, 1000 + 128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13
q = q_mini.broadcast_to(3, 2, 128, 4).to("xla")
k = k_mini.broadcast_to(3, 2, 128, 4).to("xla")
v = torch.ones(3, 2, 128, 4, dtype=torch.bfloat16).to("xla")
def shape_dtype(q, *args):
return [(q.shape, q.dtype)]

o = paged_attention_kernel(q, k, v)
expected_o = self._attention(q, k, v)
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
paged_attention_kernel = make_kernel_from_pallas(
paged_attention,
shape_dtype,
static_argnames=['pages_per_compute_block'])

batch_size = 4
max_kv_len = 2048
block_size = 512
page_size = 16
num_kv_heads = 1
q_kv_head_ratio = 1
head_dim = 128
dtype = torch.float32
seq_lens = torch.tensor(
[max_kv_len // batch_size * (i + 1) for i in range(batch_size)])

q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
seq_lens,
page_size,
max_kv_len,
num_kv_heads,
num_kv_heads * q_kv_head_ratio,
head_dim,
dtype=dtype,
)

q_xla = q.to("xla")
k_pages_xla = k_pages.to("xla")
v_pages_xla = v_pages.to("xla")
seq_lens_xla = seq_lens.to("xla")
page_indices_xla = page_indices.to("xla")
o = paged_attention_kernel(
q_xla,
k_pages_xla,
v_pages_xla,
seq_lens_xla,
page_indices_xla,
pages_per_compute_block=block_size // page_size,
)
k = self._pagedattention_reconstruct_kv(page_indices, k_pages)
v = self._pagedattention_reconstruct_kv(page_indices, v_pages)

o_expected = self._pagedattention_grouped_query_attention_reference(
q, k, v, seq_lens)

self.assertEqual(o.shape, o_expected.shape)

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_paged_attention_wrapper(self):
from torch_xla.experimental.custom_kernel import paged_attention

batch_size = 4
max_kv_len = 2048
block_size = 512
page_size = 16
num_kv_heads = 1
q_kv_head_ratio = 1
head_dim = 128
dtype = torch.float32
seq_lens = torch.tensor(
[max_kv_len // batch_size * (i + 1) for i in range(batch_size)])

q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
seq_lens,
page_size,
max_kv_len,
num_kv_heads,
num_kv_heads * q_kv_head_ratio,
head_dim,
dtype=dtype,
)

q_xla = q.to("xla")
k_pages_xla = k_pages.to("xla")
v_pages_xla = v_pages.to("xla")
seq_lens_xla = seq_lens.to("xla")
page_indices_xla = page_indices.to("xla")

o = paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
seq_lens_xla,
page_indices_xla,
pages_per_compute_block=block_size // page_size,
)
k = self._pagedattention_reconstruct_kv(page_indices, k_pages)
v = self._pagedattention_reconstruct_kv(page_indices, v_pages)

o_expected = self._pagedattention_grouped_query_attention_reference(
q, k, v, seq_lens)

self.assertEqual(o.shape, o_expected.shape)


if __name__ == '__main__':
Expand Down
39 changes: 30 additions & 9 deletions torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,10 @@ def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
return payload, tensor_args


def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
def make_kernel_from_pallas(kernel: Callable,
output_shape_dtype_fn: Callable,
static_argnums: List[int] = None,
static_argnames: List[str] = None):
# TODO: Maybe we can cache the payload for the same input.
def wrapped_kernel(kernel: Callable,
output_shape_dtype_fn: Callable,
Expand All @@ -156,7 +159,12 @@ def wrapped_kernel(kernel: Callable,
return outputs[0]
return tuple(outputs)

return functools.partial(wrapped_kernel, kernel, output_shape_dtype_fn)
return functools.partial(
wrapped_kernel,
kernel,
output_shape_dtype_fn,
static_argnums=static_argnums,
static_argnames=static_argnames)


class FlashAttention(torch.autograd.Function):
Expand Down Expand Up @@ -339,19 +347,26 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention

# It returns the shape and type of o, l, m.
def shape_dtype(q, *arg):
def shape_dtype(q, *args):
return [(q.shape, q.dtype)]

paged_attention_kernel = make_kernel_from_pallas(paged_attention, shape_dtype)
o = paged_attention_kernel(q, k_pages, v_pages, lengths, page_indices,
pages_per_compute_block)
paged_attention_kernel = make_kernel_from_pallas(
paged_attention, shape_dtype, static_argnames=['pages_per_compute_block'])

o = paged_attention_kernel(
q,
k_pages,
v_pages,
lengths,
page_indices,
pages_per_compute_block=pages_per_compute_block,
)

return o


XLA_LIB.define(
"flash_attention(Tensor q, Tensor k, Tensor v, bool casual=False) -> Tensor",
"paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int pages_per_compute_block) -> Tensor[]",
)


Expand Down Expand Up @@ -381,12 +396,18 @@ def flash_attention_non_xla(q: torch.Tensor,
return attn_output


XLA_LIB.define(
"paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int pages_per_compute_block) -> Tensor",
)


@impl(XLA_LIB, "paged_attention", "XLA")
def paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor,
v_pages: torch.Tensor, lengths: torch.Tensor,
page_indices: torch.Tensor,
pages_per_compute_block: int):
return flash_attention(q, k, v, causal=causal)
return paged_attention(q, k_pages, v_pages, lengths, page_indices,
pages_per_compute_block)


@impl(XLA_LIB, "paged_attention", "CompositeExplicitAutograd")
Expand All @@ -398,7 +419,7 @@ def paged_attention_non_xla(q: torch.Tensor, k_pages: torch.Tensor,
# We need to make sure output tensor's shape is correct.
if k.device != torch.device("meta"):
warnings.warn(
'XLA flash attention should only be applied to tensors on XLA device')
'XLA paged attention should only be applied to tensors on XLA device')

# perform a regular attention if input tensors are not on XLA device.
attn_weight = q @ k.transpose(-2, -1)
Expand Down

0 comments on commit f32836e

Please sign in to comment.