Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add all-gather coalescing for FSDP/ZeRO1 #5624

Closed
wants to merge 9 commits into from

Conversation

jeffhataws
Copy link
Collaborator

@jeffhataws jeffhataws commented Sep 20, 2023

(Replaced by #5950)

This PR adds all-gather coalescence support and use that in FSDP/ZeRO1. This PR is to be used in conjunction with openxla/xla#5740 .

A separate and related PR for reduce-scatter coalescence that also enables using reduce-scatter's scale param in FSDP is #5938.

This is a revival of #4145 . Will need to address the comments.

@JackCaoG
Copy link
Collaborator

@alanwaketan can you take a look?

@alanwaketan
Copy link
Collaborator

@jeffhataws Can you double check the CI failures?

@jeffhataws
Copy link
Collaborator Author

@jeffhataws Can you double check the CI failures?

Thanks. Since it depends on openxla/xla#5740 I will need to take care of merging that first. So let's leave this open for now.

copybara-service bot pushed a commit to openxla/xla that referenced this pull request Oct 17, 2023
Imported from GitHub PR #5740

This PR adds tuple input support to all-gather and reduce-scatter. This is a revival of part of tensorflow/tensorflow#58377 and to be used in conjunction with pytorch/xla#5624 .

In FSDP, different layers' weights need to be all-gathered/reduced-scatter during training. If some layers are small, multiple layers' weights can be aggregated for more efficient data transfer (same concept as bucket_cap_mb in DDP). With existing all-gather and reduce-scatter in PyTorch-XLA, you would have to do the bucketing and decomposing outside of the operation. This PR enables multiple different tensors to be all-gathered/reduce-scatter, keeping the original tensor shapes to enable bucketing and decomposing optimizations inside the operation.

Original PR has token support like the token used for allreduce to ensure order between CCops. That will be separate PR if needed.

Copybara import of the project:

--
7ea1159 by Junmin Hao <[email protected]>:

Add Tuple input and token support to all-gather and reduce-scatter.

Committer: Junmin Hao <[email protected]>

--
cdb873e by Junmin Hao <[email protected]>:

lint fix

--
aad3521 by Jeffrey Huynh <[email protected]>:

Fix hlo_verifier_test failure due to changed expectation

--
32e8145 by Jeffrey Huynh <[email protected]>:

Separate the token change out into a separate PR with RFC.

--
b301c2a by Jeffrey Huynh <[email protected]>:

Change *WithToken tests to *WithTuple

--
5890278 by Jeffrey Huynh <[email protected]>:

Fix missing parenthesis

Merging this change closes #5740

COPYBARA_INTEGRATE_REVIEW=#5740 from jeffhataws:ag_rs_coalesce_revived 14e09f0
PiperOrigin-RevId: 573976449
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 17, 2023
Imported from GitHub PR openxla/xla#5740

This PR adds tuple input support to all-gather and reduce-scatter. This is a revival of part of #58377 and to be used in conjunction with pytorch/xla#5624 .

In FSDP, different layers' weights need to be all-gathered/reduced-scatter during training. If some layers are small, multiple layers' weights can be aggregated for more efficient data transfer (same concept as bucket_cap_mb in DDP). With existing all-gather and reduce-scatter in PyTorch-XLA, you would have to do the bucketing and decomposing outside of the operation. This PR enables multiple different tensors to be all-gathered/reduce-scatter, keeping the original tensor shapes to enable bucketing and decomposing optimizations inside the operation.

Original PR has token support like the token used for allreduce to ensure order between CCops. That will be separate PR if needed.

Copybara import of the project:

--
7ea1159a1464efddebe9384e87ed6df504d89b2e by Junmin Hao <[email protected]>:

Add Tuple input and token support to all-gather and reduce-scatter.

Committer: Junmin Hao <[email protected]>

--
cdb873e6d97f5f12b3d3c587bb5782d58e3554c5 by Junmin Hao <[email protected]>:

lint fix

--
aad352117ba950ac5ae62330e3980f4b5898a701 by Jeffrey Huynh <[email protected]>:

Fix hlo_verifier_test failure due to changed expectation

--
32e814524b88a474af5e4e904c0dd19841430b86 by Jeffrey Huynh <[email protected]>:

Separate the token change out into a separate PR with RFC.

--
b301c2a2a5b52180f9e9626173e6b67a78782960 by Jeffrey Huynh <[email protected]>:

Change *WithToken tests to *WithTuple

--
5890278fc16c9f900782d32a92d40ecf548aea85 by Jeffrey Huynh <[email protected]>:

Fix missing parenthesis

Merging this change closes #5740

PiperOrigin-RevId: 573976449
@jeffhataws jeffhataws closed this Oct 18, 2023
@jeffhataws jeffhataws reopened this Oct 19, 2023
@jeffhataws jeffhataws force-pushed the cc_coalesce_revival branch 2 times, most recently from 2e861ff to 76a2f0f Compare October 20, 2023 02:42
@jeffhataws
Copy link
Collaborator Author

@alanwaketan what's the best way to check this against the openxla with merged openxla/xla#5740?

@alanwaketan
Copy link
Collaborator

@alanwaketan what's the best way to check this against the openxla with merged openxla/xla#5740?

I'm working on a pin update. Will loop you in once that PR is up.

@jeffhataws
Copy link
Collaborator Author

@alanwaketan what's the best way to check this against the openxla with merged openxla/xla#5740?

I'm working on a pin update. Will loop you in once that PR is up.

@alanwaketan just want to check how things are going with the pin update. Also, how do I make the automatic checks to run on this PR?

@jeffhataws
Copy link
Collaborator Author

@alanwaketan @JackCaoG will you help ensure this is in 2.2?

@JackCaoG
Copy link
Collaborator

JackCaoG commented Nov 9, 2023

@jeffhataws can you rebase this pr? Then we can start reviewing it. Thanks!

@alanwaketan
Copy link
Collaborator

@alanwaketan what's the best way to check this against the openxla with merged openxla/xla#5740?

I'm working on a pin update. Will loop you in once that PR is up.

The pin update is completed. So, please rebase and then I will take a look.

@jeffhataws
Copy link
Collaborator Author

@alanwaketan what's the best way to check this against the openxla with merged openxla/xla#5740?

I'm working on a pin update. Will loop you in once that PR is up.

The pin update is completed. So, please rebase and then I will take a look.

Thanks alanwaketan. I have rebased. Please take a look.

@jeffhataws
Copy link
Collaborator Author

jeffhataws commented Nov 17, 2023

@alanwaketan @JackCaoG I am not sure how to reproduce/debug the above errors.

test/run_tests.sh passes for me on CPU machine with export PJRT_DEVICE=CPU.

When I run test/cpp/run_tests.sh with export PJRT_DEVICE=CPU I get a hang:

Running all cpp test...
+ '[' /tmp/pytorch_cpp_test.log '!=' '' ']'
+ bazel test --config=tpu //torch_xla/csrc/runtime:all //test/cpp:all --test_timeout 1000
(hangs)

@jeffhataws
Copy link
Collaborator Author

@alanwaketan @JackCaoG I am not sure how to reproduce/debug the above errors.

test/run_tests.sh passes for me on CPU machine with export PJRT_DEVICE=CPU.

When I run test/cpp/run_tests.sh with export PJRT_DEVICE=CPU I get a hang:

Running all cpp test...
+ '[' /tmp/pytorch_cpp_test.log '!=' '' ']'
+ bazel test --config=tpu //torch_xla/csrc/runtime:all //test/cpp:all --test_timeout 1000
(hangs)

It wasn't hanging. After a while, it shows:

+ bazel test --config=tpu //torch_xla/csrc/runtime:all //test/cpp:all --test_timeout 1000
^[[A^[[A^[[A^[[A//test/cpp:test_aten_xla_tensor_1                                        PASSED in 23.4s
//test/cpp:test_aten_xla_tensor_2                                        PASSED in 18.5s
//test/cpp:test_aten_xla_tensor_3                                        PASSED in 44.9s
//test/cpp:test_aten_xla_tensor_4                                        PASSED in 15.4s
//test/cpp:test_aten_xla_tensor_5                                        PASSED in 85.5s
//test/cpp:test_aten_xla_tensor_6                                        PASSED in 44.5s
//test/cpp:test_ir                                                       PASSED in 0.5s
//test/cpp:test_lazy                                                     PASSED in 0.5s
//test/cpp:test_replication                                              PASSED in 0.6s
//test/cpp:test_tensor                                                   PASSED in 101.9s
//test/cpp:test_xla_sharding                                             PASSED in 0.5s
//torch_xla/csrc/runtime:cache_test                                      PASSED in 0.0s
//torch_xla/csrc/runtime:pjrt_computation_client_test                    PASSED in 0.7s
//torch_xla/csrc/runtime:sys_util_test                                   PASSED in 0.0s
//torch_xla/csrc/runtime:util_test                                       PASSED in 0.1s
//torch_xla/csrc/runtime:xla_util_test                                   PASSED in 0.4s

Executed 16 out of 16 tests: 16 tests pass.
There were tests whose specified size is too big. Use the --test_verbose_timeout_warnings command line option to see which ones these are.

@JackCaoG
Copy link
Collaborator

You didn't add any cpp test and our cpp test doesn't test distributed so that should be fine? I do see a bunch of python test failing in the CI with

ImportError: /opt/conda/lib/python3.8/site-packages/torch_xla-2.2.0+git08075a8-py3.8-linux-x86_64.egg/_XLAC.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZN9torch_xla14tensor_methods10all_gatherERKN3c1013intrusive_ptrINS_9XLATensorENS1_6detail34intrusive_target_default_null_typeIS3_EEEEllSt6vectorISA_IlSaIlEESaISC_EEb

in https://github.com/pytorch/xla/actions/runs/6896526398/job/18766614593?pr=5624

If I have to guess it is that the build pass but when we try to import the XLAC, we find that there are some functions that only has the header but missing in the cpp file? this one seems to be tensor_methods::all_gather. If you try to build this locally and just do a import torch_xla you should see the same failure.

@JackCaoG
Copy link
Collaborator

@alanwaketan is out next week, I will work with you to try to land this change before branch cut(or we will cherry-pick).

Also allow using reduce-scatter's scale param in FSDP.
(revived pytorch#4145)
@jeffhataws
Copy link
Collaborator Author

You didn't add any cpp test and our cpp test doesn't test distributed so that should be fine? I do see a bunch of python test failing in the CI with

ImportError: /opt/conda/lib/python3.8/site-packages/torch_xla-2.2.0+git08075a8-py3.8-linux-x86_64.egg/_XLAC.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZN9torch_xla14tensor_methods10all_gatherERKN3c1013intrusive_ptrINS_9XLATensorENS1_6detail34intrusive_target_default_null_typeIS3_EEEEllSt6vectorISA_IlSaIlEESaISC_EEb

in https://github.com/pytorch/xla/actions/runs/6896526398/job/18766614593?pr=5624

If I have to guess it is that the build pass but when we try to import the XLAC, we find that there are some functions that only has the header but missing in the cpp file? this one seems to be tensor_methods::all_gather. If you try to build this locally and just do a import torch_xla you should see the same failure.

Fixed. Removed the previous all_gather method by mistake.

jeffhataws added a commit to jeffhataws/openxla that referenced this pull request Nov 19, 2023
…tter

Imported from GitHub PR openxla#5740

This PR adds tuple input support to all-gather and reduce-scatter. This is a revival of part of tensorflow/tensorflow#58377 and to be used in conjunction with pytorch/xla#5624 .

In FSDP, different layers' weights need to be all-gathered/reduced-scatter during training. If some layers are small, multiple layers' weights can be aggregated for more efficient data transfer (same concept as bucket_cap_mb in DDP). With existing all-gather and reduce-scatter in PyTorch-XLA, you would have to do the bucketing and decomposing outside of the operation. This PR enables multiple different tensors to be all-gathered/reduce-scatter, keeping the original tensor shapes to enable bucketing and decomposing optimizations inside the operation.

Original PR has token support like the token used for allreduce to ensure order between CCops. That will be separate PR if needed.

Copybara import of the project:

--
7ea1159 by Junmin Hao <[email protected]>:

Add Tuple input and token support to all-gather and reduce-scatter.

Committer: Junmin Hao <[email protected]>

--
cdb873e by Junmin Hao <[email protected]>:

lint fix

--
aad3521 by Jeffrey Huynh <[email protected]>:

Fix hlo_verifier_test failure due to changed expectation

--
32e8145 by Jeffrey Huynh <[email protected]>:

Separate the token change out into a separate PR with RFC.

--
b301c2a by Jeffrey Huynh <[email protected]>:

Change *WithToken tests to *WithTuple

--
5890278 by Jeffrey Huynh <[email protected]>:

Fix missing parenthesis

Merging this change closes openxla#5740

COPYBARA_INTEGRATE_REVIEW=openxla#5740 from jeffhataws:ag_rs_coalesce_revived 14e09f0
PiperOrigin-RevId: 573976449
@jeffhataws
Copy link
Collaborator Author

Fixed the lint error.

@jeffhataws
Copy link
Collaborator Author

We can keep trying to fix test errors, but I am not sure if it will be eaiser to split all-gather and reduce-scatter into separate prs. @jeffhataws I take that this pr works in your local forks, not sure what's the difference. Test failures seem real.

Yeah the code was ported from an old version of torch/xla so there were some merge errors. Plus the final version of openxla change openxla/xla#5740 doesn't have token support, so I need to make the corresponding change here.

@JackCaoG
Copy link
Collaborator

hmm, this error seems real

2023-11-27 19:09:57.910045: E external/xla/xla/status_macros.cc:54] INTERNAL: RET_CHECK failure (external/xla/xla/service/shape_inference.cc:2155) scatter_dimension < operand_shape->rank() 
*** Begin stack trace ***
	tsl::CurrentStackTrace[abi:cxx11]()
	
	xla::status_macros::MakeErrorStream::Impl::GetStatus()
	xla::ShapeInference::InferReduceScatterShape(absl::lts_20230125::Span<xla::Shape const* const>, long, long)
	
	
	xla::XlaBuilder::ReportErrorOrReturn(absl::lts_20230125::FunctionRef<absl::lts_20230125::StatusOr<xla::XlaOp> ()>)
	xla::XlaBuilder::ReduceScatter(xla::XlaOp, xla::XlaComputation const&, long, long, 

@jeffhataws jeffhataws requested a review from JackCaoG November 28, 2023 20:36
@jeffhataws
Copy link
Collaborator Author

One of the CPU workflows failed with this:

[25,304 / 29,344] Compiling xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc; 33s local ... (48 actions, 47 running)
[28,230 / 30,817] Compiling xla/service/gpu/cub_sort_kernel.cu.cc; 31s local ... (48 actions, 47 running)
[29,163 / 30,817] Compiling xla/service/algebraic_simplifier.cc; 29s local ... (48 actions, 47 running)
/home/ec2-user/actions-runner/_work/_temp/c18e5d69-957a-4e96-8ea6-714d1254694a.sh: line 1: 27242 Killed                  docker exec --privileged -u jenkins "${pid}" bash -c '.circleci/test.sh'
Error: Process completed with exit code 137.

@JackCaoG
Copy link
Collaborator

hmm, seems like vm oom when building pytorch/xla...

@JackCaoG
Copy link
Collaborator

Let's ignore the CPU failure and focus on GPU for now. GPU failures seems real.

@jeffhataws jeffhataws changed the title Add all-gather/reduce-scatter coalescee for FSDP/ZeRO1 Add all-gather coalescing for FSDP/ZeRO1 Nov 29, 2023
@jeffhataws
Copy link
Collaborator Author

All tests passing with Reduce-Scatter change separated out in #5938 .

xla::XlaOp all_gather_result;
if (pin_layout) {
all_gather_result = xla::AllGather(
xla::Tuple(inputs[0].builder(), type_ctx.second.ops), dim,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this means even if there is single element in the all-gather, we will wrap it inside the tuple.. I need to check with xla teams whether this has any speed implications.

}
} else {
result[0] = all_gather_result;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see how token is being used here, in all_reduce, we manually append it at the end of ops

  for (auto& type_ctx : redux.contexts) {
    xla::XlaOp token_op = MaybeConvertTo(chained_token, type_ctx.first);
    type_ctx.second.ops.push_back(token_op);
    type_ctx.second.operand_shapes.push_back(
        ShapeHelper::ShapeOfXlaOp(token_op));

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw the GetOperandList below, but I think this does not gurante when you have multiple types, each types has a token.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix the usage of token.

For multiple types case, I think for now we should ensure same type in the list.

}
return {result, torch::lazy::Value(node, inputs.size())};
}

XLATensorPtr all_gather(const XLATensorPtr& input, int64_t dim,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not why does it compile when below code still pass a single IR variable to torch::lazy::MakeNode<AllGather>, while you change the constructor to take arrayRef, maybe arrayRef have a default constructor.

Copy link
Collaborator

@JackCaoG JackCaoG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left a few comments, overall I think it is better to implement a new BuildAllGatherCoalesced instead of modifying the existing BuildAllGather. The way you do it today will change the HLO generated for the single tensor all-gather. Given that branch cut is this Friday, I don't think we have time to do the performance benchmarking to make sure this is regression free, it is safer to add new features while no touch the existing logic.

@JackCaoG
Copy link
Collaborator

oh I know what's going on. Can you create a branch from pytorch/xla directly instead of creating from a fork. Fork can not used our cache so compilation will take much longer and easier to fail. I already give you write access to the project so you should be able to create a new branch from our repo directly.

@jeffhataws
Copy link
Collaborator Author

oh I know what's going on. Can you create a branch from pytorch/xla directly instead of creating from a fork. Fork can not used our cache so compilation will take much longer and easier to fail. I already give you write access to the project so you should be able to create a new branch from our repo directly.

Thanks. Changed to PR from a branch on pytorch/xla #5950 .

@jeffhataws jeffhataws closed this Nov 30, 2023
jeffhataws added a commit to jeffhataws/openxla that referenced this pull request Dec 10, 2023
…tter

Imported from GitHub PR openxla#5740

This PR adds tuple input support to all-gather and reduce-scatter. This is a revival of part of tensorflow/tensorflow#58377 and to be used in conjunction with pytorch/xla#5624 .

In FSDP, different layers' weights need to be all-gathered/reduced-scatter during training. If some layers are small, multiple layers' weights can be aggregated for more efficient data transfer (same concept as bucket_cap_mb in DDP). With existing all-gather and reduce-scatter in PyTorch-XLA, you would have to do the bucketing and decomposing outside of the operation. This PR enables multiple different tensors to be all-gathered/reduce-scatter, keeping the original tensor shapes to enable bucketing and decomposing optimizations inside the operation.

Original PR has token support like the token used for allreduce to ensure order between CCops. That will be separate PR if needed.

Copybara import of the project:

--
7ea1159 by Junmin Hao <[email protected]>:

Add Tuple input and token support to all-gather and reduce-scatter.

Committer: Junmin Hao <[email protected]>

--
cdb873e by Junmin Hao <[email protected]>:

lint fix

--
aad3521 by Jeffrey Huynh <[email protected]>:

Fix hlo_verifier_test failure due to changed expectation

--
32e8145 by Jeffrey Huynh <[email protected]>:

Separate the token change out into a separate PR with RFC.

--
b301c2a by Jeffrey Huynh <[email protected]>:

Change *WithToken tests to *WithTuple

--
5890278 by Jeffrey Huynh <[email protected]>:

Fix missing parenthesis

Merging this change closes openxla#5740

COPYBARA_INTEGRATE_REVIEW=openxla#5740 from jeffhataws:ag_rs_coalesce_revived 14e09f0
PiperOrigin-RevId: 573976449
jeffhataws added a commit to jeffhataws/openxla that referenced this pull request Dec 11, 2023
…tter

Imported from GitHub PR openxla#5740

This PR adds tuple input support to all-gather and reduce-scatter. This is a revival of part of tensorflow/tensorflow#58377 and to be used in conjunction with pytorch/xla#5624 .

In FSDP, different layers' weights need to be all-gathered/reduced-scatter during training. If some layers are small, multiple layers' weights can be aggregated for more efficient data transfer (same concept as bucket_cap_mb in DDP). With existing all-gather and reduce-scatter in PyTorch-XLA, you would have to do the bucketing and decomposing outside of the operation. This PR enables multiple different tensors to be all-gathered/reduce-scatter, keeping the original tensor shapes to enable bucketing and decomposing optimizations inside the operation.

Original PR has token support like the token used for allreduce to ensure order between CCops. That will be separate PR if needed.

Copybara import of the project:

--
7ea1159 by Junmin Hao <[email protected]>:

Add Tuple input and token support to all-gather and reduce-scatter.

Committer: Junmin Hao <[email protected]>

--
cdb873e by Junmin Hao <[email protected]>:

lint fix

--
aad3521 by Jeffrey Huynh <[email protected]>:

Fix hlo_verifier_test failure due to changed expectation

--
32e8145 by Jeffrey Huynh <[email protected]>:

Separate the token change out into a separate PR with RFC.

--
b301c2a by Jeffrey Huynh <[email protected]>:

Change *WithToken tests to *WithTuple

--
5890278 by Jeffrey Huynh <[email protected]>:

Fix missing parenthesis

Merging this change closes openxla#5740

COPYBARA_INTEGRATE_REVIEW=openxla#5740 from jeffhataws:ag_rs_coalesce_revived 14e09f0
PiperOrigin-RevId: 573976449
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants