From 99207fc93a118c0cb42f06cbe0804a8ed68e80b2 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Wed, 17 Jan 2024 08:02:10 -0800 Subject: [PATCH] PR #8151: [mhlo] AllGather variadic operands support Imported from GitHub PR https://github.com/openxla/xla/pull/8151 Currently AllGather in HLO supports multiple operands/results, while MHLO only supports a single operand/result. This change addresses the parity gap by adding MHLO AllGather variadic operands support. This change was inspired by previous commit [2457fc1](https://github.com/openxla/xla/commit/2457fc12a6cdfed22daa70a7303f166d502af7dd) - [mhlo] AllReduce tuple support. Jun 7, 2023 by @GleasonK Related commit: - [PR-5740](https://github.com/openxla/xla/pull/5740) [hlo] Add tuple input support to all-gather and reduce-scatter (Oct 16, 2023 by @jeffhataws) @GleasonK @cheshire @burmako @jurahul @thomasjoerg Could you review this PR? Copybara import of the project: -- fb53ead74cbb40177a3680c8f807149d39c396b7 by Alexander Pivovarov : [mhlo] AllGather variadic operands support Merging this change closes #8151 PiperOrigin-RevId: 599175008 --- mhlo/IR/hlo_ops.cc | 22 +++++++++++++++++----- mhlo/IR/hlo_ops.td | 9 +++++---- tests/Dialect/mhlo/ops.mlir | 11 +++++++++++ 3 files changed, 33 insertions(+), 9 deletions(-) diff --git a/mhlo/IR/hlo_ops.cc b/mhlo/IR/hlo_ops.cc index 5c0685bd0..ccddc63c7 100644 --- a/mhlo/IR/hlo_ops.cc +++ b/mhlo/IR/hlo_ops.cc @@ -2017,9 +2017,21 @@ LogicalResult AllGatherOp::verify() { if (auto channelHandleAttr = getChannelHandleAttr()) channelId = channelHandleAttr.getHandle(); - return hlo::verifyAllGatherOp(getLoc(), getOperand(), getAllGatherDim(), - getReplicaGroups(), channelId, - getUseGlobalDeviceIds(), getResult()); + if (getOperands().empty()) + return emitOptionalError(getLoc(), + "AllGather must have have at least one operand"); + if (getNumOperands() != getNumResults()) + return emitOptionalError( + getLoc(), "AllGather requires the same number of operands and results"); + 
+ for (unsigned i = 0; i < getNumOperands(); ++i) { + if (failed(hlo::verifyAllGatherOp( + getLoc(), getOperand(i), getAllGatherDim(), getReplicaGroups(), + channelId, getUseGlobalDeviceIds(), getResult(i)))) { + return failure(); + } + } + return success(); } void AllGatherOp::build(OpBuilder& odsBuilder, OperationState& odsState, @@ -2027,8 +2039,8 @@ void AllGatherOp::build(OpBuilder& odsBuilder, OperationState& odsState, IntegerAttr allGatherDim, DenseIntElementsAttr replicaGroups, ChannelHandleAttr channelHandle) { - AllGatherOp::build(odsBuilder, odsState, resultType, operand, allGatherDim, - replicaGroups, channelHandle, + AllGatherOp::build(odsBuilder, odsState, resultType, ValueRange(operand), + allGatherDim, replicaGroups, channelHandle, /*use_global_device_ids=*/nullptr); } diff --git a/mhlo/IR/hlo_ops.td b/mhlo/IR/hlo_ops.td index ccca005dc..5d25d6ae4 100644 --- a/mhlo/IR/hlo_ops.td +++ b/mhlo/IR/hlo_ops.td @@ -1449,8 +1449,9 @@ def MHLO_AllGatherOp : MHLO_Op<"all_gather", [SameOperandsAndResultElementType]> string summary = "AllGather operation"; string description = [{ Within each process group in the process grid, concatenates the values of the - `operand` tensor from each process along `all_gather_dim` and produces a - `result` tensor. + operand tensor from each process along `all_gather_dim` and produces a + result tensor. The gather is applied separately for each operand in + `operands`, producing one result per operand. 
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_gather @@ -1468,13 +1469,13 @@ def MHLO_AllGatherOp : MHLO_Op<"all_gather", [SameOperandsAndResultElementType]> }]; let arguments = (ins - MHLO_Tensor:$operand, + Variadic:$operands, I64Attr:$all_gather_dim, I64ElementsAttr:$replica_groups, OptionalAttr:$channel_handle, UnitAttr:$use_global_device_ids ); - let results = (outs MHLO_Tensor); + let results = (outs Variadic); // use_global_device_ids is rarely used, so we add simplified builder methods // for convenience. let builders = [ diff --git a/tests/Dialect/mhlo/ops.mlir b/tests/Dialect/mhlo/ops.mlir index 9bb91a3f3..5db74efa7 100644 --- a/tests/Dialect/mhlo/ops.mlir +++ b/tests/Dialect/mhlo/ops.mlir @@ -753,6 +753,17 @@ func.func @all_to_all_i5(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { // ----- +func.func @all_gather_variadic(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>) { + %0:2 = "mhlo.all_gather"(%arg0, %arg1) { + all_gather_dim = 1 : i64, + channel_handle = #mhlo.channel_handle, + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + } : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>) + func.return %0#0, %0#1 : tensor<8x8xf32>, tensor<8x16xf32> +} + +// ----- + func.func @allgather_gather_along_zero_dimension(%arg0: tensor<128x0xf32>) -> tensor<128x100xf32> { // expected-error@+1 {{dimension size of operand at 'all_gather_dim' cannot be zero}} %0 = "mhlo.all_gather"(%arg0) {