[FSDPv2] Support MultiSlice #7044
Conversation
Force-pushed from 088da28 to 76c7435
LGTM, thanks Jiewen!
@@ -24,6 +24,8 @@ def _prepare_spmd_partition_spec(param):
  # TODO: should we shard on the maximal dim for param? Then we need
  # another helper for the output.
  partition_spec[0] = "fsdp"
  if extra_data_axis:
    partition_spec[0] = ("fsdp", extra_data_axis)
We usually have this reversed for DCN: `(extra_data_axis, 'fsdp')`. The axes should be in order of increasing network intensity in the mesh, and the order in the partition spec will impact the sharding.
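The reviewer's point can be sketched as a small helper. This is a hypothetical, simplified stand-in for `_prepare_spmd_partition_spec` (the real one operates on an actual parameter tensor); it only shows the axis ordering being discussed, with the lower-bandwidth axis first:

```python
# Hypothetical sketch, not the actual torch_xla helper: build a partition
# spec that shards dim 0, putting the cross-slice (DCN) axis before the
# higher-bandwidth 'fsdp' axis, as suggested in the review.
def prepare_partition_spec(num_dims, extra_data_axis=None):
    partition_spec = [None] * num_dims
    if extra_data_axis:
        # Lower network intensity first: (extra_data_axis, 'fsdp')
        partition_spec[0] = (extra_data_axis, "fsdp")
    else:
        partition_spec[0] = "fsdp"
    return tuple(partition_spec)
```

With `extra_data_axis="data"`, a 2-D parameter would get the spec `(("data", "fsdp"), None)`.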
mesh: Optional[spmd.Mesh] = None,
shard_output: Optional[Callable] = None,
auto_wrap_policy: Optional[Callable] = None,
auto_wrapper_callable: Optional[Callable] = None,
extra_data_axis: Optional[str] = None,
What do you think of calling it `replica_axis` instead of `extra_data_axis`?
I think `replica_axis` is too tied to the underlying technology, while users may only be familiar with data parallel, FSDP, and tensor parallel.
test/spmd/test_fsdp_v2.py (Outdated)
xs.mark_sharding(x, mesh, (('data', 'fsdp'), None))
output = model(x)
# Make sure the output is sharded.
annotation = '{devices=[4,1]0,2,1,3}'
This would be different from `x`'s sharding - `x` should have `[4,1]0,1,2,3` with the iota mesh. I left a comment below; I think we should reverse the order in `_prepare_spmd_partition_spec`.
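The difference between the two annotations can be checked with a plain-Python walk over a 2x2 iota mesh. This is an illustrative sketch of how the partition-spec axis order changes device enumeration, not torch_xla code:

```python
# 2x2 iota mesh as nested lists; axes ('data', 'fsdp'), device ids 0..3.
# mesh[data][fsdp] gives the device id at that mesh coordinate.
mesh = [[0, 1],
        [2, 3]]

# Sharding a dim over ('data', 'fsdp') enumerates fsdp fastest (row-major):
order_data_fsdp = [d for row in mesh for d in row]

# Sharding over ('fsdp', 'data') enumerates data fastest (column-major):
order_fsdp_data = [mesh[i][j] for j in range(2) for i in range(2)]
```

`order_data_fsdp` matches the `[4,1]0,1,2,3` annotation expected for `x`, while `order_fsdp_data` matches the `[4,1]0,2,1,3` annotation in the test, which is why the axis order in `_prepare_spmd_partition_spec` matters.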
My mistake. Thanks for pointing it out.
@JackCaoG The TPU CI doesn't seem to be running even with the label.
Yea, I think they only check the label when CI is being run. It is OK - if you have any changes and repush, it will run; otherwise we can let the run at head check it.
I'm landing it. If the master TPU CI breaks, let's deal with that later.
Summary:
This pull request adds multi-slice support for FSDPv2. The default setup uses the dcn axis as the data axis, meaning we only do data parallelism across slices. In the future, we could also support FSDP across multiple slices.
Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_fsdp_v2.py