Skip to content

Commit

Permalink
[mhlo] AllGather variadic operands support
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov committed Jan 4, 2024
1 parent 511d186 commit 567b97b
Show file tree
Hide file tree
Showing 10 changed files with 183 additions and 30 deletions.
12 changes: 12 additions & 0 deletions xla/client/xla_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5042,6 +5042,18 @@ XlaOp AllGather(const XlaOp operand, int64_t all_gather_dimension,
layout, use_global_device_ids);
}

// Variadic all-gather entry point: packs `operands` into one tuple and hands
// that tuple to XlaBuilder::AllGather together with the remaining arguments.
// `operands` must be non-empty; the builder that owns the first operand is
// used for both the tuple and the all-gather.
XlaOp AllGatherTuple(const absl::Span<const XlaOp> operands,
                     int64_t all_gather_dimension, int64_t shard_count,
                     absl::Span<const ReplicaGroup> replica_groups,
                     const std::optional<ChannelHandle>& channel_id,
                     const std::optional<Layout>& layout,
                     const std::optional<bool> use_global_device_ids) {
  CHECK(!operands.empty());
  auto* const owning_builder = operands[0].builder();
  const XlaOp packed_operands = owning_builder->Tuple(operands);
  return owning_builder->AllGather(packed_operands, all_gather_dimension,
                                   shard_count, replica_groups, channel_id,
                                   layout, use_global_device_ids);
}

XlaOp CrossReplicaSum(const XlaOp operand,
absl::Span<const ReplicaGroup> replica_groups) {
return operand.builder()->CrossReplicaSum(operand, replica_groups);
Expand Down
13 changes: 13 additions & 0 deletions xla/client/xla_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1447,6 +1447,12 @@ class XlaBuilder {
const std::optional<ChannelHandle>& channel_id,
const std::optional<Layout>& layout,
std::optional<bool> use_global_device_ids);
friend XlaOp AllGatherTuple(absl::Span<const XlaOp> operands,
int64_t all_gather_dimension, int64_t shard_count,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id,
const std::optional<Layout>& layout,
std::optional<bool> use_global_device_ids);
friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id,
Expand Down Expand Up @@ -2441,6 +2447,13 @@ XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension,
const std::optional<Layout>& layout = std::nullopt,
std::optional<bool> use_global_device_ids = std::nullopt);

// Variadic form of AllGather. The operands are packed into a tuple and a
// single AllGather is enqueued over that tuple; all other arguments are
// forwarded unchanged. `operands` must not be empty (the implementation
// CHECK-fails otherwise) and the builder of the first operand is used.
// NOTE(review): the result appears to be tuple-shaped with one element per
// operand (see XlaBuilderTest.AllGatherTuple) -- confirm before relying on it.
XlaOp AllGatherTuple(
absl::Span<const XlaOp> operands, int64_t all_gather_dimension,
int64_t shard_count, absl::Span<const ReplicaGroup> replica_groups = {},
const std::optional<ChannelHandle>& channel_id = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
std::optional<bool> use_global_device_ids = std::nullopt);

// Enqueues an operation that does an AllReduce of the operand across cores. Here
// AllReduce means doing a reduction on the input operand across cores and then
// broadcasting the reduction result to those cores. The reduction function is
Expand Down
15 changes: 15 additions & 0 deletions xla/client/xla_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,21 @@ TEST_F(XlaBuilderTest, AllGatherWithTuple) {
ShapeUtil::MakeShape(F32, {64, 4})})));
}

// Two F32 parameters are gathered together along dimension 1 with 4 shards.
// The module root must be one all-gather whose tuple result has that
// dimension multiplied by the shard count (4 -> 16 and 8 -> 32).
TEST_F(XlaBuilderTest, AllGatherTuple) {
  XlaBuilder builder(TestName());
  auto first = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {128, 4}), "p0");
  auto second =
      Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {128, 8}), "p1");
  AllGatherTuple({first, second}, /*all_gather_dimension=*/1,
                 /*shard_count=*/4);

  auto expected_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {128, 16}),
                                 ShapeUtil::MakeShape(F32, {128, 32})});

  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Op()
                             .WithOpcode(HloOpcode::kAllGather)
                             .WithShapeEqualTo(&expected_shape)));
}

