From 3ad6ac1d186a623a9b823cf526afd94847c8ecc8 Mon Sep 17 00:00:00 2001
From: ddilbaz
Date: Thu, 17 Oct 2024 15:56:27 +0000
Subject: [PATCH] Add gather op end-to-end [#754]

Gather is lowered to ttir::EmbeddingOp when it matches embedding
semantics. A ttir::ReshapeOp is inserted on the start indices before the
EmbeddingOp when their trailing unit dimension has to be dropped. Metal
currently only supports embedding, not gather.
---
 include/ttmlir/Dialect/TTIR/IR/TTIROps.td     |  27 ++++
 .../ttmlir/Dialect/TTIR/Transforms/Passes.td  |   8 ++
 .../StableHLOToTTIRPatterns.cpp               |  51 ++++++++
 lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp  |   2 +
 lib/Dialect/TTIR/Transforms/Transforms.cpp    | 120 ++++++++++++++++++
 .../Conversion/StableHLOToTTIR/gather_op.mlir |  25 ++++
 6 files changed, 233 insertions(+)
 create mode 100644 test/ttmlir/Conversion/StableHLOToTTIR/gather_op.mlir

diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index aee213b5af..3105241f76 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -83,6 +83,33 @@ def TTIR_GenericOp : TTIR_DPSOp<"generic", [AttrSizedOperandSegments]> {
   }];
 }
 
+def TTIR_GatherOp: TTIR_DPSOp<"gather"> {
+    let summary = "Gather operation.";
+    let description = [{
+      Collects slices of the input tensor at the positions given by
+      `start_indices`, mirroring the semantics and attribute names of
+      `stablehlo.gather`.
+    }];
+
+    let arguments = (ins AnyRankedTensor:$input,
+                         AnyRankedTensor:$start_indices,
+                         AnyRankedTensor:$output,
+                         DenseI64ArrayAttr:$offset_dims,
+                         DenseI64ArrayAttr:$collapsed_slice_dims,
+                         DenseI64ArrayAttr:$operand_batching_dims,
+                         DenseI64ArrayAttr:$start_indices_batching_dims,
+                         DenseI64ArrayAttr:$start_index_map,
+                         SI64Attr:$index_vector_dim,
+                         DenseI64ArrayAttr:$slice_sizes,
+                         BoolAttr:$indices_are_sorted,
+                         TT_OperandConstraintArrayAttr:$operand_constraints);
+
+    let results = (outs AnyRankedTensor:$result);
+
+    let extraClassDeclaration = [{
+      MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+    }];
+}
+
 def TTIR_ToLayoutOp : TTIR_Op<"to_layout", [DestinationStyleOpInterface, TTIROpInterface]> {
   let summary = "Layout op.";
   let description = [{
diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td
index 3e5298440f..60bc02063c 100644
--- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td
+++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td
@@ -134,6 +134,14 @@ def TTIROptimizer: Pass<"ttir-optimizer", "::mlir::ModuleOp"> {
   ];
 }
 
+def TTIRGatherPatternMatch : Pass<"ttir-gather-pattern-match", "::mlir::ModuleOp"> {
+  let summary = "Convert GatherOp in TTIR dialect to EmbeddingOp in TTIR dialect.";
+  let description = [{
+    Rewrites `ttir.gather` as `ttir.embedding` when the gather is a plain row
+    lookup into a 2-D table. A `ttir.reshape` is added on the start indices
+    first when their trailing unit dimension has to be dropped.
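+
+    For example, a gather that selects rows of a lookup table is rewritten
+    roughly as follows (shapes are illustrative and attributes are elided
+    for brevity):
+
+      // Before:
+      %0 = "ttir.gather"(%table, %ids, %out) ...
+           : (tensor<448x384xf32>, tensor<1x2x1xi64>, tensor<1x2x384xf32>)
+           -> tensor<1x2x384xf32>
+
+      // After: the trailing unit dim of %ids is reshaped away, then embedded.
+      %1 = "ttir.reshape"(%ids, %empty) ...
+           : (tensor<1x2x1xi64>, tensor<1x2xi64>) -> tensor<1x2xi64>
+      %2 = "ttir.embedding"(%1, %table, %out)
+           : (tensor<1x2xi64>, tensor<448x384xf32>, tensor<1x2x384xf32>)
+           -> tensor<1x2x384xf32>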
+  }];
+}
+
 def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> {
   let summary = "Load system desc.";
   let description = [{
diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
index 5d63d87b0d..e6f95db27c 100644
--- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
+++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -756,6 +756,51 @@ class StableHLOToTTIRConcatOpConversionPattern
     return success();
   }
 };
+
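+// Converts stablehlo.gather to ttir.gather one-to-one, carrying the gather
+// dimension numbers over as individual attributes. For example (shapes are
+// illustrative, attributes abbreviated):
+//
+//   %0 = "stablehlo.gather"(%operand, %start_indices) <{...}>
+//        : (tensor<32000x1024xf32>, tensor<1x32xi32>) -> tensor<1x32x1024xf32>
+//
+// becomes a ttir.gather over the same operands plus a DPS output tensor
+// created with tensor::EmptyOp.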
+class StableHLOToTTIRGatherOpConversionPattern
+    : public OpConversionPattern<mlir::stablehlo::GatherOp> {
+
+  using OpConversionPattern<mlir::stablehlo::GatherOp>::OpConversionPattern;
+
+public:
+  LogicalResult
+  matchAndRewrite(mlir::stablehlo::GatherOp srcOp,
+                  mlir::stablehlo::GatherOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    // Convert the result type based on the inputs.
+    auto outputType = mlir::cast<RankedTensorType>(
+        getTypeConverter()->convertType(srcOp.getResult().getType()));
+    // Create an empty DPS output tensor with the converted shape.
+    tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
+        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
+
+    auto dimensionNumbers = srcOp.getDimensionNumbers();
+    rewriter.replaceOpWithNewOp<mlir::tt::ttir::GatherOp>(
+        srcOp,                  // the original operation to replace
+        outputType,             // result type
+        srcOp.getOperands()[0], // input tensor
+        srcOp.getOperands()[1], // start indices
+        Value(outputTensor),    // DPS output tensor
+        dimensionNumbers.getOffsetDims(),
+        dimensionNumbers.getCollapsedSliceDims(),
+        dimensionNumbers.getOperandBatchingDims(),
+        dimensionNumbers.getStartIndicesBatchingDims(),
+        dimensionNumbers.getStartIndexMap(),
+        dimensionNumbers.getIndexVectorDim(),
+        srcOp.getSliceSizesAttr(),
+        false, // indices_are_sorted attribute
+        rewriter.getArrayAttr( // operand constraints
+            SmallVector<Attribute>(adaptor.getOperands().size() + 1,
+                                   rewriter.getAttr<OperandConstraintAttr>(
+                                       OperandConstraint::AnyDeviceTile))));
+
+    return success();
+  }
+};
 
 void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
                                               RewritePatternSet &patterns,
@@ -858,6 +903,11 @@ void addReshapeOpConversionPattern(MLIRContext *ctx,
   patterns.add<StableHLOToTTIRReshapeOpConversionPattern>(typeConverter, ctx);
 }
 
+void addGatherOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
+                                  TypeConverter &typeConverter) {
+  patterns.add<StableHLOToTTIRGatherOpConversionPattern>(typeConverter, ctx);
+}
+
 } // namespace
 
 namespace mlir::tt {
@@ -877,6 +927,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
   addCompareOpsConversionPatterns(ctx, patterns, typeConverter);
   addConcatOpsConversionPatterns(ctx, patterns, typeConverter);
   addReshapeOpConversionPattern(ctx, patterns, typeConverter);
+  addGatherOpConversionPattern(ctx, patterns, typeConverter);
 }
 
 } // namespace mlir::tt
diff --git a/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp b/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp
index b4f3b5ee05..418459c3c9 100644
--- a/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp
+++ b/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp
@@ -8,6 +8,7 @@
 #include "mlir/Transforms/Passes.h"
 
 #include "ttmlir/Conversion/Passes.h"
+#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"
 
 namespace mlir::tt::ttir {
 //===----------------------------------------------------------------------===//
@@ -18,6 +19,7 @@ namespace mlir::tt::ttir {
 void createStableHLOToTTIRPipeline(
     OpPassManager &pm, const StableHLOToTTIRPipelineOptions &options) {
   pm.addPass(createConvertStableHLOToTTIRPass());
+  pm.addPass(createTTIRGatherPatternMatch());
   if (options.removeDeadValuesEnabled) {
     pm.addPass(mlir::createRemoveDeadValuesPass());
   }
diff --git a/lib/Dialect/TTIR/Transforms/Transforms.cpp b/lib/Dialect/TTIR/Transforms/Transforms.cpp
index 084f1a90d4..efa03ed4f9 100644
--- a/lib/Dialect/TTIR/Transforms/Transforms.cpp
+++ b/lib/Dialect/TTIR/Transforms/Transforms.cpp
@@ -12,6 +12,7 @@
 
 namespace mlir::tt::ttir {
 #define GEN_PASS_DEF_TTIRSLIDINGWINDOW2DFIXSHAPES
+#define GEN_PASS_DEF_TTIRGATHERPATTERNMATCH
 #include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc"
 
 //===----------------------------------------------------------------------===//
@@ -151,4 +152,123 @@ class TTIRSlidingWindow2dFixShapes
   }
 };
 
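+// Lowers ttir.gather to ttir.embedding. The match is deliberately narrow: it
+// only fires for gathers that behave as a row lookup into a 2-D table (for
+// example a token-embedding lookup), i.e. slice_sizes == [1, hiddenDim],
+// offset_dims == [2], and collapsed_slice_dims == [0].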
+class GatherOpRewritePattern : public OpRewritePattern<GatherOp> {
+public:
+  using OpRewritePattern<GatherOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GatherOp op,
+                                PatternRewriter &rewriter) const final {
+    auto outputType = mlir::cast<RankedTensorType>(op.getResult().getType());
+    auto shape = outputType.getShape();
+    auto startIndices = op.getStartIndices();
+    auto startIndicesType =
+        mlir::cast<RankedTensorType>(startIndices.getType());
+    auto sliceSizes = op.getSliceSizes();
+    auto offsetDims = op.getOffsetDims();
+    auto collapsedSliceDims = op.getCollapsedSliceDims();
+
+    // The gather must take whole rows of the input table: slice_sizes has to
+    // be [1, hiddenDim], where hiddenDim is the last output dimension.
+    if (shape.size() > 1) {
+      auto hiddenDim = shape[shape.size() - 1];
+      if (sliceSizes.size() < 2 || sliceSizes[0] != 1 ||
+          sliceSizes[1] != hiddenDim) {
+        return rewriter.notifyMatchFailure(op, "Did not satisfy sliceSizes");
+      }
+    }
+    if (offsetDims.size() != 1 ||
+        std::vector<int64_t>(offsetDims.begin(), offsetDims.end()) !=
+            std::vector<int64_t>{2}) {
+      return rewriter.notifyMatchFailure(op, "Did not satisfy offsetDims");
+    }
+    if (collapsedSliceDims.size() != 1 ||
+        std::vector<int64_t>(collapsedSliceDims.begin(),
+                             collapsedSliceDims.end()) !=
+            std::vector<int64_t>{0}) {
+      return rewriter.notifyMatchFailure(op,
+                                         "Did not satisfy collapsedSliceDims");
+    }
+
+    // If the start indices have the same rank as the output, they carry a
+    // trailing unit index-vector dimension that embedding does not expect,
+    // so it has to be reshaped away first.
+    bool reduceStartIndicesDim = false;
+    if (shape.size() == startIndicesType.getShape().size()) {
+      if (startIndicesType.getShape()[shape.size() - 1] != 1) {
+        return rewriter.notifyMatchFailure(op,
+                                           "Did not satisfy startIndicesType");
+      }
+      reduceStartIndicesDim = true;
+    }
+
+    Value embeddingInput = startIndices;
+    if (reduceStartIndicesDim) {
+      // Insert a reshape op that removes the last dimension of the start
+      // indices before the embedding op.
+      std::vector<int32_t> newShape(startIndicesType.getShape().begin(),
+                                    startIndicesType.getShape().end() - 1);
+      std::vector<int64_t> newShapeInt64(newShape.begin(), newShape.end());
+      tensor::EmptyOp reshapeOutputTensor = rewriter.create<tensor::EmptyOp>(
+          op.getLoc(), llvm::ArrayRef<int64_t>(newShapeInt64),
+          startIndicesType.getElementType());
+      auto reshapeOp = rewriter.create<mlir::tt::ttir::ReshapeOp>(
+          op.getLoc(),
+          mlir::RankedTensorType::get(newShapeInt64,
+                                      startIndicesType.getElementType()),
+          startIndices, reshapeOutputTensor,
+          rewriter.getI32ArrayAttr(newShape),
+          rewriter.getArrayAttr(SmallVector<Attribute>(
+              startIndicesType.getShape().size() - 1,
+              rewriter.getAttr<OperandConstraintAttr>(
+                  OperandConstraint::AnyDeviceTile))));
+      embeddingInput = reshapeOp.getResult();
+    }
+
+    rewriter.replaceOpWithNewOp<EmbeddingOp>(
+        op, op.getResult().getType(),
+        embeddingInput,      // input: (possibly reshaped) start indices
+        op.getOperands()[0], // weight: the gathered-from table
+        op.getOutput(),
+        rewriter.getArrayAttr( // operand constraints
+            SmallVector<Attribute>(op.getNumOperands() + 1,
+                                   rewriter.getAttr<OperandConstraintAttr>(
+                                       OperandConstraint::AnyDeviceTile))));
+    return success();
+  }
+};
+
+class TTIRGatherPatternMatch
+    : public ttir::impl::TTIRGatherPatternMatchBase<TTIRGatherPatternMatch> {
+public:
+  using ttir::impl::TTIRGatherPatternMatchBase<
+      TTIRGatherPatternMatch>::TTIRGatherPatternMatchBase;
+
+  void runOnOperation() final {
+    RewritePatternSet patterns(&getContext());
+    patterns.add<GatherOpRewritePattern>(&getContext());
+    FrozenRewritePatternSet patternSet(std::move(patterns));
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
+      signalPassFailure();
+    }
+  }
+};
+
 } // namespace mlir::tt::ttir
diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/gather_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/gather_op.mlir
new file mode 100644
index 0000000000..ed124da7de
--- /dev/null
+++ b/test/ttmlir/Conversion/StableHLOToTTIR/gather_op.mlir
@@ -0,0 +1,25 @@
+// REQUIRES: stablehlo
+// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
+#any_device = #tt.operand_constraint<any_device>
+module @jit_gather attributes {} {
+  func.func public @test_gather_0(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> {
+    %0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1024>}> : (tensor<32000x1024xf32>, tensor<1x32xi32>) -> tensor<1x32x1024xf32>
+    // CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
+    // CHECK: %[[C:.*]] = "ttir.embedding"[[C:.*]]
+    return %0 : tensor<1x32x1024xf32>
+  }
+
+  func.func public @test_gather_1(%operand: tensor<448x384xf32>, %start_indices: tensor<1x2x1xi64>) -> tensor<1x2x384xf32> {
+    %0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 384>}> : (tensor<448x384xf32>, tensor<1x2x1xi64>) -> tensor<1x2x384xf32>
+    // CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
+    // CHECK: %[[C:.*]] = "ttir.embedding"[[C:.*]]
+    return %0 : tensor<1x2x384xf32>
+  }
+
+  func.func public @test_gather_2(%operand: tensor<51864x384xf32>, %start_indices: tensor<1x2xi64>) -> tensor<1x2x384xf32> {
+    %0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 384>}> : (tensor<51864x384xf32>, tensor<1x2xi64>) -> tensor<1x2x384xf32>
+    // CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
+    // CHECK: %[[C:.*]] = "ttir.embedding"[[C:.*]]
+    return %0 : tensor<1x2x384xf32>
+  }
+
+}
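+
+// For reference, the pipeline lowers test_gather_1 roughly as follows (SSA
+// value names and attribute details may differ from the actual output):
+//   %0 = tensor.empty() : tensor<1x2xi64>
+//   %1 = "ttir.reshape"(%start_indices, %0) ... -> tensor<1x2xi64>
+//   %2 = tensor.empty() : tensor<1x2x384xf32>
+//   %3 = "ttir.embedding"(%1, %operand, %2) ... -> tensor<1x2x384xf32>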