diff --git a/xla/client/xla_builder.cc b/xla/client/xla_builder.cc
index 63bcebcdd2732..2fad69041d74b 100644
--- a/xla/client/xla_builder.cc
+++ b/xla/client/xla_builder.cc
@@ -5015,6 +5015,18 @@ XlaOp AllGather(const XlaOp operand, int64_t all_gather_dimension,
                   layout, use_global_device_ids);
 }
 
+XlaOp AllGatherTuple(const absl::Span<const XlaOp> operands,
+                     int64_t all_gather_dimension, int64_t shard_count,
+                     absl::Span<const ReplicaGroup> replica_groups,
+                     const std::optional<ChannelHandle>& channel_id,
+                     const std::optional<Layout>& layout,
+                     const std::optional<bool> use_global_device_ids) {
+  CHECK(!operands.empty());
+  return operands[0].builder()->AllGather(
+      operands[0].builder()->Tuple(operands), all_gather_dimension, shard_count,
+      replica_groups, channel_id, layout, use_global_device_ids);
+}
+
 XlaOp CrossReplicaSum(const XlaOp operand,
                       absl::Span<const ReplicaGroup> replica_groups) {
   return operand.builder()->CrossReplicaSum(operand, replica_groups);
diff --git a/xla/client/xla_builder.h b/xla/client/xla_builder.h
index 973a80cfc513b..356f6a94c9a3d 100644
--- a/xla/client/xla_builder.h
+++ b/xla/client/xla_builder.h
@@ -1437,6 +1437,12 @@ class XlaBuilder {
                          const std::optional<ChannelHandle>& channel_id,
                          const std::optional<Layout>& layout,
                          std::optional<bool> use_global_device_ids);
+  friend XlaOp AllGatherTuple(absl::Span<const XlaOp> operands,
+                              int64_t all_gather_dimension, int64_t shard_count,
+                              absl::Span<const ReplicaGroup> replica_groups,
+                              const std::optional<ChannelHandle>& channel_id,
+                              const std::optional<Layout>& layout,
+                              std::optional<bool> use_global_device_ids);
   friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
                          absl::Span<const ReplicaGroup> replica_groups,
                          const std::optional<ChannelHandle>& channel_id,
@@ -2431,6 +2437,13 @@ XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension,
                 const std::optional<Layout>& layout = std::nullopt,
                 std::optional<bool> use_global_device_ids = std::nullopt);
 
+XlaOp AllGatherTuple(
+    absl::Span<const XlaOp> operands, int64_t all_gather_dimension,
+    int64_t shard_count, absl::Span<const ReplicaGroup> replica_groups = {},
+    const std::optional<ChannelHandle>& channel_id = std::nullopt,
+    const std::optional<Layout>& layout = std::nullopt,
+    std::optional<bool> use_global_device_ids = std::nullopt);
+
 // Enqueues an operation that do an AllReduce of the operand cross cores. Here
 // AllReduce means doing a reduction on the input operand cross cores and then
 // broadcasting the reduction result to those cores. The reduction function is
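For context: the new free function simply wraps the operands in a `Tuple` and reuses the builder's existing all-gather path, so callers get one collective over several buffers. A minimal usage sketch, not part of the patch; it mirrors the `AllGatherTuple` unit test below, and the `BuildTupleAllGather` wrapper name is hypothetical:

// Hypothetical usage sketch: build a two-operand tuple all-gather with the
// new AllGatherTuple entry point. Shapes match the unit test in this patch.
#include "xla/client/xla_builder.h"
#include "xla/shape_util.h"

xla::XlaComputation BuildTupleAllGather() {
  xla::XlaBuilder b("all_gather_tuple_example");
  auto p0 = xla::Parameter(
      &b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128, 4}), "p0");
  auto p1 = xla::Parameter(
      &b, 1, xla::ShapeUtil::MakeShape(xla::F32, {128, 8}), "p1");
  // Both operands share one replica-group/channel configuration; with 4
  // shards on dimension 1 the result is a tuple (f32[128,16], f32[128,32]).
  xla::AllGatherTuple({p0, p1}, /*all_gather_dimension=*/1, /*shard_count=*/4);
  return b.Build().value();
}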
diff --git a/xla/client/xla_builder_test.cc b/xla/client/xla_builder_test.cc
index 66ae9ba8a690f..ddd1d51f1691a 100644
--- a/xla/client/xla_builder_test.cc
+++ b/xla/client/xla_builder_test.cc
@@ -454,6 +454,21 @@ TEST_F(XlaBuilderTest, AllGatherWithTuple) {
                                      ShapeUtil::MakeShape(F32, {64, 4})})));
 }
 
+TEST_F(XlaBuilderTest, AllGatherTuple) {
+  XlaBuilder b(TestName());
+  auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {128, 4}), "p0");
+  auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {128, 8}), "p1");
+  AllGatherTuple({p0, p1}, /*all_gather_dimension=*/1, /*shard_count=*/4);
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  auto root = module->entry_computation()->root_instruction();
+  auto tuple_shape =
+      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {128, 16}),
+                                 ShapeUtil::MakeShape(F32, {128, 32})});
+  EXPECT_THAT(root, GmockMatch(m::Op()
+                                   .WithOpcode(HloOpcode::kAllGather)
+                                   .WithShapeEqualTo(&tuple_shape)));
+}
+
 TEST_F(XlaBuilderTest, ReduceScatter) {
   XlaBuilder b(TestName());
   XlaComputation to_apply;
diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc
index 5c0685bd0ae7a..84eec165c3b86 100644
--- a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc
+++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc
@@ -2017,9 +2017,22 @@ LogicalResult AllGatherOp::verify() {
   if (auto channelHandleAttr = getChannelHandleAttr())
     channelId = channelHandleAttr.getHandle();
 
-  return hlo::verifyAllGatherOp(getLoc(), getOperand(), getAllGatherDim(),
-                                getReplicaGroups(), channelId,
-                                getUseGlobalDeviceIds(), getResult());
+  if (getOperands().empty())
+    return emitOptionalError(getLoc(),
+                             "AllGather must have at least one operand");
+  if (getNumOperands() != getNumResults())
+    return emitOptionalError(
+        getLoc(),
+        "AllGather requires the same number of operands and results");
+
+  for (unsigned i = 0; i < getNumOperands(); ++i) {
+    if (failed(hlo::verifyAllGatherOp(
+            getLoc(), getOperand(i), getAllGatherDim(), getReplicaGroups(),
+            channelId, getUseGlobalDeviceIds(), getResult(i)))) {
+      return failure();
+    }
+  }
+  return success();
 }
 
 void AllGatherOp::build(OpBuilder& odsBuilder, OperationState& odsState,
@@ -2027,8 +2040,8 @@ void AllGatherOp::build(OpBuilder& odsBuilder, OperationState& odsState,
                         IntegerAttr allGatherDim,
                         DenseIntElementsAttr replicaGroups,
                         ChannelHandleAttr channelHandle) {
-  AllGatherOp::build(odsBuilder, odsState, resultType, operand, allGatherDim,
-                     replicaGroups, channelHandle,
+  AllGatherOp::build(odsBuilder, odsState, resultType, ValueRange(operand),
+                     allGatherDim, replicaGroups, channelHandle,
                      /*use_global_device_ids=*/nullptr);
 }
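The verifier change above reduces the variadic case to N independent single-operand checks, plus a count check tying operands to results. The shape invariant each per-pair check enforces is the usual all-gather rule; an illustrative helper spelling out the arithmetic (hypothetical, not in the patch):

// Illustrative only: the shape relationship verifyAllGatherOp enforces per
// (operand, result) pair. The gathered dimension is scaled by the number of
// participating processes; all other dimensions must be unchanged.
#include <cstdint>
#include <vector>

std::vector<int64_t> ExpectedAllGatherResultShape(
    std::vector<int64_t> operand_shape, int64_t all_gather_dim,
    int64_t group_size) {
  // e.g. operand f32[8,2], all_gather_dim = 1, groups of 4 -> result f32[8,8].
  operand_shape[all_gather_dim] *= group_size;
  return operand_shape;
}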
diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/xla/mlir_hlo/mhlo/IR/hlo_ops.td
index ccca005dc3554..5d25d6ae4583c 100644
--- a/xla/mlir_hlo/mhlo/IR/hlo_ops.td
+++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.td
@@ -1449,8 +1449,9 @@ def MHLO_AllGatherOp : MHLO_Op<"all_gather", [SameOperandsAndResultElementType]>
   string summary = "AllGather operation";
   string description = [{
     Within each process group in the process grid, concatenates the values of the
-    `operand` tensor from each process along `all_gather_dim` and produces a
-    `result` tensor.
+    operand tensor from each process along `all_gather_dim` and produces a
+    result tensor. The operation is applied separately for each operand in
+    `operands`, producing one result per operand.
 
     See:
     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_gather
@@ -1468,13 +1469,13 @@ def MHLO_AllGatherOp : MHLO_Op<"all_gather", [SameOperandsAndResultElementType]>
   }];
 
   let arguments = (ins
-    MHLO_Tensor:$operand,
+    Variadic<MHLO_Tensor>:$operands,
     I64Attr:$all_gather_dim,
     I64ElementsAttr:$replica_groups,
     OptionalAttr<MHLO_ChannelHandle>:$channel_handle,
     UnitAttr:$use_global_device_ids
   );
-  let results = (outs MHLO_Tensor);
+  let results = (outs Variadic<MHLO_Tensor>);
   // use_global_device_ids is rarely used, so we add simplified builder methods
   // for convenience.
   let builders = [
diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
index 9bb91a3f3ac8e..5db74efa72f52 100644
--- a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
+++ b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
@@ -753,6 +753,17 @@ func.func @all_to_all_i5(%data: tensor<4x16xf32>) -> tensor<16x4xf32> {
 
 // -----
 
+func.func @all_gather_variadic(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>) {
+  %0:2 = "mhlo.all_gather"(%arg0, %arg1) {
+    all_gather_dim = 1 : i64,
+    channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
+    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
+  } : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>)
+  func.return %0#0, %0#1 : tensor<8x8xf32>, tensor<8x16xf32>
+}
+
+// -----
+
 func.func @allgather_gather_along_zero_dimension(%arg0: tensor<128x0xf32>) -> tensor<128x100xf32> {
   // expected-error@+1 {{dimension size of operand at 'all_gather_dim' cannot be zero}}
   %0 = "mhlo.all_gather"(%arg0) {
diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
index 0d9705c16320a..d3632cf29e9cf 100644
--- a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
+++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
@@ -1422,6 +1422,12 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
     }
     case HloOpcode::kAllGather: {
       auto all_gather = Cast<HloAllGatherInstruction>(instruction);
+      auto result_tuple_ty = result_type.dyn_cast<mlir::TupleType>();
+
+      llvm::SmallVector<Type> result_types = {result_type};
+      if (result_tuple_ty) {
+        result_types = llvm::to_vector(result_tuple_ty.getTypes());
+      }
       attributes.push_back(builder_->getNamedAttr(
           "all_gather_dim",
           builder_->getI64IntegerAttr(all_gather->all_gather_dimension())));
@@ -1432,10 +1438,15 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
           ConvertChannelHandle(all_gather->channel_id().value()));
       if (all_gather->use_global_device_ids())
         attributes.push_back(ConvertUseGlobalDeviceIds());
-      return func_builder
-          ->create<mlir::mhlo::AllGatherOp>(loc, result_type, operands,
-                                            attributes)
-          .getOperation();
+      auto all_gather_op = func_builder->create<mlir::mhlo::AllGatherOp>(
+          loc, result_types, operands, attributes);
+      if (result_tuple_ty) {
+        return func_builder
+            ->create<mlir::mhlo::TupleOp>(loc, result_type,
+                                          all_gather_op.getResults())
+            .getOperation();
+      }
+      return all_gather_op.getOperation();
     }
     case HloOpcode::kAllGatherStart: {
       auto all_gather_start = Cast<HloAllGatherInstruction>(instruction);
@@ -1449,6 +1460,9 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
           ConvertChannelHandle(all_gather_start->channel_id().value()));
       if (all_gather_start->use_global_device_ids())
         attributes.push_back(ConvertUseGlobalDeviceIds());
+      if (all_gather_start->operands().size() > 1)
+        return InvalidArgument(
+            "Async tuple all-gather is not supported in MHLO");
       return ImportOldStyleAsyncStart(
           attributes, operands, loc, result_type, func_builder, "all_gather_",
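On import, a tuple-shaped HLO all-gather is expanded into the variadic op and then immediately repacked with `mhlo.tuple`, so downstream users of the old tuple value are unaffected. A condensed sketch of that strategy; the helper is hypothetical, the patch inlines this logic in ImportInstructionImpl:

// Condensed sketch of the import path above: emit one variadic
// mhlo.all_gather whose result types are the tuple's element types, then
// rebuild the tuple value the rest of the imported graph expects.
#include "mhlo/IR/hlo_ops.h"
#include "mlir/IR/Builders.h"

mlir::Operation* ImportTupleAllGather(
    mlir::OpBuilder& builder, mlir::Location loc,
    mlir::TupleType result_tuple_ty, mlir::ValueRange operands,
    llvm::ArrayRef<mlir::NamedAttribute> attributes) {
  auto all_gather = builder.create<mlir::mhlo::AllGatherOp>(
      loc, llvm::to_vector(result_tuple_ty.getTypes()), operands, attributes);
  return builder.create<mlir::mhlo::TupleOp>(loc, result_tuple_ty,
                                             all_gather.getResults());
}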
diff --git a/xla/translate/hlo_to_mhlo/tests/import.hlotxt b/xla/translate/hlo_to_mhlo/tests/import.hlotxt
index c6f92908f0bcf..5df4e887cd26c 100644
--- a/xla/translate/hlo_to_mhlo/tests/import.hlotxt
+++ b/xla/translate/hlo_to_mhlo/tests/import.hlotxt
@@ -72,6 +72,18 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
   ROOT ag = f32[128,128] all-gather(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, dimensions={1}, use_global_device_ids=true
 }
 
+// CHECK-LABEL: func private @test_all_gather_variadic
+%test_all_gather_variadic {
+  input.0 = f32[128,8] parameter(0)
+  input.1 = f32[128,16] parameter(1)
+  // CHECK-NEXT: "mhlo.all_gather"(%arg0, %arg1)
+  // CHECK-SAME: all_gather_dim = 1 : i64
+  // CHECK-SAME: channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
+  // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
+  // CHECK-SAME: use_global_device_ids
+  ROOT ag = (f32[128,32], f32[128,64]) all-gather(f32[128,8] input.0, f32[128,16] input.1), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, dimensions={1}, use_global_device_ids=true
+}
+
 // Test all-to-all
 
 // CHECK-LABEL: func private @test_all_to_all
diff --git a/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
index 414d3e72689f0..ac5897c4de8aa 100644
--- a/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
+++ b/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
@@ -949,21 +949,49 @@ LogicalResult ExportXlaOp(AddDependencyOp op, OpLoweringContext ctx) {
 
 LogicalResult ExportXlaOp(AllGatherOp op, OpLoweringContext ctx) {
   auto& value_map = *ctx.values;
-  xla::XlaOp operand;
-  if (failed(GetXlaOp(op.getOperand(), value_map, &operand, op)))
-    return failure();
-  TensorType operand_type = op.getOperand().getType().cast<TensorType>();
-  TensorType result_type = op.getType();
-  if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
+
+  SmallVector<xla::XlaOp> operands;
+  if (failed(GetTuple(op.getOperation(), op.getOperands(), ctx, operands))) {
     return failure();
+  }
+
+  mlir::FailureOr<xla::Shape> shape_or = ExtractXlaShape(op.getOperation());
+  if (failed(shape_or)) return failure();
+
   auto all_gather_dim = op.getAllGatherDim();
-  int64_t shard_count = result_type.getDimSize(all_gather_dim) /
-                        operand_type.getDimSize(all_gather_dim);
-  value_map[op] = xla::AllGather(
-      operand, all_gather_dim, shard_count,
-      Convert_replica_groups(op.getReplicaGroups()),
-      Convert_channel_handle(op.getChannelHandle()), std::nullopt,
-      Convert_use_global_device_ids(op.getUseGlobalDeviceIds()));
+  int64_t shard_count = 0;
+  for (size_t i = 0; i < operands.size(); ++i) {
+    TensorType operand_type = op.getOperand(i).getType().cast<TensorType>();
+    TensorType result_type = op.getType(i).cast<TensorType>();
+    if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
+      return failure();
+    if (i == 0) {
+      shard_count = result_type.getDimSize(all_gather_dim) /
+                    operand_type.getDimSize(all_gather_dim);
+    }
+  }
+
+  if (shape_or->IsTuple()) {
+    std::optional<xla::Layout> layout = std::nullopt;
+    if (shape_or->has_layout()) {
+      layout = shape_or->layout();
+    }
+    auto tuple = xla::AllGatherTuple(
+        operands, all_gather_dim, shard_count,
+        Convert_replica_groups(op.getReplicaGroups()),
+        Convert_channel_handle(op.getChannelHandle()), layout,
+        Convert_use_global_device_ids(op.getUseGlobalDeviceIds()));
+    for (auto [index, result] : llvm::enumerate(op.getResults())) {
+      value_map[result] = xla::GetTupleElement(tuple, index);
+    }
+  } else {
+    value_map[op->getResults()[0]] = xla::AllGather(
+        operands[0], all_gather_dim, shard_count,
+        Convert_replica_groups(op.getReplicaGroups()),
+        Convert_channel_handle(op.getChannelHandle()), std::nullopt,
+        Convert_use_global_device_ids(op.getUseGlobalDeviceIds()));
+  }
+  return success();
 }
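Note that `shard_count` is derived once, from the first operand/result pair only; the MHLO verifier guarantees every pair scales `all_gather_dim` by the same factor. Worked through on the shapes in the export test below (the helper name is hypothetical):

// Worked example of the shard_count computation above. For operand 0 of
// shape f32[8,2] and result 0 of shape f32[8,8] with all_gather_dim = 1:
//   shard_count = 8 / 2 = 4, and f32[8,4] -> f32[8,16] scales by the same 4.
#include <cstdint>

int64_t ComputeShardCount(int64_t result_dim_size, int64_t operand_dim_size) {
  return result_dim_size / operand_dim_size;
}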
+        Convert_use_global_device_ids(op.getUseGlobalDeviceIds()));
+  }
+
+  return success();
 }
 
@@ -1093,8 +1121,8 @@ LogicalResult ExportXlaOp(AsyncStartOp op, OpLoweringContext ctx) {
       dyn_cast_or_null<AllGatherOp>(callee.getBody().front().front());
   if (all_gather_op && SimplyReturnedOp(all_gather_op)) {
     TensorType operand_type =
-        all_gather_op.getOperand().getType().cast<TensorType>();
-    TensorType result_type = all_gather_op.getType();
+        all_gather_op.getOperand(0).getType().cast<TensorType>();
+    TensorType result_type = all_gather_op.getType(0).cast<TensorType>();
     if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
       return failure();
     if (operands.size() != 1) return failure();
@@ -1268,7 +1296,7 @@ LogicalResult ExportXlaOp(AsyncDoneOp op, OpLoweringContext ctx) {
   if (all_gather_op && SimplyReturnedOp(all_gather_op)) {
     value_map[op.getResult(0)] =
         xla::internal::XlaBuilderFriend::BuildAllGatherDone(
-            ctx.builder, operand, xla::TypeToShape(all_gather_op.getType()));
+            ctx.builder, operand, xla::TypeToShape(all_gather_op.getType(0)));
     return success();
   }
   auto all_reduce_op =
diff --git a/xla/translate/mhlo_to_hlo/tests/export.mlir b/xla/translate/mhlo_to_hlo/tests/export.mlir
index 101d68e515a44..63144f253e593 100644
--- a/xla/translate/mhlo_to_hlo/tests/export.mlir
+++ b/xla/translate/mhlo_to_hlo/tests/export.mlir
@@ -187,6 +187,44 @@ func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>) -> (tensor<10xf32>,
 
 // -----
 
+// expected-error@-3 {{'mhlo.async_start' op can't be translated to XLA HLO}}
+func.func @all_gather_0(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) attributes {execution_thread = "main"} {
+  %0:2 = "mhlo.all_gather"(%arg0, %arg1) {
+    all_gather_dim = 1 : i64,
+    channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
+    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
+    use_global_device_ids
+  } : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>)
+  func.return %0#0, %0#1 : tensor<8x2xf32>, tensor<8x4xf32>
+}
+
+func.func @main(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) {
+  %0 = "mhlo.async_start"(%arg0, %arg1) {called_computation = @all_gather_0, execution_thread = "main"} : (tensor<8x2xf32>, tensor<8x4xf32>) -> !mhlo.async_bundle<tuple<tensor<8x2xf32>,tensor<8x4xf32>>, tuple<tensor<8x2xf32>,tensor<8x4xf32>>>
+  %1:2 = "mhlo.async_done"(%0) {called_computation = @all_gather_0, execution_thread = "main"} : (!mhlo.async_bundle<tuple<tensor<8x2xf32>,tensor<8x4xf32>>, tuple<tensor<8x2xf32>,tensor<8x4xf32>>>) -> (tensor<8x2xf32>, tensor<8x4xf32>)
+  return %1#0, %1#1 : tensor<8x2xf32>, tensor<8x4xf32>
+}
+
+// -----
+
+func.func private @main(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> tuple<tensor<8x8xf32>, tensor<8x16xf32>> {
+  // CHECK: %[[ARG0:.*]] = f32[8,2] parameter(0)
+  // CHECK-NEXT: %[[ARG1:.*]] = f32[8,4] parameter(1)
+  // CHECK-NEXT: %[[TUPLE:.*]] = (f32[8,2], f32[8,4]) tuple
+  // CHECK-NEXT: %[[TUPLE_ARG0:.*]] = f32[8,2] get-tuple-element((f32[8,2], f32[8,4]) %[[TUPLE]]), index=0
+  // CHECK-NEXT: %[[TUPLE_ARG1:.*]] = f32[8,4] get-tuple-element((f32[8,2], f32[8,4]) %[[TUPLE]]), index=1
+  // CHECK-NEXT: (f32[8,8], f32[8,16]) all-gather(f32[8,2] %[[TUPLE_ARG0]], f32[8,4] %[[TUPLE_ARG1]]), channel_id=1, replica_groups={{.*}}, dimensions={1}
+  %0:2 = "mhlo.all_gather"(%arg0, %arg1) {
+    all_gather_dim = 1 : i64,
+    channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
+    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
+    use_global_device_ids
+  } : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>)
+  %1 = mhlo.tuple %0#0, %0#1 {xla_shape = "(f32[8,8]{0,1}, f32[8,16]{0,1})"} : tuple<tensor<8x8xf32>, tensor<8x16xf32>>
+  return %1 : tuple<tensor<8x8xf32>, tensor<8x16xf32>>
+}
+
+// -----
+
 // CHECK:  HloModule
 func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
   %0 = "mhlo.all_reduce"(%arg0) ({
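Because XLA models the variadic result as a tuple, the exporter maps each MLIR result to a get-tuple-element of the single AllGatherTuple op, as the export test above checks. A standalone sketch of that mapping; the signature is hypothetical, the real code writes into the lowering context's value map:

// Sketch of the result-mapping loop in the tuple branch above: MLIR result i
// becomes get-tuple-element(tuple_all_gather, i) on the XLA side.
void MapTupleResults(xla::XlaOp tuple_all_gather, mlir::ResultRange results,
                     llvm::DenseMap<mlir::Value, xla::XlaOp>& value_map) {
  for (auto [index, result] : llvm::enumerate(results)) {
    value_map[result] = xla::GetTupleElement(tuple_all_gather, index);
  }
}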
"(f32[8,8]{0,1}, f32[8,16]{0,1})"} : tuple, tensor<8x16xf32>> + return %1 : tuple, tensor<8x16xf32>> +} + +// ----- + // CHECK: HloModule func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { %0 = "mhlo.all_reduce"(%arg0) ({ @@ -2110,7 +2148,7 @@ func.func @main(%token: !mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) { // CHECK-NOT: sharding= // CHECK: [[TUPLE1:%.*]] = token[] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=1 // CHECK-NOT: sharding= -// CHECK: ROOT {{%.*}} = (s32[3,4], token[]) tuple(s32[3,4] [[TUPLE0]], token[] [[TUPLE1]]) +// CHECK: ROOT {{%.*}} = (s32[3,4], token[]) tuple(s32[3,4] [[TUPLE0]], token[] [[TUPLE1]]) // -----