From 0271e30070f4c83d7207fc09b47746976a80d7e1 Mon Sep 17 00:00:00 2001
From: Zixuan Jiang
Date: Thu, 28 Nov 2024 17:13:03 -0800
Subject: [PATCH] Clarify index parallel dims in gather/scatter instructions.

The dims connected between the indices and the gather output (or the
scatter update) can be classified into 3 disjoint sets:
1. explicit batch dims,
2. implicit batch dims,
3. index passthrough dims.

Therefore, when partitioning gather/scatter along the index passthrough
dims, we do not need to consider the explicit and implicit batch dims;
those are handled by the other partitioning methods.

PiperOrigin-RevId: 701124558
---
 .../xla/xla/hlo/utils/hlo_sharding_util.cc    | 73 +++++++++++------
 .../xla/xla/hlo/utils/hlo_sharding_util.h     | 30 +++----
 .../service/spmd/gather_scatter_handler.cc    | 80 ++++++++----------
 .../xla/service/spmd/spmd_partitioner_test.cc | 68 +++++++++++-----
 4 files changed, 147 insertions(+), 104 deletions(-)

diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc
index 03afe18eb5f91e..c1cf48609aa75b 100644
--- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc
+++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc
@@ -1218,8 +1218,7 @@ HloSharding PropagateShardingAlongDimsAndReplicateOthers(
 }
 
 HloSharding GatherOutputShardingFromIndex(const HloSharding& index_sharding,
-                                          const HloInstruction* hlo,
-                                          bool consider_explicit_batch_dims) {
+                                          const HloInstruction* hlo) {
   CHECK(hlo->opcode() == HloOpcode::kGather);
   if (index_sharding.IsTileMaximal() || index_sharding.IsManual()) {
     return index_sharding;
@@ -1229,16 +1228,14 @@ HloSharding GatherOutputShardingFromIndex(const HloSharding& index_sharding,
   const GatherScatterDims indices_output_dims =
       GetGatherConnectedDimsAcrossIndicesAndOutput(
           hlo->operand(1)->shape().rank(), dnums.index_vector_dim(),
-          dnums.start_indices_batching_dims(), hlo->shape().rank(),
-          dnums.offset_dims(), consider_explicit_batch_dims);
+          hlo->shape().rank(), dnums.offset_dims());
   return PropagateShardingAlongDimsAndReplicateOthers(
       index_sharding, indices_output_dims.indices_dims,
       indices_output_dims.output_dims, hlo->shape().rank());
 }
 
 HloSharding GatherIndexShardingFromOutput(const HloSharding& output_sharding,
-                                          const HloInstruction* hlo,
-                                          bool consider_explicit_batch_dims) {
+                                          const HloInstruction* hlo) {
   CHECK(hlo->opcode() == HloOpcode::kGather);
   if (output_sharding.IsTileMaximal() || output_sharding.IsManual()) {
     return output_sharding;
   }
@@ -1248,8 +1245,7 @@ HloSharding GatherIndexShardingFromOutput(const HloSharding& output_sharding,
   const GatherScatterDims indices_output_dims =
       GetGatherConnectedDimsAcrossIndicesAndOutput(
           hlo->operand(1)->shape().rank(), dnums.index_vector_dim(),
-          dnums.start_indices_batching_dims(), hlo->shape().rank(),
-          dnums.offset_dims(), consider_explicit_batch_dims);
+          hlo->shape().rank(), dnums.offset_dims());
   return PropagateShardingAlongDimsAndReplicateOthers(
       output_sharding, indices_output_dims.output_dims,
       indices_output_dims.indices_dims, hlo->operand(1)->shape().rank());
@@ -1307,9 +1303,8 @@ HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
   return HloSharding::Tile(tile_assignment, hlo.sharding().metadata());
 }
 
-HloSharding ScatterIndexShardingFromUpdate(const HloSharding& update_sharding,
-                                           const HloScatterInstruction* scatter,
-                                           bool consider_explicit_batch_dims) {
+HloSharding ScatterIndexShardingFromUpdate(
+    const HloSharding& update_sharding, const HloScatterInstruction* scatter) {
   if (update_sharding.IsTileMaximal() || update_sharding.IsManual()) {
     return update_sharding;
   }
@@ -1318,18 +1313,16 @@ HloSharding ScatterIndexShardingFromUpdate(const HloSharding& update_sharding,
   const GatherScatterDims indices_update_dims =
       GetGatherConnectedDimsAcrossIndicesAndOutput(
           scatter->scatter_indices()->shape().rank(), dnums.index_vector_dim(),
-          dnums.scatter_indices_batching_dims(),
           scatter->scatter_updates()[0]->shape().rank(),
-          dnums.update_window_dims(), consider_explicit_batch_dims);
+          dnums.update_window_dims());
   return PropagateShardingAlongDimsAndReplicateOthers(
       update_sharding, indices_update_dims.output_dims,
       indices_update_dims.indices_dims,
       scatter->scatter_indices()->shape().rank());
 }
 
-HloSharding ScatterUpdateShardingFromIndex(const HloSharding& index_sharding,
-                                           const HloScatterInstruction* scatter,
-                                           bool consider_explicit_batch_dims) {
+HloSharding ScatterUpdateShardingFromIndex(
+    const HloSharding& index_sharding, const HloScatterInstruction* scatter) {
   if (index_sharding.IsTileMaximal() || index_sharding.IsManual()) {
     return index_sharding;
   }
@@ -1338,9 +1331,8 @@ HloSharding ScatterUpdateShardingFromIndex(const HloSharding& index_sharding,
   const GatherScatterDims indices_update_dims =
       GetGatherConnectedDimsAcrossIndicesAndOutput(
           scatter->scatter_indices()->shape().rank(), dnums.index_vector_dim(),
-          dnums.scatter_indices_batching_dims(),
           scatter->scatter_updates()[0]->shape().rank(),
-          dnums.update_window_dims(), consider_explicit_batch_dims);
+          dnums.update_window_dims());
   return PropagateShardingAlongDimsAndReplicateOthers(
       index_sharding, indices_update_dims.indices_dims,
       indices_update_dims.output_dims,
@@ -2402,9 +2394,9 @@ absl::InlinedVector<int64_t, 1> GetScatterOperandPassthroughUpdateDims(
 }
 
 GatherScatterDims GetGatherConnectedDimsAcrossIndicesAndOutput(
-    int64_t indices_rank, int64_t index_vector_dim,
-    absl::Span<const int64_t> indices_batching_dims, int64_t output_rank,
-    absl::Span<const int64_t> offset_dims, bool consider_explicit_batch_dims) {
+    int64_t indices_rank, int64_t index_vector_dim, int64_t output_rank,
+    absl::Span<const int64_t> offset_dims,
+    absl::Span<const int64_t> excluded_indices_dims) {
   GatherScatterDims result;
   for (int64_t output_dim = 0, indices_dim = 0; output_dim < output_rank;
        ++output_dim) {
@@ -2415,8 +2407,7 @@ GatherScatterDims GetGatherConnectedDimsAcrossIndicesAndOutput(
       indices_dim++;
     }
     CHECK_LT(indices_dim, indices_rank);
-    if (consider_explicit_batch_dims ||
-        !absl::c_linear_search(indices_batching_dims, indices_dim)) {
+    if (!absl::c_linear_search(excluded_indices_dims, indices_dim)) {
       result.indices_dims.push_back(indices_dim);
       result.output_dims.push_back(output_dim);
     }
@@ -2425,6 +2416,42 @@ GatherScatterDims GetGatherConnectedDimsAcrossIndicesAndOutput(
   return result;
 }
 
+GatherScatterDims GetGatherScatterIndexPassThroughDims(
+    const HloInstruction& hlo, const CallGraph& call_graph) {
+  if (const auto* gather = DynCast<HloGatherInstruction>(&hlo)) {
+    const GatherDimensionNumbers& dnums = gather->gather_dimension_numbers();
+    absl::InlinedVector<int64_t, 1> excluded_indices_dims{
+        dnums.start_indices_batching_dims().begin(),
+        dnums.start_indices_batching_dims().end()};
+    if (std::optional<GatherScatterDims> implicit_batch_dims =
+            GetGatherParallelBatchDims(hlo, call_graph)) {
+      absl::c_copy(implicit_batch_dims->indices_dims,
+                   std::back_inserter(excluded_indices_dims));
+    }
+    return GetGatherConnectedDimsAcrossIndicesAndOutput(
+        gather->operand(1)->shape().rank(), dnums.index_vector_dim(),
+        hlo.shape().rank(), dnums.offset_dims(), excluded_indices_dims);
+  }
+
+  if (const auto* scatter = DynCast<HloScatterInstruction>(&hlo)) {
+    const ScatterDimensionNumbers& dnums =
+        scatter->scatter_dimension_numbers();
+    absl::InlinedVector<int64_t, 1> excluded_indices_dims{
+        dnums.scatter_indices_batching_dims().begin(),
+        dnums.scatter_indices_batching_dims().end()};
+    if (std::optional<GatherScatterDims> implicit_batch_dims =
+            GetScatterParallelBatchDims(hlo, call_graph)) {
+      absl::c_copy(implicit_batch_dims->indices_dims,
+                   std::back_inserter(excluded_indices_dims));
+    }
+    return GetGatherConnectedDimsAcrossIndicesAndOutput(
+        scatter->scatter_indices()->shape().rank(), dnums.index_vector_dim(),
+        scatter->scatter_updates()[0]->shape().rank(),
+        dnums.update_window_dims(), excluded_indices_dims);
+  }
+
+  LOG(FATAL) << "Expected gather or scatter, got " << hlo.ToString();
+}
+
 HloSharding InferGatherScatterParallelShardingFromOperandSharding(
     const HloSharding& operand_sharding, const Shape& shape,
     absl::Span<const int64_t> output_aligned_operand_parallel_dims,
diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h
index c680ff39af27cc..daceedf8c17f65 100644
--- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h
+++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h
@@ -168,15 +168,13 @@ bool ContainsTileSharding(const HloModule& module);
 
 // Returns the preferred output sharding for a gather op based on the sharding
 // of the indices.
-HloSharding GatherOutputShardingFromIndex(
-    const HloSharding& index_sharding, const HloInstruction* hlo,
-    bool consider_explicit_batch_dims = true);
+HloSharding GatherOutputShardingFromIndex(const HloSharding& index_sharding,
+                                          const HloInstruction* hlo);
 
 // Returns the preferred index sharding for a gather op based on the sharding
 // of the output.
-HloSharding GatherIndexShardingFromOutput(
-    const HloSharding& output_sharding, const HloInstruction* hlo,
-    bool consider_explicit_batch_dims = true);
+HloSharding GatherIndexShardingFromOutput(const HloSharding& output_sharding,
+                                          const HloInstruction* hlo);
 
 // Returns a new HloSharding for a gather op so that only non offset dimensions
 // are sharded. Assume "result" is returned by this function. It is ensured that
@@ -187,14 +185,12 @@ HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo);
 // Returns the preferred index sharding for a scatter op based on the sharding
 // of the data.
 HloSharding ScatterIndexShardingFromUpdate(
-    const HloSharding& update_sharding, const HloScatterInstruction* scatter,
-    bool consider_explicit_batch_dims = true);
+    const HloSharding& update_sharding, const HloScatterInstruction* scatter);
 
 // Returns the preferred data sharding for a scatter op based on the sharding
 // of the index.
 HloSharding ScatterUpdateShardingFromIndex(
-    const HloSharding& index_sharding, const HloScatterInstruction* scatter,
-    bool consider_explicit_batch_dims = true);
+    const HloSharding& index_sharding, const HloScatterInstruction* scatter);
 
 // Returns a new index sharding for a scatter op so that we only shard on first
 // "number of scatter_window_dims" dimensions. Assume "result" is returned by
@@ -383,12 +379,18 @@ absl::InlinedVector<int64_t, 1> GetScatterOperandPassthroughUpdateDims(
     absl::Span<const int64_t> slice_sizes);
 
 // Returns the dims along which sharding can be propagated between indices and
-// output/update for gather/scatter operations.
+// output/update for gather/scatter operations. `excluded_indices_dims` are
+// excluded from the result.
 GatherScatterDims GetGatherConnectedDimsAcrossIndicesAndOutput(
-    int64_t indices_rank, int64_t index_vector_dim,
-    absl::Span<const int64_t> indices_batching_dims, int64_t output_rank,
+    int64_t indices_rank, int64_t index_vector_dim, int64_t output_rank,
     absl::Span<const int64_t> offset_dims,
-    bool consider_explicit_batch_dims = true);
+    absl::Span<const int64_t> excluded_indices_dims = {});
+
+// Returns the index pass-through dimensions, which are defined by
+// GetGatherConnectedDimsAcrossIndicesAndOutput - ExplicitBatchDims -
+// GetGatherScatterBatchParallelDims.
+GatherScatterDims GetGatherScatterIndexPassThroughDims(
+    const HloInstruction& hlo, const CallGraph& call_graph);
 
 // Infer output sharding on index parallel dimensions for gather/scatter from
 // gather operand/indices or scatter operands/indices/updates.
diff --git a/third_party/xla/xla/service/spmd/gather_scatter_handler.cc b/third_party/xla/xla/service/spmd/gather_scatter_handler.cc
index 87e9bffc0cfda3..3011c56a7109b1 100644
--- a/third_party/xla/xla/service/spmd/gather_scatter_handler.cc
+++ b/third_party/xla/xla/service/spmd/gather_scatter_handler.cc
@@ -289,21 +289,16 @@ absl::StatusOr<HloInstruction*> PartitionGatherIndexPassthroughDimensions(
     }
   };
 
-  const GatherDimensionNumbers& dnums = gather->gather_dimension_numbers();
   SpmdBuilder* b = visitor->builder();
-  const hlo_sharding_util::GatherScatterDims indices_output_dims =
-      hlo_sharding_util::GetGatherConnectedDimsAcrossIndicesAndOutput(
-          indices.rank(), dnums.index_vector_dim(),
-          dnums.start_indices_batching_dims(), output_shape.rank(),
-          dnums.offset_dims(), /*consider_explicit_batch_dims=*/false);
-  const int64_t num_groups =
-      indices.sharding().NumTiles(indices_output_dims.indices_dims);
-  const int64_t num_tiles = indices.sharding().TotalNumTiles();
+  const hlo_sharding_util::GatherScatterDims index_passthrough_dims =
+      hlo_sharding_util::GetGatherScatterIndexPassThroughDims(
+          *gather, visitor->call_graph());
   // Compute output sharding.
   HloSharding passthrough_sharding =
-      hlo_sharding_util::GatherOutputShardingFromIndex(
-          indices.sharding(), gather, /*consider_explicit_batch_dims=*/false);
+      hlo_sharding_util::PropagateShardingAlongDimsAndReplicateOthers(
+          indices.sharding(), index_passthrough_dims.indices_dims,
+          index_passthrough_dims.output_dims, gather->shape().rank());
   if (passthrough_sharding.IsTileMaximal()) {
     return nullptr;
   }
@@ -311,13 +306,16 @@ absl::StatusOr<HloInstruction*> PartitionGatherIndexPassthroughDimensions(
       &passthrough_sharding);
   // Group shardings on index pass-through dimensions.
   const GroupedSharding output_grouped = hlo_sharding_util::GroupShardingOnDims(
-      passthrough_sharding, indices_output_dims.output_dims);
-  const GroupedSharding indices_grouped =
-      AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
-                          indices.sharding(), indices_output_dims.indices_dims),
-                      output_grouped);
+      passthrough_sharding, index_passthrough_dims.output_dims);
+  const GroupedSharding indices_grouped = AlignGroupsWith(
+      hlo_sharding_util::GroupShardingOnDims(
+          indices.sharding(), index_passthrough_dims.indices_dims),
+      output_grouped);
   // See if we can group partially replicated dimensions from the operand
   // otherwise replicate it.
+  const int64_t num_groups =
+      indices.sharding().NumTiles(index_passthrough_dims.indices_dims);
+  const int64_t num_tiles = indices.sharding().TotalNumTiles();
   const GroupedSharding operand_grouped = AlignGroupsWith(
       hlo_sharding_util::GroupShardingOnReplicatedDim(
           operand.sharding(), num_groups, num_tiles, operand.rank(),
@@ -821,9 +819,7 @@ std::pair<int64_t, int64_t> GatherPartitionMethodCostModel(
     const PartitionedHlo& indices, const Shape& output_shape,
     const HloSharding& output_sharding, absl::Span<const int64_t> batch_dims,
     absl::Span<const int64_t> slice_sizes, SpmdPartitioningVisitor* visitor) {
-  decltype(PartitionGather)* zero_cost_method =
-      GetGatherPartitionMethod(gather_partition_method);
-  if (partition_method == zero_cost_method) {
+  if (partition_method == GetGatherPartitionMethod(gather_partition_method)) {
     // Always prioritize the user's chosen partitioning, and assume it has zero
     // cost.
     // This defaults to IndexParallel.
@@ -1224,7 +1220,7 @@ absl::StatusOr<HloInstruction*> PartitionScatterParallelDimensions(
       .hlo();
 }
 
-// Partition a scatter over a indices dimensions that are cosidered parallel
+// Partition a scatter over indices dimensions that are considered parallel
 // (which means that the indices access the operand in a monotonically
 // increasing way across the respective operand dimension referenced by the
 // index).
@@ -1246,10 +1242,8 @@ absl::StatusOr<HloInstruction*> PartitionScatterIndexParallelDimensions(
       slice_sizes, visitor, allow_recursive, *parallel_dims, true);
 }
 
-// Partition a scatter over a indices dimensions that are cosidered parallel
-// (which means that the indices access the operand in a monotonically
-// increasing way across the respective operand dimension referenced by the
-// index).
+// Partition a scatter over explicit batch dimensions defined in
+// input_batching_dims and scatter_indices_batching_dims.
 absl::StatusOr<HloInstruction*> PartitionScatterExplicitBatchDimensions(
     const HloScatterInstruction* scatter, std::vector<PartitionedHlo> operands,
     PartitionedHlo indices, std::vector<PartitionedHlo> updates,
@@ -1383,38 +1377,37 @@ absl::StatusOr<HloInstruction*> PartitionScatterIndexPassthroughDimensions(
   };
 
   SpmdBuilder* b = visitor->builder();
-  const auto& dnums = scatter->scatter_dimension_numbers();
   // Parse non-variadic computation only. Variadic case will be replicated.
   const HloSharding original_indices_sharding = indices.sharding();
-  const hlo_sharding_util::GatherScatterDims indices_update_dims =
-      hlo_sharding_util::GetGatherConnectedDimsAcrossIndicesAndOutput(
-          indices.rank(), dnums.index_vector_dim(),
-          dnums.scatter_indices_batching_dims(), updates[0].rank(),
-          dnums.update_window_dims(), /*consider_explicit_batch_dims=*/false);
-  const int64_t num_groups =
-      indices.sharding().NumTiles(indices_update_dims.indices_dims);
-  const int64_t num_tiles = indices.sharding().TotalNumTiles();
+  const hlo_sharding_util::GatherScatterDims index_passthrough_dims =
+      hlo_sharding_util::GetGatherScatterIndexPassThroughDims(
+          *scatter, visitor->call_graph());
   HloSharding passthrough_sharding =
-      hlo_sharding_util::ScatterUpdateShardingFromIndex(
-          indices.sharding(), scatter, /*consider_explicit_batch_dims=*/false);
+      hlo_sharding_util::PropagateShardingAlongDimsAndReplicateOthers(
+          indices.sharding(), index_passthrough_dims.indices_dims,
+          index_passthrough_dims.output_dims,
+          scatter->scatter_updates()[0]->shape().rank());
   if (passthrough_sharding.IsTileMaximal()) {
     return nullptr;
   }
   hlo_sharding_util::MergeShardingIfCompatible(updates[0].sharding(),
                                                &passthrough_sharding);
   const GroupedSharding update_grouped = hlo_sharding_util::GroupShardingOnDims(
-      passthrough_sharding, indices_update_dims.output_dims);
+      passthrough_sharding, index_passthrough_dims.output_dims);
   // See if we can group partially replicated dimensions from the operand
   // otherwise replicate it.
+  const int64_t num_groups =
+      indices.sharding().NumTiles(index_passthrough_dims.indices_dims);
+  const int64_t num_tiles = indices.sharding().TotalNumTiles();
   const GroupedSharding operand_grouped = AlignGroupsWith(
       hlo_sharding_util::GroupShardingOnReplicatedDim(
           operands[0].sharding(), num_groups, num_tiles, operands[0].rank(),
           ScatterOperandDimsByPriority(operands[0], scatter, slice_sizes)),
       update_grouped);
-  const GroupedSharding indices_grouped =
-      AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
-                          indices.sharding(), indices_update_dims.indices_dims),
-                      update_grouped);
+  const GroupedSharding indices_grouped = AlignGroupsWith(
+      hlo_sharding_util::GroupShardingOnDims(
+          indices.sharding(), index_passthrough_dims.indices_dims),
+      update_grouped);
   const GroupedSharding& output_grouped = operand_grouped;
   PartitionedHlo per_group_operand =
       PerGroupPartitionedHlo(operands[0], operand_grouped, b, clean_ups);
@@ -1487,7 +1480,7 @@ absl::StatusOr<HloInstruction*> PartitionScatterIndexPassthroughDimensions(
           slice_sizes, visitor, allow_recursive));
   auto all_reduce = operands[0].state().partitioner->AllReduceAlongShardingDims(
       b, pscatter, original_indices_sharding, indices.state().next_channel_id,
-      indices_update_dims.indices_dims,
+      index_passthrough_dims.indices_dims,
       operands[0].state().collective_ops_creator, scatter->to_apply());
   all_reduce->set_sharding(hlo_sharding_util::UngroupSharding(output_grouped));
   if (allow_recursive) {
@@ -1655,10 +1648,7 @@ std::pair<int64_t, int64_t> ScatterPartitionMethodCostModel(
     const std::vector<PartitionedHlo>& updates, const Shape& output_shape,
     const HloSharding& output_sharding, absl::Span<const int64_t> slice_sizes,
     SpmdPartitioningVisitor* visitor) {
-  decltype(PartitionScatter)* zero_cost_method =
-      GetScatterPartitionMethod(scatter_partition_method);
-
-  if (partition_method == zero_cost_method) {
+  if (partition_method == GetScatterPartitionMethod(scatter_partition_method)) {
    // Always prioritize index parallel partitioning, and assume it has zero
    // cost.
    return {0, 0};
diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc
index 37afd04d9d0d0f..c119d1c7b488de 100644
--- a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc
+++ b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc
@@ -11572,17 +11572,29 @@ ENTRY %module {
       collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
       slice_sizes={1,1,2,2}, sharding={replicated}
 })";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          PartitionComputation(hlo_string, /*num_devices=*/8));
-  VLOG(1) << module->ToString();
-  const auto root = module->entry_computation()->root_instruction();
-  auto operand = AllOf(op::Shape("s32[2,4,2,2]"), op::Parameter());
-  auto indices = AllOf(op::Shape("s32[2,2,2]"), op::Subtract());
-  auto gather = AllOf(op::Shape("s32[2,2,2,2]"), op::Gather(operand, indices));
-  EXPECT_THAT(
-      root, op::AllReduce(op::DynamicUpdateSlice(
-                _, op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)),
-                _, _, _, _)));
+  for (const PartitioningMethod& method :
+       {PartitioningMethod::kIndexParallel,
+        PartitioningMethod::kIndexPassthrough}) {
+    TF_ASSERT_OK_AND_ASSIGN(
+        auto module,
+        PartitionComputation(hlo_string, /*num_devices=*/8,
+                             /*conv_halo_exchange_always_on_lhs=*/true,
+                             /*choose_faster_windowed_einsum=*/false,
+                             /*unroll_windowed_einsum=*/false,
+                             /*bidirectional_windowed_einsum=*/false,
+                             /*threshold_for_windowed_einsum_mib=*/-1, method,
+                             method));
+    VLOG(1) << module->ToString();
+    auto operand = AllOf(op::Shape("s32[2,4,2,2]"), op::Parameter());
+    auto indices = AllOf(op::Shape("s32[2,2,2]"), op::Subtract());
+    auto gather =
+        AllOf(op::Shape("s32[2,2,2,2]"), op::Gather(operand, indices));
+    EXPECT_THAT(
+        module->entry_computation()->root_instruction(),
+        op::AllReduce(op::DynamicUpdateSlice(
+            _, op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)), _,
+            _, _, _)));
+  }
 }
 
 TEST_P(SpmdPartitioningTest,
@@ -12553,17 +12565,29 @@ ENTRY %module {
       scatter_dims_to_operand_dims={1,0},
       index_vector_dim=0, sharding={replicated}
 })";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          PartitionComputation(hlo_string, /*num_devices=*/8));
-  VLOG(1) << module->ToString();
-  const auto root = module->entry_computation()->root_instruction();
-  auto operand = AllOf(op::Shape("s32[2,4,2,2]"), op::Select());
-  auto indices = AllOf(op::Shape("s32[2,2,2]"), op::Subtract());
-  auto update = AllOf(op::Shape("s32[2,2,2,2]"), op::DynamicSlice());
-  auto scatter =
-      AllOf(op::Shape("s32[2,4,2,2]"), op::Scatter(operand, indices, update));
-  EXPECT_THAT(root, op::AllReduce(op::DynamicUpdateSlice(
-                        _, op::AllReduce(scatter), _, _, _, _)));
+
+  for (const PartitioningMethod& method :
+       {PartitioningMethod::kIndexParallel,
+        PartitioningMethod::kIndexPassthrough}) {
+    TF_ASSERT_OK_AND_ASSIGN(
+        auto module,
+        PartitionComputation(hlo_string, /*num_devices=*/8,
+                             /*conv_halo_exchange_always_on_lhs=*/true,
+                             /*choose_faster_windowed_einsum=*/false,
+                             /*unroll_windowed_einsum=*/false,
+                             /*bidirectional_windowed_einsum=*/false,
+                             /*threshold_for_windowed_einsum_mib=*/-1, method,
+                             method));
+    VLOG(1) << module->ToString();
+    auto operand = AllOf(op::Shape("s32[2,4,2,2]"), op::Select());
+    auto indices = AllOf(op::Shape("s32[2,2,2]"), op::Subtract());
+    auto update = AllOf(op::Shape("s32[2,2,2,2]"), op::DynamicSlice());
+    auto scatter =
+        AllOf(op::Shape("s32[2,4,2,2]"), op::Scatter(operand, indices, update));
+    EXPECT_THAT(module->entry_computation()->root_instruction(),
+                op::AllReduce(op::DynamicUpdateSlice(_, op::AllReduce(scatter),
+                                                     _, _, _, _)));
+  }
 }
 
 TEST_P(SpmdPartitioningTest,
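
For reference, below is a minimal standalone sketch (plain C++, not XLA code)
of the connected-dims walk that GetGatherConnectedDimsAcrossIndicesAndOutput
performs after this change: walk the gather output dims, pair each non-offset
dim with the next indices dim (skipping index_vector_dim), and drop pairs
whose indices dim is excluded, i.e. the explicit or implicit batch dims. The
dimension numbers in main() are hypothetical, chosen only to show one excluded
batch dim.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

struct GatherScatterDims {
  std::vector<int64_t> indices_dims;
  std::vector<int64_t> output_dims;
};

// Mirrors the loop above: offset dims come from the operand and are skipped;
// every other output dim is paired with the next indices dim, except
// index_vector_dim and any excluded (batch) dims.
GatherScatterDims ConnectedDims(int64_t index_vector_dim, int64_t output_rank,
                                const std::vector<int64_t>& offset_dims,
                                const std::vector<int64_t>& excluded) {
  GatherScatterDims result;
  for (int64_t output_dim = 0, indices_dim = 0; output_dim < output_rank;
       ++output_dim) {
    if (std::count(offset_dims.begin(), offset_dims.end(), output_dim) > 0) {
      continue;  // Offset dim: connected to the operand, not the indices.
    }
    if (indices_dim == index_vector_dim) {
      ++indices_dim;  // The index vector dim never maps to an output dim.
    }
    if (std::count(excluded.begin(), excluded.end(), indices_dim) == 0) {
      result.indices_dims.push_back(indices_dim);
      result.output_dims.push_back(output_dim);
    }
    ++indices_dim;
  }
  return result;
}

int main() {
  // Hypothetical gather: indices rank 3 with index_vector_dim=2, output rank 4
  // with offset_dims={2,3}; indices dim 0 is excluded as an explicit batch
  // dim, leaving dim 1 as the only index passthrough dim.
  const GatherScatterDims dims = ConnectedDims(/*index_vector_dim=*/2,
                                               /*output_rank=*/4,
                                               /*offset_dims=*/{2, 3},
                                               /*excluded=*/{0});
  for (std::size_t i = 0; i < dims.indices_dims.size(); ++i) {
    std::cout << "indices dim " << dims.indices_dims[i] << " <-> output dim "
              << dims.output_dims[i] << "\n";
  }
  return 0;
}

With these inputs the sketch prints a single pair, indices dim 1 <-> output
dim 1: indices dim 0 is dropped as a batch dim, and output dims 2 and 3 are
offset dims that map back to the operand rather than the indices.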