-
Notifications
You must be signed in to change notification settings - Fork 489
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add reduce-scatter coalescing for FSDP/ZeRO1 #5956
Conversation
I think my comment for this pr will be very similar to the all-gather one, let's try not to change the default behavior of |
Now that allgather one is merged, do you mind resolve the conflict in this pr? |
FSDP test passes with pin layout disabled (--no_pin_layout_in_collective_ops). @JackCaoG any idea on where this error comes from or how to resolve it?
|
I worked around this by falling back to single-tensor reduce-scatter when not coalescing (bucket size is 0). Now the FSDP test is passing on GPU. |
somewhere in the code that xla is assert on input shape has layout, but tuple does not have layout... This is the part where I say we should actually test coalescing on fsdp resnet test. If there are errors on XLA side we should fix them. If it doesn't work on GPU and TPU, for this release we can call it a Trainium specified feature and trying to fix them for next release. |
Yes I agree. Will you merge this for 2.2 then? |
I can merge to unblock you guys(and since it doesn't impact the default behavior), but if we can't verify that it actually works I am just not going to mention it in the release note. |
@jeffhataws can you separate out the FSDP change in a separate pr? I can asked @alanwaketan to review that part. I will try to help you to land the reduce scatter one after adding some tests. |
Strange, locally ./test/run_tests.sh is failing after latest changes to separate FSDP change out:
|
hmm can you print out output values? If they are close you can ignore them. Might just be a precision issue and your local machine and CI are using different random seeds. |
Something weird with my environtment so let's ignore. CI seems to be fine. |
#6024 is the PR for FSDP. |
std::vector<XLATensorPtr> xtensors_out = | ||
GetXlaTensors(outputs, /*want_all=*/true); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if you want to handle the outputs
cases, it is better to define a ReduceScatterCoalescedOut
instead of merging the logic into a single function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it ok if I work on this in another change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's OK, but can we explictly error out in python api side if output
is not None? I'd rather throw a explictly error on cases that we don't test/support yet.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can work on the reduce_scatter_out
in a separate pr
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Jack. Here's the change to error out if output!=None: 9baadf5
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left some comments, @jeffhataws can you rebase this pr and add test cases to https://github.com/pytorch/xla/blob/master/test/test_mp_reduce_scatter.py so we can actually run |
Also allow using reduce-scatter's scale param in FSDP. (revived #4145) Fix reduce-scatter-coalesce to be compatible with openxla reduce-scatter tuple change without token Switch to GetOperandListWithToken naming for func GetOperandList Add separate BuildReduceScatterCoalesced builder Use token_handler.GetInput to consume the token If bucket_size_mb is default 0, reduce-scatter every tensor rather than coalesce Fix error checking in xm.reduce_scatter Move FSDP changes to another PR
be0d3f7
to
51e9919
Compare
4624ec5
to
69968e5
Compare
Added a test case. Only can handle pin_layout=False at the moment. |
Let me take another look. 2.2 branch is cut, I will take care of bakporting this pr after it is merged. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with one minor comment. If we can explictly error out on out != None
case and CI pass I can merge this one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @jeffhataws !
Co-authored-by: jeffhataws <[email protected]>
This PR adds reduce-scatter coalescence support and use that in FSDP/ZeRO1 (replacing #5938). This also enables using reduce-scatter's scale param in FSD.. This PR is companion to #5950 and to be used in conjunction with openxla openxla/xla#5740 .
This is a revival of #4145 . Will need to address the comments.