TTIR constant materialization (#1548)
* TTIR constant materialization

* GetDimensionSizeOp removed from decomp and realized through fold + materialization

* Added test for get_dimension_size -> constant conversion
azecevicTT authored Dec 12, 2024
1 parent 63a8f1f commit ed84a56
Showing 5 changed files with 29 additions and 45 deletions.
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIRBase.td
@@ -32,6 +32,8 @@ def TTIR_Dialect : Dialect {
"::mlir::cf::ControlFlowDialect",
"::mlir::tt::TTDialect"
];

let hasConstantMaterializer = 1;
}

//===----------------------------------------------------------------------===//
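The hasConstantMaterializer = 1 flag makes ODS declare a constant-materialization hook on the generated dialect class; MLIR's folding infrastructure calls it whenever a fold produces an Attribute instead of reusing an existing Value. A minimal sketch of what the flag adds to the generated declaration (names follow mlir-tblgen's usual output and may differ slightly in the actual .h.inc):

// Sketch only: hook declared on the generated TTIRDialect when
// hasConstantMaterializer is set; the definition is added in TTIRDialect.cpp below.
class TTIRDialect : public ::mlir::Dialect {
  // ... other generated members elided ...
  ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
                                         ::mlir::Attribute value,
                                         ::mlir::Type type,
                                         ::mlir::Location loc) override;
};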
29 changes: 0 additions & 29 deletions lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
@@ -901,34 +901,6 @@ struct PoolingToPool2dPattern : public OpConversionPattern<ttir::PoolingOp> {
}
};

class GetDimensionSizeToConstantConversionPattern
    : public OpConversionPattern<ttir::GetDimensionSizeOp> {
public:
  using OpConversionPattern<ttir::GetDimensionSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ttir::GetDimensionSizeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    const RankedTensorType inputTensorType =
        mlir::cast<RankedTensorType>(op.getOperand().getType());

    int64_t dimensionIndex = op.getDimension();

    int32_t dimSize = inputTensorType.getShape()[dimensionIndex];

    mlir::ShapedType valueType = mlir::cast<mlir::ShapedType>(op.getType());

    mlir::ElementsAttr valueAttr =
        mlir::DenseElementsAttr::get<int>(valueType, dimSize);

    rewriter.replaceOpWithNewOp<mlir::tt::ttir::ConstantOp>(op, valueType,
                                                            valueAttr);

    return success();
  }
};

// SelectOp is converted to a series of SliceOp and potentially a ConcatOp if
// the sliced dimension is sliced multiple times. For example, if the input
// tensor is
@@ -1179,7 +1151,6 @@ void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx,
  patterns.add<IndexToSliceConversionPattern>(typeConverter, ctx);
  patterns.add<Legalize1DConvolutionPattern>(typeConverter, ctx);
  patterns.add<ConvolutionToConv2dPattern>(typeConverter, ctx);
  patterns.add<GetDimensionSizeToConstantConversionPattern>(typeConverter, ctx);
  patterns.add<GatherToEmbeddingConversionPattern>(typeConverter, ctx);
  patterns.add<SelectToSliceConversionPattern>(typeConverter, ctx);
  patterns.add<ArangeForceLastDimensionPattern>(typeConverter, ctx);
13 changes: 13 additions & 0 deletions lib/Dialect/TTIR/IR/TTIRDialect.cpp
@@ -72,3 +72,16 @@ void TTIRDialect::initialize() {
#include "ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.cpp.inc"
>();
}

//===----------------------------------------------------------------------===//
// TTIR constant materializer.
//===----------------------------------------------------------------------===//

::mlir::Operation *TTIRDialect::materializeConstant(OpBuilder &builder,
                                                    Attribute value, Type type,
                                                    Location loc) {
  if (auto elementsAttr = mlir::dyn_cast<mlir::ElementsAttr>(value)) {
    return builder.create<ttir::ConstantOp>(loc, type, elementsAttr);
  }
  return {};
}
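With this hook in place, an op whose fold returns an ElementsAttr no longer needs a dedicated conversion pattern: MLIR's folding machinery (OperationFolder / the greedy rewrite driver) asks the owning dialect to turn the attribute into a constant op. A rough, simplified sketch of that flow; the function name and structure below are illustrative, not TTIR code:

#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"

// Illustrative only: mirrors what OperationFolder does after an op folds to an
// Attribute. For TTIR, materializeConstant builds a ttir.constant.
static void replaceWithMaterializedConstant(mlir::Operation *op,
                                            mlir::Attribute foldedValue,
                                            mlir::OpBuilder &builder) {
  mlir::Operation *cst = op->getDialect()->materializeConstant(
      builder, foldedValue, op->getResult(0).getType(), op->getLoc());
  if (!cst) {
    return; // the dialect declined to materialize this attribute/type pair
  }
  op->getResult(0).replaceAllUsesWith(cst->getResult(0));
  op->erase();
}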
20 changes: 4 additions & 16 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -92,28 +92,16 @@ ::mlir::OpFoldResult mlir::tt::ttir::ConstantOp::fold(FoldAdaptor adaptor) {
// GetDimensionSizeOp folder
::mlir::OpFoldResult
mlir::tt::ttir::GetDimensionSizeOp::fold(FoldAdaptor adaptor) {

  const RankedTensorType inputTensorType =
      mlir::cast<RankedTensorType>(getOperand().getType());

  int64_t dimensionIndex = getDimension();

  if (dimensionIndex >=
      static_cast<int64_t>(inputTensorType.getShape().size())) {
    return nullptr;
  };

  RankedTensorType inputTensorType = getOperand().getType();
  uint32_t dimensionIndex = getDimension();
  int32_t dimSize = inputTensorType.getShape()[dimensionIndex];

  mlir::ShapedType valueType = mlir::cast<mlir::ShapedType>(getType());

  return mlir::DenseElementsAttr::get<int>(valueType, dimSize);
  return mlir::DenseElementsAttr::get<int32_t>(getType(), dimSize);
}

// GetDimensionSizeOp verification
::mlir::LogicalResult mlir::tt::ttir::GetDimensionSizeOp::verify() {
  const RankedTensorType inputTensorType =
      mlir::cast<RankedTensorType>(getOperand().getType());
  RankedTensorType inputTensorType = getOperand().getType();

  int64_t dimensionIndex = getDimension();

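Because the added/removed markers are not visible in the hunk above, here is the GetDimensionSizeOp folder reassembled as it reads after this change (taken straight from the added lines; the in-folder bounds check is dropped, presumably because the verifier already guards the dimension index):

// GetDimensionSizeOp folder after this change: read the static size of the
// requested dimension and return it as a dense i32 value of the result type.
::mlir::OpFoldResult
mlir::tt::ttir::GetDimensionSizeOp::fold(FoldAdaptor adaptor) {
  RankedTensorType inputTensorType = getOperand().getType();
  uint32_t dimensionIndex = getDimension();
  int32_t dimSize = inputTensorType.getShape()[dimensionIndex];
  return mlir::DenseElementsAttr::get<int32_t>(getType(), dimSize);
}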
@@ -0,0 +1,10 @@
// RUN: ttmlir-opt --ttir-to-ttir-decomposition %s | FileCheck %s
module {
  func.func @get_dimension_size_decomposition(%arg0: tensor<32x64x128xf32>) -> tensor<1xi32> {
    // CHECK: [[VAL:%.+]] = "ttir.constant"
    // CHECK-SAME: value = dense<128> : tensor<1xi32>
    // CHECK: return [[VAL]] : tensor<1xi32>
    %0 = "ttir.get_dimension_size"(%arg0) <{dimension = 2 : i32}> : (tensor<32x64x128xf32>) -> tensor<1xi32>
    return %0 : tensor<1xi32>
  }
}
