Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add all-gather coalescing for FSDP/ZeRO1 #5950

Merged
merged 13 commits into from
Dec 2, 2023
Merged

Conversation

jeffhataws
Copy link
Collaborator

This PR adds all-gather coalescence support and use that in FSDP/ZeRO1 (replacing #5624). This PR is to be used in conjunction with openxla/xla#5740 .

A separate and related PR for reduce-scatter coalescence that also enables using reduce-scatter's scale param in FSDP is #5938.

This is a revival of #4145 . Will need to address the comments.

@jeffhataws jeffhataws requested a review from JackCaoG November 30, 2023 06:07
@JackCaoG
Copy link
Collaborator

@jeffhataws let me know when you are done addressing comments, I will take another look

@@ -295,6 +295,7 @@ def __init__(
sharding_world_size: Optional[int] = None,
shard_param_on_dim_0: bool = False,
pin_layout_in_collective_ops: bool = True,
coalesce_all_gather_ops: bool = False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mind explaining the change in this file? I think coalesce_all_gather_ops is always False in our test, did you run into these issues with your own test?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the coalesce_all_gather_ops is True, the parameter shards are collected into a list and gathered in one all-gather coalesced command at the end (instead of all-gather one parameter at a time).

It is off by default to avoid changing existing behavior. The code is same as what we are using in our local fork.

ReduceContext cc_ctx = GetReduceContext(inputs);
std::vector<xla::XlaOp> result(inputs.size());

for (auto& type_ctx : cc_ctx.contexts) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you want to assume there is only one type_ctx, let's not use the for loop and GetReduceContext at all. This way we don't need to handle the token per type.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me check with others on this.

Copy link
Collaborator

@JackCaoG JackCaoG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mostly lgtm beside the changes in FSDP. If we didn't change the default behavior of all-gather test should pass right?

I will look into reduce scatter one today, let's try to merge these two pr soon.

Copy link
Collaborator

@JackCaoG JackCaoG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I think we should test allgather_coalesced using resnet on gpu to make sure we don't break it in the future. You can refer to existing test

PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1
.

we can do that in a separate pr.

@JackCaoG JackCaoG merged commit 1271964 into master Dec 2, 2023
17 checks passed
jeffhataws added a commit to jeffhataws/xla that referenced this pull request Dec 8, 2023
* Add all-gather and reduce-scatter coalescence support for FSDP.

Also allow using reduce-scatter's scale param in FSDP.
(revived pytorch#4145)

* clang-format-7 and python lint fixes

* Fix "SyntaxError: 'return' outside function" error

* Code/test fixes to get run_tests.sh to run on CPU

* Fix allgather to be compatible with openxla allgather tuple change without token

* Fix reduce-scatter-coalesce to be compatible with openxla reduce-scatter tuple change without token

* Separate out the reduce-scatter-coalesce changes into a separate PR

* Some cleanups

* Add separate BuildAllGatherCoalesced builder and AllGatherCoalesced class

* Use token_handler.GetInput to capture token

* Clean up

* Clean up

* Switch to GetOperandListWithToken naming for func GetOperandList
jeffhataws added a commit to jeffhataws/xla that referenced this pull request Dec 11, 2023
* Add all-gather and reduce-scatter coalescence support for FSDP.

Also allow using reduce-scatter's scale param in FSDP.
(revived pytorch#4145)

* clang-format-7 and python lint fixes

* Fix "SyntaxError: 'return' outside function" error

* Code/test fixes to get run_tests.sh to run on CPU

* Fix allgather to be compatible with openxla allgather tuple change without token

* Fix reduce-scatter-coalesce to be compatible with openxla reduce-scatter tuple change without token

* Separate out the reduce-scatter-coalesce changes into a separate PR

* Some cleanups

* Add separate BuildAllGatherCoalesced builder and AllGatherCoalesced class

* Use token_handler.GetInput to capture token

* Clean up

* Clean up

* Switch to GetOperandListWithToken naming for func GetOperandList
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
* Add all-gather and reduce-scatter coalescence support for FSDP.

Also allow using reduce-scatter's scale param in FSDP.
(revived pytorch#4145)

* clang-format-7 and python lint fixes

* Fix "SyntaxError: 'return' outside function" error

* Code/test fixes to get run_tests.sh to run on CPU

* Fix allgather to be compatible with openxla allgather tuple change without token

* Fix reduce-scatter-coalesce to be compatible with openxla reduce-scatter tuple change without token

* Separate out the reduce-scatter-coalesce changes into a separate PR

* Some cleanups

* Add separate BuildAllGatherCoalesced builder and AllGatherCoalesced class

* Use token_handler.GetInput to capture token

* Clean up

* Clean up

* Switch to GetOperandListWithToken naming for func GetOperandList
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
* Add all-gather and reduce-scatter coalescence support for FSDP.

Also allow using reduce-scatter's scale param in FSDP.
(revived #4145)

* clang-format-7 and python lint fixes

* Fix "SyntaxError: 'return' outside function" error

* Code/test fixes to get run_tests.sh to run on CPU

* Fix allgather to be compatible with openxla allgather tuple change without token

* Fix reduce-scatter-coalesce to be compatible with openxla reduce-scatter tuple change without token

* Separate out the reduce-scatter-coalesce changes into a separate PR

* Some cleanups

* Add separate BuildAllGatherCoalesced builder and AllGatherCoalesced class

* Use token_handler.GetInput to capture token

* Clean up

* Clean up

* Switch to GetOperandListWithToken naming for func GetOperandList
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024
* Add all-gather and reduce-scatter coalescence support for FSDP.

Also allow using reduce-scatter's scale param in FSDP.
(revived #4145)

* clang-format-7 and python lint fixes

* Fix "SyntaxError: 'return' outside function" error

* Code/test fixes to get run_tests.sh to run on CPU

* Fix allgather to be compatible with openxla allgather tuple change without token

* Fix reduce-scatter-coalesce to be compatible with openxla reduce-scatter tuple change without token

* Separate out the reduce-scatter-coalesce changes into a separate PR

* Some cleanups

* Add separate BuildAllGatherCoalesced builder and AllGatherCoalesced class

* Use token_handler.GetInput to capture token

* Clean up

* Clean up

* Switch to GetOperandListWithToken naming for func GetOperandList
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants