Skip to content

Commit

Permalink
Add an expander pattern for GatherOp/ScatterOp with batching dims (op…
Browse files Browse the repository at this point in the history
  • Loading branch information
abhigunj authored Sep 20, 2024
1 parent 728d625 commit 9bb28f8
Show file tree
Hide file tree
Showing 2 changed files with 309 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,170 @@ func.func @tan_op_complex(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>) -> (tensor
%3 = stablehlo.imag %1 : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
func.return %2, %3 : tensor<4xf64>, tensor<4xf64>
}

// -----

// Gather whose operand batching dims [0, 2] pair with start-indices batching
// dims [1, 0]. The expected rewrite prepends one iota per start-indices
// batching dim (dims 1 and 0, in that order) to %arg1 along dim 3, merges the
// operand batching dims into collapsed_slice_dims, puts them at the front of
// start_index_map, and emits indices_are_sorted = false even though the input
// op sets it to true.
// NOTE(review): the NO-DOWNGRADE expectations inside the function presumably
// belong to a RUN line outside this view that leaves batching dims untouched;
// confirm against the file header.
// CHECK-LABEL: @gather_with_batching_dims
// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32>
// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{
// CHECK-SAME: dimension_numbers = #stablehlo.gather<
// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3],
// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>,
// CHECK-SAME: indices_are_sorted = false,
// CHECK-SAME: slice_sizes = array<i64: 1, 1, 1, 1, 8>
// CHECK-SAME: }> : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x4xi32>) -> tensor<4x3x5x8xi32>
// CHECK-NEXT: return %[[gather]] : tensor<4x3x5x8xi32>
func.func @gather_with_batching_dims(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> {
// CHECK-NO-DOWNGRADE: operand_batching_dims = [0, 2]
// CHECK-NO-DOWNGRADE: start_indices_batching_dims = [1, 0]
// Input op: index vectors of size 2 live in dim 3 of %arg1.
%0 = "stablehlo.gather"(%arg0, %arg1) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3],
collapsed_slice_dims = [1, 3],
operand_batching_dims = [0, 2],
start_indices_batching_dims = [1, 0],
start_index_map = [1, 3],
index_vector_dim = 3
>,
slice_sizes = array<i64: 1, 1, 1, 1, 8>,
indices_are_sorted = true
} : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32>
func.return %0 : tensor<4x3x5x8xi32>
}

// -----

