Skip to content

Commit

Permalink
Run linter
Browse files Browse the repository at this point in the history
  • Loading branch information
wonjoolee95 committed Apr 16, 2024
1 parent 4b60c11 commit 9dac216
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 30 deletions.
37 changes: 20 additions & 17 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,28 @@ def _attention(self, q, k, v):
attn_output = attn_weight @ v
return attn_output

# The following helper functions prefixed with _pagedattention are used to help test PagedAttention
# The following helper functions prefixed with _pagedattention are used for PagedAttention unit tests
# Reference: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
def _pagedattention_generate_qkv(
seq_lens,
page_size,
max_seq_len,
num_kv_heads,
num_heads,
head_dim,
dtype = torch.float32,
seq_lens,
page_size,
max_seq_len,
num_kv_heads,
num_heads,
head_dim,
dtype=torch.float32,
):
assert max_seq_len % page_size == 0
pages_per_sequence = max_seq_len // page_size
batch_size = len(seq_lens)
total_pages = batch_size * pages_per_sequence
k1, k2, k3, k4 = jax.random.split(prng_key, 4)
k_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim, dtype=dtype)
v_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim, dtype=dtype)
page_indices = torch.randperm(batch_size * pages_per_sequence, dtype=torch.int32)
k_pages = torch.randn(
num_kv_heads, total_pages, page_size, head_dim, dtype=dtype)
v_pages = torch.randn(
num_kv_heads, total_pages, page_size, head_dim, dtype=dtype)
page_indices = torch.randperm(
batch_size * pages_per_sequence, dtype=torch.int32)
page_indices = page_indices.reshape(batch_size, pages_per_sequence)
q = torch.randn(batch_size, num_heads, head_dim, dtype=dtype)
return q, k_pages, v_pages, page_indices
Expand All @@ -54,11 +57,11 @@ def _pagedattention_reconstruct_kv(page_indices, pages):
num_heads, _, _, head_dim = pages.shape

def per_sequence_page_gather(pages, page_indices):
  """Select the pages listed in `page_indices` along dim 1 of `pages`."""
  # Use the function's own arguments (not outer/undefined names) and cast the
  # index to int64 so that int32 page ids are accepted by index_select.
  return torch.index_select(pages, dim=1, index=page_indices.to(torch.int64))

gathered = torch.vmap(per_sequence_page_gather, in_dims=(None, 0))(
pages, page_indices
)
gathered = torch.vmap(
per_sequence_page_gather, in_dims=(None, 0))(pages, page_indices)
return gathered.reshape(batch_size, num_heads, -1, head_dim)

def _pagedattention_grouped_query_attention_reference(q, k, v, lengths):
Expand All @@ -67,10 +70,10 @@ def _pagedattention_grouped_query_attention_reference(q, k, v, lengths):
assert k.shape == v.shape
assert num_heads % num_kv_heads == 0
q = q.reshape(batch_size, num_kv_heads, num_heads // num_kv_heads, head_dim)
logits = torch.einsum("bhgd, bhtd -> bhgt", q.float(), k.float())
logits = torch.einsum("bhgd, bhtd -> bhgt", q.float(), k.float())
mask = torch.arange(max_seq_len)[None, :] < lengths[:, None]
mask_value = -0.7 * torch.finfo(torch.float32).max
logits = logits.masked_fill(~mask, mask_value)
logits = logits.masked_fill(~mask, mask_value)
weights = torch.softmax(logits, dim=-1)
o = torch.einsum("bhgt, bhtd -> bhgd", weights, v.to(weights.dtype))
return o.reshape(batch_size, num_heads, head_dim)
Expand Down
34 changes: 21 additions & 13 deletions torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,20 +331,22 @@ def flash_attention(
return FlashAttention.apply(q, k, v, causal)


def paged_attention(q, k_pages, v_pages, lengths, page_indices,
                    pages_per_compute_block):
  """Run the JAX Pallas paged-attention kernel through `make_kernel_from_pallas`.

  Args:
    q: Query tensor; the output is declared with the same shape and dtype.
    k_pages: Paged key tensor, forwarded to the kernel unchanged.
    v_pages: Paged value tensor, forwarded to the kernel unchanged.
    lengths: Forwarded to the kernel unchanged (presumably per-sequence
      lengths -- confirm against the Pallas kernel signature).
    page_indices: Forwarded to the kernel unchanged.
    pages_per_compute_block: Forwarded to the kernel unchanged.

  Returns:
    The single output produced by the wrapped kernel.
  """
  # Import JAX within the function such that we don't need to call the jax_import_guard()
  # in the global scope which could cause problems for xmp.spawn.
  jax_import_guard()
  from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention

  # The kernel is declared with one output, shaped and typed like q.
  def shape_dtype(q, *arg):
    return [(q.shape, q.dtype)]

  paged_attention_kernel = make_kernel_from_pallas(paged_attention, shape_dtype)
  o = paged_attention_kernel(q, k_pages, v_pages, lengths, page_indices,
                             pages_per_compute_block)

  return o


XLA_LIB.define(
Expand Down Expand Up @@ -380,12 +382,18 @@ def flash_attention_non_xla(q: torch.Tensor,


@impl(XLA_LIB, "paged_attention", "XLA")
def paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor,
                        v_pages: torch.Tensor, lengths: torch.Tensor,
                        page_indices: torch.Tensor,
                        pages_per_compute_block: int):
  """Implementation registered for the `paged_attention` op under the "XLA" key.

  Forwards every argument to `paged_attention` and returns its result.
  """
  # Dispatch to the paged-attention kernel wrapper with this op's own
  # arguments (the flash-attention call here referenced undefined names).
  return paged_attention(q, k_pages, v_pages, lengths, page_indices,
                         pages_per_compute_block)


@impl(XLA_LIB, "paged_attention", "CompositeExplicitAutograd")
def paged_attention_non_xla(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int pages_per_compute_block):
def paged_attention_non_xla(q: torch.Tensor, k_pages: torch.Tensor,
v_pages: torch.Tensor, lengths: torch.Tensor,
page_indices: torch.Tensor,
pages_per_compute_block: int):
# This will be called when dynamo use fake tensor to construct the fake output.
# We need to make sure output tensor's shape is correct.
if k.device != torch.device("meta"):
Expand Down

0 comments on commit 9dac216

Please sign in to comment.