[mhlo] AllGather variadic operands support #8151

Closed · wants to merge 1 commit
12 changes: 12 additions & 0 deletions xla/client/xla_builder.cc
@@ -5015,6 +5015,18 @@ XlaOp AllGather(const XlaOp operand, int64_t all_gather_dimension,
layout, use_global_device_ids);
}

XlaOp AllGatherTuple(const absl::Span<const XlaOp> operands,
int64_t all_gather_dimension, int64_t shard_count,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id,
const std::optional<Layout>& layout,
const std::optional<bool> use_global_device_ids) {
CHECK(!operands.empty());
return operands[0].builder()->AllGather(
operands[0].builder()->Tuple(operands), all_gather_dimension, shard_count,
replica_groups, channel_id, layout, use_global_device_ids);
}

XlaOp CrossReplicaSum(const XlaOp operand,
absl::Span<const ReplicaGroup> replica_groups) {
return operand.builder()->CrossReplicaSum(operand, replica_groups);
13 changes: 13 additions & 0 deletions xla/client/xla_builder.h
@@ -1437,6 +1437,12 @@ class XlaBuilder {
const std::optional<ChannelHandle>& channel_id,
const std::optional<Layout>& layout,
std::optional<bool> use_global_device_ids);
friend XlaOp AllGatherTuple(absl::Span<const XlaOp> operands,
int64_t all_gather_dimension, int64_t shard_count,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id,
const std::optional<Layout>& layout,
std::optional<bool> use_global_device_ids);
friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id,
@@ -2431,6 +2437,13 @@ XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension,
const std::optional<Layout>& layout = std::nullopt,
std::optional<bool> use_global_device_ids = std::nullopt);

XlaOp AllGatherTuple(
absl::Span<const XlaOp> operands, int64_t all_gather_dimension,
int64_t shard_count, absl::Span<const ReplicaGroup> replica_groups = {},
const std::optional<ChannelHandle>& channel_id = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
std::optional<bool> use_global_device_ids = std::nullopt);

// Enqueues an operation that do an AllReduce of the operand cross cores. Here
// AllReduce means doing a reduction on the input operand cross cores and then
// broadcasting the reduction result to those cores. The reduction function is
15 changes: 15 additions & 0 deletions xla/client/xla_builder_test.cc
@@ -454,6 +454,21 @@ TEST_F(XlaBuilderTest, AllGatherWithTuple) {
ShapeUtil::MakeShape(F32, {64, 4})})));
}

TEST_F(XlaBuilderTest, AllGatherTuple) {
XlaBuilder b(TestName());
auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {128, 4}), "p0");
auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {128, 8}), "p1");
AllGatherTuple({p0, p1}, /*all_gather_dimension=*/1, /*shard_count=*/4);
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();
auto tuple_shape =
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {128, 16}),
ShapeUtil::MakeShape(F32, {128, 32})});
EXPECT_THAT(root, GmockMatch(m::Op()
.WithOpcode(HloOpcode::kAllGather)
.WithShapeEqualTo(&tuple_shape)));
}

TEST_F(XlaBuilderTest, ReduceScatter) {
XlaBuilder b(TestName());
XlaComputation to_apply;
23 changes: 18 additions & 5 deletions xla/mlir_hlo/mhlo/IR/hlo_ops.cc
@@ -2017,18 +2017,31 @@ LogicalResult AllGatherOp::verify() {
if (auto channelHandleAttr = getChannelHandleAttr())
channelId = channelHandleAttr.getHandle();

return hlo::verifyAllGatherOp(getLoc(), getOperand(), getAllGatherDim(),
getReplicaGroups(), channelId,
getUseGlobalDeviceIds(), getResult());
if (getOperands().empty())
return emitOptionalError(getLoc(),
"AllGather must have have at least one operand");

Member:
Do we verify anywhere that there is the same number of operands and results?

I'm not sure if there's an upstream MLIR trait for this; if not, then let's add another check here.

Contributor Author:
Good catch! Added validation for that case to AllGatherOp::verify:

  if (getNumOperands() != getNumResults())
    return emitOptionalError(
        getLoc(),
        "AllGather requires the same number of operands and results");

Contributor:
I am surprised there isn't an MLIR core trait to check this. If not, adding one to MLIR might be useful (though not as part of this PR).

Contributor Author (@apivovarov, Jan 10, 2024):
The list of OpTraits is in LLVM's mlir/include/mlir/IR/OpDefinition.h; I do not see a trait that compares the number of operands and results.
BTW, the getOperands().empty() check might be redundant, because SameOperandsAndResultElementType::verifyTrait calls OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) in Operation.cpp, which contains the following:

  if (failed(verifyAtLeastNOperands(op, 1)) ||
      failed(verifyAtLeastNResults(op, 1)))
    return failure();
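
As an aside to this thread: below is a minimal, hypothetical sketch of what such an upstream trait could look like, following the verifyTrait pattern of the existing traits in OpDefinition.h. The trait name and error wording are assumptions for illustration only; no such trait exists in MLIR, and none is added by this PR.

#include "mlir/IR/OpDefinition.h"

// Hypothetical trait (not part of MLIR or this PR): rejects ops whose
// operand count differs from their result count, mirroring the check
// this PR adds to AllGatherOp::verify.
template <typename ConcreteType>
class SameOperandsAndResultsCount
    : public mlir::OpTrait::TraitBase<ConcreteType,
                                      SameOperandsAndResultsCount> {
 public:
  static mlir::LogicalResult verifyTrait(mlir::Operation *op) {
    if (op->getNumOperands() != op->getNumResults())
      return op->emitOpError()
             << "expected the same number of operands and results, but got "
             << op->getNumOperands() << " operands and "
             << op->getNumResults() << " results";
    return mlir::success();
  }
};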

if (getNumOperands() != getNumResults())
return emitOptionalError(
getLoc(),
"AllGather requires the same number of operands and results");

for (unsigned i = 0; i < getNumOperands(); ++i) {
if (failed(hlo::verifyAllGatherOp(
getLoc(), getOperand(i), getAllGatherDim(), getReplicaGroups(),
channelId, getUseGlobalDeviceIds(), getResult(i)))) {
return failure();
}
}
return success();
}

void AllGatherOp::build(OpBuilder& odsBuilder, OperationState& odsState,
Type resultType, Value operand,
IntegerAttr allGatherDim,
DenseIntElementsAttr replicaGroups,
ChannelHandleAttr channelHandle) {
AllGatherOp::build(odsBuilder, odsState, resultType, operand, allGatherDim,
replicaGroups, channelHandle,
AllGatherOp::build(odsBuilder, odsState, resultType, ValueRange(operand),
allGatherDim, replicaGroups, channelHandle,
/*use_global_device_ids=*/nullptr);
}

9 changes: 5 additions & 4 deletions xla/mlir_hlo/mhlo/IR/hlo_ops.td
@@ -1449,8 +1449,9 @@ def MHLO_AllGatherOp : MHLO_Op<"all_gather", [SameOperandsAndResultElementType]>
string summary = "AllGather operation";
string description = [{
Within each process group in the process grid, concatenates the values of the
`operand` tensor from each process along `all_gather_dim` and produces a
`result` tensor.
operand tensor from each process along `all_gather_dim` and produces a
result tensor. The `computation` is applied separately for each operand in
`operands`, producing one result per operand.

See:
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_gather
@@ -1468,13 +1469,13 @@ def MHLO_AllGatherOp : MHLO_Op<"all_gather", [SameOperandsAndResultElementType]>
}];

let arguments = (ins
MHLO_Tensor:$operand,
Variadic<MHLO_Tensor>:$operands,
I64Attr:$all_gather_dim,
I64ElementsAttr:$replica_groups,
OptionalAttr<MHLO_ChannelHandle>:$channel_handle,
UnitAttr:$use_global_device_ids
);
let results = (outs MHLO_Tensor);
let results = (outs Variadic<MHLO_Tensor>);
// use_global_device_ids is rarely used, so we add simplified builder methods
// for convenience.
let builders = [
11 changes: 11 additions & 0 deletions xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
@@ -753,6 +753,17 @@ func.func @all_to_all_i5(%data: tensor<4x16xf32>) -> tensor<16x4xf32> {

// -----

func.func @all_gather_variadic(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>) {
%0:2 = "mhlo.all_gather"(%arg0, %arg1) {
all_gather_dim = 1 : i64,
channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
} : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>)
func.return %0#0, %0#1 : tensor<8x8xf32>, tensor<8x16xf32>
}

// -----

func.func @allgather_gather_along_zero_dimension(%arg0: tensor<128x0xf32>) -> tensor<128x100xf32> {
// expected-error@+1 {{dimension size of operand at 'all_gather_dim' cannot be zero}}
%0 = "mhlo.all_gather"(%arg0) {
22 changes: 18 additions & 4 deletions xla/translate/hlo_to_mhlo/hlo_function_importer.cc
@@ -1422,6 +1422,12 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
}
case HloOpcode::kAllGather: {
auto all_gather = Cast<HloAllGatherInstruction>(instruction);
auto result_tuple_ty = result_type.dyn_cast<mlir::TupleType>();

llvm::SmallVector<Type> result_types = {result_type};
if (result_tuple_ty) {
result_types = llvm::to_vector(result_tuple_ty.getTypes());
}
attributes.push_back(builder_->getNamedAttr(
"all_gather_dim",
builder_->getI64IntegerAttr(all_gather->all_gather_dimension())));
@@ -1432,10 +1438,15 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
ConvertChannelHandle(all_gather->channel_id().value()));
if (all_gather->use_global_device_ids())
attributes.push_back(ConvertUseGlobalDeviceIds());
return func_builder
->create<mlir::mhlo::AllGatherOp>(loc, result_type, operands,
attributes)
.getOperation();
auto all_gather_op = func_builder->create<mlir::mhlo::AllGatherOp>(
loc, result_types, operands, attributes);
if (result_tuple_ty) {
return func_builder
->create<mlir::mhlo::TupleOp>(loc, result_type,
all_gather_op.getResults())
.getOperation();
}
return all_gather_op.getOperation();
}
case HloOpcode::kAllGatherStart: {
auto all_gather_start = Cast<HloAllGatherInstruction>(instruction);
@@ -1449,6 +1460,9 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
ConvertChannelHandle(all_gather_start->channel_id().value()));
if (all_gather_start->use_global_device_ids())
attributes.push_back(ConvertUseGlobalDeviceIds());
if (all_gather_start->operands().size() > 1)
return InvalidArgument(
"Async tuple all-gather is not supported in MHLO");

return ImportOldStyleAsyncStart<mlir::mhlo::AllGatherOp>(
attributes, operands, loc, result_type, func_builder, "all_gather_",
12 changes: 12 additions & 0 deletions xla/translate/hlo_to_mhlo/tests/import.hlotxt
@@ -72,6 +72,18 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
ROOT ag = f32[128,128] all-gather(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, dimensions={1}, use_global_device_ids=true
}

// CHECK-LABEL: func private @test_all_gather_variadic
%test_all_gather_variadic {
input.0 = f32[128,8] parameter(0)
input.1 = f32[128,16] parameter(1)
// CHECK-NEXT: "mhlo.all_gather"(%arg0, %arg1)
// CHECK-SAME: all_gather_dim = 1 : i64
// CHECK-SAME: channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
// CHECK-SAME: use_global_device_ids
ROOT ag = (f32[128,32], f32[128,64]) all-gather(f32[128,8] input.0, f32[128,16] input.1), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, dimensions={1}, use_global_device_ids=true
}

// Test all-to-all

// CHECK-LABEL: func private @test_all_to_all
60 changes: 44 additions & 16 deletions xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
@@ -949,21 +949,49 @@ LogicalResult ExportXlaOp(AddDependencyOp op, OpLoweringContext ctx) {

LogicalResult ExportXlaOp(AllGatherOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaOp operand;
if (failed(GetXlaOp(op.getOperand(), value_map, &operand, op)))
return failure();
TensorType operand_type = op.getOperand().getType().cast<TensorType>();
TensorType result_type = op.getType();
if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())

SmallVector<xla::XlaOp> operands;
if (failed(GetTuple(op.getOperation(), op.getOperands(), ctx, operands))) {
return failure();
}

mlir::FailureOr<xla::Shape> shape_or = ExtractXlaShape(op.getOperation());
if (failed(shape_or)) return failure();

auto all_gather_dim = op.getAllGatherDim();
int64_t shard_count = result_type.getDimSize(all_gather_dim) /
operand_type.getDimSize(all_gather_dim);
value_map[op] = xla::AllGather(
operand, all_gather_dim, shard_count,
Convert_replica_groups(op.getReplicaGroups()),
Convert_channel_handle(op.getChannelHandle()), std::nullopt,
Convert_use_global_device_ids(op.getUseGlobalDeviceIds()));
int64_t shard_count = 0;
for (size_t i = 0; i < operands.size(); ++i) {
TensorType operand_type = op.getOperand(i).getType().cast<TensorType>();
TensorType result_type = op.getType(i).cast<TensorType>();
if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
return failure();
if (i == 0) {
shard_count = result_type.getDimSize(all_gather_dim) /
operand_type.getDimSize(all_gather_dim);
}
}

if (shape_or->IsTuple()) {
std::optional<xla::Layout> layout = std::nullopt;
if (shape_or->has_layout()) {
layout = shape_or->layout();
}
auto tuple = xla::AllGatherTuple(
operands, all_gather_dim, shard_count,
Convert_replica_groups(op.getReplicaGroups()),
Convert_channel_handle(op.getChannelHandle()), layout,
Convert_use_global_device_ids(op.getUseGlobalDeviceIds()));
for (auto [index, result] : llvm::enumerate(op.getResults())) {
value_map[result] = xla::GetTupleElement(tuple, index);
}
} else {
value_map[op->getResults()[0]] = xla::AllGather(
operands[0], all_gather_dim, shard_count,
Convert_replica_groups(op.getReplicaGroups()),
Convert_channel_handle(op.getChannelHandle()), std::nullopt,
Convert_use_global_device_ids(op.getUseGlobalDeviceIds()));
}

return success();
}

@@ -1093,8 +1121,8 @@ LogicalResult ExportXlaOp(AsyncStartOp op, OpLoweringContext ctx) {
dyn_cast_or_null<AllGatherOp>(callee.getBody().front().front());
if (all_gather_op && SimplyReturnedOp(all_gather_op)) {
TensorType operand_type =
all_gather_op.getOperand().getType().cast<TensorType>();
TensorType result_type = all_gather_op.getType();
all_gather_op.getOperand(0).getType().cast<TensorType>();
TensorType result_type = all_gather_op.getType(0).cast<TensorType>();
if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
return failure();
if (operands.size() != 1) return failure();

Member:
Let's raise this check before the call to getOperand(0).

Contributor Author (@apivovarov, Jan 10, 2024):
mhlo/IR/hlo_ops.td defines MHLO_AllGatherOp with Variadic<MHLO_Tensor>:$operands.
I tried to use mhlo.all_gather() without operands in export.mlir and got this error:

error: unexpected error: 'mhlo.all_gather' op expected 1 or more operands, but found 0
  %0:2 = "mhlo.all_gather"() {

The mhlo parser validates that mhlo.all_gather has at least one operand.
It should be safe to use all_gather_op.getOperand(0) in mlir_hlo_to_hlo.cc::ExportXlaOp, since MLIR validation happens before ExportXlaOp is called.

@@ -1268,7 +1296,7 @@ LogicalResult ExportXlaOp(AsyncDoneOp op, OpLoweringContext ctx) {
if (all_gather_op && SimplyReturnedOp(all_gather_op)) {
value_map[op.getResult(0)] =
xla::internal::XlaBuilderFriend::BuildAllGatherDone(
ctx.builder, operand, xla::TypeToShape(all_gather_op.getType()));
ctx.builder, operand, xla::TypeToShape(all_gather_op.getType(0)));
return success();
}
auto all_reduce_op =
40 changes: 39 additions & 1 deletion xla/translate/mhlo_to_hlo/tests/export.mlir
@@ -187,6 +187,44 @@ func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>) -> (tensor<10xf32>,

// -----

// expected-error@-3 {{'mhlo.async_start' op can't be translated to XLA HLO}}
func.func @all_gather_0(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) attributes {execution_thread = "main"} {
%0:2 = "mhlo.all_gather"(%arg0, %arg1) {
all_gather_dim = 1 : i64,
channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
use_global_device_ids
} : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>)
func.return %0#0, %0#1 : tensor<8x2xf32>, tensor<8x4xf32>
}

func.func @main(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) {
%0 = "mhlo.async_start"(%arg0, %arg1) {called_computation = @all_gather_0, execution_thread = "main"} : (tensor<8x2xf32>, tensor<8x4xf32>) -> !mhlo.async_bundle<tuple<tensor<8x2xf32>,tensor<8x4xf32>>, tuple<tensor<8x2xf32>,tensor<8x4xf32>>>
%1:2 = "mhlo.async_done"(%0) {called_computation = @all_gather_0, execution_thread = "main"} : (!mhlo.async_bundle<tuple<tensor<8x2xf32>,tensor<8x4xf32>>, tuple<tensor<8x2xf32>,tensor<8x4xf32>>>) -> (tensor<8x2xf32>, tensor<8x4xf32>)
return %1#0, %1#1 : tensor<8x2xf32>, tensor<8x4xf32>
}

// -----

func.func private @main(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> tuple<tensor<8x8xf32>, tensor<8x16xf32>> {
// CHECK: %[[ARG0:.*]] = f32[8,2] parameter(0)
// CHECK-NEXT: %[[ARG1:.*]] = f32[8,4] parameter(1)
// CHECK-NEXT: %[[TUPLE:.*]] = (f32[8,2], f32[8,4]) tuple
// CHECK-NEXT: %[[TUPLE_ARG0:.*]] = f32[8,2] get-tuple-element((f32[8,2], f32[8,4]) %[[TUPLE]]), index=0
// CHECK-NEXT: %[[TUPLE_ARG1:.*]] = f32[8,4] get-tuple-element((f32[8,2], f32[8,4]) %[[TUPLE]]), index=1
// CHECK-NEXT: (f32[8,8], f32[8,16]) all-gather(f32[8,2] %[[TUPLE_ARG0]], f32[8,4] %[[TUPLE_ARG1]]), channel_id=1, replica_groups={{.*}}, dimensions={1}
%0:2 = "mhlo.all_gather"(%arg0, %arg1) {
all_gather_dim = 1 : i64,
channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
use_global_device_ids
} : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>)
%1 = mhlo.tuple %0#0, %0#1 {xla_shape = "(f32[8,8]{0,1}, f32[8,16]{0,1})"} : tuple<tensor<8x8xf32>, tensor<8x16xf32>>
return %1 : tuple<tensor<8x8xf32>, tensor<8x16xf32>>
}

// -----

// CHECK: HloModule
func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
%0 = "mhlo.all_reduce"(%arg0) ({
@@ -2110,7 +2148,7 @@ func.func @main(%token: !mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) {
// CHECK-NOT: sharding=
// CHECK: [[TUPLE1:%.*]] = token[] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=1
// CHECK-NOT: sharding=
// CHECK: ROOT {{%.*}} = (s32[3,4], token[]) tuple(s32[3,4] [[TUPLE0]], token[] [[TUPLE1]])
// CHECK: ROOT {{%.*}} = (s32[3,4], token[]) tuple(s32[3,4] [[TUPLE0]], token[] [[TUPLE1]])

// -----