// Gather where index_vector_dim (3) equals the rank of the start indices
// (tensor<4x3x5xi32>), i.e. the indices carry no explicit index-vector
// dimension. The expected rewrite first reshapes %arg1 to 4x3x5x1 so that it
// can be concatenated after the two iotas along dim 3, then emits a gather
// without batching dims.
// CHECK-LABEL: @gather_with_batching_no_index_vector_dim
// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32>
// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{
// CHECK-SAME: dimension_numbers = #stablehlo.gather<
// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2],
// CHECK-SAME: start_index_map = [0, 2, 1], index_vector_dim = 3>,
// CHECK-SAME: indices_are_sorted = false,
// CHECK-SAME: slice_sizes = array<i64: 1, 1, 1, 8>
// CHECK-SAME: }> : (tensor<3x2x4x9xi32>, tensor<4x3x5x3xi32>) -> tensor<4x3x5x8xi32>
// CHECK-NEXT: return %[[gather]] : tensor<4x3x5x8xi32>
func.func @gather_with_batching_no_index_vector_dim(%arg0: tensor<3x2x4x9xi32>, %arg1: tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32> {
// CHECK-NO-DOWNGRADE: operand_batching_dims = [0, 2]
// CHECK-NO-DOWNGRADE: start_indices_batching_dims = [1, 0]
// Input op: each element of the rank-3 %arg1 is a single scalar index.
%0 = "stablehlo.gather"(%arg0, %arg1) <{
dimension_numbers = #stablehlo.gather<
offset_dims = [3],
collapsed_slice_dims = [1],
operand_batching_dims = [0, 2],
start_indices_batching_dims = [1, 0],
start_index_map = [1],
index_vector_dim = 3
>,
slice_sizes = array<i64: 1, 1, 1, 8>,
indices_are_sorted = true
}> : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32>
func.return %0 : tensor<4x3x5x8xi32>
}

// -----

// Gather whose single batching dim has size 0 (operand dim 0 and start-indices
// dim 0). The expected rewrite still produces one zero-sized iota along dim 0
// and concatenates it in front of %arg1; the slice size for the former
// batching dim stays 0.
// CHECK-LABEL: @gather_with_batching_dim_size_zero
// CHECK-NEXT: %[[iota:.*]] = stablehlo.iota dim = 0 : tensor<0x3x5x1xi32>
// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota]], %arg1, dim = 3 : (tensor<0x3x5x1xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x2xi32>
// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{
// CHECK-SAME: dimension_numbers = #stablehlo.gather<
// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1],
// CHECK-SAME: start_index_map = [0, 1], index_vector_dim = 3>,
// CHECK-SAME: indices_are_sorted = false,
// CHECK-SAME: slice_sizes = array<i64: 0, 1, 8>
// CHECK-SAME: }> : (tensor<0x2x9xi32>, tensor<0x3x5x2xi32>) -> tensor<0x3x5x8xi32>
// CHECK-NEXT: return %[[gather]] : tensor<0x3x5x8xi32>
func.func @gather_with_batching_dim_size_zero(%arg0: tensor<0x2x9xi32>, %arg1: tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32> {
// CHECK-NO-DOWNGRADE: operand_batching_dims = [0]
// CHECK-NO-DOWNGRADE: start_indices_batching_dims = [0]
// Input op: operand and indices both have extent 0 in their batching dim.
%0 = "stablehlo.gather"(%arg0, %arg1) <{
dimension_numbers = #stablehlo.gather<
offset_dims = [3],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [0],
start_index_map = [1],
index_vector_dim = 3
>,
slice_sizes = array<i64: 0, 1, 8>,
indices_are_sorted = true
}> : (tensor<0x2x9xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32>
func.return %0 : tensor<0x3x5x8xi32>
}

// -----

// Scatter whose input batching dims [0, 2] pair with scatter-indices batching
// dims [1, 0]. The expected rewrite prepends one iota per scatter-indices
// batching dim (dims 1 and 0, in that order) to %arg1 along dim 3, merges the
// input batching dims into inserted_window_dims, puts them at the front of
// scatter_dims_to_operand_dims, emits indices_are_sorted = false, and keeps
// unique_indices as given (false).
// NOTE(review): the NO-DOWNGRADE expectations inside the function presumably
// belong to a RUN line outside this view that leaves batching dims untouched;
// confirm against the file header.
// CHECK-LABEL: @scatter_with_batching_dims
// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32>
// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{
// CHECK-SAME: indices_are_sorted = false,
// CHECK-SAME: dimension_numbers = #stablehlo.scatter<
// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2, 3],
// CHECK-SAME: scatter_dims_to_operand_dims = [0, 2, 1, 3], index_vector_dim = 3>,
// CHECK-SAME: unique_indices = false}>
// CHECK: (tensor<3x2x4x7x9xi32>, tensor<4x3x5x4xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32>
// CHECK-NEXT: return %[[scatter]] : tensor<3x2x4x7x9xi32>
func.func @scatter_with_batching_dims(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> {
// CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2]
// CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0]
// Input op: the update computation simply returns the update value (%arg4).
%0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{
indices_are_sorted = true,
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3],
inserted_window_dims = [1, 3],
input_batching_dims = [0, 2],
scatter_indices_batching_dims = [1, 0],
scatter_dims_to_operand_dims = [1, 3],
index_vector_dim = 3
>,
unique_indices = false
}> ({
^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
stablehlo.return %arg4 : tensor<i32>
}) : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32>
func.return %0 : tensor<3x2x4x7x9xi32>
}

// -----

// Scatter where index_vector_dim (3) equals the rank of the scatter indices
// (tensor<4x3x5xi32>), i.e. the indices carry no explicit index-vector
// dimension. The expected rewrite first reshapes %arg1 to 4x3x5x1 so that it
// can be concatenated after the two iotas along dim 3, then emits a scatter
// without batching dims; unique_indices is kept as given (true).
// CHECK-LABEL: @scatter_with_batching_no_index_vector_dim
// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32>
// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{
// CHECK-SAME: indices_are_sorted = false,
// CHECK-SAME: dimension_numbers = #stablehlo.scatter<
// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2],
// CHECK-SAME: scatter_dims_to_operand_dims = [0, 2, 1], index_vector_dim = 3>,
// CHECK-SAME: unique_indices = true}>
// CHECK: (tensor<3x2x4x9xi32>, tensor<4x3x5x3xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32>
// CHECK-NEXT: return %[[scatter]] : tensor<3x2x4x9xi32>
func.func @scatter_with_batching_no_index_vector_dim(%arg0: tensor<3x2x4x9xi32>, %arg1: tensor<4x3x5xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> {
// CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2]
// CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0]
// Input op: each element of the rank-3 %arg1 is a single scalar index.
%0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{
indices_are_sorted = true,
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3],
inserted_window_dims = [1],
input_batching_dims = [0, 2],
scatter_indices_batching_dims = [1, 0],
scatter_dims_to_operand_dims = [1],
index_vector_dim = 3
>,
unique_indices = true
}> ({
^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
stablehlo.return %arg4 : tensor<i32>
}) : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32>
func.return %0 : tensor<3x2x4x9xi32>
}
144 changes: 142 additions & 2 deletions stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,22 @@ limitations under the License.

#include <fcntl.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <utility>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
Expand Down Expand Up @@ -58,6 +66,132 @@ vhlo::Version validateTargetVersion(llvm::StringRef versionRef) {
return targetVersion;
}

// Merges two ascending dimension lists into one ascending list.
//
// This is a plain `std::merge`, so both `dims1` and `dims2` must already be
// sorted; entries that appear in both inputs are kept twice.
SmallVector<int64_t> mergeSortedDims(ArrayRef<int64_t> dims1,
                                     ArrayRef<int64_t> dims2) {
  // Size the output up front and let `std::merge` write straight into it.
  SmallVector<int64_t> merged(dims1.size() + dims2.size());
  std::merge(dims1.begin(), dims1.end(), dims2.begin(), dims2.end(),
             merged.begin());
  return merged;
}

// Builds the indices operand for a gather/scatter that no longer has batching
// dims: for every dim in `indicesBatchingDims` (in the given order) an
// `IotaOp` along that dim is created, and all of them are concatenated in
// front of `indices` along `indexVectorDim` with a `ConcatenateOp`.
//
// Each iota has the shape of `indices` with the index vector dim set to 1.
// When `indexVectorDim` equals the rank of `indices` (no explicit index vector
// dim), `indices` is first reshaped to gain a trailing dim of size 1 so that
// it can take part in the concatenation.
Value createConcatIndices(Value indices, int64_t indexVectorDim,
                          ArrayRef<int64_t> indicesBatchingDims,
                          PatternRewriter &rewriter) {
  Location loc = indices.getLoc();
  auto indicesType = cast<RankedTensorType>(indices.getType());
  const bool needsTrailingDim = indexVectorDim == indicesType.getRank();

  // Shape of a single index component: same as `indices`, but with an index
  // vector dim of extent 1 (appended if `indices` does not have one).
  SmallVector<int64_t> componentShape(indicesType.getShape());
  if (needsTrailingDim)
    componentShape.push_back(1);
  else
    componentShape[indexVectorDim] = 1;
  auto componentType =
      RankedTensorType::get(componentShape, indicesType.getElementType());

  SmallVector<Value> concatOperands;
  concatOperands.reserve(indicesBatchingDims.size() + 1);

  // One iota per batching dim, emitted in the order the dims are listed.
  for (int64_t dim : indicesBatchingDims)
    concatOperands.push_back(rewriter.create<IotaOp>(loc, componentType, dim));

  // The original indices always go last.
  if (!needsTrailingDim) {
    concatOperands.push_back(indices);
  } else {
    concatOperands.push_back(
        rewriter.create<ReshapeOp>(loc, componentType, indices));
  }

  return rewriter.create<ConcatenateOp>(loc, concatOperands, indexVectorDim);
}

//===----------------------------------------------------------------------===//
// Patterns (non DRR)
//===----------------------------------------------------------------------===//

// Converts a `GatherOp` with batching dims to a `GatherOp` without batching
// dims, such that each batching dim becomes a collapsed slice dim with a
// corresponding `IotaOp` concatenated to the start indices.
//
// The match fails (leaving the op untouched) when the op has no batching dims
// or when the start indices are not statically shaped. The rewritten op always
// has `indices_are_sorted = false`.
class GatherWithBatchingDimsExpander : public OpRewritePattern<GatherOp> {
  using OpRewritePattern<GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp op,
                                PatternRewriter &rewriter) const override {
    GatherDimensionNumbersAttr dimNumbers = op.getDimensionNumbers();
    ArrayRef<int64_t> operandBatchingDims = dimNumbers.getOperandBatchingDims();
    if (operandBatchingDims.empty()) {
      return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
        diag << "gather op has no batching dims";
      });
    }

    // The expansion materializes `IotaOp`s (and possibly a `ReshapeOp`) whose
    // result type is derived from the start indices shape. Those ops need a
    // static result shape, so bail out instead of creating invalid IR.
    Value startIndices = op.getStartIndices();
    if (!cast<RankedTensorType>(startIndices.getType()).hasStaticShape()) {
      return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
        diag << "gather op start indices are not statically shaped";
      });
    }

    // Batching dims become collapsed slice dims; the merged list stays sorted.
    SmallVector<int64_t> newCollapsedSliceDims = mergeSortedDims(
        operandBatchingDims, dimNumbers.getCollapsedSliceDims());
    // The iotas are prepended to the index vector, so the operand batching
    // dims go first in the start index map.
    SmallVector<int64_t> newStartIndexMap =
        llvm::to_vector(llvm::concat<const int64_t>(
            operandBatchingDims, dimNumbers.getStartIndexMap()));
    Value newIndices = createConcatIndices(
        startIndices, dimNumbers.getIndexVectorDim(),
        dimNumbers.getStartIndicesBatchingDims(), rewriter);
    rewriter.replaceOpWithNewOp<GatherOp>(
        op, op.getOperand(), newIndices,
        GatherDimensionNumbersAttr::get(
            op.getContext(), dimNumbers.getOffsetDims(), newCollapsedSliceDims,
            /*operandBatchingDims=*/{}, /*startIndicesBatchingDims=*/{},
            newStartIndexMap, dimNumbers.getIndexVectorDim()),
        op.getSliceSizes(), /*indicesAreSorted=*/false);

    return success();
  }
};

