Fori loop simple case test test #7028

Conversation

ManfeiBai
Collaborator

No description provided.

qihqi and others added 30 commits April 17, 2024 15:31
Summary:
This pull request integrates FlashAttention with SPMD. It works by creating a manual sharding region for the kernel, which means we wrap all the inputs with enable_manual_sharding and all the outputs with disable_manual_sharding.

Added a new test file because the original test file is not SPMD aware.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas_spmd.py
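The wrapping pattern described above can be sketched in plain Python. This is an illustrative stand-in, not the actual torch_xla implementation: `enable_manual_sharding` / `disable_manual_sharding` here are simplified mock-ups that only demonstrate the input/output wrapping around a kernel (the slicing and tiling do not produce numerically correct distributed attention), and `NUM_DEVICES` is a hypothetical mesh size.

```python
import numpy as np

NUM_DEVICES = 4  # hypothetical mesh size

def enable_manual_sharding(x, num_devices=NUM_DEVICES):
    """Stand-in: reduce a globally sharded array to one device's local shard."""
    shard_rows = x.shape[0] // num_devices
    return x[:shard_rows]  # pretend we are device 0

def disable_manual_sharding(local, num_devices=NUM_DEVICES):
    """Stand-in: restore the global shape from the local shard."""
    return np.tile(local, (num_devices,) + (1,) * (local.ndim - 1))

def flash_attention_kernel(q, k, v):
    """Stand-in for the Pallas kernel: plain softmax attention."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def sharded_flash_attention(q, k, v):
    # Wrap every input with enable_manual_sharding ...
    q_l, k_l, v_l = (enable_manual_sharding(t) for t in (q, k, v))
    out_local = flash_attention_kernel(q_l, k_l, v_l)
    # ... run the kernel on local shards, then wrap every output
    # with disable_manual_sharding to recover the global view.
    return disable_manual_sharding(out_local)

q = k = v = np.ones((8, 16), dtype=np.float32)
out = sharded_flash_attention(q, k, v)
print(out.shape)  # global shape restored: (8, 16)
```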
Add docs for fori_loop/while_loop and a simple user guide with a basic test case
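For readers unfamiliar with the construct, the semantics of a `fori_loop` (as in the JAX-style API this work follows) can be expressed as a short pure-Python reference; the exact torch_xla signature may differ, so treat this as a sketch of the intended behavior.

```python
def fori_loop(lower, upper, body_fun, init_val):
    """Reference semantics: apply body_fun(i, carry) for i in [lower, upper)."""
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val

# Simple test case: accumulate the sum 0 + 1 + ... + 9.
result = fori_loop(0, 10, lambda i, acc: acc + i, 0)
print(result)  # 45
```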
Summary:
This PR adds a preliminary user guide for Pallas.

Test Plan:
Skip CI.
Summary:
Clarify some caveats for using the TPU docker images in the landing page.

Test Plan:
Skip CI.
will-cromar and others added 23 commits April 26, 2024 15:12
Summary:
This PR is to add segment ids to the flash attention wrapper. The segment ids are a way to create an attention mask where each token can only attend to other tokens within the same segment. The mask is therefore a block diagonal matrix.

To support it, we further split the flash attention forward pass into tracing and execution parts, and implement all the shape operations needed to make it compatible with the kernel.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py
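The block-diagonal mask induced by segment ids can be illustrated with NumPy. This is a hedged sketch of the masking rule only (token i may attend to token j iff they share a segment id); `segment_mask` is a hypothetical helper name, not the kernel's actual API.

```python
import numpy as np

def segment_mask(segment_ids_q, segment_ids_kv):
    # Attention is allowed only between tokens in the same segment.
    return segment_ids_q[:, None] == segment_ids_kv[None, :]

seg = np.array([0, 0, 1, 1, 1])
mask = segment_mask(seg, seg)
print(mask.astype(int))
# With sorted segment ids the mask is block diagonal:
# [[1 1 0 0 0]
#  [1 1 0 0 0]
#  [0 0 1 1 1]
#  [0 0 1 1 1]
#  [0 0 1 1 1]]
```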
@ManfeiBai ManfeiBai closed this May 6, 2024