From 7ea1159a1464efddebe9384e87ed6df504d89b2e Mon Sep 17 00:00:00 2001
From: Junmin Hao
Date: Mon, 31 Oct 2022 00:12:42 -0700
Subject: [PATCH 1/6] Add Tuple input and token support to all-gather and
 reduce-scatter.

Committer: Junmin Hao
---
 xla/client/xla_builder.cc                 | 30 ++++++++++---
 xla/client/xla_builder_test.cc            | 52 +++++++++++++++++++++
 xla/service/all_gather_decomposer.cc      | 53 +++++++++++++++-------
 xla/service/all_gather_decomposer_test.cc | 32 +++++++++++++
 xla/service/cpu/ir_emitter.cc             |  4 ++
 xla/service/cpu/ir_emitter.h              |  1 +
 xla/service/hlo_verifier.cc               | 26 +++++++++--
 xla/service/hlo_verifier_test.cc          | 55 +++++++++++++++++++++++
 xla/service/shape_inference.cc            | 16 +++++++
 9 files changed, 246 insertions(+), 23 deletions(-)

diff --git a/xla/client/xla_builder.cc b/xla/client/xla_builder.cc
index a03253635defd..6f2b03c2eb0de 100644
--- a/xla/client/xla_builder.cc
+++ b/xla/client/xla_builder.cc
@@ -2883,9 +2883,30 @@ XlaOp XlaBuilder::AllGatherImpl(const XlaOp operand,
     HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
 
+    std::vector<const Shape*> operand_shapes;
+    std::vector<XlaOp> operands;
+    if (operand_shape->IsTuple()) {
+      if (operand_shape->tuple_shapes_size() == 0) {
+        return Unimplemented("0 element tuple AllGather is not supported");
+      }
+      for (int i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
+        if (operand_shape->tuple_shapes(i).element_type() !=
+            operand_shape->tuple_shapes(0).element_type()) {
+          return Unimplemented(
+              "All the shapes of a tuple input of AllGather must have the same "
+              "element type");
+        }
+        operand_shapes.push_back(&operand_shape->tuple_shapes(i));
+        operands.push_back(GetTupleElement(operand, i));
+      }
+    } else {
+      operand_shapes.push_back(operand_shape);
+      operands.push_back(operand);
+    }
+
     TF_ASSIGN_OR_RETURN(
         Shape inferred_shape,
-        ShapeInference::InferAllGatherShape({operand_shape},
+        ShapeInference::InferAllGatherShape(operand_shapes,
                                             all_gather_dimension, shard_count));
     if (layout) {
       *inferred_shape.mutable_layout() = *layout;
@@ -2908,7 +2929,7 @@ XlaOp XlaBuilder::AllGatherImpl(const XlaOp operand,
         AddInstruction(std::move(instr),
                        async ? HloOpcode::kAllGatherStart : HloOpcode::kAllGather,
-                       {operand}));
+                       operands));
     return all_gather;
   });
 }
@@ -3320,8 +3341,7 @@ XlaOp XlaBuilder::ReduceScatter(
               operand_shape->tuple_shapes(0).element_type()) {
             return Unimplemented(
                 "All the shapes of a tuple input of ReduceScatter must have "
-                "the same "
-                "element type");
+                "the same element type");
           }
           operand_shapes.push_back(&operand_shape->tuple_shapes(i));
           operands.push_back(GetTupleElement(operand, i));
@@ -3355,7 +3375,7 @@ XlaOp XlaBuilder::ReduceScatter(
 
     TF_ASSIGN_OR_RETURN(
         auto reduce_scatter,
-        AddInstruction(std::move(instr), HloOpcode::kReduceScatter, {operand}));
+        AddInstruction(std::move(instr), HloOpcode::kReduceScatter, operands));
     return reduce_scatter;
   });
 }
diff --git a/xla/client/xla_builder_test.cc b/xla/client/xla_builder_test.cc
index 3214f5e6ad459..74f4b27f0670c 100644
--- a/xla/client/xla_builder_test.cc
+++ b/xla/client/xla_builder_test.cc
@@ -418,6 +418,25 @@ TEST_F(XlaBuilderTest, AllGatherR2) {
       ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 64})));
 }
 
+TEST_F(XlaBuilderTest, AllGatherWithToken) {
+  XlaBuilder b(TestName());
+  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "x");
+  auto x2 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 4}), "x2");
+  auto t = Parameter(&b, 2, ShapeUtil::MakeScalarShape(F32), "t");
+  AllGather(Tuple(&b, {x, x2, t}), /*all_gather_dimension=*/0, /*shard_count=*/4);
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  auto root = module->entry_computation()->root_instruction();
+
+  EXPECT_EQ(root->opcode(), HloOpcode::kAllGather);
+  EXPECT_TRUE(
+      ShapeUtil::Equal(root->shape(),
+                       ShapeUtil::MakeTupleShape({
+                           ShapeUtil::MakeShape(F32, {16}),
+                           ShapeUtil::MakeShape(F32, {64, 4}),
+                           ShapeUtil::MakeScalarShape(F32)})
+      ));
+}
+
 TEST_F(XlaBuilderTest, ReduceScatter) {
   XlaBuilder b(TestName());
   XlaComputation to_apply;
@@ -444,6 +463,39 @@ TEST_F(XlaBuilderTest, ReduceScatter) {
       ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 8})));
 }
 
+TEST_F(XlaBuilderTest, ReduceScatterWithToken) {
+  XlaBuilder b(TestName());
+  XlaComputation to_apply;
+  {
+    auto sub_builder = b.CreateSubBuilder("add");
+    auto arg0 =
+        Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32), "x");
+    auto arg1 =
+        Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32), "y");
+    Add(arg0, arg1);
+    TF_ASSERT_OK_AND_ASSIGN(to_apply, sub_builder->Build());
+  }
+  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
+  auto x2 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 4}), "x2");
+  auto t = Parameter(&b, 2, ShapeUtil::MakeScalarShape(F32), "t");
+  ReplicaGroup group;
+  group.add_replica_ids(0);
+  group.add_replica_ids(1);
+  ReduceScatter(Tuple(&b, {x, x2, t}), to_apply, /*scatter_dimension=*/1, /*shard_count=*/2,
+                /*replica_groups=*/{group});
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  auto root = module->entry_computation()->root_instruction();
+
+  EXPECT_EQ(root->opcode(), HloOpcode::kReduceScatter);
+  EXPECT_TRUE(
+      ShapeUtil::Equal(root->shape(),
+                       ShapeUtil::MakeTupleShape({
+                           ShapeUtil::MakeShape(F32, {4, 8}),
+                           ShapeUtil::MakeShape(F32, {16, 2}),
+                           ShapeUtil::MakeScalarShape(F32)})
+      ));
+}
+
 TEST_F(XlaBuilderTest, AllToAll) {
   XlaBuilder b(TestName());
   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
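Note on the builder-level contract the new tests pin down: when the input is a tuple, each array element grows by shard_count along all_gather_dimension, while a trailing scalar "token" rides through unchanged. A minimal client-side sketch of the same call pattern (DoTupleAllGather is a hypothetical wrapper, not part of this patch):

    #include "xla/client/xla_builder.h"

    // Hypothetical wrapper around the tuple form added above. With
    // shard_count = 4 and all_gather_dimension = 0, f32[4] -> f32[16] and
    // f32[16,4] -> f32[64,4]; the scalar t stays the last tuple element.
    xla::XlaOp DoTupleAllGather(xla::XlaBuilder* b, xla::XlaOp x,
                                xla::XlaOp x2, xla::XlaOp t) {
      return xla::AllGather(xla::Tuple(b, {x, x2, t}),
                            /*all_gather_dimension=*/0,
                            /*shard_count=*/4);
    }
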
diff --git a/xla/service/all_gather_decomposer.cc b/xla/service/all_gather_decomposer.cc
index 88e9c9d1b5d18..ea777ef185abb 100644
--- a/xla/service/all_gather_decomposer.cc
+++ b/xla/service/all_gather_decomposer.cc
@@ -52,30 +52,53 @@ HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) {
   return reduction;
 }
 
-Status DecomposeAllGather(HloAllGatherInstruction* ag, HloComputation* comp) {
-  TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
-                      GetCollectiveOpGroupMode(ag->channel_id().has_value(),
-                                               ag->use_global_device_ids()));
-  TF_ASSIGN_OR_RETURN(
-      std::vector<HloInstruction*> start_indices,
+HloInstruction* TranslateAllGatherToAllReducePerOperand(
+    CollectiveOpGroupMode group_mode,
+    const HloAllGatherInstruction& ag, const Shape& output_shape,
+    HloInstruction* operand, HloComputation* comp) {
+  std::vector<HloInstruction*> start_indices =
       CreateStartIndicesForCollectiveDecomposition(
-          group_mode, ag->replica_groups(), ag->operand(0)->shape(),
-          ag->all_gather_dimension(), comp));
+          group_mode, ag.replica_groups(), operand->shape(),
+          ag.all_gather_dimension(), comp).value();
 
   auto zero = comp->AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::Zero(ag->shape().element_type())));
+      LiteralUtil::Zero(output_shape.element_type())));
   zero = comp->AddInstruction(
-      HloInstruction::CreateBroadcast(ag->shape(), zero, {}));
+      HloInstruction::CreateBroadcast(output_shape, zero, {}));
 
   auto dus = comp->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
-      zero->shape(), zero, ag->mutable_operand(0), start_indices));
+      zero->shape(), zero, operand, start_indices));
   auto ar = comp->AddInstruction(HloInstruction::CreateAllReduce(
       dus->shape(), {dus},
       MakeBinaryAdd(dus->shape().element_type(), comp->parent()),
-      ag->replica_groups(),
-      /*constrain_layout=*/ag->constrain_layout(), ag->channel_id(),
-      ag->use_global_device_ids()));
-  TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(ar));
+      ag.replica_groups(),
+      /*constrain_layout=*/ag.constrain_layout(), ag.channel_id(),
+      ag.use_global_device_ids()));
+  return ar;
+}
+
+Status DecomposeAllGather(HloAllGatherInstruction* ag, HloComputation* comp) {
+  TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
+                      GetCollectiveOpGroupMode(ag->channel_id().has_value(),
+                                               ag->use_global_device_ids()));
+  if (ag->operand_count() > 1) {
+    HloInstruction* token = ag->mutable_operands().back();
+    std::vector<HloInstruction*> tuple_inputs;
+    for (int i = 0; i < ag->operand_count() - 1; ++i) {
+      auto* input_operand = ag->mutable_operand(i);
+      const auto& output_shape = ag->shape().tuple_shapes(i);
+      auto* ar = TranslateAllGatherToAllReducePerOperand(
+          group_mode, *ag, output_shape, input_operand, comp);
+      tuple_inputs.push_back(ar);
+    }
+    tuple_inputs.push_back(token);
+    auto tup = comp->AddInstruction(HloInstruction::CreateTuple(tuple_inputs));
+    TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(tup));
+  } else {
+    auto* ar = TranslateAllGatherToAllReducePerOperand(
+        group_mode, *ag, ag->shape(), ag->mutable_operand(0), comp);
+    TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(ar));
+  }
   TF_RETURN_IF_ERROR(comp->RemoveInstructionAndUnusedOperands(ag));
   return OkStatus();
 }
diff --git a/xla/service/all_gather_decomposer_test.cc b/xla/service/all_gather_decomposer_test.cc
index 7fae210429f57..fc3cafe1ea0ee 100644
--- a/xla/service/all_gather_decomposer_test.cc
+++ b/xla/service/all_gather_decomposer_test.cc
@@ -155,5 +155,37 @@ ENTRY entry {
                             op::Constant(), op::Multiply(id, op::Constant()))));
 }
 
+TEST_F(AllGatherDecomposerTest, CrossReplicaAllGatherWithToken) {
+  const std::string module_str = R"(
+HloModule module
+
+ENTRY entry {
+  param0 = f32[10,20] parameter(0)
+  param1 = f32[10,16] parameter(1)
+  t = f32[] parameter(2)
+  ROOT ag = (f32[10,80], f32[10,64], f32[]) all-gather(param0, param1, t),
+    replica_groups={}, dimensions={1}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnUnverifiedModule(module_str));
+  AllGatherDecomposer decomposer;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      op::Tuple(
+          op::AllReduce(op::DynamicUpdateSlice(
+              op::Broadcast(op::Constant()), op::Parameter(0), op::Constant(),
+              op::Multiply(op::ReplicaId(), op::Constant()))),
+          op::AllReduce(op::DynamicUpdateSlice(
+              op::Broadcast(op::Constant()), op::Parameter(1), op::Constant(),
+              op::Multiply(op::ReplicaId(), op::Constant()))),
+          op::Parameter(2)
+      ));
+}
+
+
 }  // namespace
 }  // namespace xla
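For readers following the decomposer change: each gathered operand is rewritten as a zero broadcast plus a dynamic-update-slice at the replica's offset plus a summing all-reduce, and with a tuple input the per-operand results (plus the pass-through token) are re-packed into a tuple. A hand-written HLO sketch of the rewrite for one f32[4] operand across 4 replicas (illustrative shapes, not output captured from the pass):

    // before: ag = f32[16] all-gather(x), dimensions={0}   // x: f32[4] per replica
    // after, per operand:
    zero   = f32[16] broadcast(f32[] constant(0)), dimensions={}
    offset = s32[] multiply(s32[] replica-id(), s32[] constant(4))
    padded = f32[16] dynamic-update-slice(zero, x, offset)
    result = f32[16] all-reduce(padded), to_apply=add
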
diff --git a/xla/service/cpu/ir_emitter.cc b/xla/service/cpu/ir_emitter.cc
index 7013a45f1edcf..3daa2ba9b2d2b 100644
--- a/xla/service/cpu/ir_emitter.cc
+++ b/xla/service/cpu/ir_emitter.cc
@@ -1289,6 +1289,10 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
   return HandleAllReduceMultipleReplica(crs);
 }
 
+Status IrEmitter::HandleReduceScatter(HloInstruction* rs) {
+  return Unimplemented("ReduceScatter is not implemented on CPU.");
+}
+
 Status IrEmitter::HandleAllToAll(HloInstruction* instruction) {
   auto* instr = Cast<HloAllToAllInstruction>(instruction);
   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction));
diff --git a/xla/service/cpu/ir_emitter.h b/xla/service/cpu/ir_emitter.h
index c4534fa619d65..3a194d054cb5f 100644
--- a/xla/service/cpu/ir_emitter.h
+++ b/xla/service/cpu/ir_emitter.h
@@ -144,6 +144,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
   Status HandleConvolution(HloInstruction* convolution) override;
   Status HandleFft(HloInstruction* fft) override;
   Status HandleAllReduce(HloInstruction* crs) override;
+  Status HandleReduceScatter(HloInstruction* crs) override;
   Status HandleCollectivePermute(HloInstruction* crs) override;
   Status HandleInfeed(HloInstruction* instruction) override;
   Status HandleOutfeed(HloInstruction* outfeed) override;
diff --git a/xla/service/hlo_verifier.cc b/xla/service/hlo_verifier.cc
index 9c8c86336c8bf..3a3e6c31d33e0 100644
--- a/xla/service/hlo_verifier.cc
+++ b/xla/service/hlo_verifier.cc
@@ -434,10 +434,20 @@ static Status CheckCommonAllGatherInvariants(HloInstruction* hlo,
                                            ag->use_global_device_ids()));
   TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, group_mode));
   TF_RET_CHECK(ag->all_gather_dimension() >= 0);
+  TF_RET_CHECK(ag->operand_count() >= 1);
 
   int64_t shard_count;
+  // There can be one token in the input Tuple. The token is a scalar or `token`.
+  bool token_encountered = false;
   for (int64_t i = 0; i < ag->operand_count(); ++i) {
-    TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(i)->shape().rank());
+    const Shape& operand_shape = ag->operand(i)->shape();
+    if (operand_shape.IsToken() || operand_shape.rank() == 0) {
+      TF_RET_CHECK(!token_encountered)
+          << "AllGather can have at most 1 token.";
+      token_encountered = true;
+      continue;
+    }
+    TF_RET_CHECK(ag->all_gather_dimension() < operand_shape.rank());
 
     Shape output_shape;
     if (hlo->opcode() == HloOpcode::kAllGather) {
@@ -453,7 +463,7 @@ static Status CheckCommonAllGatherInvariants(HloInstruction* hlo,
     if (i == 0) {
       shard_count = CeilOfRatio(
           output_shape.dimensions(ag->all_gather_dimension()),
-          ag->operand(i)->shape().dimensions(ag->all_gather_dimension()));
+          operand_shape.dimensions(ag->all_gather_dimension()));
     }
   }
@@ -521,9 +531,19 @@ Status ShapeVerifier::HandleReduceScatter(HloInstruction* hlo) {
                                            ars->use_global_device_ids()));
   TF_RETURN_IF_ERROR(CheckReplicaGroups(ars, group_mode));
   TF_RET_CHECK(ars->scatter_dimension() >= 0);
+  TF_RET_CHECK(ars->operand_count() >= 1);
 
+  // There can be one token in the inputs. The token is a scalar or `token`.
+  bool token_encountered = false;
   for (int64_t i = 0; i < ars->operand_count(); ++i) {
-    TF_RET_CHECK(ars->scatter_dimension() < ars->operand(i)->shape().rank());
+    const Shape& operand_shape = ars->operand(i)->shape();
+    if (operand_shape.IsToken() || operand_shape.rank() == 0) {
+      TF_RET_CHECK(!token_encountered)
+          << "ReduceScatter can have at most 1 token.";
+      token_encountered = true;
+      continue;
+    }
+    TF_RET_CHECK(ars->scatter_dimension() < operand_shape.rank());
 
     const Shape& output_shape = (ars->operand_count() == 1)
                                     ? ars->shape()
diff --git a/xla/service/hlo_verifier_test.cc b/xla/service/hlo_verifier_test.cc
index aa0991c26a3c4..50bda1b7670d5 100644
--- a/xla/service/hlo_verifier_test.cc
+++ b/xla/service/hlo_verifier_test.cc
@@ -2364,6 +2364,7 @@ TEST_F(HloVerifierTest, ReduceScatterNonUniformGroups) {
               HasSubstr("Replica groups expected to be of uniform size"));
 }
 
+
 TEST_F(HloVerifierTest, ScatterInvalidScatterDim) {
   const char* const hlo_string = R"(
 HloModule Module
@@ -2391,6 +2392,60 @@ TEST_F(HloVerifierTest, ScatterInvalidScatterDim) {
               HasSubstr("Invalid scatter_dims_to_operand_dims mapping"));
 }
 
+
+TEST_F(HloVerifierTest, ReduceScatterTwoTokens) {
+  const char* const hlo_string = R"(
+  HloModule Module
+  add {
+    lhs = f32[] parameter(0)
+    rhs = f32[] parameter(1)
+    ROOT add = f32[] add(lhs, rhs)
+  }
+
+  ENTRY CRS {
+    input = f32[8]{0} parameter(0)
+    token1 = f32[] parameter(1)
+    token2 = f32[] parameter(2)
+    ROOT crs = (f32[4]{0}, f32[], f32[]) reduce-scatter(input, token1, token2),
+                          replica_groups={}, to_apply=add,
+                          dimensions={0}
+  })";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.message(),
+              HasSubstr("ReduceScatter can have at most 1 token."));
+}
+
+
+
+TEST_F(HloVerifierTest, AllGatherTwoTokens) {
+  const char* const hlo_string = R"(
+  HloModule Module
+  add {
+    lhs = f32[] parameter(0)
+    rhs = f32[] parameter(1)
+    ROOT add = f32[] add(lhs, rhs)
+  }
+
+  ENTRY CRS {
+    input = f32[8]{0} parameter(0)
+    token1 = f32[] parameter(1)
+    token2 = f32[] parameter(2)
+    ROOT crs = (f32[4]{0}, f32[], f32[]) all-gather(input, token1, token2),
+                          replica_groups={}, dimensions={0}
+  })";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.message(),
+              HasSubstr("AllGather can have at most 1 token."));
+}
+
 TEST_F(HloVerifierTest, VerifyBroadcastDimensionsOrder) {
   const char* const hlo = R"(
 HloModule module
diff --git a/xla/service/shape_inference.cc b/xla/service/shape_inference.cc
index fbab0ea7fff26..3334f89ab2494 100644
--- a/xla/service/shape_inference.cc
+++ b/xla/service/shape_inference.cc
@@ -2083,7 +2083,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
   std::vector<Shape> output_shapes;
   output_shapes.reserve(operand_shapes.size());
+  // There can be one token in the input Tuple. The token is a scalar or `token`.
+  bool token_encountered = false;
   for (const Shape* operand_shape : operand_shapes) {
+    if (operand_shape->IsToken() || operand_shape->rank() == 0) {
+      TF_RET_CHECK(!token_encountered);
+      token_encountered = true;
+      output_shapes.push_back(*operand_shape);
+      continue;
+    }
     TF_RET_CHECK(all_gather_dimension < operand_shape->rank());
     TF_RETURN_IF_ERROR(ExpectArray(*operand_shape, "operand of all-gather"));
 
@@ -2139,7 +2147,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
   std::vector<Shape> output_shapes;
   output_shapes.reserve(operand_shapes.size());
+  // There can be one token in the input Tuple. The token is a scalar or `token`.
+  bool token_encountered = false;
   for (const Shape* operand_shape : operand_shapes) {
+    if (operand_shape->IsToken() || operand_shape->rank() == 0) {
+      TF_RET_CHECK(!token_encountered);
+      token_encountered = true;
+      output_shapes.push_back(*operand_shape);
+      continue;
+    }
     TF_RET_CHECK(scatter_dimension < operand_shape->rank());
     TF_RETURN_IF_ERROR(
         ExpectArray(*operand_shape, "operand of reduce-scatter"));
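To summarize the verifier rule this patch introduces: an operand counts as the token if its shape is a token or rank-0, and at most one such operand is allowed per collective. A hand-written HLO sketch of the accepted and rejected forms (illustrative, not one of the tests above):

    ok  = (f32[8]{0}, f32[]) all-gather(f32[2]{0} data, f32[] tok),
            replica_groups={}, dimensions={0}            // one token: accepted
    bad = (f32[8]{0}, f32[], f32[]) all-gather(data, tok1, tok2),
            replica_groups={}, dimensions={0}            // two tokens: rejected
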
From cdb873e6d97f5f12b3d3c587bb5782d58e3554c5 Mon Sep 17 00:00:00 2001
From: Junmin Hao
Date: Mon, 31 Oct 2022 13:24:11 -0700
Subject: [PATCH 2/6] lint fix

---
 xla/client/xla_builder.cc                 |  7 +++---
 xla/client/xla_builder_test.cc            | 30 +++++++++++------------
 xla/service/all_gather_decomposer.cc      |  8 +++---
 xla/service/all_gather_decomposer_test.cc | 16 ++++++------
 xla/service/hlo_verifier.cc               | 14 +++++------
 xla/service/shape_inference.cc            |  6 +++--
 6 files changed, 39 insertions(+), 42 deletions(-)

diff --git a/xla/client/xla_builder.cc b/xla/client/xla_builder.cc
index 6f2b03c2eb0de..b839c5747f365 100644
--- a/xla/client/xla_builder.cc
+++ b/xla/client/xla_builder.cc
@@ -2904,10 +2904,9 @@ XlaOp XlaBuilder::AllGatherImpl(const XlaOp operand,
       operands.push_back(operand);
     }
 
-    TF_ASSIGN_OR_RETURN(
-        Shape inferred_shape,
-        ShapeInference::InferAllGatherShape(operand_shapes,
-                                            all_gather_dimension, shard_count));
+    TF_ASSIGN_OR_RETURN(Shape inferred_shape,
+                        ShapeInference::InferAllGatherShape(
+                            operand_shapes, all_gather_dimension, shard_count));
     if (layout) {
       *inferred_shape.mutable_layout() = *layout;
       instr.set_constrain_layout(true);
diff --git a/xla/client/xla_builder_test.cc b/xla/client/xla_builder_test.cc
index 74f4b27f0670c..a74e36962329f 100644
--- a/xla/client/xla_builder_test.cc
+++ b/xla/client/xla_builder_test.cc
@@ -423,18 +423,17 @@ TEST_F(XlaBuilderTest, AllGatherWithToken) {
   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "x");
   auto x2 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 4}), "x2");
   auto t = Parameter(&b, 2, ShapeUtil::MakeScalarShape(F32), "t");
-  AllGather(Tuple(&b, {x, x2, t}), /*all_gather_dimension=*/0, /*shard_count=*/4);
+  AllGather(Tuple(&b, {x, x2, t}), /*all_gather_dimension=*/0,
+            /*shard_count=*/4);
   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
   auto root = module->entry_computation()->root_instruction();
 
   EXPECT_EQ(root->opcode(), HloOpcode::kAllGather);
-  EXPECT_TRUE(
-      ShapeUtil::Equal(root->shape(),
-                       ShapeUtil::MakeTupleShape({
-                           ShapeUtil::MakeShape(F32, {16}),
-                           ShapeUtil::MakeShape(F32, {64, 4}),
-                           ShapeUtil::MakeScalarShape(F32)})
-      ));
+  EXPECT_TRUE(ShapeUtil::Equal(
+      root->shape(),
+      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {16}),
+                                 ShapeUtil::MakeShape(F32, {64, 4}),
+                                 ShapeUtil::MakeScalarShape(F32)})));
 }
 
 TEST_F(XlaBuilderTest, ReduceScatter) {
@@ -481,19 +480,18 @@ TEST_F(XlaBuilderTest, ReduceScatterWithToken) {
   ReplicaGroup group;
   group.add_replica_ids(0);
   group.add_replica_ids(1);
-  ReduceScatter(Tuple(&b, {x, x2, t}), to_apply, /*scatter_dimension=*/1, /*shard_count=*/2,
+  ReduceScatter(Tuple(&b, {x, x2, t}), to_apply, /*scatter_dimension=*/1,
+                /*shard_count=*/2,
                 /*replica_groups=*/{group});
   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
   auto root = module->entry_computation()->root_instruction();
 
   EXPECT_EQ(root->opcode(), HloOpcode::kReduceScatter);
-  EXPECT_TRUE(
-      ShapeUtil::Equal(root->shape(),
-                       ShapeUtil::MakeTupleShape({
-                           ShapeUtil::MakeShape(F32, {4, 8}),
-                           ShapeUtil::MakeShape(F32, {16, 2}),
-                           ShapeUtil::MakeScalarShape(F32)})
-      ));
+  EXPECT_TRUE(ShapeUtil::Equal(
+      root->shape(),
+      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4, 8}),
+                                 ShapeUtil::MakeShape(F32, {16, 2}),
+                                 ShapeUtil::MakeScalarShape(F32)})));
 }
 
 TEST_F(XlaBuilderTest, AllToAll) {
diff --git a/xla/service/all_gather_decomposer.cc b/xla/service/all_gather_decomposer.cc
index ea777ef185abb..b170f836ba378 100644
--- a/xla/service/all_gather_decomposer.cc
+++ b/xla/service/all_gather_decomposer.cc
@@ -53,13 +53,13 @@ HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) {
 }
 
 HloInstruction* TranslateAllGatherToAllReducePerOperand(
-    CollectiveOpGroupMode group_mode,
-    const HloAllGatherInstruction& ag, const Shape& output_shape,
-    HloInstruction* operand, HloComputation* comp) {
+    CollectiveOpGroupMode group_mode, const HloAllGatherInstruction& ag,
+    const Shape& output_shape, HloInstruction* operand, HloComputation* comp) {
   std::vector<HloInstruction*> start_indices =
       CreateStartIndicesForCollectiveDecomposition(
           group_mode, ag.replica_groups(), operand->shape(),
-          ag.all_gather_dimension(), comp).value();
+          ag.all_gather_dimension(), comp)
+          .value();
 
   auto zero = comp->AddInstruction(HloInstruction::CreateConstant(
       LiteralUtil::Zero(output_shape.element_type())));
diff --git a/xla/service/all_gather_decomposer_test.cc b/xla/service/all_gather_decomposer_test.cc
index fc3cafe1ea0ee..567d95337c017 100644
--- a/xla/service/all_gather_decomposer_test.cc
+++ b/xla/service/all_gather_decomposer_test.cc
@@ -176,16 +176,14 @@ ENTRY entry {
   EXPECT_THAT(
       module->entry_computation()->root_instruction(),
       op::Tuple(
-          op::AllReduce(op::DynamicUpdateSlice(
-              op::Broadcast(op::Constant()), op::Parameter(0), op::Constant(),
-              op::Multiply(op::ReplicaId(), op::Constant()))),
-          op::AllReduce(op::DynamicUpdateSlice(
-              op::Broadcast(op::Constant()), op::Parameter(1), op::Constant(),
-              op::Multiply(op::ReplicaId(), op::Constant()))),
-          op::Parameter(2)
-      ));
+          op::AllReduce(op::DynamicUpdateSlice(
+              op::Broadcast(op::Constant()), op::Parameter(0), op::Constant(),
+              op::Multiply(op::ReplicaId(), op::Constant()))),
+          op::AllReduce(op::DynamicUpdateSlice(
+              op::Broadcast(op::Constant()), op::Parameter(1), op::Constant(),
+              op::Multiply(op::ReplicaId(), op::Constant()))),
+          op::Parameter(2)));
 }
-
 }  // namespace
 }  // namespace xla
diff --git a/xla/service/hlo_verifier.cc b/xla/service/hlo_verifier.cc
index 3a3e6c31d33e0..3304ed90b98b9 100644
--- a/xla/service/hlo_verifier.cc
+++ b/xla/service/hlo_verifier.cc
@@ -437,13 +437,13 @@ static Status CheckCommonAllGatherInvariants(HloInstruction* hlo,
   TF_RET_CHECK(ag->operand_count() >= 1);
 
   int64_t shard_count;
-  // There can be one token in the input Tuple. The token is a scalar or `token`.
+  // There can be one token in the input Tuple. The token is a scalar or
+  // `token`.
   bool token_encountered = false;
   for (int64_t i = 0; i < ag->operand_count(); ++i) {
     const Shape& operand_shape = ag->operand(i)->shape();
     if (operand_shape.IsToken() || operand_shape.rank() == 0) {
-      TF_RET_CHECK(!token_encountered)
-          << "AllGather can have at most 1 token.";
+      TF_RET_CHECK(!token_encountered) << "AllGather can have at most 1 token.";
       token_encountered = true;
       continue;
     }
@@ -461,9 +461,9 @@ static Status CheckCommonAllGatherInvariants(HloInstruction* hlo,
     }
     TF_RET_CHECK(ag->all_gather_dimension() < output_shape.rank());
     if (i == 0) {
-      shard_count = CeilOfRatio(
-          output_shape.dimensions(ag->all_gather_dimension()),
-          operand_shape.dimensions(ag->all_gather_dimension()));
+      shard_count =
+          CeilOfRatio(output_shape.dimensions(ag->all_gather_dimension()),
+                      operand_shape.dimensions(ag->all_gather_dimension()));
     }
   }
@@ -538,7 +538,7 @@ Status ShapeVerifier::HandleReduceScatter(HloInstruction* hlo) {
   for (int64_t i = 0; i < ars->operand_count(); ++i) {
     const Shape& operand_shape = ars->operand(i)->shape();
     if (operand_shape.IsToken() || operand_shape.rank() == 0) {
-      TF_RET_CHECK(!token_encountered)
+      TF_RET_CHECK(!token_encountered)
           << "ReduceScatter can have at most 1 token.";
       token_encountered = true;
       continue;
diff --git a/xla/service/shape_inference.cc b/xla/service/shape_inference.cc
index 3334f89ab2494..e4f8753f2bc06 100644
--- a/xla/service/shape_inference.cc
+++ b/xla/service/shape_inference.cc
@@ -2083,7 +2083,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
   std::vector<Shape> output_shapes;
   output_shapes.reserve(operand_shapes.size());
-  // There can be one token in the input Tuple. The token is a scalar or `token`.
+  // There can be one token in the input Tuple. The token is a scalar or
+  // `token`.
   bool token_encountered = false;
   for (const Shape* operand_shape : operand_shapes) {
     if (operand_shape->IsToken() || operand_shape->rank() == 0) {
@@ -2147,7 +2148,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
   std::vector<Shape> output_shapes;
   output_shapes.reserve(operand_shapes.size());
-  // There can be one token in the input Tuple. The token is a scalar or `token`.
+  // There can be one token in the input Tuple. The token is a scalar or
+  // `token`.
   bool token_encountered = false;
   for (const Shape* operand_shape : operand_shapes) {
     if (operand_shape->IsToken() || operand_shape->rank() == 0) {
From aad352117ba950ac5ae62330e3980f4b5898a701 Mon Sep 17 00:00:00 2001
From: Jeffrey Huynh
Date: Fri, 22 Sep 2023 04:08:27 +0000
Subject: [PATCH 3/6] Fix hlo_verifier_test failure due to changed expectation

---
 xla/service/hlo_verifier_test.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xla/service/hlo_verifier_test.cc b/xla/service/hlo_verifier_test.cc
index 5cfe47559441f..fa6704b277054 100644
--- a/xla/service/hlo_verifier_test.cc
+++ b/xla/service/hlo_verifier_test.cc
@@ -2365,7 +2365,7 @@ TEST_F(HloVerifierTest, ReduceScatterInvalidScatterDim) {
   ASSERT_FALSE(status.ok());
   EXPECT_THAT(
       status.message(),
-      HasSubstr("ars->scatter_dimension() < ars->operand(i)->shape().rank()"));
+      HasSubstr("ars->scatter_dimension() < operand_shape.rank()"));
 }
 
 TEST_F(HloVerifierTest, ReduceScatterNonUniformGroups) {
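Why this one-line fix is needed: TF_RET_CHECK embeds the literal text of its condition in the returned error, so when patch 1 renamed the checked expression, the message the test greps for changed too. A self-contained sketch of the mechanism (a stand-in macro, not the real TF_RET_CHECK):

    #include <iostream>

    // Stand-in for TF_RET_CHECK: #cond stringifies the expression, so the
    // failure text tracks whatever the source literally says.
    #define RET_CHECK_SKETCH(cond)                                   \
      do {                                                           \
        if (!(cond)) std::cerr << "RET_CHECK failed: " #cond "\n";   \
      } while (0)

    int main() {
      int rank = 1, scatter_dimension = 3;
      RET_CHECK_SKETCH(scatter_dimension < rank);
      // prints: RET_CHECK failed: scatter_dimension < rank
    }
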
From 32e814524b88a474af5e4e904c0dd19841430b86 Mon Sep 17 00:00:00 2001
From: Jeffrey Huynh
Date: Mon, 9 Oct 2023 18:56:44 +0000
Subject: [PATCH 4/6] Separate the token change out into a separate PR with
 RFC.

---
 xla/client/xla_builder.cc                 |  6 ---
 xla/service/all_gather_decomposer.cc      |  4 +-
 xla/service/all_gather_decomposer_test.cc |  8 ++--
 xla/service/hlo_verifier.cc               | 23 +++-------
 xla/service/hlo_verifier_test.cc          | 55 +----------------------
 xla/service/shape_inference.cc            | 18 --------
 6 files changed, 10 insertions(+), 104 deletions(-)

diff --git a/xla/client/xla_builder.cc b/xla/client/xla_builder.cc
index b839c5747f365..6540d6db7220e 100644
--- a/xla/client/xla_builder.cc
+++ b/xla/client/xla_builder.cc
@@ -2890,12 +2890,6 @@ XlaOp XlaBuilder::AllGatherImpl(const XlaOp operand,
         return Unimplemented("0 element tuple AllGather is not supported");
       }
       for (int i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
-        if (operand_shape->tuple_shapes(i).element_type() !=
-            operand_shape->tuple_shapes(0).element_type()) {
-          return Unimplemented(
-              "All the shapes of a tuple input of AllGather must have the same "
-              "element type");
-        }
         operand_shapes.push_back(&operand_shape->tuple_shapes(i));
         operands.push_back(GetTupleElement(operand, i));
       }
diff --git a/xla/service/all_gather_decomposer.cc b/xla/service/all_gather_decomposer.cc
index b170f836ba378..7e5325a52a595 100644
--- a/xla/service/all_gather_decomposer.cc
+++ b/xla/service/all_gather_decomposer.cc
@@ -82,16 +82,14 @@ Status DecomposeAllGather(HloAllGatherInstruction* ag, HloComputation* comp) {
                       GetCollectiveOpGroupMode(ag->channel_id().has_value(),
                                                ag->use_global_device_ids()));
   if (ag->operand_count() > 1) {
-    HloInstruction* token = ag->mutable_operands().back();
     std::vector<HloInstruction*> tuple_inputs;
-    for (int i = 0; i < ag->operand_count() - 1; ++i) {
+    for (int i = 0; i < ag->operand_count(); ++i) {
       auto* input_operand = ag->mutable_operand(i);
       const auto& output_shape = ag->shape().tuple_shapes(i);
       auto* ar = TranslateAllGatherToAllReducePerOperand(
           group_mode, *ag, output_shape, input_operand, comp);
       tuple_inputs.push_back(ar);
     }
-    tuple_inputs.push_back(token);
     auto tup = comp->AddInstruction(HloInstruction::CreateTuple(tuple_inputs));
     TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(tup));
   } else {
diff --git a/xla/service/all_gather_decomposer_test.cc b/xla/service/all_gather_decomposer_test.cc
index 567d95337c017..f6a3102db13a6 100644
--- a/xla/service/all_gather_decomposer_test.cc
+++ b/xla/service/all_gather_decomposer_test.cc
@@ -155,15 +155,14 @@ ENTRY entry {
                             op::Constant(), op::Multiply(id, op::Constant()))));
 }
 
-TEST_F(AllGatherDecomposerTest, CrossReplicaAllGatherWithToken) {
+TEST_F(AllGatherDecomposerTest, CrossReplicaAllGatherWithTuple) {
   const std::string module_str = R"(
 HloModule module
 
 ENTRY entry {
   param0 = f32[10,20] parameter(0)
   param1 = f32[10,16] parameter(1)
-  t = f32[] parameter(2)
-  ROOT ag = (f32[10,80], f32[10,64], f32[]) all-gather(param0, param1, t),
+  ROOT ag = (f32[10,80], f32[10,64], f32[]) all-gather(param0, param1),
     replica_groups={}, dimensions={1}
 }
 )";
@@ -181,8 +180,7 @@ ENTRY entry {
               op::Multiply(op::ReplicaId(), op::Constant()))),
           op::AllReduce(op::DynamicUpdateSlice(
               op::Broadcast(op::Constant()), op::Parameter(1), op::Constant(),
-              op::Multiply(op::ReplicaId(), op::Constant()))),
-          op::Parameter(2)));
+              op::Multiply(op::ReplicaId(), op::Constant()))));
 }
 
 }  // namespace
diff --git a/xla/service/hlo_verifier.cc b/xla/service/hlo_verifier.cc
index 18cbb1e907e49..2f2467af0d03c 100644
--- a/xla/service/hlo_verifier.cc
+++ b/xla/service/hlo_verifier.cc
@@ -441,13 +441,7 @@ static Status CheckCommonAllGatherInvariants(HloInstruction* hlo,
   // `token`.
   bool token_encountered = false;
   for (int64_t i = 0; i < ag->operand_count(); ++i) {
-    const Shape& operand_shape = ag->operand(i)->shape();
-    if (operand_shape.IsToken() || operand_shape.rank() == 0) {
-      TF_RET_CHECK(!token_encountered) << "AllGather can have at most 1 token.";
-      token_encountered = true;
-      continue;
-    }
-    TF_RET_CHECK(ag->all_gather_dimension() < operand_shape.rank());
+    TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(i)->shape().rank());
 
     Shape output_shape;
     if (hlo->opcode() == HloOpcode::kAllGather) {
@@ -461,9 +455,9 @@ static Status CheckCommonAllGatherInvariants(HloInstruction* hlo,
     }
     TF_RET_CHECK(ag->all_gather_dimension() < output_shape.rank());
     if (i == 0) {
-      shard_count =
-          CeilOfRatio(output_shape.dimensions(ag->all_gather_dimension()),
-                      operand_shape.dimensions(ag->all_gather_dimension()));
+      shard_count = CeilOfRatio(
+          output_shape.dimensions(ag->all_gather_dimension()),
+          ag->operand(i)->shape().dimensions(ag->all_gather_dimension()));
     }
   }
@@ -536,14 +530,7 @@ Status ShapeVerifier::HandleReduceScatter(HloInstruction* hlo) {
   // There can be one token in the inputs. The token is a scalar or `token`.
   bool token_encountered = false;
   for (int64_t i = 0; i < ars->operand_count(); ++i) {
-    const Shape& operand_shape = ars->operand(i)->shape();
-    if (operand_shape.IsToken() || operand_shape.rank() == 0) {
-      TF_RET_CHECK(!token_encountered)
-          << "ReduceScatter can have at most 1 token.";
-      token_encountered = true;
-      continue;
-    }
-    TF_RET_CHECK(ars->scatter_dimension() < operand_shape.rank());
+    TF_RET_CHECK(ars->scatter_dimension() < ars->operand(i)->shape().rank());
 
     const Shape& output_shape = (ars->operand_count() == 1)
                                     ? ars->shape()
diff --git a/xla/service/hlo_verifier_test.cc b/xla/service/hlo_verifier_test.cc
index fa6704b277054..41df9e6c064b9 100644
--- a/xla/service/hlo_verifier_test.cc
+++ b/xla/service/hlo_verifier_test.cc
@@ -2365,7 +2365,7 @@ TEST_F(HloVerifierTest, ReduceScatterInvalidScatterDim) {
   ASSERT_FALSE(status.ok());
   EXPECT_THAT(
       status.message(),
-      HasSubstr("ars->scatter_dimension() < operand_shape.rank()"));
+      HasSubstr("ars->scatter_dimension() < ars->operand(i)->shape().rank()"));
 }
 
 TEST_F(HloVerifierTest, ReduceScatterNonUniformGroups) {
@@ -2420,59 +2420,6 @@ TEST_F(HloVerifierTest, ScatterInvalidScatterDim) {
 }
 
 
-TEST_F(HloVerifierTest, ReduceScatterTwoTokens) {
-  const char* const hlo_string = R"(
-  HloModule Module
-  add {
-    lhs = f32[] parameter(0)
-    rhs = f32[] parameter(1)
-    ROOT add = f32[] add(lhs, rhs)
-  }
-
-  ENTRY CRS {
-    input = f32[8]{0} parameter(0)
-    token1 = f32[] parameter(1)
-    token2 = f32[] parameter(2)
-    ROOT crs = (f32[4]{0}, f32[], f32[]) reduce-scatter(input, token1, token2),
-                          replica_groups={}, to_apply=add,
-                          dimensions={0}
-  })";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnUnverifiedModule(hlo_string));
-
-  auto status = verifier().Run(module.get()).status();
-  ASSERT_FALSE(status.ok());
-  EXPECT_THAT(status.message(),
-              HasSubstr("ReduceScatter can have at most 1 token."));
-}
-
-
-
-TEST_F(HloVerifierTest, AllGatherTwoTokens) {
-  const char* const hlo_string = R"(
-  HloModule Module
-  add {
-    lhs = f32[] parameter(0)
-    rhs = f32[] parameter(1)
-    ROOT add = f32[] add(lhs, rhs)
-  }
-
-  ENTRY CRS {
-    input = f32[8]{0} parameter(0)
-    token1 = f32[] parameter(1)
-    token2 = f32[] parameter(2)
-    ROOT crs = (f32[4]{0}, f32[], f32[]) all-gather(input, token1, token2),
-                          replica_groups={}, dimensions={0}
-  })";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnUnverifiedModule(hlo_string));
-
-  auto status = verifier().Run(module.get()).status();
-  ASSERT_FALSE(status.ok());
-  EXPECT_THAT(status.message(),
-              HasSubstr("AllGather can have at most 1 token."));
-}
-
 TEST_F(HloVerifierTest, VerifyBroadcastDimensionsOrder) {
   const char* const hlo = R"(
 HloModule module
diff --git a/xla/service/shape_inference.cc b/xla/service/shape_inference.cc
index e4f8753f2bc06..fbab0ea7fff26 100644
--- a/xla/service/shape_inference.cc
+++ b/xla/service/shape_inference.cc
@@ -2083,16 +2083,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
   std::vector<Shape> output_shapes;
   output_shapes.reserve(operand_shapes.size());
-  // There can be one token in the input Tuple. The token is a scalar or
-  // `token`.
-  bool token_encountered = false;
   for (const Shape* operand_shape : operand_shapes) {
-    if (operand_shape->IsToken() || operand_shape->rank() == 0) {
-      TF_RET_CHECK(!token_encountered);
-      token_encountered = true;
-      output_shapes.push_back(*operand_shape);
-      continue;
-    }
     TF_RET_CHECK(all_gather_dimension < operand_shape->rank());
     TF_RETURN_IF_ERROR(ExpectArray(*operand_shape, "operand of all-gather"));
 
@@ -2148,16 +2139,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
   std::vector<Shape> output_shapes;
   output_shapes.reserve(operand_shapes.size());
-  // There can be one token in the input Tuple. The token is a scalar or
-  // `token`.
-  bool token_encountered = false;
   for (const Shape* operand_shape : operand_shapes) {
-    if (operand_shape->IsToken() || operand_shape->rank() == 0) {
-      TF_RET_CHECK(!token_encountered);
-      token_encountered = true;
-      output_shapes.push_back(*operand_shape);
-      continue;
-    }
     TF_RET_CHECK(scatter_dimension < operand_shape->rank());
     TF_RETURN_IF_ERROR(
         ExpectArray(*operand_shape, "operand of reduce-scatter"));
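With the token path reverted, tuple inputs remain supported but every tuple element must now be a gatherable array; the decomposer simply translates each operand and re-tuples the results. A minimal sketch of the call shape that remains valid after this patch (it mirrors the test renames in the next two commits):

    // Both leaves are arrays; with shard_count = 4 and dimension 0,
    // f32[4] -> f32[16] and f32[16,4] -> f32[64,4].
    xla::XlaOp ag = xla::AllGather(xla::Tuple(&b, {x, x2}),
                                   /*all_gather_dimension=*/0,
                                   /*shard_count=*/4);
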
From b301c2a2a5b52180f9e9626173e6b67a78782960 Mon Sep 17 00:00:00 2001
From: Jeffrey Huynh
Date: Mon, 9 Oct 2023 21:46:24 +0000
Subject: [PATCH 5/6] Change *WithToken tests to *WithTuple

---
 xla/client/xla_builder_test.cc | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/xla/client/xla_builder_test.cc b/xla/client/xla_builder_test.cc
index a74e36962329f..60503f890be01 100644
--- a/xla/client/xla_builder_test.cc
+++ b/xla/client/xla_builder_test.cc
@@ -418,12 +418,11 @@ TEST_F(XlaBuilderTest, AllGatherR2) {
       ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 64})));
 }
 
-TEST_F(XlaBuilderTest, AllGatherWithToken) {
+TEST_F(XlaBuilderTest, AllGatherWithTuple) {
   XlaBuilder b(TestName());
   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "x");
   auto x2 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 4}), "x2");
-  auto t = Parameter(&b, 2, ShapeUtil::MakeScalarShape(F32), "t");
-  AllGather(Tuple(&b, {x, x2, t}), /*all_gather_dimension=*/0,
+  AllGather(Tuple(&b, {x, x2}), /*all_gather_dimension=*/0,
             /*shard_count=*/4);
   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
   auto root = module->entry_computation()->root_instruction();
@@ -432,8 +431,7 @@ TEST_F(XlaBuilderTest, AllGatherR2) {
   EXPECT_TRUE(ShapeUtil::Equal(
       root->shape(),
       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {16}),
-                                 ShapeUtil::MakeShape(F32, {64, 4}),
-                                 ShapeUtil::MakeScalarShape(F32)})));
+                                 ShapeUtil::MakeShape(F32, {64, 4})})));
 }
 
 TEST_F(XlaBuilderTest, ReduceScatter) {
@@ -462,7 +460,7 @@ TEST_F(XlaBuilderTest, ReduceScatter) {
       ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 8})));
 }
 
-TEST_F(XlaBuilderTest, ReduceScatterWithToken) {
+TEST_F(XlaBuilderTest, ReduceScatterWithTuple) {
   XlaBuilder b(TestName());
   XlaComputation to_apply;
   {
@@ -476,11 +474,10 @@ TEST_F(XlaBuilderTest, ReduceScatterWithToken) {
   }
   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
   auto x2 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 4}), "x2");
-  auto t = Parameter(&b, 2, ShapeUtil::MakeScalarShape(F32), "t");
   ReplicaGroup group;
   group.add_replica_ids(0);
   group.add_replica_ids(1);
-  ReduceScatter(Tuple(&b, {x, x2, t}), to_apply, /*scatter_dimension=*/1,
+  ReduceScatter(Tuple(&b, {x, x2}), to_apply, /*scatter_dimension=*/1,
                 /*shard_count=*/2,
                 /*replica_groups=*/{group});
   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
@@ -490,8 +487,7 @@ TEST_F(XlaBuilderTest, ReduceScatterWithToken) {
   EXPECT_EQ(root->opcode(), HloOpcode::kReduceScatter);
   EXPECT_TRUE(ShapeUtil::Equal(
       root->shape(),
       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4, 8}),
-                                 ShapeUtil::MakeShape(F32, {16, 2}),
-                                 ShapeUtil::MakeScalarShape(F32)})));
+                                 ShapeUtil::MakeShape(F32, {16, 2})})));
 }
 
 TEST_F(XlaBuilderTest, AllToAll) {
From 5890278fc16c9f900782d32a92d40ecf548aea85 Mon Sep 17 00:00:00 2001
From: Jeffrey Huynh
Date: Mon, 9 Oct 2023 22:26:25 +0000
Subject: [PATCH 6/6] Fix missing parenthesis

---
 xla/service/all_gather_decomposer_test.cc | 4 ++--
 xla/service/hlo_verifier.cc               | 5 -----
 xla/service/hlo_verifier_test.cc          | 2 --
 3 files changed, 2 insertions(+), 9 deletions(-)

diff --git a/xla/service/all_gather_decomposer_test.cc b/xla/service/all_gather_decomposer_test.cc
index f6a3102db13a6..4eb629a347311 100644
--- a/xla/service/all_gather_decomposer_test.cc
+++ b/xla/service/all_gather_decomposer_test.cc
@@ -162,7 +162,7 @@ HloModule module
 ENTRY entry {
   param0 = f32[10,20] parameter(0)
   param1 = f32[10,16] parameter(1)
-  ROOT ag = (f32[10,80], f32[10,64], f32[]) all-gather(param0, param1),
+  ROOT ag = (f32[10,80], f32[10,64]) all-gather(param0, param1),
     replica_groups={}, dimensions={1}
 }
 )";
@@ -180,7 +180,7 @@ ENTRY entry {
               op::Multiply(op::ReplicaId(), op::Constant()))),
           op::AllReduce(op::DynamicUpdateSlice(
               op::Broadcast(op::Constant()), op::Parameter(1), op::Constant(),
-              op::Multiply(op::ReplicaId(), op::Constant()))));
+              op::Multiply(op::ReplicaId(), op::Constant())))));
 }
 
 }  // namespace
diff --git a/xla/service/hlo_verifier.cc b/xla/service/hlo_verifier.cc
index 2f2467af0d03c..fd34bc4ba7a61 100644
--- a/xla/service/hlo_verifier.cc
+++ b/xla/service/hlo_verifier.cc
@@ -437,9 +437,6 @@ static Status CheckCommonAllGatherInvariants(HloInstruction* hlo,
   TF_RET_CHECK(ag->operand_count() >= 1);
 
   int64_t shard_count;
-  // There can be one token in the input Tuple. The token is a scalar or
-  // `token`.
-  bool token_encountered = false;
   for (int64_t i = 0; i < ag->operand_count(); ++i) {
     TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(i)->shape().rank());
 
@@ -527,8 +524,6 @@ Status ShapeVerifier::HandleReduceScatter(HloInstruction* hlo) {
   TF_RET_CHECK(ars->scatter_dimension() >= 0);
   TF_RET_CHECK(ars->operand_count() >= 1);
 
-  // There can be one token in the inputs. The token is a scalar or `token`.
-  bool token_encountered = false;
   for (int64_t i = 0; i < ars->operand_count(); ++i) {
     TF_RET_CHECK(ars->scatter_dimension() < ars->operand(i)->shape().rank());
 
diff --git a/xla/service/hlo_verifier_test.cc b/xla/service/hlo_verifier_test.cc
index 41df9e6c064b9..8a4cbe59165ed 100644
--- a/xla/service/hlo_verifier_test.cc
+++ b/xla/service/hlo_verifier_test.cc
@@ -2391,7 +2391,6 @@ TEST_F(HloVerifierTest, ReduceScatterNonUniformGroups) {
               HasSubstr("Replica groups expected to be of uniform size"));
 }
 
-
 TEST_F(HloVerifierTest, ScatterInvalidScatterDim) {
   const char* const hlo_string = R"(
 HloModule Module
@@ -2419,7 +2418,6 @@ TEST_F(HloVerifierTest, ScatterInvalidScatterDim) {
               HasSubstr("Invalid scatter_dims_to_operand_dims mapping"));
 }
 
-
 TEST_F(HloVerifierTest, VerifyBroadcastDimensionsOrder) {
   const char* const hlo = R"(
 HloModule module