TEST_F(XlaBuilderTest, ReduceScatter) {
XlaBuilder b(TestName());
XlaComputation to_apply;
Expand Down
19 changes: 14 additions & 5 deletions xla/mlir_hlo/mhlo/IR/hlo_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2013,18 +2013,27 @@ LogicalResult AllGatherOp::verify() {
if (auto channelHandleAttr = getChannelHandleAttr())
channelId = channelHandleAttr.getHandle();

return hlo::verifyAllGatherOp(getLoc(), getOperand(), getAllGatherDim(),
getReplicaGroups(), channelId,
getUseGlobalDeviceIds(), getResult());
// A variadic all_gather needs at least one operand to gather.
if (getOperands().empty())
  return emitOptionalError(getLoc(),
                           "AllGather must have at least one operand");

// Operands and results are paired by index below, so the counts have to agree
// before getResult(i) is touched; otherwise a malformed op would index past
// the end of the result list.
if (getNumOperands() != getNumResults())
  return emitOptionalError(
      getLoc(), "AllGather requires the same number of operands and results");

// Verify each operand/result pair with the shared single-operand verifier.
for (unsigned i = 0; i < getNumOperands(); ++i) {
  if (failed(hlo::verifyAllGatherOp(
          getLoc(), getOperand(i), getAllGatherDim(), getReplicaGroups(),
          channelId, getUseGlobalDeviceIds(), getResult(i)))) {
    return failure();
  }
}
return success();
}

void AllGatherOp::build(OpBuilder& odsBuilder, OperationState& odsState,
Type resultType, Value operand,
IntegerAttr allGatherDim,
DenseIntElementsAttr replicaGroups,
ChannelHandleAttr channelHandle) {
AllGatherOp::build(odsBuilder, odsState, resultType, operand, allGatherDim,
replicaGroups, channelHandle,
AllGatherOp::build(odsBuilder, odsState, resultType, ValueRange(operand),
allGatherDim, replicaGroups, channelHandle,
/*use_global_device_ids=*/nullptr);
}

Expand Down
9 changes: 5 additions & 4 deletions xla/mlir_hlo/mhlo/IR/hlo_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -1455,8 +1455,9 @@ def MHLO_AllGatherOp : MHLO_Op<"all_gather", [SameOperandsAndResultElementType]>
string summary = "AllGather operation";
string description = [{
Within each process group in the process grid, concatenates the values of the
`operand` tensor from each process along `all_gather_dim` and produces a
`result` tensor.
`operands` tensors from each process along `all_gather_dim` and produces the
`results` tensors. The gather is applied separately to each operand in
`operands`, producing one result per operand.

See:
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_gather
Expand All @@ -1474,13 +1475,13 @@ def MHLO_AllGatherOp : MHLO_Op<"all_gather", [SameOperandsAndResultElementType]>
}];

let arguments = (ins
MHLO_Tensor:$operand,
Variadic<MHLO_Tensor>:$operands,
I64Attr:$all_gather_dim,
I64ElementsAttr:$replica_groups,
OptionalAttr<MHLO_ChannelHandle>:$channel_handle,
UnitAttr:$use_global_device_ids
);
let results = (outs MHLO_Tensor);
let results = (outs Variadic<MHLO_Tensor>);
// use_global_device_ids is rarely used, so we add simplified builder methods
// for convenience.
let builders = [
Expand Down
11 changes: 11 additions & 0 deletions xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,17 @@ func.func @all_to_all_i5(%data: tensor<4x16xf32>) -> tensor<16x4xf32> {

// -----

// Positive case: an all_gather with two operands and two results is expected
// to verify without diagnostics. With all_gather_dim = 1, each result is four
// times as wide as its operand along dim 1 (2 -> 8, 4 -> 16), matching the
// replica groups of size 4.
func.func @all_gather_variadic(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>) {
%0:2 = "mhlo.all_gather"(%arg0, %arg1) {
all_gather_dim = 1 : i64,
channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
} : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>)
func.return %0#0, %0#1 : tensor<8x8xf32>, tensor<8x16xf32>
}

// -----

func.func @allgather_gather_along_zero_dimension(%arg0: tensor<128x0xf32>) -> tensor<128x100xf32> {
// expected-error@+1 {{dimension size of operand at 'all_gather_dim' cannot be zero}}
%0 = "mhlo.all_gather"(%arg0) {
Expand Down
22 changes: 18 additions & 4 deletions xla/translate/hlo_to_mhlo/hlo_function_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1427,6 +1427,12 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
}
case HloOpcode::kAllGather: {
auto all_gather = Cast<HloAllGatherInstruction>(instruction);
auto result_tuple_ty = result_type.dyn_cast<mlir::TupleType>();

llvm::SmallVector<Type> result_types = {result_type};
if (result_tuple_ty) {
result_types = llvm::to_vector(result_tuple_ty.getTypes());
}
attributes.push_back(builder_->getNamedAttr(
"all_gather_dim",
builder_->getI64IntegerAttr(all_gather->all_gather_dimension())));
Expand All @@ -1437,10 +1443,15 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
ConvertChannelHandle(all_gather->channel_id().value()));
if (all_gather->use_global_device_ids())
attributes.push_back(ConvertUseGlobalDeviceIds());
return func_builder
->create<mlir::mhlo::AllGatherOp>(loc, result_type, operands,
attributes)
.getOperation();
auto all_gather_op = func_builder->create<mlir::mhlo::AllGatherOp>(
loc, result_types, operands, attributes);
if (result_tuple_ty) {
return func_builder
->create<mlir::mhlo::TupleOp>(loc, result_type,
all_gather_op.getResults())
.getOperation();
}
return all_gather_op.getOperation();
}
case HloOpcode::kAllGatherStart: {
auto all_gather_start = Cast<HloAllGatherInstruction>(instruction);
Expand All @@ -1454,6 +1465,9 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
ConvertChannelHandle(all_gather_start->channel_id().value()));
if (all_gather_start->use_global_device_ids())
attributes.push_back(ConvertUseGlobalDeviceIds());
if (all_gather_start->operands().size() > 1)
return InvalidArgument(
"Async tuple all-gather is not supported in MHLO");

return ImportOldStyleAsyncStart<mlir::mhlo::AllGatherOp>(
attributes, operands, loc, result_type, func_builder, "all_gather_",
Expand Down
12 changes: 12 additions & 0 deletions xla/translate/hlo_to_mhlo/tests/import.hlotxt
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,18 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
ROOT ag = f32[128,128] all-gather(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, dimensions={1}, use_global_device_ids=true
}

// Test variadic all-gather: two operands gathered along dimension 1 into a
// tuple-shaped result (8 -> 32 and 16 -> 64, with replica groups of size 4).

// CHECK-LABEL: func private @test_all_gather_variadic
%test_all_gather_variadic {
input.0 = f32[128,8] parameter(0)
input.1 = f32[128,16] parameter(1)
// CHECK-NEXT: "mhlo.all_gather"(%arg0, %arg1)
// CHECK-SAME: all_gather_dim = 1 : i64
// CHECK-SAME: channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
// CHECK-SAME: use_global_device_ids
ROOT ag = (f32[128,32], f32[128,64]) all-gather(f32[128,8] input.0, f32[128,16] input.1), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, dimensions={1}, use_global_device_ids=true
}

// Test all-to-all

// CHECK-LABEL: func private @test_all_to_all
Expand Down
60 changes: 44 additions & 16 deletions xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -949,21 +949,49 @@ LogicalResult ExportXlaOp(AddDependencyOp op, OpLoweringContext ctx) {

LogicalResult ExportXlaOp(AllGatherOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaOp operand;
if (failed(GetXlaOp(op.getOperand(), value_map, &operand, op)))
return failure();
TensorType operand_type = op.getOperand().getType().cast<TensorType>();
TensorType result_type = op.getType();
if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())

SmallVector<xla::XlaOp> operands;
if (failed(GetTuple(op.getOperation(), op.getOperands(), ctx, operands))) {
return failure();
}

mlir::FailureOr<xla::Shape> shape_or = ExtractXlaShape(op.getOperation());
if (failed(shape_or)) return failure();

auto all_gather_dim = op.getAllGatherDim();
int64_t shard_count = result_type.getDimSize(all_gather_dim) /
operand_type.getDimSize(all_gather_dim);
value_map[op] = xla::AllGather(
operand, all_gather_dim, shard_count,
Convert_replica_groups(op.getReplicaGroups()),
Convert_channel_handle(op.getChannelHandle()), std::nullopt,
Convert_use_global_device_ids(op.getUseGlobalDeviceIds()));
// One shard_count is passed to XLA for all operands, so every operand/result
// pair must be statically shaped and must imply the same ratio along
// all_gather_dim. A mismatch is rejected here rather than silently exporting
// with the ratio of the first pair only.
int64_t shard_count = 0;
for (size_t i = 0; i < operands.size(); ++i) {
  TensorType operand_type = op.getOperand(i).getType().cast<TensorType>();
  TensorType result_type = op.getType(i).cast<TensorType>();
  if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
    return failure();
  const int64_t pair_shard_count = result_type.getDimSize(all_gather_dim) /
                                   operand_type.getDimSize(all_gather_dim);
  if (i == 0) {
    shard_count = pair_shard_count;
  } else if (pair_shard_count != shard_count) {
    return op.emitOpError()
           << "requires every operand/result pair to imply the same shard "
              "count along all_gather_dim";
  }
}

if (shape_or->IsTuple()) {
std::optional<xla::Layout> layout = std::nullopt;
if (shape_or->has_layout()) {
layout = shape_or->layout();
}
auto tuple = xla::AllGatherTuple(
operands, all_gather_dim, shard_count,
Convert_replica_groups(op.getReplicaGroups()),
Convert_channel_handle(op.getChannelHandle()), layout,
Convert_use_global_device_ids(op.getUseGlobalDeviceIds()));
for (auto [index, result] : llvm::enumerate(op.getResults())) {
value_map[result] = xla::GetTupleElement(tuple, index);
}
} else {
value_map[op->getResults()[0]] = xla::AllGather(
operands[0], all_gather_dim, shard_count,
Convert_replica_groups(op.getReplicaGroups()),
Convert_channel_handle(op.getChannelHandle()), std::nullopt,
Convert_use_global_device_ids(op.getUseGlobalDeviceIds()));
}

return success();
}

Expand Down Expand Up @@ -1095,8 +1123,8 @@ LogicalResult ExportXlaOp(AsyncStartOp op, OpLoweringContext ctx) {
dyn_cast_or_null<AllGatherOp>(callee.getBody().front().front());
if (all_gather_op && SimplyReturnedOp(all_gather_op)) {
TensorType operand_type =
all_gather_op.getOperand().getType().cast<TensorType>();
TensorType result_type = all_gather_op.getType();
all_gather_op.getOperand(0).getType().cast<TensorType>();
TensorType result_type = all_gather_op.getType(0).cast<TensorType>();
if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
return failure();
if (operands.size() != 1) return failure();
Expand Down Expand Up @@ -1290,7 +1318,7 @@ LogicalResult ExportXlaOp(AsyncDoneOp op, OpLoweringContext ctx) {
if (all_gather_op && SimplyReturnedOp(all_gather_op)) {
value_map[op.getResult(0)] =
xla::internal::XlaBuilderFriend::BuildAllGatherDone(
ctx.builder, operand, xla::TypeToShape(all_gather_op.getType()));
ctx.builder, operand, xla::TypeToShape(all_gather_op.getType(0)));
return success();
}
auto all_reduce_op =
Expand Down
40 changes: 39 additions & 1 deletion xla/translate/mhlo_to_hlo/tests/export.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,44 @@ func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>) -> (tensor<10xf32>,

// -----

// expected-error@-3 {{'mhlo.async_start' op can't be translated to XLA HLO}}
// Negative case: @main wraps a two-operand all_gather callee in
// async_start/async_done. The expected-error directive above requires the
// export of the async_start to be rejected.
// NOTE(review): presumably because the async all-gather export path accepts
// only a single operand -- confirm against the AsyncStartOp exporter.
func.func @all_gather_0(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) attributes {execution_thread = "main"} {
%0:2 = "mhlo.all_gather"(%arg0, %arg1) {
all_gather_dim = 1 : i64,
channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
use_global_device_ids
} : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>)
func.return %0#0, %0#1 : tensor<8x2xf32>, tensor<8x4xf32>
}

func.func @main(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) {
%0 = "mhlo.async_start"(%arg0, %arg1) {called_computation = @all_gather_0, execution_thread = "main"} : (tensor<8x2xf32>, tensor<8x4xf32>) -> !mhlo.async_bundle<tuple<tensor<8x2xf32>,tensor<8x4xf32>>, tuple<tensor<8x2xf32>,tensor<8x4xf32>>>
%1:2 = "mhlo.async_done"(%0) {called_computation = @all_gather_0, execution_thread = "main"} : (!mhlo.async_bundle<tuple<tensor<8x2xf32>,tensor<8x4xf32>>, tuple<tensor<8x2xf32>,tensor<8x4xf32>>>) -> (tensor<8x2xf32>, tensor<8x4xf32>)
return %1#0, %1#1 : tensor<8x2xf32>, tensor<8x4xf32>
}

// -----

// Variadic all_gather export: the two operands are expected to be packed into
// an HLO tuple, read back with get-tuple-element, and fed to one tuple-shaped
// all-gather along dimension 1 (2 -> 8 and 4 -> 16).
func.func private @main(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> tuple<tensor<8x8xf32>, tensor<8x16xf32>> {
// CHECK: %[[ARG0:.*]] = f32[8,2] parameter(0)
// CHECK-NEXT: %[[ARG1:.*]] = f32[8,4] parameter(1)
// CHECK-NEXT: %[[TUPLE:.*]] = (f32[8,2], f32[8,4]) tuple
// CHECK-NEXT: %[[TUPLE_ARG0:.*]] = f32[8,2] get-tuple-element((f32[8,2], f32[8,4]) %[[TUPLE]]), index=0
// CHECK-NEXT: %[[TUPLE_ARG1:.*]] = f32[8,4] get-tuple-element((f32[8,2], f32[8,4]) %[[TUPLE]]), index=1
// CHECK-NEXT: (f32[8,8], f32[8,16]) all-gather(f32[8,2] %[[TUPLE_ARG0]], f32[8,4] %[[TUPLE_ARG1]]), channel_id=1, replica_groups={{.*}}, dimensions={1}
%0:2 = "mhlo.all_gather"(%arg0, %arg1) {
all_gather_dim = 1 : i64,
channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
use_global_device_ids
} : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>)
%1 = mhlo.tuple %0#0, %0#1 {xla_shape = "(f32[8,8]{0,1}, f32[8,16]{0,1})"} : tuple<tensor<8x8xf32>, tensor<8x16xf32>>
return %1 : tuple<tensor<8x8xf32>, tensor<8x16xf32>>
}

// -----

// CHECK: HloModule
func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
%0 = "mhlo.all_reduce"(%arg0) ({
Expand Down Expand Up @@ -2110,7 +2148,7 @@ func.func @main(%token: !mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) {
// CHECK-NOT: sharding=
// CHECK: [[TUPLE1:%.*]] = token[] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=1
// CHECK-NOT: sharding=
// CHECK: ROOT {{%.*}} = (s32[3,4], token[]) tuple(s32[3,4] [[TUPLE0]], token[] [[TUPLE1]])
// CHECK: ROOT {{%.*}} = (s32[3,4], token[]) tuple(s32[3,4] [[TUPLE0]], token[] [[TUPLE1]])

// -----

Expand Down

0 comments on commit 567b97b

Please sign in to comment.