Add tuple input support to all-gather and reduce-scatter #5740
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Hi @jeffhataws, can you sign the CLA provided below?
Done in the latest merge from main. |
39d57a6 to aad3521
Overall, I think my comment on the earlier version of this change is still valid. It's doing three things: (1) adding a convenience to pass tuples into the AllGather and ReduceScatter XLA builder API and have the API flatten them, (2) adding a way to control the ordering of the collective ops by threading them through "tokens", and (3) allowing these threading tokens to be either a token type in XLA or any scalar type.
Part 1 can easily be split into its own change and is not concerning, so that should be the first step. Parts 2 and 3 really need to be discussed with the broader team. Can we start an RFC around this? XLA already has a control-dependency mechanism to arbitrarily order instructions; this would be a new mechanism specific to collectives. It seems we should have a common mechanism to express any scheduling constraints as HLO comes into XLA, since I can imagine this token-based scheme being applicable to all other ops for various reasons (like controlling memory pressure or something else).
Also, allowing any scalar as a token for these instructions seems confusing and should be discussed as well. Another aspect is that this needs equivalent changes in StableHLO too.
@burmako can you please advise on how HLO/StableHLO-level changes like parts of this PR should be driven?
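For context on part (1), here is a minimal sketch of what the tuple-in convenience looks like from the client side. The shapes, shard count, and function name are illustrative assumptions, not code from this PR; only the use of a tuple-shaped operand reflects the change being reviewed.

```cpp
#include "xla/client/xla_builder.h"
#include "xla/shape_util.h"

// Sketch only: bucket two differently shaped weights into one all-gather by
// passing a tuple-shaped operand; the builder flattens the tuple internally.
xla::XlaOp BuildBucketedAllGather(xla::XlaBuilder* b) {
  xla::XlaOp w0 =
      xla::Parameter(b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "w0");
  xla::XlaOp w1 =
      xla::Parameter(b, 1, xla::ShapeUtil::MakeShape(xla::F32, {16, 16}), "w1");
  xla::XlaOp bucket = xla::Tuple(b, {w0, w1});
  // Gather along dimension 0 across 4 shards (illustrative values).
  return xla::AllGather(bucket, /*all_gather_dimension=*/0, /*shard_count=*/4);
}
```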
xla/client/xla_builder.cc
Outdated
return Unimplemented("0 element tuple AllGather is not supported");
}
for (int i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
  if (operand_shape->tuple_shapes(i).element_type() !=
Why is this required here? I don't think an all-gather in particular has this restriction at the HLO level. For all-reduce such a restriction makes sense, since the reduction computation is typed.
Also, can we share the tuple-flattening code with all-reduce/reduce-scatter and make a helper function for that?
@jurahul you are right. Removed the check for same shape.
The other request, sharing the tuple-flattening code, would be a larger refactor across the existing all-reduce/reduce-scatter, so I want to avoid that to keep the changes minimal and self-contained.
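For illustration only, a shared helper of the kind discussed above might look roughly like the following. FlattenTupleOperand is a hypothetical name; this is a sketch of the suggested refactor, not code from this PR or from XLA.

```cpp
#include <vector>

#include "absl/status/statusor.h"
#include "xla/client/xla_builder.h"

// Hypothetical helper: expand a tuple-shaped operand into its leaf ops so
// AllGather/ReduceScatter/AllReduce could share the flattening logic.
absl::StatusOr<std::vector<xla::XlaOp>> FlattenTupleOperand(
    xla::XlaBuilder* builder, xla::XlaOp operand) {
  auto shape_or = builder->GetShapePtr(operand);
  if (!shape_or.ok()) return shape_or.status();
  const xla::Shape* shape = *shape_or;
  // Non-tuple operands pass through unchanged.
  if (!shape->IsTuple()) {
    return std::vector<xla::XlaOp>{operand};
  }
  std::vector<xla::XlaOp> leaves;
  leaves.reserve(shape->tuple_shapes_size());
  for (int i = 0; i < shape->tuple_shapes_size(); ++i) {
    leaves.push_back(xla::GetTupleElement(operand, i));
  }
  return leaves;
}
```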
+1 to @jurahul's feedback. I would also like to request an RFC for (2) and (3), so that we can discuss design tradeoffs across multiple groups and end up with clear documentation of the new functionality.
We haven't yet standardized on an RFC process for StableHLO, and I've seen several ways of approaching this. My recommendation would be to post a proposal on openxla-discuss, with a description of updated semantics for the affected ops.
Here's a question which may help to get the RFC started. The StableHLO specification has a description of semantics and constraints for the existing collective ops, e.g. here's an example for all_gather. How would this specification need to change to accommodate the proposal in this PR?
Hi @jeffhataws, any update on this? Thanks.
Thanks for checking. Will post an update by next week, with RFC as requested.
@@ -434,6 +434,7 @@ static Status CheckCommonAllGatherInvariants(HloInstruction* hlo,
      ag->use_global_device_ids()));
  TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, group_mode));
  TF_RET_CHECK(ag->all_gather_dimension() >= 0);
  TF_RET_CHECK(ag->operand_count() >= 1);
@jurahul, @burmako, @radhakrishnaba I think this is correct. Let me know if you want me to remove this.
Also seems ok to me.
@@ -521,6 +522,7 @@ Status ShapeVerifier::HandleReduceScatter(HloInstruction* hlo) {
      ars->use_global_device_ids()));
  TF_RETURN_IF_ERROR(CheckReplicaGroups(ars, group_mode));
  TF_RET_CHECK(ars->scatter_dimension() >= 0);
  TF_RET_CHECK(ars->operand_count() >= 1);
same here.
@@ -1289,6 +1289,10 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
    return HandleAllReduceMultipleReplica(crs);
  }

Status IrEmitter::HandleReduceScatter(HloInstruction* rs) {
  return Unimplemented("ReduceScatter is not implemented on CPU.");
@jurahul, @burmako, @radhakrishnaba I think this is correct. Let me know if you want me to remove this.
This seems fine to me.
LGTM
Hi @thomasjoerg, will you help complete the review? Thanks.
@ddunl can you please help take a look at why this is blocked?
Disabled the already failing tests by editing
Imported from GitHub PR openxla/xla#5740

This PR adds tuple input support to all-gather and reduce-scatter. This is a revival of part of #58377 and is to be used in conjunction with pytorch/xla#5624.

In FSDP, different layers' weights need to be all-gathered/reduce-scattered during training. If some layers are small, multiple layers' weights can be aggregated for more efficient data transfer (the same concept as bucket_cap_mb in DDP). With the existing all-gather and reduce-scatter in PyTorch-XLA, you would have to do the bucketing and decomposing outside of the operation. This PR enables multiple different tensors to be all-gathered/reduce-scattered while keeping the original tensor shapes, which enables bucketing and decomposing optimizations inside the operation.

The original PR had token support, like the token used for all-reduce, to ensure ordering between collective-communication ops. That will be a separate PR if needed.

Copybara import of the project:

-- 7ea1159a1464efddebe9384e87ed6df504d89b2e by Junmin Hao <[email protected]>: Add Tuple input and token support to all-gather and reduce-scatter. Committer: Junmin Hao <[email protected]>
-- cdb873e6d97f5f12b3d3c587bb5782d58e3554c5 by Junmin Hao <[email protected]>: lint fix
-- aad352117ba950ac5ae62330e3980f4b5898a701 by Jeffrey Huynh <[email protected]>: Fix hlo_verifier_test failure due to changed expectation
-- 32e814524b88a474af5e4e904c0dd19841430b86 by Jeffrey Huynh <[email protected]>: Separate the token change out into a separate PR with RFC.
-- b301c2a2a5b52180f9e9626173e6b67a78782960 by Jeffrey Huynh <[email protected]>: Change *WithToken tests to *WithTuple
-- 5890278fc16c9f900782d32a92d40ecf548aea85 by Jeffrey Huynh <[email protected]>: Fix missing parenthesis

Merging this change closes #5740

PiperOrigin-RevId: 573976449
Imported from GitHub PR openxla/xla#8151

Currently AllGather in HLO supports multiple operands/results, while MHLO only supports a single operand/result. This change addresses the parity gap by adding MHLO AllGather variadic operands support.

This change was inspired by previous commit [2457fc1](openxla/xla@2457fc1) - [mhlo] AllReduce tuple support (Jun 7, 2023 by @GleasonK).

Related commit:
- [PR-5740](openxla/xla#5740) [hlo] Add tuple input support to all-gather and reduce-scatter (Oct 16, 2023 by @jeffhataws)

@GleasonK @cheshire @burmako @jurahul @thomasjoerg Could you review this PR?

Copybara import of the project:

-- fb53ead74cbb40177a3680c8f807149d39c396b7 by Alexander Pivovarov <[email protected]>: [mhlo] AllGather variadic operands support

Merging this change closes #8151

PiperOrigin-RevId: 599175008
This PR adds tuple input support to all-gather and reduce-scatter. It is a revival of part of tensorflow/tensorflow#58377 and is to be used in conjunction with pytorch/xla#5624.
In FSDP, different layers' weights need to be all-gathered/reduce-scattered during training. If some layers are small, multiple layers' weights can be aggregated for more efficient data transfer (the same concept as bucket_cap_mb in DDP). With the existing all-gather and reduce-scatter in PyTorch-XLA, you would have to do the bucketing and decomposing outside of the operation. This PR enables multiple different tensors to be all-gathered/reduce-scattered while keeping the original tensor shapes, which enables bucketing and decomposing optimizations inside the operation.
The original PR had token support, like the token used for all-reduce, to ensure ordering between collective-communication ops. That will be a separate PR if needed.
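As a rough usage illustration of the tuple path described above, a bucketed reduce-scatter could be built as in the sketch below. The function name, shapes, and shard count are assumptions for illustration, not code from the PR; the reduction sub-computation is the usual scalar-add pattern used with ReduceScatter.

```cpp
#include "xla/client/xla_builder.h"
#include "xla/shape_util.h"

// Sketch only: reduce-scatter two gradient buffers in one op by passing a
// tuple-shaped operand, keeping each tensor's original shape.
xla::XlaOp BuildBucketedReduceScatter(xla::XlaBuilder* b) {
  // Scalar add computation used as the reduction.
  xla::XlaComputation add;
  {
    xla::XlaBuilder sum_builder("sum");
    auto x = xla::Parameter(&sum_builder, 0,
                            xla::ShapeUtil::MakeShape(xla::F32, {}), "x");
    auto y = xla::Parameter(&sum_builder, 1,
                            xla::ShapeUtil::MakeShape(xla::F32, {}), "y");
    xla::Add(x, y);
    add = sum_builder.Build().value();
  }
  xla::XlaOp g0 =
      xla::Parameter(b, 0, xla::ShapeUtil::MakeShape(xla::F32, {64}), "g0");
  xla::XlaOp g1 =
      xla::Parameter(b, 1, xla::ShapeUtil::MakeShape(xla::F32, {8, 8}), "g1");
  xla::XlaOp bucket = xla::Tuple(b, {g0, g1});
  // Scatter along dimension 0 across 4 shards (illustrative values).
  return xla::ReduceScatter(bucket, add,
                            /*scatter_dimension=*/0, /*shard_count=*/4);
}
```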