Enable PagedAttention through Pallas #6912
@@ -371,6 +371,67 @@ def flash_attention(
  return FlashAttention.apply(q, k, v, causal, partition_spec, mesh)


def paged_attention(q, k_pages, v_pages, lengths, page_indices,
Review discussion on the paged_attention definition (marked as resolved by wonjoolee95):
- "The original kernel has this thing called q_dtype_for_kernel_launch? What does it do? Should we copy that as well?"
- "In the original kernel, the …"
- "No, I don't think that will be the case for the actual workflow. It could be bf16 or even int8, etc."
- "I see, makes sense. Just updated to handle …"
                    pages_per_compute_block):
  # Import JAX within the function such that we don't need to call the jax_import_guard()
  # in the global scope, which could cause problems for xmp.spawn.
  jax_import_guard()
  from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention

  payload, tensor_args = trace_pallas(
      paged_attention,
      q,
      k_pages,
      v_pages,
      lengths,
      page_indices,
      pages_per_compute_block=pages_per_compute_block,
      static_argnames=["pages_per_compute_block"],
  )

  batch_size, num_heads, head_dim = q.shape
  num_kv_heads, _, page_size, head_dim_k = k_pages.shape
  batch_size_paged_indices, pages_per_sequence = page_indices.shape
  q_output_dtype = torch.float32
  if (num_heads // num_kv_heads) % 8 != 0:
    q = q.reshape(batch_size, num_heads, 1, head_dim)
    q_output_dtype = q.dtype

  page_indices_reshaped = page_indices.reshape(-1)
  buffer_index = torch.zeros((1,), dtype=torch.int32).to("xla")
  step = torch.zeros((1,), dtype=torch.int32).to("xla")
  output_shape = torch.Size(list(q.shape[:-1]) + [1])

  output, _, _ = torch_xla._XLAC._xla_tpu_custom_call(
      [
          lengths,
          page_indices_reshaped,
          buffer_index,
          step,
          q,
          k_pages,
          v_pages,
      ], payload, [q.shape, output_shape, output_shape],
      [q_output_dtype, torch.float32, torch.float32])

  return output.reshape(batch_size, num_heads, head_dim).to(q.dtype)


def non_xla_attetion(q, k, v, attention_type):
  # This will be called when dynamo uses fake tensors to construct the fake output.
  # We need to make sure the output tensor's shape is correct.
  if k.device != torch.device("meta"):
    warnings.warn(
        f'XLA {attention_type} attention should only be applied to tensors on XLA device'
    )

  # Perform regular attention if the input tensors are not on an XLA device.
  attn_weight = q @ k.transpose(-2, -1)
  attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1)
  attn_output = attn_weight @ v
  return attn_output


XLA_LIB.define(
    "flash_attention(Tensor q, Tensor k, Tensor v, bool casual=False) -> Tensor",
)
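For reference, below is a minimal call sketch for the new wrapper. The import path (torch_xla.experimental.custom_kernel) and every size and dtype used here are illustrative assumptions, not something this diff pins down:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import paged_attention  # assumed module path

device = xm.xla_device()
batch_size, num_heads, num_kv_heads, head_dim = 4, 8, 8, 128
page_size, total_pages, pages_per_sequence = 16, 512, 32

# Query for the current decode step: one token per sequence.
q = torch.randn(batch_size, num_heads, head_dim, dtype=torch.bfloat16, device=device)
# Paged KV cache: (num_kv_heads, total_pages, page_size, head_dim).
k_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim, dtype=torch.bfloat16, device=device)
v_pages = torch.randn_like(k_pages)
# Number of valid tokens already cached for each sequence in the batch.
lengths = torch.full((batch_size,), 128, dtype=torch.int32, device=device)
# For each sequence, the physical pages it occupies in k_pages/v_pages.
page_indices = torch.randint(0, total_pages, (batch_size, pages_per_sequence), dtype=torch.int32, device=device)

out = paged_attention(q, k_pages, v_pages, lengths, page_indices, pages_per_compute_block=4)
assert out.shape == (batch_size, num_heads, head_dim)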
@@ -389,14 +450,26 @@ def flash_attention_non_xla(q: torch.Tensor,
                             k: torch.Tensor,
                             v: torch.Tensor,
                             causal: bool = False):
-  # This will be called when dynamo use fake tensor to construct the fake output.
-  # We need to make sure output tensor's shape is correct.
-  if k.device != torch.device("meta"):
-    warnings.warn(
-        'XLA flash attention should only be applied to tensors on XLA device')
-
-  # perform a regular attention if input tensors are not on XLA device.
-  attn_weight = q @ k.transpose(-2, -1)
-  attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1)
-  attn_output = attn_weight @ v
-  return attn_output
+  return non_xla_attetion(q, k, v, "flash")
+
+
+XLA_LIB.define(
+    "paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int pages_per_compute_block) -> Tensor",
+)
+
+
+@impl(XLA_LIB, "paged_attention", "XLA")
+def paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor,
+                        v_pages: torch.Tensor, lengths: torch.Tensor,
+                        page_indices: torch.Tensor,
+                        pages_per_compute_block: int):
+  return paged_attention(q, k_pages, v_pages, lengths, page_indices,
+                         pages_per_compute_block)
+
+
+@impl(XLA_LIB, "paged_attention", "CompositeExplicitAutograd")
+def paged_attention_non_xla(q: torch.Tensor, k_pages: torch.Tensor,
+                            v_pages: torch.Tensor, lengths: torch.Tensor,
+                            page_indices: torch.Tensor,
+                            pages_per_compute_block: int):
+  return non_xla_attetion(q, k_pages, v_pages, "paged")
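With the schema registered above, the kernel should also be reachable through the PyTorch dispatcher. A hedged sketch, assuming XLA_LIB defines ops in the "xla" namespace (consistent with the existing flash_attention registration) and reusing the tensors from the earlier sketch:

# Dispatches to paged_attention_xla on XLA tensors, and to the
# CompositeExplicitAutograd fallback otherwise (e.g. for dynamo's fake tensors).
out = torch.ops.xla.paged_attention(q, k_pages, v_pages, lengths, page_indices, 4)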
Review discussion on seq_lens:
- "Can you explain what these seq_lens are? Are these the previous tokens for each batch in k, v?"
- "Yep, that is my understanding -- the seq_lens here equals the number of tokens that are processed in the batch. Reference: https://docs.vllm.ai/en/latest/dev/kernel/paged_attention.html#concepts."
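To make the seq_lens semantics above concrete, here is a small worked example. The page size, sequence lengths, and padding convention are assumptions chosen for illustration; only the general idea (lengths counts the cached tokens per sequence, page_indices maps each sequence's logical pages to physical pages of the KV cache) comes from the discussion and the linked vLLM docs.

import math

page_size = 16
seq_lens = [18, 35]  # tokens already held in the KV cache for each of the 2 sequences
pages_needed = [math.ceil(n / page_size) for n in seq_lens]  # -> [2, 3]

# page_indices has shape (batch_size, pages_per_sequence); entry [i][j] is the
# physical page in k_pages/v_pages backing the j-th logical page of sequence i.
# Slots beyond pages_needed[i] are padding (a convention assumed here).
pages_per_sequence = 4
page_indices = [
    [7, 12, 0, 0],   # sequence 0: 18 tokens -> physical pages 7 and 12
    [3, 5, 9, 0],    # sequence 1: 35 tokens -> physical pages 3, 5 and 9
]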