
[FSDPv2] Support MultiSlice #7044

Merged: 4 commits merged into master from alanwaketan/fsdp2_ms on May 11, 2024
Conversation

alanwaketan (Collaborator):

Summary:
This pull request adds multi-slice support for FSDPv2. By default, the DCN axis is used as the data axis, which means we only do data parallelism across slices. In the future, we could also support FSDP across multiple slices.

Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_fsdp_v2.py
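For context, a minimal sketch of how the new extra_data_axis argument might be used on a two-slice setup. The mesh shapes, toy model, and variable names below are illustrative assumptions; only the extra_data_axis parameter and the 'data'/'fsdp' axis names come from this change.

import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

xr.use_spmd()

# Hypothetical 2-slice mesh: shard parameters within each slice over the
# 'fsdp' axis (ICI) and do data parallelism across slices over the 'data'
# axis (DCN).
num_slices = 2  # assumption for illustration
devices_per_slice = xr.global_runtime_device_count() // num_slices
mesh = xs.HybridMesh(
    ici_mesh_shape=(1, devices_per_slice),
    dcn_mesh_shape=(num_slices, 1),
    axis_names=('data', 'fsdp'))

model = nn.Linear(128, 128).to(xm.xla_device())
model = FSDPv2(model, mesh=mesh, extra_data_axis='data')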

@alanwaketan alanwaketan requested a review from jonb377 May 10, 2024 01:17
@alanwaketan alanwaketan self-assigned this May 10, 2024
@alanwaketan alanwaketan force-pushed the alanwaketan/fsdp2_ms branch from 088da28 to 76c7435 on May 10, 2024 01:24
@JackCaoG JackCaoG added the tpuci label May 10, 2024

jonb377 (Collaborator) left a comment:

LGTM, thanks Jiewen!

@@ -24,6 +24,8 @@ def _prepare_spmd_partition_spec(param):
# TODO: should we shard on the maximal dim for param? Then we need
# another helper for the output.
partition_spec[0] = "fsdp"
if extra_data_axis:
partition_spec[0] = ("fsdp", extra_data_axis)

jonb377 (Collaborator):

We usually have this reversed for DCN - (extra_data_axis, 'fsdp'). The axes should be in order of increasing network intensity in the mesh, and the order in the partition spec will impact the sharding.
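
To make the suggestion concrete, here is a sketch of the helper with the DCN axis placed first in the combined spec. It mirrors the diff above but is not the actual patch, and the surrounding lines are paraphrased:

def _prepare_spmd_partition_spec(param, extra_data_axis=None):
  partition_spec = [None] * len(param.shape)
  # Shard the 0th dimension of the parameter over the fsdp axis.
  partition_spec[0] = "fsdp"
  if extra_data_axis:
    # Per the review: list axes from least to most network-intensive,
    # so the DCN (data) axis comes before the ICI (fsdp) axis.
    partition_spec[0] = (extra_data_axis, "fsdp")
  return tuple(partition_spec)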

mesh: Optional[spmd.Mesh] = None,
shard_output: Optional[Callable] = None,
auto_wrap_policy: Optional[Callable] = None,
auto_wrapper_callable: Optional[Callable] = None,
extra_data_axis: Optional[str] = None,

jonb377 (Collaborator):

What do you think of calling it replica_axis instead of extra_data_axis?

alanwaketan (Collaborator, Author), May 10, 2024:

I think replica_axis is too tied to the underlying technology, while users may only be familiar with data parallel, FSDP, and tensor parallel.

xs.mark_sharding(x, mesh, (('data', 'fsdp'), None))
output = model(x)
# Make sure output are sharded.
annotation = '{devices=[4,1]0,2,1,3}'

jonb377 (Collaborator):

This would be different from x's sharding: x should have [4,1]0,1,2,3 with the iota mesh. I left a comment below; I think we should reverse the order in _prepare_spmd_partition_spec.
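
For concreteness, a hedged sketch of the kind of check being discussed, assuming 4 devices and an iota mesh of shape (2, 2) with axes ('data', 'fsdp'); everything outside mark_sharding and the expected annotations is illustrative:

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

num_devices = xr.global_runtime_device_count()  # assumed to be 4 here
mesh = xs.Mesh(np.arange(num_devices), (2, num_devices // 2), ('data', 'fsdp'))

x = torch.randn(16, 128).to(xm.xla_device())
# With the iota mesh, sharding dim 0 over ('data', 'fsdp') walks devices
# data-major, giving {devices=[4,1]0,1,2,3}; ('fsdp', 'data') would give
# {devices=[4,1]0,2,1,3} instead.
xs.mark_sharding(x, mesh, (('data', 'fsdp'), None))
print(torch_xla._XLAC._get_xla_sharding_spec(x))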

alanwaketan (Collaborator, Author):

My mistake. Thanks for pointing it out.

alanwaketan (Collaborator, Author):

@JackCaoG The TPU CI doesn't seem to be running even with the label.

JackCaoG (Collaborator):

Yeah, I think they only check the label when CI is being run. It's OK: if you have any changes and repush, it will run; otherwise we can let it be checked at head.

alanwaketan (Collaborator, Author):

I'm landing it. If the master TPU CI breaks, let's deal with that later.

@alanwaketan alanwaketan merged commit 6f0b61e into master May 11, 2024
19 of 20 checks passed