[mhlo] AllGather variadic operands support #8151

Closed
wants to merge 1 commit into from

Contributor

@apivovarov apivovarov commented Jan 3, 2024

Currently AllGather in HLO supports multiple operands/results, while MHLO only supports a single operand/result. This change addresses the parity gap by adding MHLO AllGather variadic operands support.
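
To illustrate what variadic support means here, the following is a minimal standalone sketch, not the MHLO or XLA implementation. It assumes tensors can be modeled as flat 1-D buffers and that gathering along the single dimension is plain concatenation in replica order. `Tensor` and `allGatherVariadic` are hypothetical names introduced only for this example.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a tensor: a flat 1-D buffer.
using Tensor = std::vector<float>;

// perReplica[r][i] is operand i as held by replica r.
// Returns one gathered result per operand: result i is the concatenation
// of operand i from every replica, in replica order.
std::vector<Tensor> allGatherVariadic(
    const std::vector<std::vector<Tensor>>& perReplica) {
  assert(!perReplica.empty());
  const std::size_t numOperands = perReplica.front().size();
  std::vector<Tensor> results(numOperands);
  for (const auto& operands : perReplica) {
    // Every replica must contribute the same number of operands.
    assert(operands.size() == numOperands);
    for (std::size_t i = 0; i < numOperands; ++i) {
      results[i].insert(results[i].end(), operands[i].begin(),
                        operands[i].end());
    }
  }
  return results;
}
```

The number of results always equals the number of operands, which is the invariant the review discussion later asks the verifier to enforce.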

This change was inspired by previous commit 2457fc1 - [mhlo] AllReduce tuple support. Jun 7, 2023 by @GleasonK

Related commit:

  • PR-5740 [hlo] Add tuple input support to all-gather and reduce-scatter (Oct 16, 2023 by @jeffhataws)

@GleasonK @cheshire @burmako @jurahul @thomasjoerg Could you review this PR?

@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Jan 3, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Jan 3, 2024
@apivovarov apivovarov force-pushed the mhlo_allgather_variadic branch from 2e3f49f to ed93860 Compare January 3, 2024 23:38
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Jan 3, 2024
@apivovarov apivovarov changed the title MHLO AllReduce variadic operands support MHLO AllGather variadic operands support Jan 3, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Jan 3, 2024
@apivovarov apivovarov changed the title MHLO AllGather variadic operands support [mhlo] AllGather variadic operands support Jan 3, 2024
@kamaljeeti kamaljeeti requested a review from tdanyluk January 4, 2024 04:10
@tdanyluk tdanyluk requested a review from GleasonK January 4, 2024 15:32
Contributor

@tdanyluk tdanyluk left a comment

It looks good based on the similarity to [mhlo] AllReduce tuple support.

I wouldn't mind if GleasonK also took a look.

def MHLO_AllGatherOp : MHLO_Op<"all_gather", [SameOperandsAndResultElementType]> {
def MHLO_AllGatherOp : MHLO_Op<"all_gather", [
SameOperandsAndResultElementType,
SingleBlockImplicitTerminator<"ReturnOp">
Contributor

The SingleBlockImplicitTerminator trait is not needed here. Unlike all-reduce op which has an attached region for the reduction computation, AllGather does not, so this trait is not relevant.

Contributor Author

fixed

@apivovarov apivovarov force-pushed the mhlo_allgather_variadic branch from ed93860 to 567b97b Compare January 4, 2024 19:30
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Jan 4, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Jan 4, 2024
@apivovarov
Contributor Author

apivovarov commented Jan 9, 2024

Hi Tamás, hi Rahul, it seems that Kevin does not have the time or resources to look at this PR. Can we merge it without his review?
@tdanyluk @jurahul

Member

@GleasonK GleasonK left a comment

LGTM! Apologies for the delay, this fell off my radar with holidays!

getUseGlobalDeviceIds(), getResult());
if (getOperands().empty())
return emitOptionalError(getLoc(),
"AllGather must have at least one operand");
Member

Do we verify anywhere that there are the same number of operands and results?

I'm not sure if there's an upstream MLIR trait for this; if not, let's add another check here.

Contributor Author

Good catch! Added validation for that case to AllGatherOp::verify

  if (getNumOperands() != getNumResults())
    return emitOptionalError(
        getLoc(),
        "AllGather requires the same number of operands and results");

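For readers without an MLIR checkout, the arity checks discussed here can be sketched in standalone C++. This is only a model under stated assumptions: `OpModel` and `verifyAllGatherArity` are hypothetical names, operand and result types are reduced to strings, and an error is returned as an optional message rather than through `emitOptionalError`.

```cpp
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of an op: only the operand and result lists matter here.
struct OpModel {
  std::vector<std::string> operandTypes;
  std::vector<std::string> resultTypes;
};

// Returns an error message if the arity constraints are violated,
// or std::nullopt if the op passes both checks.
std::optional<std::string> verifyAllGatherArity(const OpModel& op) {
  if (op.operandTypes.empty())
    return std::string("AllGather must have at least one operand");
  if (op.operandTypes.size() != op.resultTypes.size())
    return std::string(
        "AllGather requires the same number of operands and results");
  return std::nullopt;
}
```
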
Contributor

I am surprised there isn't an MLIR core trait to check this. If not, maybe adding one to MLIR would be useful (not as a part of this PR, though).

Contributor Author

@apivovarov apivovarov Jan 10, 2024

The list of OpTraits is in LLVM's mlir/include/mlir/IR/OpDefinition.h. I do not see a trait that compares the number of operands and results.
BTW, the getOperands().empty() check might be redundant, because SameOperandsAndResultElementType::verifyTrait calls OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) in Operation.cpp, which contains the following:

  if (failed(verifyAtLeastNOperands(op, 1)) ||
      failed(verifyAtLeastNResults(op, 1)))
    return failure();
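
The redundancy argument can be sketched as a two-stage check. This is a simplified model, assuming the trait-level verification runs before the op-specific verifier; `OpModel`, `Verdict`, `traitVerifies`, and `verify` are hypothetical names, not MLIR APIs.

```cpp
#include <cstddef>

// Hypothetical model: only operand and result counts are tracked.
struct OpModel {
  std::size_t numOperands;
  std::size_t numResults;
};

enum class Verdict { Ok, RejectedByTrait, RejectedByOpVerifier };

// Models the trait precondition quoted above:
// at least one operand and at least one result.
bool traitVerifies(const OpModel& op) {
  return op.numOperands >= 1 && op.numResults >= 1;
}

// Trait check first, then the op-specific check. An "operands are empty"
// test in the second stage could never fire, because such ops are
// already rejected by the first stage.
Verdict verify(const OpModel& op) {
  if (!traitVerifies(op)) return Verdict::RejectedByTrait;
  if (op.numOperands != op.numResults) return Verdict::RejectedByOpVerifier;
  return Verdict::Ok;
}
```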

all_gather_op.getOperand().getType().cast<TensorType>();
TensorType result_type = all_gather_op.getType();
all_gather_op.getOperand(0).getType().cast<TensorType>();
TensorType result_type = all_gather_op.getType(0).cast<TensorType>();
if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
return failure();
if (operands.size() != 1) return failure();
Member

Let's raise this check before the call to getOperand(0)

Contributor Author

@apivovarov apivovarov Jan 10, 2024

mhlo/IR/hlo_ops.td defines MHLO_AllGatherOp with Variadic<MHLO_Tensor>:$operands.
I tried using mhlo.all_gather() without operands in export.mlir and got an error:

error: unexpected error: 'mhlo.all_gather' op expected 1 or more operands, but found 0
  %0:2 = "mhlo.all_gather"() {

The mhlo parser validates that mhlo.all_gather() has at least one operand.
It should be safe to use all_gather_op.getOperand(0) in mlir_hlo_to_hlo.cc::ExportXlaOp, since MLIR validation happens before the ExportXlaOp call.
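
For comparison, the ordering the reviewer asked for (arity check before the first indexed access) can be sketched in standalone C++. This is an illustration only: `TensorTypeModel` and `singleStaticOperandType` are hypothetical names, and the static-shape flag stands in for `hasStaticShape()`.

```cpp
#include <optional>
#include <vector>

// Hypothetical stand-in for a tensor type; only static-ness is modeled.
struct TensorTypeModel {
  bool hasStaticShape;
};

// Returns the operand type only when there is exactly one operand and it
// has a static shape. The size check comes first, so the indexed access
// below can never read past the end of the list.
std::optional<TensorTypeModel> singleStaticOperandType(
    const std::vector<TensorTypeModel>& operandTypes) {
  if (operandTypes.size() != 1) return std::nullopt;  // arity check first
  const TensorTypeModel& type = operandTypes[0];      // safe after the check
  if (!type.hasStaticShape) return std::nullopt;
  return type;
}
```

With this ordering the function does not rely on an earlier verification pass having rejected empty operand lists.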

@apivovarov apivovarov force-pushed the mhlo_allgather_variadic branch from 567b97b to 713913e Compare January 10, 2024 01:11
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Jan 10, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Jan 10, 2024
@apivovarov
Contributor Author

Just want to check if any other changes are needed for this PR? @GleasonK @jurahul @tdanyluk

@apivovarov
Contributor Author

apivovarov commented Jan 16, 2024

This PR summary:

Should we finally merge this PR? @ezhulenev @GleasonK @jurahul @tdanyluk

Contributor

@jurahul jurahul left a comment

SGTM.

@apivovarov apivovarov force-pushed the mhlo_allgather_variadic branch from 713913e to c9331d0 Compare January 16, 2024 20:24
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Jan 16, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Jan 16, 2024
getLoc(),
"AllGather requires the same number of operands and results");

for (auto i = 0; i < getNumOperands(); ++i) {
Member

mhlo/IR/hlo_ops.cc:2027: error: comparison of integers of different signs: 'int' and 'unsigned int' [-Werror,-Wsign-compare]
  for (auto i = 0; i < getNumOperands(); ++i) {
                   ~ ^ ~~~~~~~~~~~~~~~~

We have an internal build failure on this. Let's use unsigned i to make sure the proper type is inferred.
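
A minimal standalone sketch of the fix follows. `OpModel` and `sumOperands` are hypothetical names; the only assumption carried over from the error message is that `getNumOperands()` returns `unsigned`.

```cpp
#include <vector>

// Hypothetical model whose getNumOperands() returns unsigned,
// matching the 'unsigned int' in the compiler error above.
struct OpModel {
  std::vector<int> operands;
  unsigned getNumOperands() const {
    return static_cast<unsigned>(operands.size());
  }
};

long sumOperands(const OpModel& op) {
  long total = 0;
  // `auto i = 0` deduces `int`, so `i < op.getNumOperands()` compares
  // signed with unsigned and fails under -Werror,-Wsign-compare.
  // Declaring the index as `unsigned` makes both sides the same type.
  for (unsigned i = 0; i < op.getNumOperands(); ++i) {
    total += op.operands[i];
  }
  return total;
}
```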

Contributor Author

@apivovarov apivovarov Jan 16, 2024

fixed in 2 places

Contributor Author

import/copybara still failing

Member

@GleasonK GleasonK Jan 16, 2024

I think the next action is on my end, need to update some internal callsites now that getOperand()/getResult() methods don't exist anymore (since they assume 1op/res). Will report back shortly.

Edit: Need to address the bazelrc feedback then I should be able to take care of the remaining integ work.

Contributor Author

@apivovarov apivovarov Jan 16, 2024

I removed .tf_configure.bazelrc and tools/python_bin_path.sh.
import/copybara check is OK now!

Member

I believe I've made all the internal changes needed. Hopefully this will be merged tomorrow. I'll need to submit internally; once that happens, this PR will be marked as closed and a GH commit attributed to you will be added to the main branch. Will give an update tomorrow!

Contributor Author

Thank you, Kevin, for your help!

@apivovarov apivovarov force-pushed the mhlo_allgather_variadic branch from c9331d0 to 93e11aa Compare January 16, 2024 21:10
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Jan 16, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Jan 16, 2024
@apivovarov apivovarov force-pushed the mhlo_allgather_variadic branch from 93e11aa to e8bcccb Compare January 16, 2024 21:24
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Jan 16, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Jan 16, 2024
@@ -0,0 +1,16 @@
build --action_env PYTHON_BIN_PATH="/usr/local/bin/python3.9"
Member

I think this and tools/python_bin_path.sh should be removed. (I'll add .tf_configure.bazelrc to the gitignore, I'm surprised this hasn't happened before)

Contributor Author

Removed both. Sorry for the mess.

Member

No worries at all!!

@apivovarov apivovarov force-pushed the mhlo_allgather_variadic branch from e8bcccb to fb53ead Compare January 16, 2024 23:27
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Jan 16, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Jan 16, 2024
@GleasonK GleasonK self-assigned this Jan 17, 2024
copybara-service bot pushed a commit to tensorflow/mlir-hlo that referenced this pull request Jan 17, 2024
Imported from GitHub PR openxla/xla#8151

Currently AllGather in HLO supports multiple operands/results, while MHLO only supports a single operand/result. This change addresses the parity gap by adding MHLO AllGather variadic operands support.

This change was inspired by previous commit [2457fc1](openxla/xla@2457fc1) - [mhlo] AllReduce tuple support. Jun 7, 2023 by @GleasonK

Related commit:
- [PR-5740](openxla/xla#5740) [hlo] Add tuple input support to all-gather and reduce-scatter (Oct 16, 2023 by @jeffhataws)

@GleasonK @cheshire @burmako @jurahul @thomasjoerg Could you review this PR?
Copybara import of the project:

--
fb53ead74cbb40177a3680c8f807149d39c396b7 by Alexander Pivovarov <[email protected]>:

[mhlo] AllGather variadic operands support

Merging this change closes #8151

PiperOrigin-RevId: 599175008
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jan 17, 2024
@apivovarov apivovarov deleted the mhlo_allgather_variadic branch January 17, 2024 17:12
@apivovarov
Contributor Author

Do I also need to update the documentation https://www.tensorflow.org/mlir/hlo_ops#mhloall_gather_mhloallgatherop?
What repo has the source of this page?

@apivovarov
Contributor Author

It seems that the documentation was updated automatically:

operands | variadic of tensor of f8E4M3B11FNUZ type ...
