[mhlo] AllGather variadic operands support #8151
2e3f49f to ed93860
It looks good based on the similarity to [mhlo] AllReduce tuple support.
I wouldn't mind if GleasonK also took a look.
xla/mlir_hlo/mhlo/IR/hlo_ops.td
Outdated
def MHLO_AllGatherOp : MHLO_Op<"all_gather", [SameOperandsAndResultElementType]> {
def MHLO_AllGatherOp : MHLO_Op<"all_gather", [
SameOperandsAndResultElementType,
SingleBlockImplicitTerminator<"ReturnOp">
The SingleBlockImplicitTerminator trait is not needed here. Unlike the all-reduce op, which has an attached region for the reduction computation, AllGather has no region, so this trait is not relevant.
fixed
ed93860 to 567b97b
LGTM! Apologies for the delay, this fell off my radar with holidays!
getUseGlobalDeviceIds(), getResult());
if (getOperands().empty())
return emitOptionalError(getLoc(),
"AllGather must have have at least one operand");
Do we verify anywhere that the number of operands matches the number of results?
I'm not sure whether there's an upstream MLIR trait for this; if not, let's add another check here.
Good catch! Added validation for that case to AllGatherOp::verify
if (getNumOperands() != getNumResults())
return emitOptionalError(
getLoc(),
"AllGather requires the same number of operands and results");
I am surprised there isn't an MLIR core trait to check this. If there isn't, adding one to MLIR might be useful (not as part of this PR, though).
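For readers following the thread, the two count checks under discussion can be sketched outside of MLIR. This is a minimal standalone model, not the code from the PR: `ToyAllGather` and `verifyCounts` are hypothetical names, and the real verifier works on an MLIR op through `getNumOperands()` and `getNumResults()`.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>

// Hypothetical stand-in for the op: only the operand/result counts matter here.
struct ToyAllGather {
  std::size_t numOperands;
  std::size_t numResults;
};

// Returns an error message when a check fails, std::nullopt when both pass.
std::optional<std::string> verifyCounts(const ToyAllGather &op) {
  // A variadic op must still carry at least one operand.
  if (op.numOperands == 0)
    return std::string("AllGather must have at least one operand");
  // Each operand is gathered into exactly one result.
  if (op.numOperands != op.numResults)
    return std::string(
        "AllGather requires the same number of operands and results");
  return std::nullopt;
}
```

Under this model, an op with two operands and two results passes, while two operands and one result yields the mismatch message.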
The list of OpTraits is in LLVM's mlir/include/mlir/IR/OpDefinition.h. I do not see a trait that compares the number of operands and results.
BTW, the getOperands().empty() check might be redundant, because SameOperandsAndResultElementType::verifyTrait calls OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) in Operation.cpp, which has the following:
if (failed(verifyAtLeastNOperands(op, 1)) ||
    failed(verifyAtLeastNResults(op, 1)))
  return failure();
all_gather_op.getOperand().getType().cast<TensorType>();
TensorType result_type = all_gather_op.getType();
all_gather_op.getOperand(0).getType().cast<TensorType>();
TensorType result_type = all_gather_op.getType(0).cast<TensorType>();
if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
return failure();
if (operands.size() != 1) return failure();
Let's raise this check before the call to getOperand(0)
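As an illustration of the ordering being requested, here is a small standalone sketch that checks the operand count before touching element 0. The helper `onlyOperand` and the use of `std::vector<int>` are assumptions made for the example; the PR itself operates on MLIR values in ExportXlaOp.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical helper: returns the single operand, or std::nullopt when the
// operand count is not exactly one.
std::optional<int> onlyOperand(const std::vector<int> &operands) {
  // Guard first, so the index below is never evaluated on a bad count.
  if (operands.size() != 1) return std::nullopt;
  return operands[0];
}
```

With the guard placed first, an empty or multi-element list is rejected before any indexing happens.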
mhlo/IR/hlo_ops.td defines that MHLO_AllGatherOp has Variadic<MHLO_Tensor>:$operands.
I tried to use mhlo.all_gather() without operands in export.mlir and got this error:
error: unexpected error: 'mhlo.all_gather' op expected 1 or more operands, but found 0
%0:2 = "mhlo.all_gather"() {
The mhlo parser validates that mhlo.all_gather() has at least one operand.
It should be safe to use all_gather_op.getOperand(0) in mlir_hlo_to_hlo.cc::ExportXlaOp, since MLIR validation happens before the ExportXlaOp call.
567b97b to 713913e
SGTM.
713913e to c9331d0
xla/mlir_hlo/mhlo/IR/hlo_ops.cc
Outdated
getLoc(),
"AllGather requires the same number of operands and results");

for (auto i = 0; i < getNumOperands(); ++i) {
mhlo/IR/hlo_ops.cc:2027: error: comparison of integers of different signs: 'int' and 'unsigned int' [-Werror,-Wsign-compare]
for (auto i = 0; i < getNumOperands(); ++i) {
~ ^ ~~~~~~~~~~~~~~~~
We have an internal build failure on this. Let's use unsigned i to make sure the proper type is inferred.
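A minimal standalone sketch of the fix: `auto i = 0` deduces `int`, which compares against an unsigned count and trips -Wsign-compare, whereas declaring the index as `unsigned` matches the bound's type. `numOperands` and `sumOperands` are hypothetical stand-ins for the op's `getNumOperands()` and the loop body, assumed only for this example.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for getNumOperands(), which returns an unsigned count.
unsigned numOperands(const std::vector<int> &operands) {
  return static_cast<unsigned>(operands.size());
}

int sumOperands(const std::vector<int> &operands) {
  int total = 0;
  // `auto i = 0` would make i an int and warn under -Wsign-compare;
  // an explicitly unsigned index has the same type as the bound.
  for (unsigned i = 0; i < numOperands(operands); ++i)
    total += operands[i];
  return total;
}
```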
fixed in 2 places
import/copybara still failing
I think the next action is on my end: I need to update some internal call sites now that the getOperand()/getResult() methods don't exist anymore (they assumed a single operand/result). Will report back shortly.
Edit: Need to address the bazelrc feedback, then I should be able to take care of the remaining integration work.
I removed .tf_configure.bazelrc and tools/python_bin_path.sh.
The import/copybara check is OK now!
I believe I've made all the internal changes needed, and I hope to have this merged tomorrow. I'll need to submit internally; once that happens, this PR will be marked as closed and a GH commit attributed to you will be added to the main branch. Will give an update tomorrow!
Thank you, Kevin, for your help!
c9331d0 to 93e11aa
93e11aa to e8bcccb
.tf_configure.bazelrc
Outdated
@@ -0,0 +1,16 @@
build --action_env PYTHON_BIN_PATH="/usr/local/bin/python3.9"
I think this and tools/python_bin_path.sh should be removed. (I'll add .tf_configure.bazelrc to the gitignore; I'm surprised this hasn't happened before.)
Removed both. Sorry for the mess.
No worries at all!!
e8bcccb to fb53ead
Imported from GitHub PR openxla/xla#8151

Currently AllGather in HLO supports multiple operands/results, while MHLO only supports a single operand/result. This change addresses the parity gap by adding MHLO AllGather variadic operands support.

This change was inspired by previous commit [2457fc1](openxla/xla@2457fc1) - [mhlo] AllReduce tuple support. Jun 7, 2023 by @GleasonK

Related commit:
- [PR-5740](openxla/xla#5740) [hlo] Add tuple input support to all-gather and reduce-scatter (Oct 16, 2023 by @jeffhataws)

@GleasonK @cheshire @burmako @jurahul @thomasjoerg Could you review this PR?

Copybara import of the project:

--
fb53ead74cbb40177a3680c8f807149d39c396b7 by Alexander Pivovarov <[email protected]>:

[mhlo] AllGather variadic operands support

Merging this change closes #8151

PiperOrigin-RevId: 599175008
Do I also need to update the documentation https://www.tensorflow.org/mlir/hlo_ops#mhloall_gather_mhloallgatherop?
It seems that the documentation was updated automatically.
Currently AllGather in HLO supports multiple operands/results, while MHLO only supports a single operand/result. This change addresses the parity gap by adding MHLO AllGather variadic operands support.
This change was inspired by previous commit 2457fc1 - [mhlo] AllReduce tuple support. Jun 7, 2023 by @GleasonK
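Related commit:
- PR-5740 (openxla/xla#5740) [hlo] Add tuple input support to all-gather and reduce-scatter (Oct 16, 2023 by @jeffhataws)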
@GleasonK @cheshire @burmako @jurahul @thomasjoerg Could you review this PR?