PR openxla#5740: Add tuple input support to all-gather and reduce-scatter

Imported from GitHub PR openxla#5740

This PR adds tuple input support to all-gather and reduce-scatter. It revives part of tensorflow/tensorflow#58377 and is intended to be used in conjunction with pytorch/xla#5624.

In FSDP, different layers' weights need to be all-gathered/reduce-scattered during training. If some layers are small, multiple layers' weights can be aggregated into one transfer for more efficient communication (the same concept as bucket_cap_mb in DDP). With the existing all-gather and reduce-scatter in PyTorch-XLA, the bucketing and decomposing have to be done outside of the operation. This PR enables multiple tensors to be all-gathered/reduce-scattered in a single operation while keeping their original shapes, so that bucketing and decomposing optimizations can happen inside the operation.
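For illustration, here is a minimal usage sketch mirroring the AllGatherWithTuple test added in xla/client/xla_builder_test.cc below; the helper name and include paths are assumptions and not part of this change. Two differently shaped weight tensors are bucketed into one collective by passing them to AllGather as a tuple.

#include "xla/client/xla_builder.h"
#include "xla/shape_util.h"

// Bucket two differently shaped weights into a single all-gather by
// wrapping them in a tuple. With shard_count=4 along dimension 0, the
// result is a tuple of the gathered shapes: (f32[16], f32[64, 4]).
xla::XlaOp BuildBucketedAllGather(xla::XlaBuilder* b) {
  auto w0 = xla::Parameter(b, 0, xla::ShapeUtil::MakeShape(xla::F32, {4}), "w0");
  auto w1 = xla::Parameter(b, 1, xla::ShapeUtil::MakeShape(xla::F32, {16, 4}), "w1");
  return xla::AllGather(xla::Tuple(b, {w0, w1}), /*all_gather_dimension=*/0,
                        /*shard_count=*/4);
}

ReduceScatter follows the same pattern, taking a tuple operand plus a reduction computation, as exercised by the ReduceScatterWithTuple test below.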

The original PR also included token support (like the token used for all-reduce) to enforce ordering between collective-communication ops. That will be submitted as a separate PR if needed.

Copybara import of the project:

--
7ea1159 by Junmin Hao <[email protected]>:

Add Tuple input and token support to all-gather and reduce-scatter.

Committer: Junmin Hao <[email protected]>

--
cdb873e by Junmin Hao <[email protected]>:

lint fix

--
aad3521 by Jeffrey Huynh <[email protected]>:

Fix hlo_verifier_test failure due to changed expectation

--
32e8145 by Jeffrey Huynh <[email protected]>:

Separate the token change out into a separate PR with RFC.

--
b301c2a by Jeffrey Huynh <[email protected]>:

Change *WithToken tests to *WithTuple

--
5890278 by Jeffrey Huynh <[email protected]>:

Fix missing parenthesis

Merging this change closes openxla#5740

COPYBARA_INTEGRATE_REVIEW=openxla#5740 from jeffhataws:ag_rs_coalesce_revived 14e09f0
PiperOrigin-RevId: 573976449
jeffhataws committed Dec 11, 2023
1 parent 83d9984 commit 721f725
Showing 7 changed files with 138 additions and 23 deletions.
29 changes: 21 additions & 8 deletions xla/client/xla_builder.cc
@@ -2883,10 +2883,24 @@ XlaOp XlaBuilder::AllGatherImpl(const XlaOp operand,
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

TF_ASSIGN_OR_RETURN(
Shape inferred_shape,
ShapeInference::InferAllGatherShape({operand_shape},
all_gather_dimension, shard_count));
std::vector<const Shape*> operand_shapes;
std::vector<XlaOp> operands;
if (operand_shape->IsTuple()) {
if (operand_shape->tuple_shapes_size() == 0) {
return Unimplemented("0 element tuple AllGather is not supported");
}
for (int i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
operand_shapes.push_back(&operand_shape->tuple_shapes(i));
operands.push_back(GetTupleElement(operand, i));
}
} else {
operand_shapes.push_back(operand_shape);
operands.push_back(operand);
}

TF_ASSIGN_OR_RETURN(Shape inferred_shape,
ShapeInference::InferAllGatherShape(
operand_shapes, all_gather_dimension, shard_count));
if (layout) {
*inferred_shape.mutable_layout() = *layout;
instr.set_constrain_layout(true);
@@ -2908,7 +2922,7 @@ XlaOp XlaBuilder::AllGatherImpl(const XlaOp operand,
AddInstruction(std::move(instr),
async ? HloOpcode::kAllGatherStart
: HloOpcode::kAllGather,
{operand}));
operands));
return all_gather;
});
}
@@ -3320,8 +3334,7 @@ XlaOp XlaBuilder::ReduceScatter(
operand_shape->tuple_shapes(0).element_type()) {
return Unimplemented(
"All the shapes of a tuple input of ReduceScatter must have "
"the same "
"element type");
"the same element type");
}
operand_shapes.push_back(&operand_shape->tuple_shapes(i));
operands.push_back(GetTupleElement(operand, i));
@@ -3355,7 +3368,7 @@

TF_ASSIGN_OR_RETURN(
auto reduce_scatter,
AddInstruction(std::move(instr), HloOpcode::kReduceScatter, {operand}));
AddInstruction(std::move(instr), HloOpcode::kReduceScatter, operands));
return reduce_scatter;
});
}
46 changes: 46 additions & 0 deletions xla/client/xla_builder_test.cc
@@ -418,6 +418,22 @@ TEST_F(XlaBuilderTest, AllGatherR2) {
ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 64})));
}

TEST_F(XlaBuilderTest, AllGatherWithTuple) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "x");
auto x2 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 4}), "x2");
AllGather(Tuple(&b, {x, x2}), /*all_gather_dimension=*/0,
/*shard_count=*/4);
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();

EXPECT_EQ(root->opcode(), HloOpcode::kAllGather);
EXPECT_TRUE(ShapeUtil::Equal(
root->shape(),
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {16}),
ShapeUtil::MakeShape(F32, {64, 4})})));
}

TEST_F(XlaBuilderTest, ReduceScatter) {
XlaBuilder b(TestName());
XlaComputation to_apply;
@@ -444,6 +460,36 @@ TEST_F(XlaBuilderTest, ReduceScatter) {
ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 8})));
}

TEST_F(XlaBuilderTest, ReduceScatterWithTuple) {
XlaBuilder b(TestName());
XlaComputation to_apply;
{
auto sub_builder = b.CreateSubBuilder("add");
auto arg0 =
Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32), "x");
auto arg1 =
Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32), "y");
Add(arg0, arg1);
TF_ASSERT_OK_AND_ASSIGN(to_apply, sub_builder->Build());
}
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
auto x2 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 4}), "x2");
ReplicaGroup group;
group.add_replica_ids(0);
group.add_replica_ids(1);
ReduceScatter(Tuple(&b, {x, x2}), to_apply, /*scatter_dimension=*/1,
/*shard_count=*/2,
/*replica_groups=*/{group});
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();

EXPECT_EQ(root->opcode(), HloOpcode::kReduceScatter);
EXPECT_TRUE(ShapeUtil::Equal(
root->shape(),
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4, 8}),
ShapeUtil::MakeShape(F32, {16, 2})})));
}

TEST_F(XlaBuilderTest, AllToAll) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
51 changes: 36 additions & 15 deletions xla/service/all_gather_decomposer.cc
@@ -52,30 +52,51 @@ HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) {
return reduction;
}

Status DecomposeAllGather(HloAllGatherInstruction* ag, HloComputation* comp) {
TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
GetCollectiveOpGroupMode(ag->channel_id().has_value(),
ag->use_global_device_ids()));
TF_ASSIGN_OR_RETURN(
std::vector<HloInstruction*> start_indices,
HloInstruction* TranslateAllGatherToAllReducePerOperand(
CollectiveOpGroupMode group_mode, const HloAllGatherInstruction& ag,
const Shape& output_shape, HloInstruction* operand, HloComputation* comp) {
std::vector<HloInstruction*> start_indices =
CreateStartIndicesForCollectiveDecomposition(
group_mode, ag->replica_groups(), ag->operand(0)->shape(),
ag->all_gather_dimension(), comp));
group_mode, ag.replica_groups(), operand->shape(),
ag.all_gather_dimension(), comp)
.value();

auto zero = comp->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(ag->shape().element_type())));
LiteralUtil::Zero(output_shape.element_type())));
zero = comp->AddInstruction(
HloInstruction::CreateBroadcast(ag->shape(), zero, {}));
HloInstruction::CreateBroadcast(output_shape, zero, {}));

auto dus = comp->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
zero->shape(), zero, ag->mutable_operand(0), start_indices));
zero->shape(), zero, operand, start_indices));
auto ar = comp->AddInstruction(HloInstruction::CreateAllReduce(
dus->shape(), {dus},
MakeBinaryAdd(dus->shape().element_type(), comp->parent()),
ag->replica_groups(),
/*constrain_layout=*/ag->constrain_layout(), ag->channel_id(),
ag->use_global_device_ids()));
TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(ar));
ag.replica_groups(),
/*constrain_layout=*/ag.constrain_layout(), ag.channel_id(),
ag.use_global_device_ids()));
return ar;
}

Status DecomposeAllGather(HloAllGatherInstruction* ag, HloComputation* comp) {
TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
GetCollectiveOpGroupMode(ag->channel_id().has_value(),
ag->use_global_device_ids()));
if (ag->operand_count() > 1) {
std::vector<HloInstruction*> tuple_inputs;
for (int i = 0; i < ag->operand_count(); ++i) {
auto* input_operand = ag->mutable_operand(i);
const auto& output_shape = ag->shape().tuple_shapes(i);
auto* ar = TranslateAllGatherToAllReducePerOperand(
group_mode, *ag, output_shape, input_operand, comp);
tuple_inputs.push_back(ar);
}
auto tup = comp->AddInstruction(HloInstruction::CreateTuple(tuple_inputs));
TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(tup));
} else {
auto* ar = TranslateAllGatherToAllReducePerOperand(
group_mode, *ag, ag->shape(), ag->mutable_operand(0), comp);
TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(ar));
}
TF_RETURN_IF_ERROR(comp->RemoveInstructionAndUnusedOperands(ag));
return OkStatus();
}
28 changes: 28 additions & 0 deletions xla/service/all_gather_decomposer_test.cc
@@ -155,5 +155,33 @@ ENTRY entry {
op::Constant(), op::Multiply(id, op::Constant()))));
}

TEST_F(AllGatherDecomposerTest, CrossReplicaAllGatherWithTuple) {
const std::string module_str = R"(
HloModule module
ENTRY entry {
param0 = f32[10,20] parameter(0)
param1 = f32[10,16] parameter(1)
ROOT ag = (f32[10,80], f32[10,64]) all-gather(param0, param1),
replica_groups={}, dimensions={1}
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule((module_str)));
AllGatherDecomposer decomposer;
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(
module->entry_computation()->root_instruction(),
op::Tuple(
op::AllReduce(op::DynamicUpdateSlice(
op::Broadcast(op::Constant()), op::Parameter(0), op::Constant(),
op::Multiply(op::ReplicaId(), op::Constant()))),
op::AllReduce(op::DynamicUpdateSlice(
op::Broadcast(op::Constant()), op::Parameter(1), op::Constant(),
op::Multiply(op::ReplicaId(), op::Constant())))));
}

} // namespace
} // namespace xla
4 changes: 4 additions & 0 deletions xla/service/cpu/ir_emitter.cc
@@ -1289,6 +1289,10 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
return HandleAllReduceMultipleReplica(crs);
}

Status IrEmitter::HandleReduceScatter(HloInstruction* rs) {
return Unimplemented("ReduceScatter is not implemented on CPU.");
}

Status IrEmitter::HandleAllToAll(HloInstruction* instruction) {
auto* instr = Cast<HloAllToAllInstruction>(instruction);
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction));
1 change: 1 addition & 0 deletions xla/service/cpu/ir_emitter.h
@@ -144,6 +144,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleFft(HloInstruction* fft) override;
Status HandleAllReduce(HloInstruction* crs) override;
Status HandleReduceScatter(HloInstruction* crs) override;
Status HandleCollectivePermute(HloInstruction* crs) override;
Status HandleInfeed(HloInstruction* instruction) override;
Status HandleOutfeed(HloInstruction* outfeed) override;
2 changes: 2 additions & 0 deletions xla/service/hlo_verifier.cc
@@ -434,6 +434,7 @@ static Status CheckCommonAllGatherInvariants(HloInstruction* hlo,
ag->use_global_device_ids()));
TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, group_mode));
TF_RET_CHECK(ag->all_gather_dimension() >= 0);
TF_RET_CHECK(ag->operand_count() >= 1);

int64_t shard_count;
for (int64_t i = 0; i < ag->operand_count(); ++i) {
@@ -521,6 +522,7 @@ Status ShapeVerifier::HandleReduceScatter(HloInstruction* hlo) {
ars->use_global_device_ids()));
TF_RETURN_IF_ERROR(CheckReplicaGroups(ars, group_mode));
TF_RET_CHECK(ars->scatter_dimension() >= 0);
TF_RET_CHECK(ars->operand_count() >= 1);

for (int64_t i = 0; i < ars->operand_count(); ++i) {
TF_RET_CHECK(ars->scatter_dimension() < ars->operand(i)->shape().rank());