diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIRBase.td b/include/ttmlir/Dialect/TTIR/IR/TTIRBase.td index b71541b8d..f05cfdee7 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIRBase.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIRBase.td @@ -32,6 +32,8 @@ def TTIR_Dialect : Dialect { "::mlir::cf::ControlFlowDialect", "::mlir::tt::TTDialect" ]; + + let hasConstantMaterializer = 1; } //===----------------------------------------------------------------------===// diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp index 7a1b4e812..25140b9bc 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp @@ -901,34 +901,6 @@ struct PoolingToPool2dPattern : public OpConversionPattern { } }; -class GetDimensionSizeToConstantConversionPattern - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(ttir::GetDimensionSizeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - const RankedTensorType inputTensorType = - mlir::cast(op.getOperand().getType()); - - int64_t dimensionIndex = op.getDimension(); - - int32_t dimSize = inputTensorType.getShape()[dimensionIndex]; - - mlir::ShapedType valueType = mlir::cast(op.getType()); - - mlir::ElementsAttr valueAttr = - mlir::DenseElementsAttr::get(valueType, dimSize); - - rewriter.replaceOpWithNewOp(op, valueType, - valueAttr); - - return success(); - } -}; - // SelectOp is converted to a series of SliceOp and potentially a ConcatOp if // the sliced dimension is sliced multiple times. For example, if the input // tensor is @@ -1179,7 +1151,6 @@ void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); diff --git a/lib/Dialect/TTIR/IR/TTIRDialect.cpp b/lib/Dialect/TTIR/IR/TTIRDialect.cpp index b51a80cd6..120650d74 100644 --- a/lib/Dialect/TTIR/IR/TTIRDialect.cpp +++ b/lib/Dialect/TTIR/IR/TTIRDialect.cpp @@ -72,3 +72,16 @@ void TTIRDialect::initialize() { #include "ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.cpp.inc" >(); } + +//===----------------------------------------------------------------------===// +// TTIR constant materializer. +//===----------------------------------------------------------------------===// + +::mlir::Operation *TTIRDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + if (auto elementsAttr = mlir::dyn_cast(value)) { + return builder.create(loc, type, elementsAttr); + } + return {}; +} diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index cfc149de1..319d8b1e6 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -92,28 +92,16 @@ ::mlir::OpFoldResult mlir::tt::ttir::ConstantOp::fold(FoldAdaptor adaptor) { // GetDimensionSizeOp folder ::mlir::OpFoldResult mlir::tt::ttir::GetDimensionSizeOp::fold(FoldAdaptor adaptor) { - - const RankedTensorType inputTensorType = - mlir::cast(getOperand().getType()); - - int64_t dimensionIndex = getDimension(); - - if (dimensionIndex >= - static_cast(inputTensorType.getShape().size())) { - return nullptr; - }; - + RankedTensorType inputTensorType = getOperand().getType(); + uint32_t dimensionIndex = getDimension(); int32_t dimSize = inputTensorType.getShape()[dimensionIndex]; - mlir::ShapedType valueType = mlir::cast(getType()); - - return mlir::DenseElementsAttr::get(valueType, dimSize); + return mlir::DenseElementsAttr::get(getType(), dimSize); } // GetDimensionSizeOp verification ::mlir::LogicalResult mlir::tt::ttir::GetDimensionSizeOp::verify() { - const RankedTensorType inputTensorType = - mlir::cast(getOperand().getType()); + RankedTensorType inputTensorType = getOperand().getType(); int64_t dimensionIndex = getDimension(); diff --git a/test/ttmlir/Dialect/TTIR/Decomposition/get_dimension_size_decomposition.mlir b/test/ttmlir/Dialect/TTIR/Decomposition/get_dimension_size_decomposition.mlir new file mode 100644 index 000000000..aea7e3eb1 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/Decomposition/get_dimension_size_decomposition.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --ttir-to-ttir-decomposition %s | FileCheck %s +module { + func.func @get_dimension_size_decomposition(%arg0: tensor<32x64x128xf32>) -> tensor<1xi32> { + // CHECK: [[VAL:%.+]] = "ttir.constant" + // CHECK-SAME: value = dense<128> : tensor<1xi32> + // CHECK: return [[VAL]] : tensor<1xi32> + %0 = "ttir.get_dimension_size"(%arg0) <{dimension = 2 : i32}> : (tensor<32x64x128xf32>) -> tensor<1xi32> + return %0 : tensor<1xi32> + } +}