diff --git a/xla/client/xla_builder.cc b/xla/client/xla_builder.cc
index a03253635defdd..6540d6db7220e0 100644
--- a/xla/client/xla_builder.cc
+++ b/xla/client/xla_builder.cc
@@ -2883,10 +2883,24 @@ XlaOp XlaBuilder::AllGatherImpl(const XlaOp operand,
     HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
-    TF_ASSIGN_OR_RETURN(
-        Shape inferred_shape,
-        ShapeInference::InferAllGatherShape({operand_shape},
-                                            all_gather_dimension, shard_count));
+    std::vector<const Shape*> operand_shapes;
+    std::vector<XlaOp> operands;
+    if (operand_shape->IsTuple()) {
+      if (operand_shape->tuple_shapes_size() == 0) {
+        return Unimplemented("0 element tuple AllGather is not supported");
+      }
+      for (int i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
+        operand_shapes.push_back(&operand_shape->tuple_shapes(i));
+        operands.push_back(GetTupleElement(operand, i));
+      }
+    } else {
+      operand_shapes.push_back(operand_shape);
+      operands.push_back(operand);
+    }
+
+    TF_ASSIGN_OR_RETURN(Shape inferred_shape,
+                        ShapeInference::InferAllGatherShape(
+                            operand_shapes, all_gather_dimension, shard_count));
     if (layout) {
       *inferred_shape.mutable_layout() = *layout;
       instr.set_constrain_layout(true);
@@ -2908,7 +2922,7 @@ XlaOp XlaBuilder::AllGatherImpl(const XlaOp operand,
         AddInstruction(std::move(instr),
                        async ? HloOpcode::kAllGatherStart
                              : HloOpcode::kAllGather,
-                       {operand}));
+                       operands));
     return all_gather;
   });
 }
@@ -3320,8 +3334,7 @@ XlaOp XlaBuilder::ReduceScatter(
               operand_shape->tuple_shapes(0).element_type()) {
             return Unimplemented(
                 "All the shapes of a tuple input of ReduceScatter must have "
-                "the same "
-                "element type");
+                "the same element type");
           }
           operand_shapes.push_back(&operand_shape->tuple_shapes(i));
           operands.push_back(GetTupleElement(operand, i));
@@ -3355,7 +3368,7 @@ XlaOp XlaBuilder::ReduceScatter(
     TF_ASSIGN_OR_RETURN(
         auto reduce_scatter,
-        AddInstruction(std::move(instr), HloOpcode::kReduceScatter, {operand}));
+        AddInstruction(std::move(instr), HloOpcode::kReduceScatter, operands));
     return reduce_scatter;
   });
 }
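
Note on the builder change above: a tuple operand is unpacked with GetTupleElement and a single variadic all-gather is emitted whose result is a tuple of per-element gathered shapes. A minimal HLO sketch of the emitted op (hypothetical module; shapes follow the AllGatherWithTuple test below, shard_count=4 on dimension 0):

  HloModule tuple_all_gather_sketch

  ENTRY entry {
    p0 = f32[4] parameter(0)
    p1 = f32[16,4] parameter(1)
    ROOT ag = (f32[16], f32[64,4]) all-gather(p0, p1), replica_groups={},
      dimensions={0}
  }
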
diff --git a/xla/client/xla_builder_test.cc b/xla/client/xla_builder_test.cc
index 3214f5e6ad4590..60503f890be013 100644
--- a/xla/client/xla_builder_test.cc
+++ b/xla/client/xla_builder_test.cc
@@ -418,6 +418,22 @@ TEST_F(XlaBuilderTest, AllGatherR2) {
       ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 64})));
 }
 
+TEST_F(XlaBuilderTest, AllGatherWithTuple) {
+  XlaBuilder b(TestName());
+  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "x");
+  auto x2 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 4}), "x2");
+  AllGather(Tuple(&b, {x, x2}), /*all_gather_dimension=*/0,
+            /*shard_count=*/4);
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  auto root = module->entry_computation()->root_instruction();
+
+  EXPECT_EQ(root->opcode(), HloOpcode::kAllGather);
+  EXPECT_TRUE(ShapeUtil::Equal(
+      root->shape(),
+      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {16}),
+                                 ShapeUtil::MakeShape(F32, {64, 4})})));
+}
+
 TEST_F(XlaBuilderTest, ReduceScatter) {
   XlaBuilder b(TestName());
   XlaComputation to_apply;
@@ -444,6 +460,36 @@ TEST_F(XlaBuilderTest, ReduceScatter) {
       ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 8})));
 }
 
+TEST_F(XlaBuilderTest, ReduceScatterWithTuple) {
+  XlaBuilder b(TestName());
+  XlaComputation to_apply;
+  {
+    auto sub_builder = b.CreateSubBuilder("add");
+    auto arg0 =
+        Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32), "x");
+    auto arg1 =
+        Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32), "y");
+    Add(arg0, arg1);
+    TF_ASSERT_OK_AND_ASSIGN(to_apply, sub_builder->Build());
+  }
+  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
+  auto x2 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 4}), "x2");
+  ReplicaGroup group;
+  group.add_replica_ids(0);
+  group.add_replica_ids(1);
+  ReduceScatter(Tuple(&b, {x, x2}), to_apply, /*scatter_dimension=*/1,
+                /*shard_count=*/2,
+                /*replica_groups=*/{group});
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  auto root = module->entry_computation()->root_instruction();
+
+  EXPECT_EQ(root->opcode(), HloOpcode::kReduceScatter);
+  EXPECT_TRUE(ShapeUtil::Equal(
+      root->shape(),
+      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4, 8}),
+                                 ShapeUtil::MakeShape(F32, {16, 2})})));
+}
+
 TEST_F(XlaBuilderTest, AllToAll) {
   XlaBuilder b(TestName());
   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
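
Shape check for the ReduceScatterWithTuple test above: scattering dimension 1 across shard_count=2 halves that dimension, so f32[4,16] becomes f32[4,8] and f32[16,4] becomes f32[16,2]. A hypothetical HLO rendering of the op the builder produces (a sketch, not text emitted by this patch):

  HloModule tuple_reduce_scatter_sketch

  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT sum = f32[] add(x, y)
  }

  ENTRY entry {
    p0 = f32[4,16] parameter(0)
    p1 = f32[16,4] parameter(1)
    ROOT rs = (f32[4,8], f32[16,2]) reduce-scatter(p0, p1),
      replica_groups={{0,1}}, dimensions={1}, to_apply=add
  }
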
diff --git a/xla/service/all_gather_decomposer.cc b/xla/service/all_gather_decomposer.cc
index 88e9c9d1b5d18e..7e5325a52a595d 100644
--- a/xla/service/all_gather_decomposer.cc
+++ b/xla/service/all_gather_decomposer.cc
@@ -52,30 +52,51 @@ HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) {
   return reduction;
 }
 
-Status DecomposeAllGather(HloAllGatherInstruction* ag, HloComputation* comp) {
-  TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
-                      GetCollectiveOpGroupMode(ag->channel_id().has_value(),
-                                               ag->use_global_device_ids()));
-  TF_ASSIGN_OR_RETURN(
-      std::vector<HloInstruction*> start_indices,
+HloInstruction* TranslateAllGatherToAllReducePerOperand(
+    CollectiveOpGroupMode group_mode, const HloAllGatherInstruction& ag,
+    const Shape& output_shape, HloInstruction* operand, HloComputation* comp) {
+  std::vector<HloInstruction*> start_indices =
       CreateStartIndicesForCollectiveDecomposition(
-          group_mode, ag->replica_groups(), ag->operand(0)->shape(),
-          ag->all_gather_dimension(), comp));
+          group_mode, ag.replica_groups(), operand->shape(),
+          ag.all_gather_dimension(), comp)
+          .value();
   auto zero = comp->AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::Zero(ag->shape().element_type())));
+      LiteralUtil::Zero(output_shape.element_type())));
   zero = comp->AddInstruction(
-      HloInstruction::CreateBroadcast(ag->shape(), zero, {}));
+      HloInstruction::CreateBroadcast(output_shape, zero, {}));
   auto dus = comp->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
-      zero->shape(), zero, ag->mutable_operand(0), start_indices));
+      zero->shape(), zero, operand, start_indices));
   auto ar = comp->AddInstruction(HloInstruction::CreateAllReduce(
       dus->shape(), {dus},
       MakeBinaryAdd(dus->shape().element_type(), comp->parent()),
-      ag->replica_groups(),
-      /*constrain_layout=*/ag->constrain_layout(), ag->channel_id(),
-      ag->use_global_device_ids()));
-  TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(ar));
+      ag.replica_groups(),
+      /*constrain_layout=*/ag.constrain_layout(), ag.channel_id(),
+      ag.use_global_device_ids()));
+  return ar;
+}
+
+Status DecomposeAllGather(HloAllGatherInstruction* ag, HloComputation* comp) {
+  TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
+                      GetCollectiveOpGroupMode(ag->channel_id().has_value(),
+                                               ag->use_global_device_ids()));
+  if (ag->operand_count() > 1) {
+    std::vector<HloInstruction*> tuple_inputs;
+    for (int i = 0; i < ag->operand_count(); ++i) {
+      auto* input_operand = ag->mutable_operand(i);
+      const auto& output_shape = ag->shape().tuple_shapes(i);
+      auto* ar = TranslateAllGatherToAllReducePerOperand(
+          group_mode, *ag, output_shape, input_operand, comp);
+      tuple_inputs.push_back(ar);
+    }
+    auto tup = comp->AddInstruction(HloInstruction::CreateTuple(tuple_inputs));
+    TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(tup));
+  } else {
+    auto* ar = TranslateAllGatherToAllReducePerOperand(
+        group_mode, *ag, ag->shape(), ag->mutable_operand(0), comp);
+    TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(ar));
+  }
   TF_RETURN_IF_ERROR(comp->RemoveInstructionAndUnusedOperands(ag));
   return OkStatus();
 }
diff --git a/xla/service/all_gather_decomposer_test.cc b/xla/service/all_gather_decomposer_test.cc
index 7fae210429f571..4eb629a3473112 100644
--- a/xla/service/all_gather_decomposer_test.cc
+++ b/xla/service/all_gather_decomposer_test.cc
@@ -155,5 +155,33 @@ ENTRY entry {
                             op::Constant(), op::Multiply(id, op::Constant()))));
 }
 
+TEST_F(AllGatherDecomposerTest, CrossReplicaAllGatherWithTuple) {
+  const std::string module_str = R"(
+HloModule module
+
+ENTRY entry {
+  param0 = f32[10,20] parameter(0)
+  param1 = f32[10,16] parameter(1)
+  ROOT ag = (f32[10,80], f32[10,64]) all-gather(param0, param1),
+    replica_groups={}, dimensions={1}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnUnverifiedModule(module_str));
+  AllGatherDecomposer decomposer;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      op::Tuple(
+          op::AllReduce(op::DynamicUpdateSlice(
+              op::Broadcast(op::Constant()), op::Parameter(0), op::Constant(),
+              op::Multiply(op::ReplicaId(), op::Constant()))),
+          op::AllReduce(op::DynamicUpdateSlice(
+              op::Broadcast(op::Constant()), op::Parameter(1), op::Constant(),
+              op::Multiply(op::ReplicaId(), op::Constant())))));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/xla/service/cpu/ir_emitter.cc b/xla/service/cpu/ir_emitter.cc
index 7013a45f1edcfb..3daa2ba9b2d2b2 100644
--- a/xla/service/cpu/ir_emitter.cc
+++ b/xla/service/cpu/ir_emitter.cc
@@ -1289,6 +1289,10 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
   return HandleAllReduceMultipleReplica(crs);
 }
 
+Status IrEmitter::HandleReduceScatter(HloInstruction* rs) {
+  return Unimplemented("ReduceScatter is not implemented on CPU.");
+}
+
 Status IrEmitter::HandleAllToAll(HloInstruction* instruction) {
   auto* instr = Cast<HloAllToAllInstruction>(instruction);
   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction));
diff --git a/xla/service/cpu/ir_emitter.h b/xla/service/cpu/ir_emitter.h
index c4534fa619d650..3a194d054cb5fd 100644
--- a/xla/service/cpu/ir_emitter.h
+++ b/xla/service/cpu/ir_emitter.h
@@ -144,6 +144,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
   Status HandleConvolution(HloInstruction* convolution) override;
   Status HandleFft(HloInstruction* fft) override;
   Status HandleAllReduce(HloInstruction* crs) override;
+  Status HandleReduceScatter(HloInstruction* crs) override;
   Status HandleCollectivePermute(HloInstruction* crs) override;
   Status HandleInfeed(HloInstruction* instruction) override;
   Status HandleOutfeed(HloInstruction* outfeed) override;
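
The decomposer test above checks the rewrite through matchers. Spelled out as HLO for a single tuple element (a sketch under the test's shapes: f32[10,20] gathered on dimension 1 across 4 replicas into f32[10,80]), each element becomes a zero-filled buffer, updated at a replica-dependent offset, then summed with an all-reduce:

  HloModule decomposed_all_gather_sketch

  add {
    a = f32[] parameter(0)
    b = f32[] parameter(1)
    ROOT sum = f32[] add(a, b)
  }

  ENTRY entry {
    param0 = f32[10,20] parameter(0)
    zero = f32[] constant(0)
    zeros = f32[10,80] broadcast(zero), dimensions={}
    c0 = u32[] constant(0)
    shard_size = u32[] constant(20)
    rid = u32[] replica-id()
    offset = u32[] multiply(rid, shard_size)
    dus = f32[10,80] dynamic-update-slice(zeros, param0, c0, offset)
    ROOT ar = f32[10,80] all-reduce(dus), replica_groups={}, to_apply=add
  }
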
diff --git a/xla/service/hlo_verifier.cc b/xla/service/hlo_verifier.cc
index 9c8c86336c8bfb..5ad9d2323bd221 100644
--- a/xla/service/hlo_verifier.cc
+++ b/xla/service/hlo_verifier.cc
@@ -434,6 +434,7 @@ static Status CheckCommonAllGatherInvariants(HloInstruction* hlo,
                                                ag->use_global_device_ids()));
   TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, group_mode));
   TF_RET_CHECK(ag->all_gather_dimension() >= 0);
+  TF_RET_CHECK(ag->operand_count() >= 1);
 
   int64_t shard_count;
   for (int64_t i = 0; i < ag->operand_count(); ++i) {
@@ -521,6 +522,7 @@ Status ShapeVerifier::HandleReduceScatter(HloInstruction* hlo) {
                                                ars->use_global_device_ids()));
   TF_RETURN_IF_ERROR(CheckReplicaGroups(ars, group_mode));
   TF_RET_CHECK(ars->scatter_dimension() >= 0);
+  TF_RET_CHECK(ars->operand_count() >= 1);
   for (int64_t i = 0; i < ars->operand_count(); ++i) {
     TF_RET_CHECK(ars->scatter_dimension() <
                  ars->operand(i)->shape().rank());
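
With the new operand-count checks, a degenerate collective with no operands is rejected by the verifier rather than failing later in shape inference. For example, a hypothetical zero-operand op such as

  ROOT ag = () all-gather(), replica_groups={}, dimensions={0}

now trips TF_RET_CHECK(ag->operand_count() >= 1); the matching guard on reduce-scatter enforces the same invariant.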