Enable PagedAttention through Pallas #6912
Conversation
cc @WoosukKwon to take a look
Locally, the tests are succeeding on my v4:
I also just triggered the TPU CI on this PR.
The CPU CI is failing with an unrelated test:
The CI, including the TPU CI, is passing, so this PR should be ready for review. Thanks!
test/test_pallas.py (outdated)
    torch.allclose(
        output.cpu()[seq_lens > 0],
        expected_output.cpu()[seq_lens > 0],
        atol=1e-1,
wdyt we use a tighter bound for atol and rtol? e.g. 1e-3
Sg, updated to 1e-5 for both tests.
Thanks @wonjoolee95 - left a comment for you to eval and address - approving to unblock you
In general, it looks good to me. Left a few comments.
@@ -331,6 +331,51 @@ def flash_attention(
  return FlashAttention.apply(q, k, v, causal)


def paged_attention(q, k_pages, v_pages, lengths, page_indices,
The original kernel has this thing called q_dtype_for_kernel_launch. What does it do? Should we copy that as well?
In the original kernel, the q_dtype_for_kernel_launch is always either jnp.float32 or q's dtype. In our case, I'm expecting the passed-in q's dtype to be torch.float32, so the q_dtype_for_kernel_launch will always be float32.
No, I don't think that will be the case for actual workloads. It could be bf16 or even int8, etc.
I see, makes sense. Just updated to handle q_dtype_for_kernel_launch, following JAX's kernel -- https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py#L393. I can follow up in another PR to add some more unit tests for different dtypes for q.
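To make that rule concrete, here is a hedged sketch of the dtype selection on the torch side, mirroring the referenced JAX kernel. The helper name _q_dtype_for_kernel_launch is made up for illustration and is not the PR's exact code: the kernel is launched with float32 queries when num_heads // num_kv_heads is not a multiple of 8, and otherwise keeps q's own dtype.
import torch

def _q_dtype_for_kernel_launch(q: torch.Tensor, num_heads: int,
                               num_kv_heads: int) -> torch.dtype:
  # Hypothetical helper sketch: mirror the JAX Pallas kernel's rule of launching
  # in float32 when the head ratio is not a multiple of 8 (lane padding), and
  # otherwise keeping the query's own dtype (e.g. bf16).
  if (num_heads // num_kv_heads) % 8 != 0:
    return torch.float32
  return q.dtype

# Example: a bf16 query with num_heads == num_kv_heads launches in float32,
# while a bf16 query with an 8x head ratio keeps bf16.
q = torch.randn(4, 8, 128, dtype=torch.bfloat16)
print(_q_dtype_for_kernel_launch(q, num_heads=8, num_kv_heads=8))  # torch.float32
print(_q_dtype_for_kernel_launch(q, num_heads=8, num_kv_heads=1))  # torch.bfloat16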
                    pages_per_compute_block: int):
  # This will be called when dynamo uses fake tensors to construct the fake output.
  # We need to make sure the output tensor's shape is correct.
  if k.device != torch.device("meta"):
It feels like this part can be consolidated with the flash attention one.
Sg, refactored these into a helper function.
test/test_pallas.py (outdated)
  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                   "This test only works on TPUv4+.")
  def test_paged_attention_wrapper(self):
    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
It's interesting that you use JAX as the reference. I guess that works too. Wondering if we can just use the eager attention helper in the class instead? Or does that not work? Anyway, if you are using JAX as the reference, you can drop this.
Sg, yeah I saw that we're dependent on JAX Pallas anyway, so I thought it may be easier to just test against JAX's outputs.
Ah, makes sense. Just removed the jax.config updates.
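To make the testing approach concrete, here is a hedged sketch of the comparison the test performs. The shapes, variable names, and the JAX import path are assumptions (the reference kernel has moved between JAX versions), and it needs a v4+ TPU to actually run: build one set of inputs, run both the torch_xla wrapper and the JAX Pallas kernel, and compare.
import jax.numpy as jnp
import numpy as np
import torch
from torch_xla.experimental.custom_kernel import paged_attention
# Assumed import path for the reference kernel; may differ across JAX versions.
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import (
    paged_attention as jax_paged_attention)

batch_size, num_heads, num_kv_heads, head_dim = 4, 8, 8, 128
total_pages, page_size, pages_per_seq = 32, 16, 8

q = torch.randn(batch_size, num_heads, head_dim, dtype=torch.float32)
k_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim)
v_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim)
seq_lens = torch.tensor([0, 3, 64, 128], dtype=torch.int32)
page_indices = torch.randint(
    0, total_pages, (batch_size, pages_per_seq), dtype=torch.int32)

# torch_xla wrapper, running the Pallas kernel on the TPU.
output = paged_attention(
    q.to("xla"), k_pages.to("xla"), v_pages.to("xla"), seq_lens.to("xla"),
    page_indices.to("xla"), pages_per_compute_block=4)

# JAX Pallas kernel on the same host inputs, used as the reference.
expected = jax_paged_attention(
    jnp.asarray(q.numpy()), jnp.asarray(k_pages.numpy()),
    jnp.asarray(v_pages.numpy()), jnp.asarray(seq_lens.numpy()),
    jnp.asarray(page_indices.numpy()), pages_per_compute_block=4)
expected = torch.from_numpy(np.asarray(expected))

# Only compare non-empty sequences, as in the test snippet above.
assert torch.allclose(
    output.cpu()[seq_lens > 0], expected[seq_lens > 0], atol=1e-5, rtol=1e-5)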
        q_xla,
        k_pages_xla,
        v_pages_xla,
        seq_lens_xla,
Can you explain what these seq_lens are? Are these the previous tokens for each batch in k, v?
Yep, that is my understanding -- the seq_lens here equal the number of tokens that have already been processed for each sequence in the batch. Reference: https://docs.vllm.ai/en/latest/dev/kernel/paged_attention.html#concepts.
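A small illustrative sketch of that relationship (the numbers are made up, not taken from the test): seq_lens holds the per-sequence token counts, and together with the page size it determines how many entries of page_indices are actually read for each sequence.
import torch

page_size = 16
# Tokens already processed for each of the 3 sequences in the batch.
seq_lens = torch.tensor([40, 0, 17], dtype=torch.int32)
# Number of KV-cache pages each sequence occupies (ceiling division).
pages_used = (seq_lens + page_size - 1) // page_size
print(pages_used)  # -> [3, 0, 2] pages per sequence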
  # We need to make sure output tensor's shape is correct.
  if k.device != torch.device("meta"):
    warnings.warn(
        'XLA flash attention should only be applied to tensors on XLA device')
nit: paged attention instead of flash attention
Actually, it is not even specific to paged attention; you can just make this warning message more general.
Good catch, updated to use an f-string.
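Putting the last two threads together, here is a hedged sketch of what the shared helper could look like. The helper name and its arguments are hypothetical, not the PR's exact code: it warns with an f-string so flash attention and paged attention can reuse the same message, and it hands back a correctly shaped placeholder for dynamo's fake-tensor tracing.
import warnings
import torch

def _non_xla_attention_fallback(t: torch.Tensor, output_shape, kernel_name: str):
  # Hypothetical shared helper sketch, not the PR's exact code.
  # Warn only for real (non-meta) tensors that are not on the XLA device;
  # meta tensors are expected when dynamo constructs the fake output.
  if t.device != torch.device("meta"):
    warnings.warn(
        f'XLA {kernel_name} should only be applied to tensors on XLA device')
  # Return an empty tensor of the right shape so tracing can continue.
  return torch.empty(output_shape, dtype=t.dtype, device=t.device)

# Usage sketch inside the non-XLA branch of paged_attention:
# return _non_xla_attention_fallback(k, q.shape, 'paged attention')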
@@ -400,6 +400,9 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
  buffer_index = torch.zeros((1,), dtype=torch.int32).to("xla")
  step = torch.zeros((1,), dtype=torch.int32).to("xla")
  output_shape = torch.Size(list(q.shape[:-1]) + [1])
  q_output_dtype = torch.float32
  if (num_heads // num_kv_heads) % 8 != 0:
I guess you can combine this with the above L396 code.
Good catch! Updated.
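A hedged sketch of the combined version this suggestion points at (variable names follow the snippet above; this is not the exact merged code): a single branch on the head ratio decides both the dtype the kernel is launched with and the dtype torch_xla expects back for the first output.
# Single branch deciding both dtypes, instead of repeating the head-ratio check.
if (num_heads // num_kv_heads) % 8 != 0:
  q_dtype_for_kernel_launch = torch.float32
  q_output_dtype = torch.float32
else:
  q_dtype_for_kernel_launch = q.dtype
  q_output_dtype = q.dtype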
Thanks all for the reviews. After addressing all the comments, the two unit tests are still passing locally on my v4. I'll let the TPU CI verify one more time before merging.
      ], payload, [q.shape, output_shape, output_shape],
      [q_output_dtype, torch.float32, torch.float32])

  return output.reshape(batch_size, num_heads, head_dim)
You probably want to use .to to cast the output back to the original dtype here.
Updated.
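A hedged sketch of the resulting return, following the snippet above rather than the exact merged code: the kernel output, which may have been computed in float32, is cast back to the query's original dtype before being returned.
  # Cast back to q's original dtype in case the kernel ran in float32.
  return output.reshape(batch_size, num_heads, head_dim).to(q.dtype)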
Merging as all CI is green.
Enable PagedAttention through Pallas
Test plan:
Todo as follow-ups: