Add gather op end-to-end [#754]
Gather is lowered to ttir::EmbeddingOp when possible. A ttir::ReshapeOp is
inserted before the EmbeddingOp when the start indices carry a trailing unit
dimension that must be dropped. Metal currently only supports embedding, not
gather.
ddilbazTT committed Oct 17, 2024
1 parent 998c81a commit 3ad6ac1
Showing 6 changed files with 233 additions and 0 deletions.
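
In outline, the new path rewrites a stablehlo.gather that selects whole rows of a table into a ttir.embedding. A minimal before/after sketch using the shapes from test_gather_0 below; the TTIR attributes are elided, and the exact printed form is defined by the tests at the end of this commit, not by this sketch:

// Before: gather 1x32 row indices out of a 32000x1024 table.
%0 = "stablehlo.gather"(%operand, %start_indices) <{
    dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0],
                                          start_index_map = [0], index_vector_dim = 2>,
    indices_are_sorted = false, slice_sizes = array<i64: 1, 1024>
}> : (tensor<32000x1024xf32>, tensor<1x32xi32>) -> tensor<1x32x1024xf32>

// After the pipeline (sketch): a DPS output buffer plus an embedding lookup,
// with the start indices as the input and the gathered table as the weight.
%out = tensor.empty() : tensor<1x32x1024xf32>
%1 = "ttir.embedding"(%start_indices, %operand, %out) {...}
    : (tensor<1x32xi32>, tensor<32000x1024xf32>, tensor<1x32x1024xf32>) -> tensor<1x32x1024xf32>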
27 changes: 27 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -83,6 +83,33 @@ def TTIR_GenericOp : TTIR_DPSOp<"generic", [AttrSizedOperandSegments]> {
}];
}

def TTIR_GatherOp: TTIR_DPSOp<"gather"> {
let summary = "Gather operation.";
let description = [{
Gather operation.
}];

let arguments = (ins AnyRankedTensor:$input, // operand
AnyRankedTensor:$start_indices, // start_indices
AnyRankedTensor:$output, // result
DenseI64ArrayAttr:$offset_dims, // offset_dims
DenseI64ArrayAttr:$collapsed_slice_dims, // collapsed_slice_dims
DenseI64ArrayAttr:$operand_batching_dims, // operand_batching_dims
DenseI64ArrayAttr:$start_indices_batching_dims, // start_indices_batching_dims
DenseI64ArrayAttr:$start_index_map, // start_index_map
SI64Attr:$index_vector_dim, // index_vector_dim
DenseI64ArrayAttr:$slice_sizes, // slice_sizes
BoolAttr:$indices_are_sorted, // indices_are_sorted (bool)
TT_OperandConstraintArrayAttr:$operand_constraints);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

}

def TTIR_ToLayoutOp : TTIR_Op<"to_layout", [DestinationStyleOpInterface, TTIROpInterface]> {
let summary = "Layout op.";
let description = [{
8 changes: 8 additions & 0 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
@@ -134,6 +134,14 @@ def TTIROptimizer: Pass<"ttir-optimizer", "::mlir::ModuleOp"> {
];
}

def TTIRGatherPatternMatch : Pass<"convert-ttir-to-ttir", "::mlir::ModuleOp"> {
let summary = "Convert GatherOp in TTIR dialect to EmbeddingOp in TTIR dialect.";
let description = [{
Convert GatherOp in TTIR dialect to EmbeddingOp in TTIR dialect. Add ReshapeOp to the input of EmbeddingOp if needed.
}];
}


def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> {
let summary = "Load system desc.";
let description = [{
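The new pass registers under the convert-ttir-to-ttir flag declared above, so it can also be exercised outside the pipeline. Below is a hypothetical standalone test in the RUN-line style of this commit's test file; the generic "ttir.gather" form and the #tt.operand_constraint values are inferred from the op definition and the existing tests, not copied from compiler output:

// RUN: ttmlir-opt --convert-ttir-to-ttir %s | FileCheck %s
func.func @gather_to_embedding(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> {
  %out = tensor.empty() : tensor<1x32x1024xf32>
  // A gather whose slice sizes cover full 1024-wide rows, so the pattern can
  // match it to an embedding lookup.
  %0 = "ttir.gather"(%operand, %start_indices, %out) <{
      offset_dims = array<i64: 2>, collapsed_slice_dims = array<i64: 0>,
      operand_batching_dims = array<i64>, start_indices_batching_dims = array<i64>,
      start_index_map = array<i64: 0>, index_vector_dim = 2 : si64,
      slice_sizes = array<i64: 1, 1024>, indices_are_sorted = false,
      operand_constraints = [#tt.operand_constraint<any_device_tile>,
                             #tt.operand_constraint<any_device_tile>,
                             #tt.operand_constraint<any_device_tile>]
  }> : (tensor<32000x1024xf32>, tensor<1x32xi32>, tensor<1x32x1024xf32>) -> tensor<1x32x1024xf32>
  // CHECK: %[[C:.*]] = "ttir.embedding"[[C:.*]]
  return %0 : tensor<1x32x1024xf32>
}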
51 changes: 51 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -756,6 +756,51 @@ class StableHLOToTTIRConcatOpConversionPattern
return success();
}
};
class StableHLOToTTIRGatherOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::GatherOp> {

using OpConversionPattern<mlir::stablehlo::GatherOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(mlir::stablehlo::GatherOp srcOp,
mlir::stablehlo::GatherOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// Create the output tensor type based on inputs
auto outputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp.getResult().getType()));
// Create an empty output tensor with the computed shape
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

auto dimensionNumbers = srcOp.getDimensionNumbers();
rewriter.replaceOpWithNewOp<mlir::tt::ttir::GatherOp>(
srcOp, // The original operation to replace
outputType, // Result type
srcOp.getOperands()[0], // Input tensor
srcOp.getOperands()[1], // Start indices
Value(outputTensor), // Output tensor
dimensionNumbers.getOffsetDims(), // offset_dims attribute
dimensionNumbers
.getCollapsedSliceDims(), // collapsed_slice_dims attribute
dimensionNumbers
.getOperandBatchingDims(), // operand_batching_dims attribute
dimensionNumbers
.getStartIndicesBatchingDims(), // start_indices_batching_dims
// attribute
dimensionNumbers.getStartIndexMap(), // start_index_map attribute
dimensionNumbers.getIndexVectorDim(), // index_vector_dim attribute
srcOp.getSliceSizesAttr(), // slice_sizes attribute
false, // indices_are_sorted attribute
rewriter.getArrayAttr( // operand constraints
SmallVector<Attribute>(adaptor.getOperands().size() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));

return success();
}
};

void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
@@ -858,6 +903,11 @@ void addReshapeOpConversionPattern(MLIRContext *ctx,
patterns.add<StableHLOToTTIRReshapeOpConversionPattern>(typeConverter, ctx);
}

void addGatherOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<StableHLOToTTIRGatherOpConversionPattern>(typeConverter, ctx);
}

} // namespace

namespace mlir::tt {
@@ -877,6 +927,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addCompareOpsConversionPatterns(ctx, patterns, typeConverter);
addConcatOpsConversionPatterns(ctx, patterns, typeConverter);
addReshapeOpConversionPattern(ctx, patterns, typeConverter);
addGatherOpConversionPattern(ctx, patterns, typeConverter);
}

} // namespace mlir::tt
2 changes: 2 additions & 0 deletions lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp
@@ -8,6 +8,7 @@
#include "mlir/Transforms/Passes.h"

#include "ttmlir/Conversion/Passes.h"
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"

namespace mlir::tt::ttir {
//===----------------------------------------------------------------------===//
@@ -18,6 +19,7 @@ namespace mlir::tt::ttir {
void createStableHLOToTTIRPipeline(
OpPassManager &pm, const StableHLOToTTIRPipelineOptions &options) {
pm.addPass(createConvertStableHLOToTTIRPass());
pm.addPass(createTTIRGatherPatternMatch());
if (options.removeDeadValuesEnabled) {
pm.addPass(mlir::createRemoveDeadValuesPass());
}
120 changes: 120 additions & 0 deletions lib/Dialect/TTIR/Transforms/Transforms.cpp
Expand Up @@ -12,6 +12,7 @@

namespace mlir::tt::ttir {
#define GEN_PASS_DEF_TTIRSLIDINGWINDOW2DFIXSHAPES
#define GEN_PASS_DEF_TTIRGATHERPATTERNMATCH
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc"

//===----------------------------------------------------------------------===//
@@ -151,4 +152,123 @@ class TTIRSlidingWindow2dFixShapes
}
};

class GatherOpRewritePattern : public OpRewritePattern<GatherOp> {
public:
using OpRewritePattern<GatherOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GatherOp op,
PatternRewriter &rewriter) const final {
bool reduceWeightTensorDim = false;
auto outputType = mlir::cast<RankedTensorType>(op.getResult().getType());
auto shape = outputType.getShape();
auto startIndices = op.getStartIndices(); // start indices of the gather op
auto startIndicesType =
mlir::cast<RankedTensorType>(startIndices.getType());
auto sliceSizes = op.getSliceSizes(); // slice sizes of the gather op
auto offsetDims = op.getOffsetDims();
auto collapsedSliceDims =
op.getCollapsedSliceDims(); // collapsed slice dims of the gather op

if (shape.size() > 1) {
auto hiddenDim = shape[shape.size() - 1];
assert(sliceSizes.size() > 1 &&
"sliceSizes should have at least 2 elements");
if (sliceSizes[0] != 1 || sliceSizes[1] != hiddenDim) {
return rewriter.notifyMatchFailure(op, "Did not satisfy sliceSizes");
}
}
if (offsetDims.size() != 1 &&
std::vector<int64_t>(offsetDims.begin(), offsetDims.end()) !=
std::vector<int64_t>{2}) {
return rewriter.notifyMatchFailure(op, "Did not satisfy offsetDims");
}
llvm::outs() << "collapsedSliceDims.size() = " << collapsedSliceDims.size()
<< "\n";
if (collapsedSliceDims.size() != 1 ||
std::vector<int64_t>(collapsedSliceDims.begin(),
collapsedSliceDims.end()) !=
std::vector<int64_t>{0}) {
return rewriter.notifyMatchFailure(op,
"Did not satisfy collapsedSliceDims");
}
if (shape.size() == startIndicesType.getShape().size()) {
if (startIndicesType.getShape()[shape.size() - 1] == 1) {
reduceWeightTensorDim = true;
} else {
return rewriter.notifyMatchFailure(op,
"Did not satisfy startIndicesType");
}
}

if (reduceWeightTensorDim) {
// insert reshape op to remove the last dimension of start indices
// before gather/ embedding op
std::vector<int32_t> newShape(startIndicesType.getShape().begin(),
startIndicesType.getShape().end() - 1);
std::vector<int64_t> newShapeInt64(newShape.begin(), newShape.end());
tensor::EmptyOp reshapeOutputTensor = rewriter.create<tensor::EmptyOp>(
op.getLoc(), llvm::ArrayRef<int64_t>(newShapeInt64),
startIndicesType.getElementType());
mlir::tt::ttir::ReshapeOp reshapeOp =
rewriter.create<mlir::tt::ttir::ReshapeOp>(
op.getLoc(),
mlir::RankedTensorType::get(newShapeInt64,
startIndicesType.getElementType()),
startIndices, reshapeOutputTensor,
rewriter.getI32ArrayAttr(newShape),
rewriter.getArrayAttr(SmallVector<Attribute>(
startIndicesType.getShape().size() - 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));
if (reshapeOp && op) {
reshapeOp->moveBefore(op);
EmbeddingOp embeddingOp = rewriter.create<EmbeddingOp>(
op.getLoc(), op.getResult().getType(),
reshapeOp.getResult(), // input - start indices
op.getOperands()[0], // weight - input tensor
op.getOutput(),
rewriter.getArrayAttr( // operand constraints
SmallVector<Attribute>(op.getNumOperands() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));
assert(embeddingOp != nullptr && "Failed to create embedding op");
rewriter.replaceOp(op, embeddingOp);
return success();
} else {
return rewriter.notifyMatchFailure(op, "Failed to create reshape op");
}
} else {
EmbeddingOp embeddingOp = rewriter.create<EmbeddingOp>(
op.getLoc(), op.getResult().getType(),
op.getStartIndices(), // input - start indices
op.getOperands()[0], // weight - input tensor
op.getOutput(),
rewriter.getArrayAttr( // operand constraints
SmallVector<Attribute>(op.getNumOperands() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));
assert(embeddingOp != nullptr && "Failed to create embedding op");
rewriter.replaceOp(op, embeddingOp);
return success();
}

return rewriter.notifyMatchFailure(op, "Failed to create embedding op");
}
};

class TTIRGatherPatternMatch
: public ttir::impl::TTIRGatherPatternMatchBase<TTIRGatherPatternMatch> {
public:
using ttir::impl::TTIRGatherPatternMatchBase<
TTIRGatherPatternMatch>::TTIRGatherPatternMatchBase;
void runOnOperation() final {
RewritePatternSet patterns(&getContext());
patterns.add<GatherOpRewritePattern>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
signalPassFailure();
return;
}
}
};

} // namespace mlir::tt::ttir
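
For the test_gather_1 shapes below, the start indices tensor<1x2x1xi64> has the same rank as the output with a trailing dimension of 1, so reduceWeightTensorDim triggers: the pattern reshapes that unit dimension away and feeds the result to the embedding as its input. A sketch of the rewritten IR, with attributes elided:

// Drop the trailing unit dim of the start indices before the embedding.
%empty = tensor.empty() : tensor<1x2xi64>
%idx = "ttir.reshape"(%start_indices, %empty) {...}
    : (tensor<1x2x1xi64>, tensor<1x2xi64>) -> tensor<1x2xi64>
// Reshaped indices are the embedding input; the gather operand is the weight.
%res = "ttir.embedding"(%idx, %operand, %output) {...}
    : (tensor<1x2xi64>, tensor<448x384xf32>, tensor<1x2x384xf32>) -> tensor<1x2x384xf32>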
25 changes: 25 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/gather_op.mlir
@@ -0,0 +1,25 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module @jit_gather attributes {} {
func.func public @test_gather_0(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> {
%0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1024>}> : (tensor<32000x1024xf32>, tensor<1x32xi32>) -> tensor<1x32x1024xf32>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.embedding"[[C:.*]]
return %0 : tensor<1x32x1024xf32>
}
func.func public @test_gather_1(%operand: tensor<448x384xf32>, %start_indices: tensor<1x2x1xi64>) -> tensor<1x2x384xf32> {
%0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 384>}> : (tensor<448x384xf32>, tensor<1x2x1xi64>) -> tensor<1x2x384xf32>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.embedding"[[C:.*]]
return %0 : tensor<1x2x384xf32>
}

func.func public @test_gather_2(%operand: tensor<51864x384xf32>, %start_indices: tensor<1x2xi64>) -> tensor<1x2x384xf32> {
%0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 384>}> : (tensor<51864x384xf32>, tensor<1x2xi64>) -> tensor<1x2x384xf32>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.embedding"[[C:.*]]
return %0 : tensor<1x2x384xf32>
}

}
