Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reduce-scatter coalescing for FSDP/ZeRO1 #5956

Merged
merged 4 commits into from
Dec 7, 2023

Conversation

jeffhataws
Copy link
Collaborator

@jeffhataws jeffhataws commented Nov 30, 2023

This PR adds reduce-scatter coalescence support and use that in FSDP/ZeRO1 (replacing #5938). This also enables using reduce-scatter's scale param in FSD.. This PR is companion to #5950 and to be used in conjunction with openxla openxla/xla#5740 .

This is a revival of #4145 . Will need to address the comments.

@JackCaoG
Copy link
Collaborator

I think my comment for this pr will be very similar to the all-gather one, let's try not to change the default behavior of reduce_scatter. Let me know when I should do another round of review

@JackCaoG
Copy link
Collaborator

JackCaoG commented Dec 2, 2023

Now that allgather one is merged, do you mind resolve the conflict in this pr?

@jeffhataws
Copy link
Collaborator Author

FSDP test passes with pin layout disabled (--no_pin_layout_in_collective_ops).

@JackCaoG any idea on where this error comes from or how to resolve it?

F0000 00:00:1701586053.119077 375805 shape.h:169] Check failed: has_layout() element_type: TUPLE tuple_shapes { element_type: S64 dimensions: 4 layout { minor_to_major: 0 } is_dynamic_dimension: false } tuple_shapes { element_type: S64 dimensions: 10 layout { minor_to_major: 0 } is_dynamic_dimension: false }

@jeffhataws
Copy link
Collaborator Author

FSDP test passes with pin layout disabled (--no_pin_layout_in_collective_ops).

@JackCaoG any idea on where this error comes from or how to resolve it?

F0000 00:00:1701586053.119077 375805 shape.h:169] Check failed: has_layout() element_type: TUPLE tuple_shapes { element_type: S64 dimensions: 4 layout { minor_to_major: 0 } is_dynamic_dimension: false } tuple_shapes { element_type: S64 dimensions: 10 layout { minor_to_major: 0 } is_dynamic_dimension: false }

I worked around this by falling back to single-tensor reduce-scatter when not coalescing (bucket size is 0). Now the FSDP test is passing on GPU.

@JackCaoG
Copy link
Collaborator

JackCaoG commented Dec 4, 2023

somewhere in the code that xla is assert on input shape has layout, but tuple does not have layout... This is the part where I say we should actually test coalescing on fsdp resnet test. If there are errors on XLA side we should fix them. If it doesn't work on GPU and TPU, for this release we can call it a Trainium specified feature and trying to fix them for next release.

@jeffhataws
Copy link
Collaborator Author

somewhere in the code that xla is assert on input shape has layout, but tuple does not have layout... This is the part where I say we should actually test coalescing on fsdp resnet test. If there are errors on XLA side we should fix them. If it doesn't work on GPU and TPU, for this release we can call it a Trainium specified feature and trying to fix them for next release.

Yes I agree. Will you merge this for 2.2 then?

@JackCaoG
Copy link
Collaborator

JackCaoG commented Dec 4, 2023

