
[Pallas] Support Flash Attention #6658

Merged: 9 commits merged into master on Mar 6, 2024
Conversation

@alanwaketan (Collaborator) commented on Mar 1, 2024:

Summary:
This PR makes all necessary changes to support Pallas's FlashAttention.

The major change here is to disable TPU layout that places bigger dimensions on most minor layout locations. This optimization boosts performance for dim > 2 tensor input applications, like resnet. For LLM, it should be fine to disable since all inputs are 2D.
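
For reference, a minimal sketch of toggling the flag from Python. This assumes `XLA_TPU_LAYOUT` is read from the environment when torch_xla initializes, so it has to be set before anything from torch_xla runs (the benchmark commands below simply export it in the shell instead):

```python
import os

# Assumption: XLA_TPU_LAYOUT is read once at torch_xla initialization,
# so set it before importing/using torch_xla.
# 0 = plain layout (compatible with Pallas kernels), 1 = TPU-optimized layout (default).
os.environ["XLA_TPU_LAYOUT"] = "0"

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# 2D tensors, as used by LLM workloads, are unaffected by the layout choice.
x = torch.randn(128, 128, device=device)
```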

ResNet: python test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1 --metrics_debug
XLA_TPU_LAYOUT = 0

| Training Device=xla:0/0 Epoch=1 Step=2320 Loss=0.00135 Rate=474.16 GlobalRate=255.14 Time=00:06:30

XLA_TPU_LAYOUT = 1

| Training Device=xla:0/1 Epoch=1 Step=2340 Loss=0.00135 Rate=1750.53 GlobalRate=1151.09 Time=00:15:46

Llama 2 2B
XLA_TPU_LAYOUT = 0
[Screenshot: Llama 2 2B training profile with XLA_TPU_LAYOUT = 0]
XLA_TPU_LAYOUT = 1
[Screenshot: Llama 2 2B training profile with XLA_TPU_LAYOUT = 1]

Test Plan:
python test/test_operations.py -v -k test_tpu_custom_call_pallas_flash_attention
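
For illustration, a minimal sketch of the kind of call this test exercises, assuming the `make_kernel_from_pallas` wrapper in `torch_xla.experimental.custom_kernel` and the JAX Pallas TPU flash-attention kernel. The import paths, wrapper signature, and tensor shapes are assumptions for the sketch, not taken verbatim from this PR:

```python
import torch
import torch_xla.core.xla_model as xm

# Assumption: these are the experimental custom-kernel wrapper and the
# JAX Pallas TPU flash-attention kernel; exact module paths may differ.
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention

# Wrap the JAX/Pallas kernel so it can be invoked on XLA tensors.
# Assumption: the second argument maps the inputs to the output
# (shape, dtype) list; flash attention returns a tensor shaped like q.
flash_attention_kernel = make_kernel_from_pallas(
    flash_attention, lambda q, k, v: [(q.shape, q.dtype)]
)

device = xm.xla_device()
q = torch.randn(3, 2, 128, 4, device=device)  # (batch, heads, seq_len, head_dim)
k = torch.randn(3, 2, 128, 4, device=device)
v = torch.randn(3, 2, 128, 4, device=device)

out = flash_attention_kernel(q, k, v)
```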

@alanwaketan alanwaketan changed the title [WIP] Support Pallas Flash Attention [Pallas] Support Flash Attention Mar 6, 2024
@alanwaketan alanwaketan self-assigned this Mar 6, 2024
@alanwaketan alanwaketan marked this pull request as ready for review March 6, 2024 00:49
@alanwaketan alanwaketan requested review from JackCaoG and qihqi March 6, 2024 00:49
@alanwaketan alanwaketan force-pushed the alanwaketan/flash_attention branch from 37cf282 to e60f07b on March 6, 2024 05:55
@alanwaketan (Collaborator, Author) commented:

Thanks, Han!

@alanwaketan alanwaketan merged commit 9d4dcae into master Mar 6, 2024
18 checks passed