Include token as part of the input/output tuple in all-gather and reduce-scatter #7338

Closed · wants to merge 4 commits
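For orientation, this is the shape of HLO the change makes legal: the operand list of all-gather and reduce-scatter may include one token (or scalar) alongside the data operands, and it rides through the collective untouched into the corresponding position of the result tuple. A minimal sketch, adapted from the decomposer test in this diff (shapes are illustrative):

ENTRY entry {
  param0 = f32[10,20] parameter(0)
  t = f32[] parameter(1)
  // Four shards along dimension 1: f32[10,20] gathers to f32[10,80];
  // the scalar/token t is passed through unchanged.
  ROOT ag = (f32[10,80], f32[]) all-gather(param0, t),
      replica_groups={}, dimensions={1}
}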
35 changes: 27 additions & 8 deletions xla/client/xla_builder.cc
@@ -2883,10 +2883,30 @@ XlaOp XlaBuilder::AllGatherImpl(const XlaOp operand,
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

TF_ASSIGN_OR_RETURN(
Shape inferred_shape,
ShapeInference::InferAllGatherShape({operand_shape},
all_gather_dimension, shard_count));
std::vector<const Shape*> operand_shapes;
std::vector<XlaOp> operands;
if (operand_shape->IsTuple()) {
if (operand_shape->tuple_shapes_size() == 0) {
return Unimplemented("0 element tuple AllGather is not supported");
}
for (int i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
if (operand_shape->tuple_shapes(i).element_type() !=
operand_shape->tuple_shapes(0).element_type()) {
return Unimplemented(
"All the shapes of a tuple input of AllGather must have the same "
"element type");
}
operand_shapes.push_back(&operand_shape->tuple_shapes(i));
operands.push_back(GetTupleElement(operand, i));
}
} else {
operand_shapes.push_back(operand_shape);
operands.push_back(operand);
}

TF_ASSIGN_OR_RETURN(Shape inferred_shape,
ShapeInference::InferAllGatherShape(
operand_shapes, all_gather_dimension, shard_count));
if (layout) {
*inferred_shape.mutable_layout() = *layout;
instr.set_constrain_layout(true);
@@ -2908,7 +2928,7 @@ XlaOp XlaBuilder::AllGatherImpl(const XlaOp operand,
AddInstruction(std::move(instr),
async ? HloOpcode::kAllGatherStart
: HloOpcode::kAllGather,
{operand}));
operands));
return all_gather;
});
}
@@ -3320,8 +3340,7 @@ XlaOp XlaBuilder::ReduceScatter(
operand_shape->tuple_shapes(0).element_type()) {
return Unimplemented(
"All the shapes of a tuple input of ReduceScatter must have "
"the same "
"element type");
"the same element type");
}
operand_shapes.push_back(&operand_shape->tuple_shapes(i));
operands.push_back(GetTupleElement(operand, i));
@@ -3355,7 +3374,7 @@ XlaOp XlaBuilder::ReduceScatter(

TF_ASSIGN_OR_RETURN(
auto reduce_scatter,
AddInstruction(std::move(instr), HloOpcode::kReduceScatter, {operand}));
AddInstruction(std::move(instr), HloOpcode::kReduceScatter, operands));
return reduce_scatter;
});
}
50 changes: 50 additions & 0 deletions xla/client/xla_builder_test.cc
@@ -418,6 +418,24 @@ TEST_F(XlaBuilderTest, AllGatherR2) {
ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 64})));
}

TEST_F(XlaBuilderTest, AllGatherWithToken) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "x");
auto x2 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 4}), "x2");
auto t = Parameter(&b, 2, ShapeUtil::MakeScalarShape(F32), "t");
AllGather(Tuple(&b, {x, x2, t}), /*all_gather_dimension=*/0,
/*shard_count=*/4);
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();

EXPECT_EQ(root->opcode(), HloOpcode::kAllGather);
EXPECT_TRUE(ShapeUtil::Equal(
root->shape(),
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {16}),
ShapeUtil::MakeShape(F32, {64, 4}),
ShapeUtil::MakeScalarShape(F32)})));
}

TEST_F(XlaBuilderTest, ReduceScatter) {
XlaBuilder b(TestName());
XlaComputation to_apply;
@@ -444,6 +462,38 @@ TEST_F(XlaBuilderTest, ReduceScatter) {
ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 8})));
}

TEST_F(XlaBuilderTest, ReduceScatterWithToken) {
XlaBuilder b(TestName());
XlaComputation to_apply;
{
auto sub_builder = b.CreateSubBuilder("add");
auto arg0 =
Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32), "x");
auto arg1 =
Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32), "y");
Add(arg0, arg1);
TF_ASSERT_OK_AND_ASSIGN(to_apply, sub_builder->Build());
}
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
auto x2 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 4}), "x2");
auto t = Parameter(&b, 2, ShapeUtil::MakeScalarShape(F32), "t");
ReplicaGroup group;
group.add_replica_ids(0);
group.add_replica_ids(1);
ReduceScatter(Tuple(&b, {x, x2, t}), to_apply, /*scatter_dimension=*/1,
/*shard_count=*/2,
/*replica_groups=*/{group});
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();

EXPECT_EQ(root->opcode(), HloOpcode::kReduceScatter);
EXPECT_TRUE(ShapeUtil::Equal(
root->shape(),
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4, 8}),
ShapeUtil::MakeShape(F32, {16, 2}),
ShapeUtil::MakeScalarShape(F32)})));
}

TEST_F(XlaBuilderTest, AllToAll) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
53 changes: 38 additions & 15 deletions xla/service/all_gather_decomposer.cc
@@ -52,30 +52,53 @@ HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) {
return reduction;
}

Status DecomposeAllGather(HloAllGatherInstruction* ag, HloComputation* comp) {
TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
GetCollectiveOpGroupMode(ag->channel_id().has_value(),
ag->use_global_device_ids()));
TF_ASSIGN_OR_RETURN(
std::vector<HloInstruction*> start_indices,
HloInstruction* TranslateAllGatherToAllReducePerOperand(
CollectiveOpGroupMode group_mode, const HloAllGatherInstruction& ag,
const Shape& output_shape, HloInstruction* operand, HloComputation* comp) {
std::vector<HloInstruction*> start_indices =
CreateStartIndicesForCollectiveDecomposition(
group_mode, ag->replica_groups(), ag->operand(0)->shape(),
ag->all_gather_dimension(), comp));
group_mode, ag.replica_groups(), operand->shape(),
ag.all_gather_dimension(), comp)
.value();

auto zero = comp->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(ag->shape().element_type())));
LiteralUtil::Zero(output_shape.element_type())));
zero = comp->AddInstruction(
HloInstruction::CreateBroadcast(ag->shape(), zero, {}));
HloInstruction::CreateBroadcast(output_shape, zero, {}));

auto dus = comp->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
zero->shape(), zero, ag->mutable_operand(0), start_indices));
zero->shape(), zero, operand, start_indices));
auto ar = comp->AddInstruction(HloInstruction::CreateAllReduce(
dus->shape(), {dus},
MakeBinaryAdd(dus->shape().element_type(), comp->parent()),
ag->replica_groups(),
/*constrain_layout=*/ag->constrain_layout(), ag->channel_id(),
ag->use_global_device_ids()));
TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(ar));
ag.replica_groups(),
/*constrain_layout=*/ag.constrain_layout(), ag.channel_id(),
ag.use_global_device_ids()));
return ar;
}

Status DecomposeAllGather(HloAllGatherInstruction* ag, HloComputation* comp) {
TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
GetCollectiveOpGroupMode(ag->channel_id().has_value(),
ag->use_global_device_ids()));
if (ag->operand_count() > 1) {
HloInstruction* token = ag->mutable_operands().back();
std::vector<HloInstruction*> tuple_inputs;
for (int i = 0; i < ag->operand_count() - 1; ++i) {
auto* input_operand = ag->mutable_operand(i);
const auto& output_shape = ag->shape().tuple_shapes(i);
auto* ar = TranslateAllGatherToAllReducePerOperand(
group_mode, *ag, output_shape, input_operand, comp);
tuple_inputs.push_back(ar);
}
tuple_inputs.push_back(token);
auto tup = comp->AddInstruction(HloInstruction::CreateTuple(tuple_inputs));
TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(tup));
} else {
auto* ar = TranslateAllGatherToAllReducePerOperand(
group_mode, *ag, ag->shape(), ag->mutable_operand(0), comp);
TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(ar));
}
TF_RETURN_IF_ERROR(comp->RemoveInstructionAndUnusedOperands(ag));
return OkStatus();
}
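Schematically, a token-carrying all-gather now decomposes into one dynamic-update-slice + all-reduce per data operand, and the untouched token is re-attached via a tuple. A rough sketch of the rewritten HLO for a single data operand plus token (adapted from the matcher in the test below; names and shapes are illustrative, not literal pass output):

  zero   = f32[] constant(0)
  bcast  = f32[10,80] broadcast(zero), dimensions={}
  rid    = u32[] replica-id()
  stride = u32[] constant(20)                          // per-replica shard size
  offset = u32[] multiply(rid, stride)
  c0     = u32[] constant(0)
  dus    = f32[10,80] dynamic-update-slice(bcast, param0, c0, offset)
  sum    = f32[10,80] all-reduce(dus), replica_groups={}, to_apply=add
  ROOT r = (f32[10,80], f32[]) tuple(sum, t)           // token t passes through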
30 changes: 30 additions & 0 deletions xla/service/all_gather_decomposer_test.cc
@@ -155,5 +155,35 @@ ENTRY entry {
op::Constant(), op::Multiply(id, op::Constant()))));
}

TEST_F(AllGatherDecomposerTest, CrossReplicaAllGatherWithToken) {
const std::string module_str = R"(
HloModule module

ENTRY entry {
param0 = f32[10,20] parameter(0)
param1 = f32[10,16] parameter(1)
t = f32[] parameter(2)
ROOT ag = (f32[10,80], f32[10,64], f32[]) all-gather(param0, param1, t),
replica_groups={}, dimensions={1}
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule((module_str)));
AllGatherDecomposer decomposer;
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(
module->entry_computation()->root_instruction(),
op::Tuple(
op::AllReduce(op::DynamicUpdateSlice(
op::Broadcast(op::Constant()), op::Parameter(0), op::Constant(),
op::Multiply(op::ReplicaId(), op::Constant()))),
op::AllReduce(op::DynamicUpdateSlice(
op::Broadcast(op::Constant()), op::Parameter(1), op::Constant(),
op::Multiply(op::ReplicaId(), op::Constant()))),
op::Parameter(2)));
}

} // namespace
} // namespace xla
4 changes: 4 additions & 0 deletions xla/service/cpu/ir_emitter.cc
@@ -1285,6 +1285,10 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
return HandleAllReduceMultipleReplica(crs);
}

Status IrEmitter::HandleReduceScatter(HloInstruction* rs) {
return Unimplemented("ReduceScatter is not implemented on CPU.");
}

Status IrEmitter::HandleAllToAll(HloInstruction* instruction) {
auto* instr = Cast<HloAllToAllInstruction>(instruction);
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction));
1 change: 1 addition & 0 deletions xla/service/cpu/ir_emitter.h
@@ -144,6 +144,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleFft(HloInstruction* fft) override;
Status HandleAllReduce(HloInstruction* crs) override;
Status HandleReduceScatter(HloInstruction* crs) override;
Status HandleCollectivePermute(HloInstruction* crs) override;
Status HandleInfeed(HloInstruction* instruction) override;
Status HandleOutfeed(HloInstruction* outfeed) override;
30 changes: 25 additions & 5 deletions xla/service/hlo_verifier.cc
@@ -434,10 +434,20 @@ static Status CheckCommonAllGatherInvariants(HloInstruction* hlo,
ag->use_global_device_ids()));
TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, group_mode));
TF_RET_CHECK(ag->all_gather_dimension() >= 0);
TF_RET_CHECK(ag->operand_count() >= 1);

int64_t shard_count;
// There can be one token in the input Tuple. The token is a scalar or
// `token`.
bool token_encountered = false;
for (int64_t i = 0; i < ag->operand_count(); ++i) {
TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(i)->shape().rank());
const Shape& operand_shape = ag->operand(i)->shape();
if (operand_shape.IsToken() || operand_shape.rank() == 0) {
TF_RET_CHECK(!token_encountered) << "AllGather can have at most 1 token.";
token_encountered = true;
continue;
}
TF_RET_CHECK(ag->all_gather_dimension() < operand_shape.rank());

Shape output_shape;
if (hlo->opcode() == HloOpcode::kAllGather) {
@@ -451,9 +461,9 @@
}
TF_RET_CHECK(ag->all_gather_dimension() < output_shape.rank());
if (i == 0) {
shard_count = CeilOfRatio(
output_shape.dimensions(ag->all_gather_dimension()),
ag->operand(i)->shape().dimensions(ag->all_gather_dimension()));
shard_count =
CeilOfRatio(output_shape.dimensions(ag->all_gather_dimension()),
operand_shape.dimensions(ag->all_gather_dimension()));
}
}

@@ -521,9 +531,19 @@ Status ShapeVerifier::HandleReduceScatter(HloInstruction* hlo) {
ars->use_global_device_ids()));
TF_RETURN_IF_ERROR(CheckReplicaGroups(ars, group_mode));
TF_RET_CHECK(ars->scatter_dimension() >= 0);
TF_RET_CHECK(ars->operand_count() >= 1);

// There can be one token in the inputs. The token is a scalar or `token`.
bool token_encountered = false;
for (int64_t i = 0; i < ars->operand_count(); ++i) {
TF_RET_CHECK(ars->scatter_dimension() < ars->operand(i)->shape().rank());
const Shape& operand_shape = ars->operand(i)->shape();
if (operand_shape.IsToken() || operand_shape.rank() == 0) {
TF_RET_CHECK(!token_encountered)
<< "ReduceScatter can have at most 1 token.";
token_encountered = true;
continue;
}
TF_RET_CHECK(ars->scatter_dimension() < operand_shape.rank());

const Shape& output_shape = (ars->operand_count() == 1)
? ars->shape()
57 changes: 56 additions & 1 deletion xla/service/hlo_verifier_test.cc
@@ -2338,7 +2338,7 @@ TEST_F(HloVerifierTest, ReduceScatterInvalidScatterDim) {
ASSERT_FALSE(status.ok());
EXPECT_THAT(
status.message(),
HasSubstr("ars->scatter_dimension() < ars->operand(i)->shape().rank()"));
HasSubstr("ars->scatter_dimension() < operand_shape.rank()"));
}

TEST_F(HloVerifierTest, ReduceScatterNonUniformGroups) {
@@ -2364,6 +2364,7 @@ TEST_F(HloVerifierTest, ReduceScatterNonUniformGroups) {
HasSubstr("Replica groups expected to be of uniform size"));
}


TEST_F(HloVerifierTest, ScatterInvalidScatterDim) {
const char* const hlo_string = R"(
HloModule Module
@@ -2391,6 +2392,60 @@ TEST_F(HloVerifierTest, ScatterInvalidScatterDim) {
HasSubstr("Invalid scatter_dims_to_operand_dims mapping"));
}


TEST_F(HloVerifierTest, ReduceScatterTwoTokens) {
const char* const hlo_string = R"(
HloModule Module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}

ENTRY CRS {
input = f32[8]{0} parameter(0)
token1 = f32[] parameter(1)
token2 = f32[] parameter(2)
ROOT crs = (f32[4]{0}, f32[], f32[]) reduce-scatter(input, token1, token2),
replica_groups={}, to_apply=add,
dimensions={0}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));

auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(),
HasSubstr("ReduceScatter can have at most 1 token."));
}



TEST_F(HloVerifierTest, AllGatherTwoTokens) {
const char* const hlo_string = R"(
HloModule Module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}

ENTRY CRS {
input = f32[8]{0} parameter(0)
token1 = f32[] parameter(1)
token2 = f32[] parameter(2)
ROOT crs = (f32[4]{0}, f32[], f32[]) all-gather(input, token1, token2),
replica_groups={}, dimensions={0}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));

auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(),
HasSubstr("AllGather can have at most 1 token."));
}

TEST_F(HloVerifierTest, VerifyBroadcastDimensionsOrder) {
const char* const hlo = R"(
HloModule module