
Add tuple input support to all-gather and reduce-scatter #5740

Closed
jeffhataws wants to merge 9 commits from jeffhataws:ag_rs_coalesce_revived

Conversation

Contributor

@jeffhataws commented Sep 20, 2023

This PR adds tuple input support to all-gather and reduce-scatter. It is a revival of part of tensorflow/tensorflow#58377 and is meant to be used in conjunction with pytorch/xla#5624.

In FSDP, different layers' weights need to be all-gathered/reduce-scattered during training. If some layers are small, multiple layers' weights can be aggregated for more efficient data transfer (the same concept as bucket_cap_mb in DDP). With the existing all-gather and reduce-scatter in PyTorch-XLA, the bucketing and decomposing have to be done outside of the operation. This PR enables multiple different tensors to be all-gathered/reduce-scattered together, keeping the original tensor shapes so that bucketing and decomposing optimizations can be done inside the operation.

The original PR also had token support, like the token used for all-reduce, to ensure ordering between collective-communication ops. That will be a separate PR if needed.
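For illustration only (not code from this PR), here is a minimal sketch of the kind of XlaBuilder usage this tuple support is meant to enable; the header paths, shapes, and the tuple-accepting AllGather call are assumptions based on the description above:

```cpp
#include <cstdint>

#include "xla/client/xla_builder.h"
#include "xla/shape_util.h"

// Sketch: bucket two differently shaped FSDP weight shards into one tuple and
// gather them with a single all-gather, instead of issuing two separate ops.
xla::XlaOp BuildBucketedAllGather(xla::XlaBuilder* b, int64_t shard_count) {
  xla::XlaOp w0 = xla::Parameter(
      b, 0, xla::ShapeUtil::MakeShape(xla::F32, {4, 256}), "w0");
  xla::XlaOp w1 = xla::Parameter(
      b, 1, xla::ShapeUtil::MakeShape(xla::BF16, {8, 128}), "w1");
  // Assumed behavior per this PR: a tuple operand is flattened into a
  // multi-operand HLO all-gather, preserving the per-tensor shapes.
  xla::XlaOp bucket = xla::Tuple(b, {w0, w1});
  return xla::AllGather(bucket, /*all_gather_dimension=*/0, shard_count);
}
```

The same pattern would apply to ReduceScatter, which additionally takes the reduction computation.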

@google-cla

google-cla bot commented Sep 20, 2023

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@radhakrishnaba

Hi @jeffhataws, can you please sign the CLA linked below?
https://cla.developers.google.com/clas

github-actions bot added the kokoro:force-run label Sep 21, 2023
kokoro-team removed the kokoro:force-run label Sep 21, 2023
@jeffhataws
Contributor Author

> Hi @jeffhataws, can you please sign the CLA linked below? https://cla.developers.google.com/clas

Done in the latest merge from main.

github-actions bot added the kokoro:force-run label Sep 22, 2023
kokoro-team removed the kokoro:force-run label Sep 22, 2023
Contributor

@jurahul left a comment

Overall, I think my comment on the earlier version of this change is still valid. It's doing three things: (1) adding a convenience to pass tuples into the AllGather and ReduceScatter XLA Builder APIs and have the API flatten those tuples; (2) adding a way to control the ordering of the collective ops by threading them through "tokens"; and (3) allowing these threading tokens to be either a token type in XLA or any scalar type.

Part 1 can easily be split into its own change and is not concerning, so that should be the first step. Parts 2 and 3 really need to be discussed with the broader team. Can we start an RFC around this? XLA already has a control-dependency mechanism to arbitrarily order instructions, and this would be a new mechanism specific to collectives. It seems we should have a common mechanism for expressing any scheduling constraints as HLO comes into XLA, since I can imagine this token-based scheme being potentially applicable to all other ops for various reasons (like controlling memory pressure or something else).
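For context, the existing control-dependency mechanism referred to above looks roughly like the following sketch (the helper name and the absl::Status-returning signature are assumptions, not code from this PR):

```cpp
#include "absl/status/status.h"
#include "xla/hlo/ir/hlo_instruction.h"

// Adds an explicit control edge so that `first` must be scheduled before
// `second`, without threading a token through either instruction's operands.
absl::Status OrderInstructions(xla::HloInstruction* first,
                               xla::HloInstruction* second) {
  return first->AddControlDependencyTo(second);
}
```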

Allowing any scalar as a token for these instructions also seems confusing, and that should be discussed as well. There is another aspect of this: equivalent changes are needed in StableHLO as well.

@burmako, can you please advise on how HLO/StableHLO-level changes like the parts of this PR should be driven?

    return Unimplemented("0 element tuple AllGather is not supported");
  }
  for (int i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
    if (operand_shape->tuple_shapes(i).element_type() !=
Contributor

Why is this required here? I don't think an all-gather in particular has this restriction at the HLO level. For all-reduce, such a restriction makes sense since the reduction computation is typed.

Also, can we share the tuple flattening code with all-reduce/reduce-scatter and make a helper function for that?
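A rough sketch of what such a shared helper might look like (hypothetical; not part of this PR or the existing code):

```cpp
#include <vector>

#include "xla/shape.h"

// Hypothetical shared helper: expand a (possibly tuple-shaped) collective
// operand into its list of per-tensor shapes, so AllGather, ReduceScatter,
// and AllReduce could reuse the same flattening logic.
std::vector<xla::Shape> FlattenOperandShapes(const xla::Shape& operand_shape) {
  if (!operand_shape.IsTuple()) {
    return {operand_shape};
  }
  std::vector<xla::Shape> flattened;
  flattened.reserve(operand_shape.tuple_shapes_size());
  for (int i = 0; i < operand_shape.tuple_shapes_size(); ++i) {
    flattened.push_back(operand_shape.tuple_shapes(i));
  }
  return flattened;
}
```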

Contributor Author

@jurahul you are right. Removed the check for same shape.

The other request, sharing the tuple-flattening code, would be a larger refactor across the existing all-reduce/reduce-scatter, so I want to avoid it to keep the changes minimal and self-contained.

@burmako
Contributor

burmako commented Sep 27, 2023

+1 to @jurahul's feedback. I would also like to request an RFC for (2) and (3), so that we can discuss design tradeoffs across multiple groups and end up with clear documentation of the new functionality.

We haven't yet standardized on an RFC process for StableHLO, and I've seen several ways of approaching this. My recommendation would be to post a proposal on openxla-discuss, with a description of updated semantics for the affected ops.

Here's a question which may help to get the RFC started. The StableHLO specification has a description of semantics and constraints for the existing collective ops, e.g. here's an example for all_gather. How would this specification need to change to accommodate the proposal in this PR?

@jeffhataws
Contributor Author

Thanks @jurahul and @burmako. Will make modifications based on your feedback.

@radhakrishnaba

Hi @jeffhataws, any update on this? Thanks.

@jeffhataws
Contributor Author

> Hi @jeffhataws, any update on this? Thanks.

Thanks for checking. Will post an update by next week, with RFC as requested.

github-actions bot added the kokoro:force-run label Oct 9, 2023
@@ -434,6 +434,7 @@ static Status CheckCommonAllGatherInvariants(HloInstruction* hlo,
                                               ag->use_global_device_ids()));
  TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, group_mode));
  TF_RET_CHECK(ag->all_gather_dimension() >= 0);
  TF_RET_CHECK(ag->operand_count() >= 1);
Contributor Author

@jurahul, @burmako, @radhakrishnaba I think this is correct. Let me know if you want me to remove this.

Contributor

Also seems ok to me.

@@ -521,6 +522,7 @@ Status ShapeVerifier::HandleReduceScatter(HloInstruction* hlo) {
                                               ars->use_global_device_ids()));
  TF_RETURN_IF_ERROR(CheckReplicaGroups(ars, group_mode));
  TF_RET_CHECK(ars->scatter_dimension() >= 0);
  TF_RET_CHECK(ars->operand_count() >= 1);
Contributor Author

Same here.

@@ -1289,6 +1289,10 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
  return HandleAllReduceMultipleReplica(crs);
}

Status IrEmitter::HandleReduceScatter(HloInstruction* rs) {
  return Unimplemented("ReduceScatter is not implemented on CPU.");
Contributor Author

@jurahul, @burmako, @radhakrishnaba I think this is correct. Let me know if you want me to remove this.

Contributor

This seems fine to me.

Contributor

@jurahul left a comment

LGTM

@@ -1289,6 +1289,10 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
  return HandleAllReduceMultipleReplica(crs);
}

Status IrEmitter::HandleReduceScatter(HloInstruction* rs) {
  return Unimplemented("ReduceScatter is not implemented on CPU.");
Contributor

This seems fine to me.

@@ -434,6 +434,7 @@ static Status CheckCommonAllGatherInvariants(HloInstruction* hlo,
                                               ag->use_global_device_ids()));
  TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, group_mode));
  TF_RET_CHECK(ag->all_gather_dimension() >= 0);
  TF_RET_CHECK(ag->operand_count() >= 1);
Contributor

Also seems ok to me.

@jeffhataws
Contributor Author

Hi @thomasjoerg, will you help complete the review? Thanks.

github-actions bot added the kokoro:force-run label Oct 11, 2023
kokoro-team removed the kokoro:force-run label Oct 11, 2023
gbaned added the ready to pull label Oct 13, 2023
@JackCaoG

@cheshire @burmako Do you know what the process is for merging this PR on the open-source side? I saw that copybara is currently failing for this PR.

@jurahul
Contributor

jurahul commented Oct 16, 2023

@ddunl, can you please help take a look at why this is blocked?

@ddunl
Member

ddunl commented Oct 16, 2023

I disabled the already-failing tests by editing TAP_PROJECTS in the description; hopefully this will submit on the upcoming try.

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 17, 2023
Imported from GitHub PR openxla/xla#5740

This PR adds tuple input support to all-gather and reduce-scatter. It is a revival of part of #58377 and is meant to be used in conjunction with pytorch/xla#5624.

In FSDP, different layers' weights need to be all-gathered/reduce-scattered during training. If some layers are small, multiple layers' weights can be aggregated for more efficient data transfer (the same concept as bucket_cap_mb in DDP). With the existing all-gather and reduce-scatter in PyTorch-XLA, the bucketing and decomposing have to be done outside of the operation. This PR enables multiple different tensors to be all-gathered/reduce-scattered together, keeping the original tensor shapes so that bucketing and decomposing optimizations can be done inside the operation.

The original PR also had token support, like the token used for all-reduce, to ensure ordering between collective-communication ops. That will be a separate PR if needed.

Copybara import of the project:

--
7ea1159a1464efddebe9384e87ed6df504d89b2e by Junmin Hao <[email protected]>:

Add Tuple input and token support to all-gather and reduce-scatter.

Committer: Junmin Hao <[email protected]>

--
cdb873e6d97f5f12b3d3c587bb5782d58e3554c5 by Junmin Hao <[email protected]>:

lint fix

--
aad352117ba950ac5ae62330e3980f4b5898a701 by Jeffrey Huynh <[email protected]>:

Fix hlo_verifier_test failure due to changed expectation

--
32e814524b88a474af5e4e904c0dd19841430b86 by Jeffrey Huynh <[email protected]>:

Separate the token change out into a separate PR with RFC.

--
b301c2a2a5b52180f9e9626173e6b67a78782960 by Jeffrey Huynh <[email protected]>:

Change *WithToken tests to *WithTuple

--
5890278fc16c9f900782d32a92d40ecf548aea85 by Jeffrey Huynh <[email protected]>:

Fix missing parenthesis

Merging this change closes #5740

PiperOrigin-RevId: 573976449
jeffhataws added a commit to jeffhataws/openxla that referenced this pull request Nov 19, 2023

Imported from GitHub PR openxla#5740

This PR adds tuple input support to all-gather and reduce-scatter. It is a revival of part of tensorflow/tensorflow#58377 and is meant to be used in conjunction with pytorch/xla#5624.

In FSDP, different layers' weights need to be all-gathered/reduce-scattered during training. If some layers are small, multiple layers' weights can be aggregated for more efficient data transfer (the same concept as bucket_cap_mb in DDP). With the existing all-gather and reduce-scatter in PyTorch-XLA, the bucketing and decomposing have to be done outside of the operation. This PR enables multiple different tensors to be all-gathered/reduce-scattered together, keeping the original tensor shapes so that bucketing and decomposing optimizations can be done inside the operation.

The original PR also had token support, like the token used for all-reduce, to ensure ordering between collective-communication ops. That will be a separate PR if needed.

Copybara import of the project:

--
7ea1159 by Junmin Hao <[email protected]>:

Add Tuple input and token support to all-gather and reduce-scatter.

Committer: Junmin Hao <[email protected]>

--
cdb873e by Junmin Hao <[email protected]>:

lint fix

--
aad3521 by Jeffrey Huynh <[email protected]>:

Fix hlo_verifier_test failure due to changed expectation

--
32e8145 by Jeffrey Huynh <[email protected]>:

Separate the token change out into a separate PR with RFC.

--
b301c2a by Jeffrey Huynh <[email protected]>:

Change *WithToken tests to *WithTuple

--
5890278 by Jeffrey Huynh <[email protected]>:

Fix missing parenthesis

Merging this change closes openxla#5740

COPYBARA_INTEGRATE_REVIEW=openxla#5740 from jeffhataws:ag_rs_coalesce_revived 14e09f0
PiperOrigin-RevId: 573976449
jeffhataws added a commit to jeffhataws/openxla that referenced this pull request Dec 10, 2023
jeffhataws added a commit to jeffhataws/openxla that referenced this pull request Dec 11, 2023
copybara-service bot pushed a commit to tensorflow/mlir-hlo that referenced this pull request Jan 17, 2024
Imported from GitHub PR openxla/xla#8151

Currently AllGather in HLO supports multiple operands/results, while MHLO only supports a single operand/result. This change addresses the parity gap by adding MHLO AllGather variadic operands support.

This change was inspired by previous commit [2457fc1](openxla/xla@2457fc1) - [mhlo] AllReduce tuple support. Jun 7, 2023 by @GleasonK

Related commit:
- [PR-5740](openxla/xla#5740) [hlo] Add tuple input support to all-gather and reduce-scatter (Oct 16, 2023 by @jeffhataws)

@GleasonK @cheshire @burmako @jurahul @thomasjoerg Could you review this PR?
Copybara import of the project:

--
fb53ead74cbb40177a3680c8f807149d39c396b7 by Alexander Pivovarov <[email protected]>:

[mhlo] AllGather variadic operands support

Merging this change closes #8151

PiperOrigin-RevId: 599175008
copybara-service bot pushed a commit that referenced this pull request Jan 17, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jan 17, 2024