I can merge to unblock you guys(and since it doesn't impact the default behavior), but if we can't verify that it actually works I am just not going to mention it in the release note.

@JackCaoG
Copy link
Collaborator

JackCaoG commented Dec 4, 2023

@jeffhataws can you separate out the FSDP change in a separate pr? I can asked @alanwaketan to review that part. I will try to help you to land the reduce scatter one after adding some tests.

@jeffhataws
Copy link
Collaborator Author

Strange, locally ./test/run_tests.sh is failing after latest changes to separate FSDP change out:

======================================================================
FAIL: test_patched_linear_2D_bias (__main__.TestAtenXlaTensor)
----------------------------------------------------------------------
Traceback (most recent call last):                                                                
  File "/home/ubuntu/jthuynh/pytorch/xla/test/test_operations.py", line 1817, in test_patched_linear_2D_bias                                                                                         
    self.assertTrue(torch.allclose(output.cpu(), output_cpu))                                     
AssertionError: False is not true     
                                                                                                  
======================================================================
FAIL: test_patched_linear_3D (__main__.TestAtenXlaTensor)
----------------------------------------------------------------------
Traceback (most recent call last):                                                                
  File "/home/ubuntu/jthuynh/pytorch/xla/test/test_operations.py", line 1761, in test_patched_linear_3D                                                                                              
    self.assertTrue(torch.allclose(output.cpu(), output_cpu))                                     
AssertionError: False is not true         
                                                                                                  
======================================================================
FAIL: test_patched_linear_3D_bias (__main__.TestAtenXlaTensor)                                    
----------------------------------------------------------------------                            
Traceback (most recent call last):      
  File "/home/ubuntu/jthuynh/pytorch/xla/test/test_operations.py", line 1789, in test_patched_linear_3D_bias                                                                                         
    self.assertTrue(torch.allclose(output.cpu(), output_cpu))
AssertionError: False is not true                                                                                                                                                                    
                                        

@JackCaoG
Copy link
Collaborator

JackCaoG commented Dec 4, 2023

hmm can you print out output values? If they are close you can ignore them. Might just be a precision issue and your local machine and CI are using different random seeds.

@jeffhataws
Copy link
Collaborator Author

hmm can you print out output values? If they are close you can ignore them. Might just be a precision issue and your local machine and CI are using different random seeds.

Something weird with my environtment so let's ignore. CI seems to be fine.

@jeffhataws
Copy link
Collaborator Author

@jeffhataws can you separate out the FSDP change in a separate pr? I can asked @alanwaketan to review that part. I will try to help you to land the reduce scatter one after adding some tests.

#6024 is the PR for FSDP.

Comment on lines +247 to +248
std::vector<XLATensorPtr> xtensors_out =
GetXlaTensors(outputs, /*want_all=*/true);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you want to handle the outputs cases, it is better to define a ReduceScatterCoalescedOut instead of merging the logic into a single function.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it ok if I work on this in another change?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's OK, but can we explictly error out in python api side if output is not None? I'd rather throw a explictly error on cases that we don't test/support yet.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can work on the reduce_scatter_out in a separate pr

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Jack. Here's the change to error out if output!=None: 9baadf5

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#6058 is for adding out-of-place version for reduce-scatter, along with #6059 for out-of-place all-gather.

@JackCaoG
Copy link
Collaborator

JackCaoG commented Dec 5, 2023

Left some comments, @jeffhataws can you rebase this pr and add test cases to https://github.com/pytorch/xla/blob/master/test/test_mp_reduce_scatter.py so we can actually run reduce_scatter with list input on GPU?

Also allow using reduce-scatter's scale param in FSDP.
(revived #4145)

Fix reduce-scatter-coalesce to be compatible with openxla reduce-scatter tuple change without token

Switch to GetOperandListWithToken naming for func GetOperandList

Add separate BuildReduceScatterCoalesced builder

Use token_handler.GetInput to consume the token

If bucket_size_mb is default 0, reduce-scatter every tensor rather than coalesce

Fix error checking in xm.reduce_scatter

Move FSDP changes to another PR
@jeffhataws jeffhataws force-pushed the cc_coalesce_reducescatter branch 2 times, most recently from be0d3f7 to 51e9919 Compare December 5, 2023 20:54
@jeffhataws jeffhataws force-pushed the cc_coalesce_reducescatter branch from 4624ec5 to 69968e5 Compare December 6, 2023 03:34
@jeffhataws
Copy link
Collaborator Author

Left some comments, @jeffhataws can you rebase this pr and add test cases to https://github.com/pytorch/xla/blob/master/test/test_mp_reduce_scatter.py so we can actually run reduce_scatter with list input on GPU?

Added a test case. Only can handle pin_layout=False at the moment.

@JackCaoG
Copy link
Collaborator

JackCaoG commented Dec 6, 2023

Let me take another look. 2.2 branch is cut, I will take care of bakporting this pr after it is merged.

Copy link
Collaborator

@JackCaoG JackCaoG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with one minor comment. If we can explictly error out on out != None case and CI pass I can merge this one.

Copy link
Collaborator

@JackCaoG JackCaoG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @jeffhataws !

@JackCaoG JackCaoG merged commit ac94781 into master Dec 7, 2023
18 checks passed
ManfeiBai pushed a commit that referenced this pull request Dec 8, 2023
jeffhataws added a commit to jeffhataws/xla that referenced this pull request Dec 8, 2023
jeffhataws added a commit to jeffhataws/xla that referenced this pull request Dec 11, 2023
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants