Clarify index parallel dims in gather/scatter instructions.
The connected dims between indices and gather output (scatter update) can be classified into three disjoint sets:
1. explicit batch dims
2. implicit batch dims
3. index passthrough dims

Therefore, when partitioning gather/scatter along index passthrough dims, we do not consider explicit batch and implicit batch dims. The batch dims are considered in other partitioning methods.
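
To make the classification concrete, here is a minimal standalone sketch (plain C++, not XLA code; every dim number below is invented for illustration) of how the index passthrough dims fall out once the explicit and implicit batch dims are removed from the connected dims:

```cpp
// Standalone illustration only. The "connected" indices dims of a
// hypothetical gather split into the three disjoint sets named above; the
// index passthrough dims are whatever remains after subtracting both kinds
// of batch dims.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Returns `dims` with every element of `excluded` removed.
std::vector<int64_t> Subtract(std::vector<int64_t> dims,
                              const std::vector<int64_t>& excluded) {
  dims.erase(std::remove_if(dims.begin(), dims.end(),
                            [&](int64_t d) {
                              return std::find(excluded.begin(),
                                               excluded.end(), d) !=
                                     excluded.end();
                            }),
             dims.end());
  return dims;
}

int main() {
  // Indices dims connected to the gather output (hypothetical example).
  const std::vector<int64_t> connected_dims = {0, 1, 2};
  // 1. explicit batch dims (start_indices_batching_dims).
  const std::vector<int64_t> explicit_batch_dims = {0};
  // 2. implicit batch dims (detected parallel dims).
  const std::vector<int64_t> implicit_batch_dims = {1};
  // 3. index passthrough dims = connected - explicit batch - implicit batch.
  const std::vector<int64_t> index_passthrough_dims =
      Subtract(Subtract(connected_dims, explicit_batch_dims),
               implicit_batch_dims);
  for (int64_t d : index_passthrough_dims) {
    std::cout << d << "\n";  // Prints 2.
  }
  return 0;
}
```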

PiperOrigin-RevId: 701124558
ZixuanJiang authored and tensorflower-gardener committed Nov 29, 2024
1 parent 946f70b commit 0271e30
Showing 4 changed files with 147 additions and 104 deletions.
73 changes: 50 additions & 23 deletions third_party/xla/xla/hlo/utils/hlo_sharding_util.cc
@@ -1218,8 +1218,7 @@ HloSharding PropagateShardingAlongDimsAndReplicateOthers(
}

HloSharding GatherOutputShardingFromIndex(const HloSharding& index_sharding,
const HloInstruction* hlo,
bool consider_explicit_batch_dims) {
const HloInstruction* hlo) {
CHECK(hlo->opcode() == HloOpcode::kGather);
if (index_sharding.IsTileMaximal() || index_sharding.IsManual()) {
return index_sharding;
@@ -1229,16 +1228,14 @@ HloSharding GatherOutputShardingFromIndex(const HloSharding& index_sharding,
const GatherScatterDims indices_output_dims =
GetGatherConnectedDimsAcrossIndicesAndOutput(
hlo->operand(1)->shape().rank(), dnums.index_vector_dim(),
dnums.start_indices_batching_dims(), hlo->shape().rank(),
dnums.offset_dims(), consider_explicit_batch_dims);
hlo->shape().rank(), dnums.offset_dims());
return PropagateShardingAlongDimsAndReplicateOthers(
index_sharding, indices_output_dims.indices_dims,
indices_output_dims.output_dims, hlo->shape().rank());
}

HloSharding GatherIndexShardingFromOutput(const HloSharding& output_sharding,
const HloInstruction* hlo,
bool consider_explicit_batch_dims) {
const HloInstruction* hlo) {
CHECK(hlo->opcode() == HloOpcode::kGather);
if (output_sharding.IsTileMaximal() || output_sharding.IsManual()) {
return output_sharding;
@@ -1248,8 +1245,7 @@ HloSharding GatherIndexShardingFromOutput(const HloSharding& output_sharding,
const GatherScatterDims indices_output_dims =
GetGatherConnectedDimsAcrossIndicesAndOutput(
hlo->operand(1)->shape().rank(), dnums.index_vector_dim(),
dnums.start_indices_batching_dims(), hlo->shape().rank(),
dnums.offset_dims(), consider_explicit_batch_dims);
hlo->shape().rank(), dnums.offset_dims());
return PropagateShardingAlongDimsAndReplicateOthers(
output_sharding, indices_output_dims.output_dims,
indices_output_dims.indices_dims, hlo->operand(1)->shape().rank());
@@ -1307,9 +1303,8 @@ HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
return HloSharding::Tile(tile_assignment, hlo.sharding().metadata());
}

HloSharding ScatterIndexShardingFromUpdate(const HloSharding& update_sharding,
const HloScatterInstruction* scatter,
bool consider_explicit_batch_dims) {
HloSharding ScatterIndexShardingFromUpdate(
const HloSharding& update_sharding, const HloScatterInstruction* scatter) {
if (update_sharding.IsTileMaximal() || update_sharding.IsManual()) {
return update_sharding;
}
@@ -1318,18 +1313,16 @@ HloSharding ScatterIndexShardingFromUpdate(const HloSharding& update_sharding,
const GatherScatterDims indices_update_dims =
GetGatherConnectedDimsAcrossIndicesAndOutput(
scatter->scatter_indices()->shape().rank(), dnums.index_vector_dim(),
dnums.scatter_indices_batching_dims(),
scatter->scatter_updates()[0]->shape().rank(),
dnums.update_window_dims(), consider_explicit_batch_dims);
dnums.update_window_dims());
return PropagateShardingAlongDimsAndReplicateOthers(
update_sharding, indices_update_dims.output_dims,
indices_update_dims.indices_dims,
scatter->scatter_indices()->shape().rank());
}

HloSharding ScatterUpdateShardingFromIndex(const HloSharding& index_sharding,
const HloScatterInstruction* scatter,
bool consider_explicit_batch_dims) {
HloSharding ScatterUpdateShardingFromIndex(
const HloSharding& index_sharding, const HloScatterInstruction* scatter) {
if (index_sharding.IsTileMaximal() || index_sharding.IsManual()) {
return index_sharding;
}
@@ -1338,9 +1331,8 @@ HloSharding ScatterUpdateShardingFromIndex(const HloSharding& index_sharding,
const GatherScatterDims indices_update_dims =
GetGatherConnectedDimsAcrossIndicesAndOutput(
scatter->scatter_indices()->shape().rank(), dnums.index_vector_dim(),
dnums.scatter_indices_batching_dims(),
scatter->scatter_updates()[0]->shape().rank(),
dnums.update_window_dims(), consider_explicit_batch_dims);
dnums.update_window_dims());
return PropagateShardingAlongDimsAndReplicateOthers(
index_sharding, indices_update_dims.indices_dims,
indices_update_dims.output_dims,
@@ -2402,9 +2394,9 @@ absl::InlinedVector<int64_t, 1> GetScatterOperandPassthroughUpdateDims(
}

GatherScatterDims GetGatherConnectedDimsAcrossIndicesAndOutput(
int64_t indices_rank, int64_t index_vector_dim,
absl::Span<const int64_t> indices_batching_dims, int64_t output_rank,
absl::Span<const int64_t> offset_dims, bool consider_explicit_batch_dims) {
int64_t indices_rank, int64_t index_vector_dim, int64_t output_rank,
absl::Span<const int64_t> offset_dims,
absl::Span<const int64_t> excluded_indices_dims) {
GatherScatterDims result;
for (int64_t output_dim = 0, indices_dim = 0; output_dim < output_rank;
++output_dim) {
@@ -2415,8 +2407,7 @@ GatherScatterDims GetGatherConnectedDimsAcrossIndicesAndOutput(
indices_dim++;
}
CHECK_LT(indices_dim, indices_rank);
if (consider_explicit_batch_dims ||
!absl::c_linear_search(indices_batching_dims, indices_dim)) {
if (!absl::c_linear_search(excluded_indices_dims, indices_dim)) {
result.indices_dims.push_back(indices_dim);
result.output_dims.push_back(output_dim);
}
@@ -2425,6 +2416,42 @@ GatherScatterDims GetGatherConnectedDimsAcrossIndicesAndOutput(
return result;
}

GatherScatterDims GetGatherScatterIndexPassThroughDims(
const HloInstruction& hlo, const CallGraph& call_graph) {
if (const auto* gather = DynCast<HloGatherInstruction>(&hlo)) {
const GatherDimensionNumbers& dnums = gather->gather_dimension_numbers();
absl::InlinedVector<int64_t, 1> excluded_indices_dims{
dnums.start_indices_batching_dims().begin(),
dnums.start_indices_batching_dims().end()};
if (std::optional<GatherScatterDims> implicit_batch_dims =
GetGatherParallelBatchDims(hlo, call_graph)) {
absl::c_copy(implicit_batch_dims->indices_dims,
std::back_inserter(excluded_indices_dims));
}
return GetGatherConnectedDimsAcrossIndicesAndOutput(
gather->operand(1)->shape().rank(), dnums.index_vector_dim(),
hlo.shape().rank(), dnums.offset_dims(), excluded_indices_dims);
}

if (const auto* scatter = DynCast<HloScatterInstruction>(&hlo)) {
const ScatterDimensionNumbers& dnums = scatter->scatter_dimension_numbers();
absl::InlinedVector<int64_t, 1> excluded_indices_dims{
dnums.scatter_indices_batching_dims().begin(),
dnums.scatter_indices_batching_dims().end()};
if (std::optional<GatherScatterDims> implicit_batch_dims =
GetScatterParallelBatchDims(hlo, call_graph)) {
absl::c_copy(implicit_batch_dims->indices_dims,
std::back_inserter(excluded_indices_dims));
}
return GetGatherConnectedDimsAcrossIndicesAndOutput(
scatter->scatter_indices()->shape().rank(), dnums.index_vector_dim(),
scatter->scatter_updates()[0]->shape().rank(),
dnums.update_window_dims(), excluded_indices_dims);
}

LOG(FATAL) << "Expected gather or scatter, got " << hlo.ToString();
}

HloSharding InferGatherScatterParallelShardingFromOperandSharding(
const HloSharding& operand_sharding, const Shape& shape,
absl::Span<const int64_t> output_aligned_operand_parallel_dims,
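For readers following the hunk above, here is a standalone sketch (plain C++, not XLA code; the ranks and dim numbers in main() are invented) of the matching loop in GetGatherConnectedDimsAcrossIndicesAndOutput: every output dim that is not an offset dim is paired with the next indices dim, the index vector dim is skipped, and any dim listed in excluded_indices_dims is dropped from the result.

```cpp
// Standalone illustration of the dim-matching walk shown in the hunk above.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

struct ConnectedDims {
  std::vector<int64_t> indices_dims;
  std::vector<int64_t> output_dims;
};

bool Contains(const std::vector<int64_t>& v, int64_t x) {
  return std::find(v.begin(), v.end(), x) != v.end();
}

ConnectedDims GetConnectedDims(
    int64_t indices_rank, int64_t index_vector_dim, int64_t output_rank,
    const std::vector<int64_t>& offset_dims,
    const std::vector<int64_t>& excluded_indices_dims) {
  ConnectedDims result;
  for (int64_t output_dim = 0, indices_dim = 0; output_dim < output_rank;
       ++output_dim) {
    // Offset dims of the output come from the operand, not the indices.
    if (Contains(offset_dims, output_dim)) continue;
    // The index vector dim never maps to an output dim.
    if (indices_dim == index_vector_dim) ++indices_dim;
    assert(indices_dim < indices_rank);
    if (!Contains(excluded_indices_dims, indices_dim)) {
      result.indices_dims.push_back(indices_dim);
      result.output_dims.push_back(output_dim);
    }
    ++indices_dim;
  }
  return result;
}

int main() {
  // Hypothetical gather: indices rank 3 with index vector dim 2, output rank 3
  // with offset dim 2, and indices dim 0 excluded as a batch dim.
  const ConnectedDims dims = GetConnectedDims(
      /*indices_rank=*/3, /*index_vector_dim=*/2, /*output_rank=*/3,
      /*offset_dims=*/{2}, /*excluded_indices_dims=*/{0});
  for (std::size_t i = 0; i < dims.indices_dims.size(); ++i) {
    // Prints: indices dim 1 <-> output dim 1
    std::cout << "indices dim " << dims.indices_dims[i] << " <-> output dim "
              << dims.output_dims[i] << "\n";
  }
  return 0;
}
```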
30 changes: 16 additions & 14 deletions third_party/xla/xla/hlo/utils/hlo_sharding_util.h
@@ -168,15 +168,13 @@ bool ContainsTileSharding(const HloModule& module);

// Returns the preferred output sharding for a gather op based on the sharding
// of the indices.
HloSharding GatherOutputShardingFromIndex(
const HloSharding& index_sharding, const HloInstruction* hlo,
bool consider_explicit_batch_dims = true);
HloSharding GatherOutputShardingFromIndex(const HloSharding& index_sharding,
const HloInstruction* hlo);

// Returns the preferred index sharding for a gather op based on the sharding
// of the output.
HloSharding GatherIndexShardingFromOutput(
const HloSharding& output_sharding, const HloInstruction* hlo,
bool consider_explicit_batch_dims = true);
HloSharding GatherIndexShardingFromOutput(const HloSharding& output_sharding,
const HloInstruction* hlo);

// Returns a new HloSharding for a gather op so that only non offset dimensions
// are sharded. Assume "result" is returned by this function. It is ensured that
@@ -187,14 +185,12 @@ HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo);
// Returns the preferred index sharding for a scatter op based on the sharding
// of the data.
HloSharding ScatterIndexShardingFromUpdate(
const HloSharding& update_sharding, const HloScatterInstruction* scatter,
bool consider_explicit_batch_dims = true);
const HloSharding& update_sharding, const HloScatterInstruction* scatter);

// Returns the preferred data sharding for a scatter op based on the sharding
// of the index.
HloSharding ScatterUpdateShardingFromIndex(
const HloSharding& index_sharding, const HloScatterInstruction* scatter,
bool consider_explicit_batch_dims = true);
const HloSharding& index_sharding, const HloScatterInstruction* scatter);

// Returns a new index sharding for a scatter op so that we only shard on first
// "number of scatter_window_dims" dimensions. Assume "result" is returned by
@@ -383,12 +379,18 @@ absl::InlinedVector<int64_t, 1> GetScatterOperandPassthroughUpdateDims(
absl::Span<const int64_t> slice_sizes);

// Returns the dims along which sharding can be propagated between indices and
// output/update for gather/scatter operations.
// output/update for gather/scatter operations. `excluded_indices_dims` are
// excluded from the result.
GatherScatterDims GetGatherConnectedDimsAcrossIndicesAndOutput(
int64_t indices_rank, int64_t index_vector_dim,
absl::Span<const int64_t> indices_batching_dims, int64_t output_rank,
int64_t indices_rank, int64_t index_vector_dim, int64_t output_rank,
absl::Span<const int64_t> offset_dims,
bool consider_explicit_batch_dims = true);
absl::Span<const int64_t> excluded_indices_dims = {});

// Returns the index pass-through dimensions, which are defined by
// GetGatherConnectedDimsAcrossIndicesAndOutput - ExplicitBatchDims -
// GetGatherScatterBatchParallelDims.
GatherScatterDims GetGatherScatterIndexPassThroughDims(
const HloInstruction& hlo, const CallGraph& call_graph);

// Infer output sharding on index parallel dimensions for gather/scatter from
// gather operand/indices or scatter operands/indices/updates.
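For orientation, a hedged sketch of how the new GetGatherScatterIndexPassThroughDims helper declared above can feed PropagateShardingAlongDimsAndReplicateOthers, mirroring the partitioner changes in the next file; the wrapper name is illustrative and the include paths are best-effort assumptions, not part of this commit.

```cpp
// Illustrative wrapper only (not part of the commit): propagate an indices
// sharding to the gather output along the index pass-through dims. Explicit
// and implicit batch dims are already excluded by the helper itself.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/hlo/utils/hlo_sharding_util.h"
#include "xla/service/call_graph.h"

namespace xla {

HloSharding GatherOutputShardingFromIndexPassthroughDims(
    const HloInstruction& gather, const HloSharding& indices_sharding,
    const CallGraph& call_graph) {
  const hlo_sharding_util::GatherScatterDims dims =
      hlo_sharding_util::GetGatherScatterIndexPassThroughDims(gather,
                                                              call_graph);
  // Keep the sharding on the index pass-through dims and replicate the rest.
  return hlo_sharding_util::PropagateShardingAlongDimsAndReplicateOthers(
      indices_sharding, dims.indices_dims, dims.output_dims,
      gather.shape().rank());
}

}  // namespace xla
```

This is the same pattern PartitionGatherIndexPassthroughDimensions uses below to compute the pass-through output sharding.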
80 changes: 35 additions & 45 deletions third_party/xla/xla/service/spmd/gather_scatter_handler.cc
@@ -289,35 +289,33 @@ absl::StatusOr<HloInstruction*> PartitionGatherIndexPassthroughDimensions(
}
};

const GatherDimensionNumbers& dnums = gather->gather_dimension_numbers();
SpmdBuilder* b = visitor->builder();

const hlo_sharding_util::GatherScatterDims indices_output_dims =
hlo_sharding_util::GetGatherConnectedDimsAcrossIndicesAndOutput(
indices.rank(), dnums.index_vector_dim(),
dnums.start_indices_batching_dims(), output_shape.rank(),
dnums.offset_dims(), /*consider_explicit_batch_dims=*/false);
const int64_t num_groups =
indices.sharding().NumTiles(indices_output_dims.indices_dims);
const int64_t num_tiles = indices.sharding().TotalNumTiles();
const hlo_sharding_util::GatherScatterDims index_passthrough_dims =
hlo_sharding_util::GetGatherScatterIndexPassThroughDims(
*gather, visitor->call_graph());
// Compute output sharding.
HloSharding passthrough_sharding =
hlo_sharding_util::GatherOutputShardingFromIndex(
indices.sharding(), gather, /*consider_explicit_batch_dims=*/false);
hlo_sharding_util::PropagateShardingAlongDimsAndReplicateOthers(
indices.sharding(), index_passthrough_dims.indices_dims,
index_passthrough_dims.output_dims, gather->shape().rank());
if (passthrough_sharding.IsTileMaximal()) {
return nullptr;
}
hlo_sharding_util::MergeShardingIfCompatible(output_sharding,
&passthrough_sharding);
// Group shardings on index pass-through dimensions.
const GroupedSharding output_grouped = hlo_sharding_util::GroupShardingOnDims(
passthrough_sharding, indices_output_dims.output_dims);
const GroupedSharding indices_grouped =
AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
indices.sharding(), indices_output_dims.indices_dims),
output_grouped);
passthrough_sharding, index_passthrough_dims.output_dims);
const GroupedSharding indices_grouped = AlignGroupsWith(
hlo_sharding_util::GroupShardingOnDims(
indices.sharding(), index_passthrough_dims.indices_dims),
output_grouped);
// See if we can group partially replicated dimensions from the operand
// otherwise replicate it.
const int64_t num_groups =
indices.sharding().NumTiles(index_passthrough_dims.indices_dims);
const int64_t num_tiles = indices.sharding().TotalNumTiles();
const GroupedSharding operand_grouped = AlignGroupsWith(
hlo_sharding_util::GroupShardingOnReplicatedDim(
operand.sharding(), num_groups, num_tiles, operand.rank(),
@@ -821,9 +819,7 @@ std::pair<int64_t, int64_t> GatherPartitionMethodCostModel(
const PartitionedHlo& indices, const Shape& output_shape,
const HloSharding& output_sharding, absl::Span<const int64_t> batch_dims,
absl::Span<const int64_t> slice_sizes, SpmdPartitioningVisitor* visitor) {
decltype(PartitionGather)* zero_cost_method =
GetGatherPartitionMethod(gather_partition_method);
if (partition_method == zero_cost_method) {
if (partition_method == GetGatherPartitionMethod(gather_partition_method)) {
// Always prioritize the user's chosen partitioning, and assume it has zero
// cost.
// This defaults to IndexParallel.
@@ -1224,7 +1220,7 @@ absl::StatusOr<HloInstruction*> PartitionScatterParallelDimensions(
.hlo();
}

// Partition a scatter over a indices dimensions that are considered parallel
// Partition a scatter over a indices dimensions that are considered parallel
// (which means that the indices access the operand in a monotonically
// increasing way across the respective operand dimension referenced by the
// index).
Expand All @@ -1246,10 +1242,8 @@ absl::StatusOr<HloInstruction*> PartitionScatterIndexParallelDimensions(
slice_sizes, visitor, allow_recursive, *parallel_dims, true);
}

// Partition a scatter over a indices dimensions that are cosidered parallel
// (which means that the indices access the operand in a monotonically
// increasing way across the respective operand dimension referenced by the
// index).
// Partition a scatter over explicit batch dimensions defined in
// input_batching_dims and scatter_indices_batching_dims.
absl::StatusOr<HloInstruction*> PartitionScatterExplicitBatchDimensions(
const HloScatterInstruction* scatter, std::vector<PartitionedHlo> operands,
PartitionedHlo indices, std::vector<PartitionedHlo> updates,
@@ -1383,38 +1377,37 @@ absl::StatusOr<HloInstruction*> PartitionScatterIndexPassthroughDimensions(
};

SpmdBuilder* b = visitor->builder();
const auto& dnums = scatter->scatter_dimension_numbers();
// Parse non-variadic computation only. Variadic case will be replicated.
const HloSharding original_indices_sharding = indices.sharding();
const hlo_sharding_util::GatherScatterDims indices_update_dims =
hlo_sharding_util::GetGatherConnectedDimsAcrossIndicesAndOutput(
indices.rank(), dnums.index_vector_dim(),
dnums.scatter_indices_batching_dims(), updates[0].rank(),
dnums.update_window_dims(), /*consider_explicit_batch_dims=*/false);
const int64_t num_groups =
indices.sharding().NumTiles(indices_update_dims.indices_dims);
const int64_t num_tiles = indices.sharding().TotalNumTiles();
const hlo_sharding_util::GatherScatterDims index_passthrough_dims =
hlo_sharding_util::GetGatherScatterIndexPassThroughDims(
*scatter, visitor->call_graph());
HloSharding passthrough_sharding =
hlo_sharding_util::ScatterUpdateShardingFromIndex(
indices.sharding(), scatter, /*consider_explicit_batch_dims=*/false);
hlo_sharding_util::PropagateShardingAlongDimsAndReplicateOthers(
indices.sharding(), index_passthrough_dims.indices_dims,
index_passthrough_dims.output_dims,
scatter->scatter_updates()[0]->shape().rank());
if (passthrough_sharding.IsTileMaximal()) {
return nullptr;
}
hlo_sharding_util::MergeShardingIfCompatible(updates[0].sharding(),
&passthrough_sharding);
const GroupedSharding update_grouped = hlo_sharding_util::GroupShardingOnDims(
passthrough_sharding, indices_update_dims.output_dims);
passthrough_sharding, index_passthrough_dims.output_dims);
// See if we can group partially replicated dimensions from the operand
// otherwise replicate it.
const int64_t num_groups =
indices.sharding().NumTiles(index_passthrough_dims.indices_dims);
const int64_t num_tiles = indices.sharding().TotalNumTiles();
const GroupedSharding operand_grouped = AlignGroupsWith(
hlo_sharding_util::GroupShardingOnReplicatedDim(
operands[0].sharding(), num_groups, num_tiles, operands[0].rank(),
ScatterOperandDimsByPriority(operands[0], scatter, slice_sizes)),
update_grouped);
const GroupedSharding indices_grouped =
AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
indices.sharding(), indices_update_dims.indices_dims),
update_grouped);
const GroupedSharding indices_grouped = AlignGroupsWith(
hlo_sharding_util::GroupShardingOnDims(
indices.sharding(), index_passthrough_dims.indices_dims),
update_grouped);
const GroupedSharding& output_grouped = operand_grouped;
PartitionedHlo per_group_operand =
PerGroupPartitionedHlo(operands[0], operand_grouped, b, clean_ups);
@@ -1487,7 +1480,7 @@ absl::StatusOr<HloInstruction*> PartitionScatterIndexPassthroughDimensions(
slice_sizes, visitor, allow_recursive));
auto all_reduce = operands[0].state().partitioner->AllReduceAlongShardingDims(
b, pscatter, original_indices_sharding, indices.state().next_channel_id,
indices_update_dims.indices_dims,
index_passthrough_dims.indices_dims,
operands[0].state().collective_ops_creator, scatter->to_apply());
all_reduce->set_sharding(hlo_sharding_util::UngroupSharding(output_grouped));
if (allow_recursive) {
@@ -1655,10 +1648,7 @@ std::pair<int64_t, int64_t> ScatterPartitionMethodCostModel(
const std::vector<PartitionedHlo>& updates, const Shape& output_shape,
const HloSharding& output_sharding, absl::Span<const int64_t> slice_sizes,
SpmdPartitioningVisitor* visitor) {
decltype(PartitionScatter)* zero_cost_method =
GetScatterPartitionMethod(scatter_partition_method);

if (partition_method == zero_cost_method) {
if (partition_method == GetScatterPartitionMethod(scatter_partition_method)) {
// Always prioritize index parallel partitioning, and assume it has zero
// cost.
return {0, 0};
