Add tuple input support to all-gather and reduce-scatter #5740

Closed
wants to merge 9 commits into from
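Rough sketch of what this enables at the HLO level: all-gather and reduce-scatter accept multiple operands and produce a tuple-shaped result, one element per operand. The module below is hand-written for illustration (shapes and replica groups mirror the new builder tests; it is not compiler output):

HloModule tuple_collectives_sketch

add {
  x = f32[] parameter(0)
  y = f32[] parameter(1)
  ROOT sum = f32[] add(x, y)
}

ENTRY entry {
  p0 = f32[4,16] parameter(0)
  p1 = f32[16,4] parameter(1)
  // Every tuple element is gathered along dimension 0 across the 2 replicas in the group.
  ag = (f32[8,16], f32[32,4]) all-gather(p0, p1), replica_groups={{0,1}}, dimensions={0}
  // Every tuple element is reduced with "add", then scattered along dimension 1.
  rs = (f32[4,8], f32[16,2]) reduce-scatter(p0, p1), replica_groups={{0,1}}, dimensions={1}, to_apply=add
  ROOT out = ((f32[8,16], f32[32,4]), (f32[4,8], f32[16,2])) tuple(ag, rs)
}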
29 changes: 21 additions & 8 deletions xla/client/xla_builder.cc
@@ -2883,10 +2883,24 @@ XlaOp XlaBuilder::AllGatherImpl(const XlaOp operand,
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

TF_ASSIGN_OR_RETURN(
Shape inferred_shape,
ShapeInference::InferAllGatherShape({operand_shape},
all_gather_dimension, shard_count));
std::vector<const Shape*> operand_shapes;
std::vector<XlaOp> operands;
if (operand_shape->IsTuple()) {
if (operand_shape->tuple_shapes_size() == 0) {
return Unimplemented("0 element tuple AllGather is not supported");
}
for (int i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
operand_shapes.push_back(&operand_shape->tuple_shapes(i));
operands.push_back(GetTupleElement(operand, i));
}
} else {
operand_shapes.push_back(operand_shape);
operands.push_back(operand);
}

TF_ASSIGN_OR_RETURN(Shape inferred_shape,
ShapeInference::InferAllGatherShape(
operand_shapes, all_gather_dimension, shard_count));
if (layout) {
*inferred_shape.mutable_layout() = *layout;
instr.set_constrain_layout(true);
@@ -2908,7 +2922,7 @@ XlaOp XlaBuilder::AllGatherImpl(const XlaOp operand,
AddInstruction(std::move(instr),
async ? HloOpcode::kAllGatherStart
: HloOpcode::kAllGather,
{operand}));
operands));
return all_gather;
});
}
@@ -3320,8 +3334,7 @@ XlaOp XlaBuilder::ReduceScatter(
operand_shape->tuple_shapes(0).element_type()) {
return Unimplemented(
"All the shapes of a tuple input of ReduceScatter must have "
"the same "
"element type");
"the same element type");
}
operand_shapes.push_back(&operand_shape->tuple_shapes(i));
operands.push_back(GetTupleElement(operand, i));
@@ -3355,7 +3368,7 @@ XlaOp XlaBuilder::ReduceScatter(

TF_ASSIGN_OR_RETURN(
auto reduce_scatter,
AddInstruction(std::move(instr), HloOpcode::kReduceScatter, {operand}));
AddInstruction(std::move(instr), HloOpcode::kReduceScatter, operands));
return reduce_scatter;
});
}
46 changes: 46 additions & 0 deletions xla/client/xla_builder_test.cc
@@ -418,6 +418,22 @@ TEST_F(XlaBuilderTest, AllGatherR2) {
ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 64})));
}

TEST_F(XlaBuilderTest, AllGatherWithTuple) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "x");
auto x2 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 4}), "x2");
AllGather(Tuple(&b, {x, x2}), /*all_gather_dimension=*/0,
/*shard_count=*/4);
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();

EXPECT_EQ(root->opcode(), HloOpcode::kAllGather);
EXPECT_TRUE(ShapeUtil::Equal(
root->shape(),
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {16}),
ShapeUtil::MakeShape(F32, {64, 4})})));
}

TEST_F(XlaBuilderTest, ReduceScatter) {
XlaBuilder b(TestName());
XlaComputation to_apply;
@@ -444,6 +460,36 @@ TEST_F(XlaBuilderTest, ReduceScatter) {
ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 8})));
}

TEST_F(XlaBuilderTest, ReduceScatterWithTuple) {
XlaBuilder b(TestName());
XlaComputation to_apply;
{
auto sub_builder = b.CreateSubBuilder("add");
auto arg0 =
Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32), "x");
auto arg1 =
Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32), "y");
Add(arg0, arg1);
TF_ASSERT_OK_AND_ASSIGN(to_apply, sub_builder->Build());
}
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
auto x2 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 4}), "x2");
ReplicaGroup group;
group.add_replica_ids(0);
group.add_replica_ids(1);
ReduceScatter(Tuple(&b, {x, x2}), to_apply, /*scatter_dimension=*/1,
/*shard_count=*/2,
/*replica_groups=*/{group});
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();

EXPECT_EQ(root->opcode(), HloOpcode::kReduceScatter);
EXPECT_TRUE(ShapeUtil::Equal(
root->shape(),
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4, 8}),
ShapeUtil::MakeShape(F32, {16, 2})})));
}

TEST_F(XlaBuilderTest, AllToAll) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
51 changes: 36 additions & 15 deletions xla/service/all_gather_decomposer.cc
@@ -52,30 +52,51 @@ HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) {
return reduction;
}

Status DecomposeAllGather(HloAllGatherInstruction* ag, HloComputation* comp) {
TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
GetCollectiveOpGroupMode(ag->channel_id().has_value(),
ag->use_global_device_ids()));
TF_ASSIGN_OR_RETURN(
std::vector<HloInstruction*> start_indices,
HloInstruction* TranslateAllGatherToAllReducePerOperand(
CollectiveOpGroupMode group_mode, const HloAllGatherInstruction& ag,
const Shape& output_shape, HloInstruction* operand, HloComputation* comp) {
std::vector<HloInstruction*> start_indices =
CreateStartIndicesForCollectiveDecomposition(
group_mode, ag->replica_groups(), ag->operand(0)->shape(),
ag->all_gather_dimension(), comp));
group_mode, ag.replica_groups(), operand->shape(),
ag.all_gather_dimension(), comp)
.value();

auto zero = comp->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(ag->shape().element_type())));
LiteralUtil::Zero(output_shape.element_type())));
zero = comp->AddInstruction(
HloInstruction::CreateBroadcast(ag->shape(), zero, {}));
HloInstruction::CreateBroadcast(output_shape, zero, {}));

auto dus = comp->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
zero->shape(), zero, ag->mutable_operand(0), start_indices));
zero->shape(), zero, operand, start_indices));
auto ar = comp->AddInstruction(HloInstruction::CreateAllReduce(
dus->shape(), {dus},
MakeBinaryAdd(dus->shape().element_type(), comp->parent()),
ag->replica_groups(),
/*constrain_layout=*/ag->constrain_layout(), ag->channel_id(),
ag->use_global_device_ids()));
TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(ar));
ag.replica_groups(),
/*constrain_layout=*/ag.constrain_layout(), ag.channel_id(),
ag.use_global_device_ids()));
return ar;
}

Status DecomposeAllGather(HloAllGatherInstruction* ag, HloComputation* comp) {
TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
GetCollectiveOpGroupMode(ag->channel_id().has_value(),
ag->use_global_device_ids()));
if (ag->operand_count() > 1) {
std::vector<HloInstruction*> tuple_inputs;
for (int i = 0; i < ag->operand_count(); ++i) {
auto* input_operand = ag->mutable_operand(i);
const auto& output_shape = ag->shape().tuple_shapes(i);
auto* ar = TranslateAllGatherToAllReducePerOperand(
group_mode, *ag, output_shape, input_operand, comp);
tuple_inputs.push_back(ar);
}
auto tup = comp->AddInstruction(HloInstruction::CreateTuple(tuple_inputs));
TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(tup));
} else {
auto* ar = TranslateAllGatherToAllReducePerOperand(
group_mode, *ag, ag->shape(), ag->mutable_operand(0), comp);
TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(ar));
}
TF_RETURN_IF_ERROR(comp->RemoveInstructionAndUnusedOperands(ag));
return OkStatus();
}
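
For the tuple case, the decomposer now emits one zero-broadcast + dynamic-update-slice + all-reduce chain per operand and re-tuples the results. A hand-written sketch of the decomposed HLO for the test module in the new decomposer test below (instruction names and index types are illustrative, not the pass's exact output):

HloModule decomposed_tuple_all_gather_sketch

add {
  x = f32[] parameter(0)
  y = f32[] parameter(1)
  ROOT sum = f32[] add(x, y)
}

ENTRY entry {
  param0 = f32[10,20] parameter(0)
  param1 = f32[10,16] parameter(1)

  // Per-replica offset inputs: replica id and a zero row index shared by both operands.
  rid = u32[] replica-id()
  row_start = u32[] constant(0)

  // Operand 0: place the local f32[10,20] shard into a zeroed f32[10,80] buffer
  // at column offset rid * 20, then sum across replicas.
  zero0 = f32[] constant(0)
  acc0 = f32[10,80] broadcast(zero0), dimensions={}
  chunk0 = u32[] constant(20)
  col_start0 = u32[] multiply(rid, chunk0)
  dus0 = f32[10,80] dynamic-update-slice(acc0, param0, row_start, col_start0)
  ar0 = f32[10,80] all-reduce(dus0), replica_groups={}, to_apply=add

  // Operand 1: same pattern with a per-shard chunk of 16 columns.
  zero1 = f32[] constant(0)
  acc1 = f32[10,64] broadcast(zero1), dimensions={}
  chunk1 = u32[] constant(16)
  col_start1 = u32[] multiply(rid, chunk1)
  dus1 = f32[10,64] dynamic-update-slice(acc1, param1, row_start, col_start1)
  ar1 = f32[10,64] all-reduce(dus1), replica_groups={}, to_apply=add

  // The per-operand results are re-tupled to preserve the original tuple shape.
  ROOT result = (f32[10,80], f32[10,64]) tuple(ar0, ar1)
}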
28 changes: 28 additions & 0 deletions xla/service/all_gather_decomposer_test.cc
@@ -155,5 +155,33 @@ ENTRY entry {
op::Constant(), op::Multiply(id, op::Constant()))));
}

TEST_F(AllGatherDecomposerTest, CrossReplicaAllGatherWithTuple) {
const std::string module_str = R"(
HloModule module

ENTRY entry {
param0 = f32[10,20] parameter(0)
param1 = f32[10,16] parameter(1)
ROOT ag = (f32[10,80], f32[10,64]) all-gather(param0, param1),
replica_groups={}, dimensions={1}
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
AllGatherDecomposer decomposer;
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(
module->entry_computation()->root_instruction(),
op::Tuple(
op::AllReduce(op::DynamicUpdateSlice(
op::Broadcast(op::Constant()), op::Parameter(0), op::Constant(),
op::Multiply(op::ReplicaId(), op::Constant()))),
op::AllReduce(op::DynamicUpdateSlice(
op::Broadcast(op::Constant()), op::Parameter(1), op::Constant(),
op::Multiply(op::ReplicaId(), op::Constant())))));
}

} // namespace
} // namespace xla
4 changes: 4 additions & 0 deletions xla/service/cpu/ir_emitter.cc
@@ -1289,6 +1289,10 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
return HandleAllReduceMultipleReplica(crs);
}

Status IrEmitter::HandleReduceScatter(HloInstruction* rs) {
return Unimplemented("ReduceScatter is not implemented on CPU.");
}

Contributor Author: @jurahul, @burmako, @radhakrishnaba I think this is correct. Let me know if you want me to remove this.

Contributor: This seems fine to me.

Status IrEmitter::HandleAllToAll(HloInstruction* instruction) {
auto* instr = Cast<HloAllToAllInstruction>(instruction);
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction));
1 change: 1 addition & 0 deletions xla/service/cpu/ir_emitter.h
@@ -144,6 +144,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleFft(HloInstruction* fft) override;
Status HandleAllReduce(HloInstruction* crs) override;
Status HandleReduceScatter(HloInstruction* crs) override;
Status HandleCollectivePermute(HloInstruction* crs) override;
Status HandleInfeed(HloInstruction* instruction) override;
Status HandleOutfeed(HloInstruction* outfeed) override;
2 changes: 2 additions & 0 deletions xla/service/hlo_verifier.cc
@@ -434,6 +434,7 @@ static Status CheckCommonAllGatherInvariants(HloInstruction* hlo,
ag->use_global_device_ids()));
TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, group_mode));
TF_RET_CHECK(ag->all_gather_dimension() >= 0);
TF_RET_CHECK(ag->operand_count() >= 1);
Contributor Author: @jurahul, @burmako, @radhakrishnaba I think this is correct. Let me know if you want me to remove this.

Contributor: Also seems ok to me.

int64_t shard_count;
for (int64_t i = 0; i < ag->operand_count(); ++i) {
@@ -521,6 +522,7 @@ Status ShapeVerifier::HandleReduceScatter(HloInstruction* hlo) {
ars->use_global_device_ids()));
TF_RETURN_IF_ERROR(CheckReplicaGroups(ars, group_mode));
TF_RET_CHECK(ars->scatter_dimension() >= 0);
TF_RET_CHECK(ars->operand_count() >= 1);
Contributor Author: Same here.

for (int64_t i = 0; i < ars->operand_count(); ++i) {
TF_RET_CHECK(ars->scatter_dimension() < ars->operand(i)->shape().rank());