
Enable PagedAttention through Pallas #6912

Merged: 15 commits merged into master on Apr 25, 2024
Conversation

wonjoolee95 (Collaborator) commented Apr 10, 2024

Enable PagedAttention through Pallas

Test plan:

root@1fdc3324aeef:/pytorch/xla# python test/test_pallas.py PallasTest.test_paged_attention_wrapper
.
----------------------------------------------------------------------
Ran 1 test in 2.209s

OK
root@1fdc3324aeef:/pytorch/xla# python test/test_pallas.py PallasTest.test_paged_attention_wrapper_with_dynamo
.
----------------------------------------------------------------------
Ran 1 test in 2.114s

OK

Follow-up TODOs:

  • Add a unit test for Dynamo
  • Enable all other parameters of jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel
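
For context, a rough usage sketch of the new wrapper follows. The import path, tensor shapes, and dtypes are inferred from the PR diff and the JAX reference kernel rather than copied from this PR, the concrete sizes are illustrative, and it assumes a TPU-backed XLA device.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import paged_attention  # assumed import path

device = xm.xla_device()

# Illustrative sizes; shapes follow the JAX paged_attention kernel's layout.
batch_size, num_heads, num_kv_heads, head_dim = 4, 32, 8, 128
total_pages, page_size, pages_per_sequence = 64, 16, 8

q = torch.randn(batch_size, num_heads, head_dim, dtype=torch.float32).to(device)
k_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim, dtype=torch.float32).to(device)
v_pages = torch.randn_like(k_pages)
lengths = torch.randint(1, pages_per_sequence * page_size, (batch_size,), dtype=torch.int32).to(device)
page_indices = torch.randint(0, total_pages, (batch_size, pages_per_sequence), dtype=torch.int32).to(device)

output = paged_attention(q, k_pages, v_pages, lengths, page_indices, pages_per_compute_block=4)
# output has shape [batch_size, num_heads, head_dim], matching q.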

@wonjoolee95 wonjoolee95 force-pushed the wonjoo/pallas-pagedattention branch from 50dac57 to b6822a3 on April 10, 2024 19:55
@wonjoolee95 wonjoolee95 force-pushed the wonjoo/pallas-pagedattention branch 3 times, most recently from b0262b0 to 72fdd57 on April 10, 2024 21:12
@alanwaketan alanwaketan self-requested a review April 11, 2024 18:45
@wonjoolee95 wonjoolee95 force-pushed the wonjoo/pallas-pagedattention branch 4 times, most recently from 45d9fe3 to f07c7e3 on April 12, 2024 22:14
@wonjoolee95 wonjoolee95 force-pushed the wonjoo/pallas-pagedattention branch from f07c7e3 to 1f70ed9 on April 16, 2024 20:13
@wonjoolee95 wonjoolee95 changed the title from "[WIP] Enable PagedAttention through Pallas" to "Enable PagedAttention through Pallas" on Apr 16, 2024
@wonjoolee95 wonjoolee95 marked this pull request as ready for review April 16, 2024 20:16
@wonjoolee95 wonjoolee95 force-pushed the wonjoo/pallas-pagedattention branch from 1f70ed9 to f32836e on April 16, 2024 20:27
miladm (Collaborator) commented Apr 16, 2024

cc @WoosukKwon to take a look

@wonjoolee95 wonjoolee95 force-pushed the wonjoo/pallas-pagedattention branch 4 times, most recently from cc5ad3a to 312bef1 on April 22, 2024 21:45
wonjoolee95 (Collaborator, Author) commented:

Locally, the tests are succeeding on my v4:

root@1fdc3324aeef:/pytorch/xla# python test/test_pallas.py PallasTest.test_paged_attention_wrapper
.
----------------------------------------------------------------------
Ran 1 test in 2.209s

OK
root@1fdc3324aeef:/pytorch/xla# python test/test_pallas.py PallasTest.test_paged_attention_wrapper_with_dynamo
.
----------------------------------------------------------------------
Ran 1 test in 2.114s

OK

I also just triggered the TPU CI on this PR.

@wonjoolee95 wonjoolee95 force-pushed the wonjoo/pallas-pagedattention branch from c6040cf to 19e28f8 on April 22, 2024 21:58
@wonjoolee95 wonjoolee95 requested a review from JackCaoG April 24, 2024 21:14
wonjoolee95 (Collaborator, Author) commented Apr 24, 2024

The CPU CI is failing with an unrelated test:

======================================================================
FAIL: test_resnet18 (__main__.DynamoTrainingBasicTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/__w/xla/xla/pytorch/xla/test/dynamo/test_dynamo.py", line 494, in test_resnet18
    self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count * 3)
AssertionError: 29 != 30

The rest of the CI, including the TPU CI, is passing, so this PR should be ready for review. Thanks!

@wonjoolee95 wonjoolee95 requested a review from miladm April 24, 2024 21:16
torch.allclose(
output.cpu()[seq_lens > 0],
expected_output.cpu()[seq_lens > 0],
atol=1e-1,
Reviewer (Collaborator): wdyt we use a tighter bound for atol and rtol? e.g. 1e-3

wonjoolee95 (Collaborator, Author) replied Apr 25, 2024: Sounds good, updated to 1e-5 for both tests.
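
For reference, the tightened check would look roughly like the sketch below; the surrounding assertTrue and the exact rtol value are assumptions, not quoted from the PR.

self.assertTrue(
    torch.allclose(
        output.cpu()[seq_lens > 0],
        expected_output.cpu()[seq_lens > 0],
        atol=1e-5,   # tightened from 1e-1 per the review
        rtol=1e-5))  # assumed; the PR may use a different rtol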

@miladm miladm (Collaborator) left a comment:

Thanks @wonjoolee95 - left a comment for you to evaluate and address - approving to unblock you.

@alanwaketan alanwaketan (Collaborator) left a comment:

In general, it looks good to me. Left a few comments.

@@ -331,6 +331,51 @@ def flash_attention(
return FlashAttention.apply(q, k, v, causal)


def paged_attention(q, k_pages, v_pages, lengths, page_indices,
Reviewer (Collaborator): The original kernel has this thing called q_dtype_for_kernel_launch. What does it do? Should we copy that as well?

wonjoolee95 (Collaborator, Author) replied Apr 25, 2024: In the original kernel, q_dtype_for_kernel_launch is always either jnp.float32 or q's dtype. In our case, I'm expecting the passed-in q's dtype to be torch.float32, so q_dtype_for_kernel_launch will always be float32.

Reviewer (Collaborator): No, I don't think that will be the case for an actual workload. It could be bf16 or even int8, etc.

wonjoolee95 (Collaborator, Author): I see, makes sense. Just updated to handle q_dtype_for_kernel_launch, following JAX's kernel -- https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py#L393. I can follow up in another PR to add more unit tests for different q dtypes.
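
A minimal sketch of that dtype selection, mirroring the linked JAX kernel; the exact variable names and surrounding layout handling in the PR may differ.

# Launch in float32 when the query-head to KV-head ratio is not a multiple of 8
# (the JAX kernel does the same to get a TPU-friendly layout); otherwise keep q's dtype.
if (num_heads // num_kv_heads) % 8 != 0:
    q_dtype_for_kernel_launch = torch.float32
else:
    q_dtype_for_kernel_launch = q.dtype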

pages_per_compute_block: int):
# This will be called when dynamo use fake tensor to construct the fake output.
# We need to make sure output tensor's shape is correct.
if k.device != torch.device("meta"):
Reviewer (Collaborator): It feels like this part can be consolidated with the flash attention one.

wonjoolee95 (Collaborator, Author): Sounds good, refactored these into a helper function.
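
A minimal sketch of what such a shared helper could look like; the name, signature, and warning text are hypothetical, not the actual refactor landed in the PR.

import warnings
import torch

def _fake_attention_output(t, kernel_name):
    # Hypothetical shared helper for the dynamo fake-tensor path used by both
    # flash_attention and paged_attention.
    if t.device != torch.device("meta"):
        warnings.warn(
            f'XLA {kernel_name} should only be applied to tensors on XLA device')
    # Return an empty tensor with the expected output shape so dynamo can trace
    # shapes without launching the Pallas kernel.
    return torch.empty_like(t)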

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_paged_attention_wrapper(self):
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
Reviewer (Collaborator): It's interesting that you use JAX as the reference. I guess that works too. Wondering if we can just use the eager attention helper in the class instead? Or does that not work? Anyway, if you are using JAX as the reference, you can drop this.

wonjoolee95 (Collaborator, Author) replied Apr 25, 2024: Sounds good. Yeah, I saw that we depend on JAX Pallas anyway, so I thought it may be easier to just test against JAX's outputs.

Ah, makes sense. Just removed the jax.config updates.

q_xla,
k_pages_xla,
v_pages_xla,
seq_lens_xla,
Reviewer (Collaborator): Can you explain what these seq_lens are? Are these the previous tokens for each batch in k, v?

wonjoolee95 (Collaborator, Author): Yep, that is my understanding -- seq_lens here equals the number of tokens already processed for each batch entry. Reference: https://docs.vllm.ai/en/latest/dev/kernel/paged_attention.html#concepts.
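
To make the layout concrete, here is a small sketch of these inputs as understood from the JAX reference kernel; the sizes and values are made up and not taken from the PR's tests.

import torch

batch_size, total_pages, page_size, pages_per_sequence = 4, 64, 16, 8

# One entry per sequence: how many tokens are already in that sequence's KV cache.
seq_lens = torch.tensor([7, 0, 128, 33], dtype=torch.int32)

# Which KV-cache pages each sequence owns; only the first ceil(seq_lens[b] / page_size)
# entries of row b are actually read, and sequences with seq_lens[b] == 0 contribute
# nothing (hence the [seq_lens > 0] masking in the test).
page_indices = torch.randint(0, total_pages, (batch_size, pages_per_sequence),
                             dtype=torch.int32)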

@wonjoolee95 wonjoolee95 force-pushed the wonjoo/pallas-pagedattention branch from 19e28f8 to b3a5948 on April 25, 2024 00:37
# We need to make sure output tensor's shape is correct.
if k.device != torch.device("meta"):
warnings.warn(
'XLA flash attention should only be applied to tensors on XLA device')
Reviewer (Collaborator): nit, paged attention instead of flash attention

Reviewer (Collaborator): Actually, it is not even paged attention; you can just make this warning message more general.

wonjoolee95 (Collaborator, Author): Good catch, updated to use an f-string.

@@ -400,6 +400,9 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
buffer_index = torch.zeros((1,), dtype=torch.int32).to("xla")
step = torch.zeros((1,), dtype=torch.int32).to("xla")
output_shape = torch.Size(list(q.shape[:-1]) + [1])
q_output_dtype = torch.float32
if (num_heads // num_kv_heads) % 8 != 0:
Reviewer (Collaborator): I guess you can combine this with the above L396 code.

wonjoolee95 (Collaborator, Author): Good catch! Updated.

wonjoolee95 (Collaborator, Author) commented:

Thanks all for the reviews. After addressing all the comments, the two unit tests are still passing locally on my v4. I'll let the TPU CI verify one more time before merging.

], payload, [q.shape, output_shape, output_shape],
[q_output_dtype, torch.float32, torch.float32])

return output.reshape(batch_size, num_heads, head_dim)
Reviewer (Collaborator): You probably want to use .to to cast the output back to the original dtype here.

wonjoolee95 (Collaborator, Author): Updated.
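
The suggested cast would look roughly like this sketch; variable names follow the quoted snippet, and the exact final line in the PR may differ.

# Cast back to the caller's dtype after the float32 (or other launch-dtype) kernel output.
return output.reshape(batch_size, num_heads, head_dim).to(q.dtype)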

@wonjoolee95 wonjoolee95 force-pushed the wonjoo/pallas-pagedattention branch from 961dfff to 7818d0f on April 25, 2024 01:04
wonjoolee95 (Collaborator, Author) commented:
Merging as all CI is green.

@wonjoolee95 wonjoolee95 merged commit 6ed2026 into master Apr 25, 2024
22 checks passed