PR #8151: [mhlo] AllGather variadic operands support
Imported from GitHub PR openxla/xla#8151

Currently, AllGather in HLO supports multiple operands/results, while MHLO supports only a single operand/result. This change closes that parity gap by adding variadic operand/result support to the MHLO AllGather op.
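For intuition on the per-operand semantics: each result shape equals its operand shape with the `all_gather_dim` dimension multiplied by the replica-group size. A minimal sketch of that arithmetic (illustrative helper, not part of this change):

#include <cstdint>
#include <vector>

// Computes the all-gather result shape for one operand: the gathered
// dimension grows by the number of processes in each replica group.
// With variadic support this applies independently to every operand.
std::vector<int64_t> allGatherResultShape(std::vector<int64_t> operandShape,
                                          int64_t allGatherDim,
                                          int64_t groupSize) {
  operandShape[allGatherDim] *= groupSize;
  return operandShape;
}

// e.g. group size 4, all_gather_dim = 1: an 8x2 operand yields an 8x8 result
// and an 8x4 operand yields an 8x16 result, matching the new test below.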

This change was inspired by the earlier commit [2457fc1](openxla/xla@2457fc1), "[mhlo] AllReduce tuple support" (Jun 7, 2023, by @GleasonK).

Related commit:
- [PR-5740](openxla/xla#5740) [hlo] Add tuple input support to all-gather and reduce-scatter (Oct 16, 2023 by @jeffhataws)

@GleasonK @cheshire @burmako @jurahul @thomasjoerg Could you review this PR?
Copybara import of the project:

--
fb53ead74cbb40177a3680c8f807149d39c396b7 by Alexander Pivovarov <[email protected]>:

[mhlo] AllGather variadic operands support

Merging this change closes #8151

PiperOrigin-RevId: 599175008
apivovarov authored and TensorFlow MLIR Team committed Jan 17, 2024
1 parent 4490cef commit 99207fc
Showing 3 changed files with 33 additions and 9 deletions.
22 changes: 17 additions & 5 deletions mhlo/IR/hlo_ops.cc
@@ -2017,18 +2017,30 @@ LogicalResult AllGatherOp::verify() {
   if (auto channelHandleAttr = getChannelHandleAttr())
     channelId = channelHandleAttr.getHandle();
 
-  return hlo::verifyAllGatherOp(getLoc(), getOperand(), getAllGatherDim(),
-                                getReplicaGroups(), channelId,
-                                getUseGlobalDeviceIds(), getResult());
+  if (getOperands().empty())
+    return emitOptionalError(getLoc(),
+                             "AllGather must have at least one operand");
+  if (getNumOperands() != getNumResults())
+    return emitOptionalError(
+        getLoc(), "AllGather requires the same number of operands and results");
+
+  for (unsigned i = 0; i < getNumOperands(); ++i) {
+    if (failed(hlo::verifyAllGatherOp(
+            getLoc(), getOperand(i), getAllGatherDim(), getReplicaGroups(),
+            channelId, getUseGlobalDeviceIds(), getResult(i)))) {
+      return failure();
+    }
+  }
+  return success();
 }

 void AllGatherOp::build(OpBuilder& odsBuilder, OperationState& odsState,
                         Type resultType, Value operand,
                         IntegerAttr allGatherDim,
                         DenseIntElementsAttr replicaGroups,
                         ChannelHandleAttr channelHandle) {
-  AllGatherOp::build(odsBuilder, odsState, resultType, operand, allGatherDim,
-                     replicaGroups, channelHandle,
+  AllGatherOp::build(odsBuilder, odsState, resultType, ValueRange(operand),
+                     allGatherDim, replicaGroups, channelHandle,
                      /*use_global_device_ids=*/nullptr);
 }

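With the operand now variadic, the ODS-generated builders are expected to take TypeRange/ValueRange; the simplified single-operand builder above stays source-compatible by wrapping its operand in ValueRange(operand). A caller-side sketch of constructing the variadic form (builder signature assumed from standard ODS conventions, not shown in this diff):

#include "mhlo/IR/hlo_ops.h"
#include "mlir/IR/Builders.h"

// Builds a variadic all_gather over `operands`, one result type per operand.
mlir::mhlo::AllGatherOp buildVariadicAllGather(
    mlir::OpBuilder& b, mlir::Location loc, mlir::TypeRange resultTypes,
    mlir::ValueRange operands, mlir::DenseIntElementsAttr replicaGroups) {
  return b.create<mlir::mhlo::AllGatherOp>(
      loc, resultTypes, operands, /*all_gather_dim=*/b.getI64IntegerAttr(1),
      replicaGroups, /*channel_handle=*/mlir::mhlo::ChannelHandleAttr(),
      /*use_global_device_ids=*/nullptr);
}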
9 changes: 5 additions & 4 deletions mhlo/IR/hlo_ops.td
@@ -1449,8 +1449,9 @@ def MHLO_AllGatherOp : MHLO_Op<"all_gather", [SameOperandsAndResultElementType]>
   string summary = "AllGather operation";
   string description = [{
     Within each process group in the process grid, concatenates the values of the
-    `operand` tensor from each process along `all_gather_dim` and produces a
-    `result` tensor.
+    operand tensor from each process along `all_gather_dim` and produces a
+    result tensor. The concatenation is applied separately for each operand in
+    `operands`, producing one result per operand.
 
     See:
     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_gather
@@ -1468,13 +1469,13 @@ def MHLO_AllGatherOp : MHLO_Op<"all_gather", [SameOperandsAndResultElementType]>
   }];
 
   let arguments = (ins
-    MHLO_Tensor:$operand,
+    Variadic<MHLO_Tensor>:$operands,
     I64Attr:$all_gather_dim,
     I64ElementsAttr:$replica_groups,
     OptionalAttr<MHLO_ChannelHandle>:$channel_handle,
     UnitAttr:$use_global_device_ids
   );
-  let results = (outs MHLO_Tensor);
+  let results = (outs Variadic<MHLO_Tensor>);
   // use_global_device_ids is rarely used, so we add simplified builder methods
   // for convenience.
   let builders = [
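One downstream consequence of Variadic<MHLO_Tensor>:$operands is that the generated accessors return ranges rather than single values, so code that previously called getOperand()/getResult() now handles operand/result pairs in a loop. A hedged sketch of that pattern (accessor names assumed from ODS conventions):

#include "mhlo/IR/hlo_ops.h"
#include "llvm/ADT/STLExtras.h"

// Visits each (operand, result) pair of a variadic all_gather; the verifier
// change above follows the same pairwise structure.
void forEachOperandResultPair(mlir::mhlo::AllGatherOp op) {
  for (auto [operand, result] :
       llvm::zip_equal(op.getOperands(), op.getResults())) {
    (void)operand;  // per-pair processing goes here
    (void)result;
  }
}

Note that the op retains the SameOperandsAndResultElementType trait, so all operands and results must share one element type even though their shapes may differ per operand (8x2 and 8x4 f32 tensors in the test below).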
11 changes: 11 additions & 0 deletions tests/Dialect/mhlo/ops.mlir
@@ -753,6 +753,17 @@ func.func @all_to_all_i5(%data: tensor<4x16xf32>) -> tensor<16x4xf32> {
 
 // -----
 
+func.func @all_gather_variadic(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>) {
+  %0:2 = "mhlo.all_gather"(%arg0, %arg1) {
+    all_gather_dim = 1 : i64,
+    channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
+    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
+  } : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>)
+  func.return %0#0, %0#1 : tensor<8x8xf32>, tensor<8x16xf32>
+}
+
+// -----
+
 func.func @allgather_gather_along_zero_dimension(%arg0: tensor<128x0xf32>) -> tensor<128x100xf32> {
   // expected-error@+1 {{dimension size of operand at 'all_gather_dim' cannot be zero}}
   %0 = "mhlo.all_gather"(%arg0) {