// Converts a `ScatterOp` with batching dims to a `ScatterOp` without batching
// dims, such that each batching dim becomes an inserted window dim with a
// corresponding `IotaOp` concatenated to the scatter indices.
//
// The match fails (leaving the op untouched) when the op has no batching dims
// or when the scatter indices are not statically shaped. The rewritten op
// always has `indices_are_sorted = false` and keeps `unique_indices`.
class ScatterWithBatchingDimsExpander : public OpRewritePattern<ScatterOp> {
  using OpRewritePattern<ScatterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ScatterOp op,
                                PatternRewriter &rewriter) const override {
    ScatterDimensionNumbersAttr dimNumbers = op.getScatterDimensionNumbers();
    ArrayRef<int64_t> inputBatchingDims = dimNumbers.getInputBatchingDims();
    if (inputBatchingDims.empty()) {
      return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
        diag << "scatter op has no batching dims";
      });
    }

    // The expansion materializes `IotaOp`s (and possibly a `ReshapeOp`) whose
    // result type is derived from the scatter indices shape. Those ops need a
    // static result shape, so bail out instead of creating invalid IR.
    Value scatterIndices = op.getScatterIndices();
    if (!cast<RankedTensorType>(scatterIndices.getType()).hasStaticShape()) {
      return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
        diag << "scatter op scatter indices are not statically shaped";
      });
    }

    // Batching dims become inserted window dims; the merged list stays sorted.
    SmallVector<int64_t> newInsertedWindowDims =
        mergeSortedDims(inputBatchingDims, dimNumbers.getInsertedWindowDims());
    // The iotas are prepended to the index vector, so the input batching dims
    // go first in the scatter-dims-to-operand-dims map.
    SmallVector<int64_t> newScatterDimsToOperandDims =
        llvm::to_vector(llvm::concat<const int64_t>(
            inputBatchingDims, dimNumbers.getScatterDimsToOperandDims()));
    Value newIndices = createConcatIndices(
        scatterIndices, dimNumbers.getIndexVectorDim(),
        dimNumbers.getScatterIndicesBatchingDims(), rewriter);
    auto newScatterOp = rewriter.create<ScatterOp>(
        op.getLoc(), op->getResultTypes(), op.getInputs(), newIndices,
        op.getUpdates(),
        ScatterDimensionNumbersAttr::get(
            op.getContext(), dimNumbers.getUpdateWindowDims(),
            newInsertedWindowDims,
            /*inputBatchingDims=*/{}, /*scatterIndicesBatchingDims=*/{},
            newScatterDimsToOperandDims, dimNumbers.getIndexVectorDim()),
        /*indicesAreSorted=*/false, op.getUniqueIndices());

    // Move the update computation region from the old op into the new one
    // before the old op is replaced.
    newScatterOp.getUpdateComputation().takeBody(op.getUpdateComputation());
    rewriter.replaceOp(op, newScatterOp.getResults());

    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -107,10 +241,16 @@ struct StablehloCreateCompatibilityExpanderPass
// Adds to `patterns` the expander patterns required to make a program
// representable at `targetVersion`. A group of patterns is only added when
// `targetVersion` is older than the version that introduced the feature the
// patterns expand away.
void populateStablehloCreateCompatibilityExpanderPatterns(
    RewritePatternSet *patterns, MLIRContext *context,
    vhlo::Version targetVersion) {
  // StableHLO GatherOp/ScatterOp with batching dims is introduced in v1.1.0.
  if (targetVersion < vhlo::Version(1, 1, 0)) {
    patterns
        ->add<GatherWithBatchingDimsExpander, ScatterWithBatchingDimsExpander>(
            context);
  }
  // StableHLO TanOp is introduced in v1.4.0.
  if (targetVersion < vhlo::Version(1, 4, 0)) {
    // Register each TanOp expander exactly once.
    patterns->add<TanOp_ComplexElementType_CompatiblityExpander,
                  TanOp_CompatiblityExpander>(context);
  }
}

Expand Down

0 comments on commit 9bb28f8

Please sign in to comment.