[mhlo] AllGather variadic operands support #8151
Changes from all commits
@@ -949,21 +949,49 @@ LogicalResult ExportXlaOp(AddDependencyOp op, OpLoweringContext ctx) {

 LogicalResult ExportXlaOp(AllGatherOp op, OpLoweringContext ctx) {
   auto& value_map = *ctx.values;
-  xla::XlaOp operand;
-  if (failed(GetXlaOp(op.getOperand(), value_map, &operand, op)))
-    return failure();
-  TensorType operand_type = op.getOperand().getType().cast<TensorType>();
-  TensorType result_type = op.getType();
-  if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
-    return failure();
+
+  SmallVector<xla::XlaOp> operands;
+  if (failed(GetTuple(op.getOperation(), op.getOperands(), ctx, operands))) {
+    return failure();
+  }
+
+  mlir::FailureOr<xla::Shape> shape_or = ExtractXlaShape(op.getOperation());
+  if (failed(shape_or)) return failure();
+
   auto all_gather_dim = op.getAllGatherDim();
-  int64_t shard_count = result_type.getDimSize(all_gather_dim) /
-                        operand_type.getDimSize(all_gather_dim);
-  value_map[op] = xla::AllGather(
-      operand, all_gather_dim, shard_count,
-      Convert_replica_groups(op.getReplicaGroups()),
-      Convert_channel_handle(op.getChannelHandle()), std::nullopt,
-      Convert_use_global_device_ids(op.getUseGlobalDeviceIds()));
+  int64_t shard_count = 0;
+  for (size_t i = 0; i < operands.size(); ++i) {
+    TensorType operand_type = op.getOperand(i).getType().cast<TensorType>();
+    TensorType result_type = op.getType(i).cast<TensorType>();
+    if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
+      return failure();
+    if (i == 0) {
+      shard_count = result_type.getDimSize(all_gather_dim) /
+                    operand_type.getDimSize(all_gather_dim);
+    }
+  }
+
+  if (shape_or->IsTuple()) {
+    std::optional<xla::Layout> layout = std::nullopt;
+    if (shape_or->has_layout()) {
+      layout = shape_or->layout();
+    }
+    auto tuple = xla::AllGatherTuple(
+        operands, all_gather_dim, shard_count,
+        Convert_replica_groups(op.getReplicaGroups()),
+        Convert_channel_handle(op.getChannelHandle()), layout,
+        Convert_use_global_device_ids(op.getUseGlobalDeviceIds()));
+    for (auto [index, result] : llvm::enumerate(op.getResults())) {
+      value_map[result] = xla::GetTupleElement(tuple, index);
+    }
+  } else {
+    value_map[op->getResults()[0]] = xla::AllGather(
+        operands[0], all_gather_dim, shard_count,
+        Convert_replica_groups(op.getReplicaGroups()),
+        Convert_channel_handle(op.getChannelHandle()), std::nullopt,
+        Convert_use_global_device_ids(op.getUseGlobalDeviceIds()));
+  }
+
   return success();
 }
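For intuition about the shard_count computation above: the exporter recovers the per-group replica count by dividing the result extent by the operand extent along all_gather_dim, using only the first operand/result pair. A standalone sketch of that arithmetic (the concrete extents 128 and 512 are made up for illustration, not taken from the PR):

#include <cassert>
#include <cstdint>

// An all-gather along one dimension across N replicas multiplies that
// dimension's extent by N, so N can be recovered from the ratio.
int64_t InferShardCount(int64_t result_extent, int64_t operand_extent) {
  return result_extent / operand_extent;
}

int main() {
  // A 128-row shard gathered across 4 replicas yields 512 rows.
  assert(InferShardCount(/*result_extent=*/512, /*operand_extent=*/128) == 4);
  return 0;
}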
@@ -1093,8 +1121,8 @@ LogicalResult ExportXlaOp(AsyncStartOp op, OpLoweringContext ctx) {
       dyn_cast_or_null<AllGatherOp>(callee.getBody().front().front());
   if (all_gather_op && SimplyReturnedOp(all_gather_op)) {
     TensorType operand_type =
-        all_gather_op.getOperand().getType().cast<TensorType>();
-    TensorType result_type = all_gather_op.getType();
+        all_gather_op.getOperand(0).getType().cast<TensorType>();
+    TensorType result_type = all_gather_op.getType(0).cast<TensorType>();
     if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
       return failure();
     if (operands.size() != 1) return failure();
Review comment: Let's raise this check before the call to …

Reply: mhlo parser validates that …
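Reading the truncated suggestion as referring to the operands.size() check in the hunk above, a purely hypothetical sketch of the reordering (this is not the PR's code; the names are as in the diff):

// Hypothetical: fail fast on the operand count before querying types.
if (operands.size() != 1) return failure();
TensorType operand_type =
    all_gather_op.getOperand(0).getType().cast<TensorType>();
TensorType result_type = all_gather_op.getType(0).cast<TensorType>();
if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
  return failure();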
@@ -1268,7 +1296,7 @@ LogicalResult ExportXlaOp(AsyncDoneOp op, OpLoweringContext ctx) {
   if (all_gather_op && SimplyReturnedOp(all_gather_op)) {
     value_map[op.getResult(0)] =
         xla::internal::XlaBuilderFriend::BuildAllGatherDone(
-            ctx.builder, operand, xla::TypeToShape(all_gather_op.getType()));
+            ctx.builder, operand, xla::TypeToShape(all_gather_op.getType(0)));
     return success();
   }
   auto all_reduce_op =
Review comment: Do we verify anywhere that there is the same number of operands and results? I'm not sure if there's an upstream MLIR trait for this; if not, then let's add another check here.
Reply: Good catch! Added validation for that case to AllGatherOp::verify.
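The verifier change itself is not part of the hunks shown here; below is a minimal sketch of the kind of check described, assuming the standard MLIR Operation accessors. The PR's actual implementation may differ.

// Hypothetical sketch of the added check in AllGatherOp::verify.
LogicalResult AllGatherOp::verify() {
  if (getNumOperands() != getNumResults())
    return emitOpError() << "requires the same number of operands ("
                         << getNumOperands() << ") and results ("
                         << getNumResults() << ")";
  // ... remaining verification elided ...
  return success();
}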
Review comment: I am surprised there isn't an MLIR core trait to check this. If not, maybe adding one to MLIR might be useful (not as part of this PR, though).
Reply: The list of OpTraits is in LLVM's mlir/include/mlir/IR/OpDefinition.h. I do not see a trait which compares the number of operands and results. BTW, the getOperands().empty() check might be redundant, because SameOperandsAndResultElementType::verifyTrait calls OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) in Operation.cpp, which has the following:
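The quoted code is cut off above; from memory, the upstream helper looks roughly like the sketch below (a paraphrase, not a verbatim quote of Operation.cpp). The relevant point is that verifyAtLeastNOperands(op, 1) already rejects an empty operand list, which is what would make a separate getOperands().empty() check redundant.

LogicalResult OpTrait::impl::verifySameOperandsAndResultElementType(
    Operation *op) {
  // Requires at least one operand and one result, so an op with no
  // operands already fails verification here.
  if (failed(verifyAtLeastNOperands(op, 1)) ||
      failed(verifyAtLeastNResults(op, 1)))
    return failure();

  Type elementType = getElementTypeOrSelf(op->getResult(0));

  // Every result and every operand must share that element type.
  for (auto result : op->getResults())
    if (getElementTypeOrSelf(result) != elementType)
      return op->emitOpError(
          "requires the same element type for all operands and results");
  for (auto operand : op->getOperands())
    if (getElementTypeOrSelf(operand) != elementType)
      return op->emitOpError(
          "requires the same element type for all operands and results");

  return success();
}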