diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIRBase.td b/include/ttmlir/Dialect/TTIR/IR/TTIRBase.td
index f05cfdee7..6d43043aa 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIRBase.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIRBase.td
@@ -6,6 +6,7 @@
 #define TTMLIR_TTMLIR_DIALECT_TTIR_TTIRDIALECT_TD
 
 include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 //===----------------------------------------------------------------------===//
@@ -41,6 +42,28 @@ def TTIR_Dialect : Dialect {
 //===----------------------------------------------------------------------===//
 
 class TTIR_Op<string mnemonic, list<Trait> traits = []> :
-    Op<TTIR_Dialect, mnemonic, traits>;
+    Op<TTIR_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// TTIR traits definition.
+//===----------------------------------------------------------------------===//
+
+def TwoOperands : ParamNativeOpTrait<"NOperands", "2">;
+def ThreeOperands : ParamNativeOpTrait<"NOperands", "3">;
+def FourOperands : ParamNativeOpTrait<"NOperands", "4">;
+
+class TTIR_Trait<string name, list<Trait> traits = []> : NativeOpTrait<name, traits> {
+  let cppNamespace = "::mlir::tt::ttir::OpTrait";
+}
+
+// Involution is a property of an operation where applying the operation twice results in the original value.
+// Example: not(not(x)) == x
+def TTIR_Involution : TTIR_Trait<"TTIRInvolution", [DestinationStyleOpInterface, TwoOperands, NoMemoryEffect]>;
+// Idempotence is a property of an operation where applying the operation twice results in the same value as applying it once.
+// Example: abs(abs(x)) == abs(x)
+def TTIR_Idempotence : TTIR_Trait<"TTIRIdempotence", [DestinationStyleOpInterface, TwoOperands, NoMemoryEffect]>;
+// BinaryIdempotence is a property of a binary operation where applying the operation to two equal values results in that same value.
+// Example: and(x, x) == x
+def TTIR_BinaryIdempotence : TTIR_Trait<"TTIRBinaryIdempotence", [DestinationStyleOpInterface, ThreeOperands, NoMemoryEffect]>;
 
 #endif
diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.h b/include/ttmlir/Dialect/TTIR/IR/TTIROps.h
index f23fd6d88..4b54d5164 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.h
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.h
@@ -5,16 +5,18 @@
 #ifndef TTMLIR_DIALECT_TTIR_IR_TTIROPS_H
 #define TTMLIR_DIALECT_TTIR_IR_TTIROPS_H
 
+#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
+#include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h"
+#include "ttmlir/Dialect/TTIR/IR/TTIRTraits.h"
+
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
-
-#include "TTIROpsInterfaces.h"
 
 #include "ttmlir/Dialect/TTIR/IR/TTIROpsEnums.h.inc"
diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index 839bd81d9..710f88cfe 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -19,7 +19,7 @@ include "mlir/IR/CommonAttrConstraints.td"
 include "mlir/IR/OpBase.td"
 
 class TTIR_DPSOp<string mnemonic, list<Trait> traits = []> :
-    TTIR_Op<mnemonic, !listconcat(traits, [DestinationStyleOpInterface, TTIROpInterface])> {
+    TTIR_Op<mnemonic, !listconcat(traits, [DestinationStyleOpInterface, TTIROpInterface])> {
   let extraClassDeclaration = [{
     MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
   }];
@@ -98,8 +98,9 @@ def TTIR_GetDimensionSizeOp : TTIR_Op<"get_dimension_size"> {
 
   let results = (outs AnyRankedTensor:$result);
 
-  let hasFolder = 1;
   let hasVerifier = 1;
+
+  let hasFolder = 1;
 }
 
 def TTIR_ToLayoutOp : TTIR_Op<"to_layout", [DestinationStyleOpInterface, TTIROpInterface]> {
@@ -186,12 +187,8 @@ def TTIR_DeallocOp : TTIR_Op<"dealloc"> {
 // TTIR top level named ops
 //===----------------------------------------------------------------------===//
 
-def TwoOperands : ParamNativeOpTrait<"NOperands", "2">;
-def ThreeOperands : ParamNativeOpTrait<"NOperands", "3">;
-def FourOperands : ParamNativeOpTrait<"NOperands", "4">;
-
 class TTIR_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
-    TTIR_DPSOp<mnemonic, !listconcat(traits, [AttrSizedOperandSegments])> {
+    TTIR_DPSOp<mnemonic, !listconcat(traits, [AttrSizedOperandSegments])> {
   let description = [{
     Base class for elementwise operations. Elementwise operations can take
     inputs with different shape,
@@ -204,7 +201,7 @@
 class TTIR_ElementwiseTernaryOp<string mnemonic, list<Trait> traits = []> :
-    TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [FourOperands])> {
+    TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [FourOperands])> {
   let summary = "Eltwise ternary op.";
   let description = [{
     Eltwise ternary op.
@@ -227,7 +224,7 @@ def TTIR_WhereOp: TTIR_ElementwiseTernaryOp<"where"> {
 }
 
 class TTIR_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
-    TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [TwoOperands])> {
+    TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [TwoOperands])> {
   let summary = "Eltwise unary op.";
   let description = [{
     Eltwise unary op.
   }];
 
   let builders = [
     OpBuilder<(ins "Value": $in, "Value": $out),
     [{
       build($_builder, $_state, {out.getType()}, in, out);
     }]>
   ];
 }
 
-def TTIR_AbsOp: TTIR_ElementwiseUnaryOp<"abs"> {
+def TTIR_AbsOp: TTIR_ElementwiseUnaryOp<"abs", [TTIR_Idempotence]> {
   let summary = "Eltwise absolute op.";
   let description = [{
     Eltwise absolute operation.
@@ -256,7 +253,7 @@ def TTIR_CbrtOp: TTIR_ElementwiseUnaryOp<"cbrt"> {
   }];
 }
 
-def TTIR_CeilOp: TTIR_ElementwiseUnaryOp<"ceil"> {
+def TTIR_CeilOp: TTIR_ElementwiseUnaryOp<"ceil", [TTIR_Idempotence]> {
   let summary = "Eltwise ceil op.";
   let description = [{
     Eltwise ceil operation.
@@ -270,7 +267,7 @@ def TTIR_CosOp: TTIR_ElementwiseUnaryOp<"cos"> {
   }];
 }
 
-def TTIR_FloorOp: TTIR_ElementwiseUnaryOp<"floor"> {
+def TTIR_FloorOp: TTIR_ElementwiseUnaryOp<"floor", [TTIR_Idempotence]> {
   let summary = "Eltwise floor op.";
   let description = [{
     Eltwise floor operation.
@@ -298,7 +295,7 @@ def TTIR_LogicalNotOp: TTIR_ElementwiseUnaryOp<"logical_not"> {
   }];
 }
 
-def TTIR_BitwiseNotOp : TTIR_ElementwiseUnaryOp<"bitwise_not"> {
+def TTIR_BitwiseNotOp : TTIR_ElementwiseUnaryOp<"bitwise_not", [TTIR_Involution]> {
   let summary = "Eltwise bitwise NOT.";
   let description = [{
     Performs element-wise NOT of tensor `operand` and produces a `result` tensor.
@@ -311,7 +308,7 @@ def TTIR_BitwiseNotOp : TTIR_ElementwiseUnaryOp<"bitwise_not"> {
   }];
 }
 
-def TTIR_NegOp: TTIR_ElementwiseUnaryOp<"neg"> {
+def TTIR_NegOp: TTIR_ElementwiseUnaryOp<"neg", [TTIR_Involution]> {
   let summary = "Eltwise negate op.";
   let description = [{
     Eltwise negate operation.
@@ -332,14 +329,17 @@ def TTIR_TanhOp: TTIR_ElementwiseUnaryOp<"tanh"> {
   }];
 }
 
-def TTIR_ReciprocalOp : TTIR_ElementwiseUnaryOp<"reciprocal"> {
+// TODO (azecevic): What should we do with 0.0 case?
+// 1/0.0 = inf, 1/inf = 0.0, but TTNN isn't IEEE754 compliant.
+// https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/Handling_Special_Value/special_values.md
+def TTIR_ReciprocalOp : TTIR_ElementwiseUnaryOp<"reciprocal", [TTIR_Involution]> {
   let summary = "Eltwise reciprocal.";
   let description = [{
     Eltwise reciprocal operation.
   }];
 }
 
-def TTIR_ReluOp : TTIR_ElementwiseUnaryOp<"relu"> {
+def TTIR_ReluOp : TTIR_ElementwiseUnaryOp<"relu", [TTIR_Idempotence]> {
   let summary = "Eltwise ReLU.";
   let description = [{
     Eltwise ReLU operation.
@@ -360,7 +360,7 @@ def TTIR_SigmoidOp: TTIR_ElementwiseUnaryOp<"sigmoid"> {
   }];
 }
 
-def TTIR_SignOp: TTIR_ElementwiseUnaryOp<"sign"> {
+def TTIR_SignOp: TTIR_ElementwiseUnaryOp<"sign", [TTIR_Idempotence]> {
   let summary = "Eltwise sign operation.";
   let description = [{
     Returns the sign of the `operand` element-wise and produces a `result`
@@ -469,7 +469,7 @@ def TTIR_LeakyReluOp : TTIR_ElementwiseUnaryWithFloatParameterOp<"leaky_relu"> {
 }
 
 class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
-    TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [ThreeOperands])> {
+    TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [ThreeOperands])> {
   let summary = "Eltwise binary op.";
   let description = [{
     Eltwise binary op.
@@ -484,6 +484,7 @@ class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
   ];
 }
 
+// TODO (azecevic): NaN != NaN, otherwise eq(x, x) == 1.
 def TTIR_EqualOp : TTIR_ElementwiseBinaryOp<"eq"> {
   let summary = "Eltwise equal to.";
   let description = [{
@@ -491,6 +492,7 @@ def TTIR_EqualOp : TTIR_ElementwiseBinaryOp<"eq"> {
   }];
 }
 
+// TODO (azecevic): NaN != NaN, otherwise ne(x, x) == 0.
 def TTIR_NotEqualOp : TTIR_ElementwiseBinaryOp<"ne"> {
   let summary = "Eltwise not equal to.";
   let description = [{
@@ -498,6 +500,7 @@ def TTIR_NotEqualOp : TTIR_ElementwiseBinaryOp<"ne"> {
   }];
 }
 
+// TODO (azecevic): NaN != NaN, otherwise ge(x, x) == 1.
 def TTIR_GreaterEqualOp : TTIR_ElementwiseBinaryOp<"ge"> {
   let summary = "Eltwise greater than or equal to.";
   let description = [{
@@ -505,6 +508,7 @@ def TTIR_GreaterEqualOp : TTIR_ElementwiseBinaryOp<"ge"> {
   }];
 }
 
+// TODO (azecevic): NaN != NaN, otherwise gt(x, x) == 0.
 def TTIR_GreaterThanOp : TTIR_ElementwiseBinaryOp<"gt"> {
   let summary = "Eltwise greater than.";
   let description = [{
@@ -512,6 +516,7 @@ def TTIR_GreaterThanOp : TTIR_ElementwiseBinaryOp<"gt"> {
   }];
 }
 
+// TODO (azecevic): NaN != NaN, otherwise le(x, x) == 1.
 def TTIR_LessEqualOp : TTIR_ElementwiseBinaryOp<"le"> {
   let summary = "Eltwise less than or equal to.";
   let description = [{
@@ -519,6 +524,7 @@ def TTIR_LessEqualOp : TTIR_ElementwiseBinaryOp<"le"> {
   }];
 }
 
+// TODO (azecevic): NaN != NaN, otherwise lt(x, x) == 0.
 def TTIR_LessThanOp : TTIR_ElementwiseBinaryOp<"lt"> {
   let summary = "Eltwise less than.";
   let description = [{
@@ -526,14 +532,14 @@ def TTIR_LessThanOp : TTIR_ElementwiseBinaryOp<"lt"> {
   }];
 }
 
-def TTIR_LogicalAndOp : TTIR_ElementwiseBinaryOp<"logical_and"> {
+def TTIR_LogicalAndOp : TTIR_ElementwiseBinaryOp<"logical_and", [TTIR_BinaryIdempotence]> {
   let summary = "Eltwise logical and.";
   let description = [{
     Eltwise logical and operation.
   }];
 }
 
-def TTIR_LogicalOrOp : TTIR_ElementwiseBinaryOp<"logical_or"> {
+def TTIR_LogicalOrOp : TTIR_ElementwiseBinaryOp<"logical_or", [TTIR_BinaryIdempotence]> {
   let summary = "Eltwise logical or.";
   let description = [{
     Eltwise logical or operation.
@@ -547,7 +553,7 @@ def TTIR_LogicalXorOp : TTIR_ElementwiseBinaryOp<"logical_xor"> {
   }];
 }
 
-def TTIR_BitwiseAndOp : TTIR_ElementwiseBinaryOp<"bitwise_and"> {
+def TTIR_BitwiseAndOp : TTIR_ElementwiseBinaryOp<"bitwise_and", [TTIR_BinaryIdempotence]> {
   let summary = "Eltwise bitwise AND.";
   let description = [{
     Performs element-wise bitwise AND of two tensors `lhs` and `rhs`
@@ -561,7 +567,7 @@ def TTIR_BitwiseAndOp : TTIR_ElementwiseBinaryOp<"bitwise_and"> {
   }];
 }
 
-def TTIR_BitwiseOrOp : TTIR_ElementwiseBinaryOp<"bitwise_or"> {
+def TTIR_BitwiseOrOp : TTIR_ElementwiseBinaryOp<"bitwise_or", [TTIR_BinaryIdempotence]> {
   let summary = "Eltwise bitwise OR.";
   let description = [{
     Performs element-wise bitwise OR of two tensors `lhs` and `rhs`
@@ -587,9 +593,11 @@ def TTIR_BitwiseXorOp : TTIR_ElementwiseBinaryOp<"bitwise_xor"> {
     %result = "ttir.bitwise_xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
     // %result: [[4, 4], [4, 12]]
   }];
+
+  let hasCanonicalizer = 1;
 }
 
-def TTIR_MinimumOp : TTIR_ElementwiseBinaryOp<"minimum"> {
+def TTIR_MinimumOp : TTIR_ElementwiseBinaryOp<"minimum", [TTIR_BinaryIdempotence]> {
   let summary = "Eltwise minimum OP.";
   let description = [{
     Calculates minimum of input tensors' values element-wise and stores result
@@ -625,7 +633,7 @@ def TTIR_RemainderOp : TTIR_ElementwiseBinaryOp<"remainder"> {
 }
 
 class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> :
-    TTIR_DPSOp<mnemonic, traits> {
+    TTIR_DPSOp<mnemonic, traits> {
   let summary = "Reduction op.";
   let description = [{
@@ -772,6 +780,8 @@ def TTIR_TransposeOp : TTIR_DPSOp<"transpose"> {
   }];
 
   let hasVerifier = 1;
+
+  let hasCanonicalizer = 1;
 }
 
 def TTIR_ConcatOp : TTIR_DPSOp<"concat"> {
@@ -854,6 +864,8 @@ def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> {
   }];
 
   let hasVerifier = 1;
+
+  let hasFolder = 1;
 }
 
 def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
@@ -1245,6 +1257,8 @@ def TTIR_ReverseOp : TTIR_DPSOp<"reverse", [AllShapesMatch<["input", "result"]>]
   }];
 
   let hasVerifier = 1;
+
+  let hasCanonicalizer = 1;
 }
 
 def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike,
@@ -1314,6 +1328,8 @@ def TTIR_LinearOp : TTIR_DPSOp<"linear"> {
   }];
 
   let hasVerifier = 1;
+
+  let hasCanonicalizeMethod = 1;
 }
 
 // ANCHOR: adding_an_op_matmul_ttir
@@ -1362,6 +1378,8 @@ def TTIR_PermuteOp : TTIR_DPSOp<"permute"> {
   }];
 
   let hasVerifier = 1;
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1369,7 +1387,7 @@ def TTIR_PermuteOp : TTIR_DPSOp<"permute"> {
 //===----------------------------------------------------------------------===//
 
 class TTIR_GenericElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
-    TTIR_ElementwiseUnaryOp<mnemonic, traits> {
+    TTIR_ElementwiseUnaryOp<mnemonic, traits> {
 
   let extraClassDeclaration = [{
     MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
@@ -1400,7 +1418,7 @@ def TTIR_ExpOp: TTIR_GenericElementwiseUnaryOp<"exp"> {
 }
 
 class TTIR_GenericElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
-    TTIR_ElementwiseBinaryOp<mnemonic, traits> {
+    TTIR_ElementwiseBinaryOp<mnemonic, traits> {
 
   let extraClassDeclaration = [{
     MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIRTraits.h b/include/ttmlir/Dialect/TTIR/IR/TTIRTraits.h
new file mode 100644
index 000000000..6ab34df99
--- /dev/null
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIRTraits.h
@@ -0,0 +1,75 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#ifndef TTMLIR_DIALECT_TTIR_IR_TTIRTRAITS_H
+#define TTMLIR_DIALECT_TTIR_IR_TTIRTRAITS_H
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+namespace tt {
+namespace ttir {
+namespace OpTrait {
+
+namespace impl {
+bool verifyInvolution(mlir::Operation *op);
+bool verifyIdempotence(mlir::Operation *op);
+bool verifyBinaryIdempotence(mlir::Operation *op);
+mlir::OpFoldResult foldInvolution(mlir::Operation *op);
+mlir::OpFoldResult foldIdempotence(mlir::Operation *op);
+mlir::OpFoldResult foldBinaryIdempotence(mlir::Operation *op);
+} // namespace impl
+
+template <typename ConcreteType>
+class TTIRInvolution
+    : public mlir::TypeTrait::TraitBase<ConcreteType, TTIRInvolution> {
+public:
+  static mlir::LogicalResult foldTrait(mlir::Operation *op, ArrayRef<Attribute>,
+                                       SmallVectorImpl<OpFoldResult> &results) {
+    if (!impl::verifyInvolution(op)) {
+      return mlir::failure();
+    }
+
+    results.push_back(impl::foldInvolution(op));
+    return mlir::success();
+  }
+};
+
+template <typename ConcreteType>
+class TTIRIdempotence
+    : public mlir::TypeTrait::TraitBase<ConcreteType, TTIRIdempotence> {
+public:
+  static mlir::LogicalResult foldTrait(mlir::Operation *op, ArrayRef<Attribute>,
+                                       SmallVectorImpl<OpFoldResult> &results) {
+    if (!impl::verifyIdempotence(op)) {
+      return mlir::failure();
+    }
+
+    results.push_back(impl::foldIdempotence(op));
+    return mlir::success();
+  }
+};
+
+template <typename ConcreteType>
+class TTIRBinaryIdempotence
+    : public mlir::TypeTrait::TraitBase<ConcreteType, TTIRBinaryIdempotence> {
+public:
+  static mlir::LogicalResult foldTrait(mlir::Operation *op, ArrayRef<Attribute>,
+                                       SmallVectorImpl<OpFoldResult> &results) {
+    if (!impl::verifyBinaryIdempotence(op)) {
+      return mlir::failure();
+    }
+
+    results.push_back(impl::foldBinaryIdempotence(op));
+    return mlir::success();
+  }
+};
+
+} // namespace OpTrait
+} // namespace ttir
+} // namespace tt
+} // namespace mlir
+
+#endif // TTMLIR_DIALECT_TTIR_IR_TTIRTRAITS_H
diff --git a/lib/Dialect/TTIR/IR/CMakeLists.txt b/lib/Dialect/TTIR/IR/CMakeLists.txt
index 2dad1a49a..3ea07725a 100644
--- a/lib/Dialect/TTIR/IR/CMakeLists.txt
+++ b/lib/Dialect/TTIR/IR/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTTIRDialect
   TTIRDialect.cpp
   TTIROps.cpp
   TTIROpsInterfaces.cpp
+  TTIRTraits.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/include/ttmlir
diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp
index 73daad713..0b58ed860 100644
--- a/lib/Dialect/TTIR/IR/TTIROps.cpp
+++ b/lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -5,7 +5,6 @@
 #include "ttmlir/Dialect/TTIR/IR/TTIROps.h"
 
 #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
-#include "ttmlir/Dialect/TTIR/IR/TTIR.h"
 #include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.cpp.inc"
 #include "ttmlir/Utils.h"
 
@@ -29,6 +28,53 @@
 #define GET_OP_CLASSES
 #include "ttmlir/Dialect/TTIR/IR/TTIROps.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// BitwiseXorOp
+//===----------------------------------------------------------------------===//
+
+// BitwiseXorOp canonicalization
+void mlir::tt::ttir::BitwiseXorOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
+  // x ^ x == 0
+  patterns.add(
+      +[](mlir::tt::ttir::BitwiseXorOp op, mlir::PatternRewriter &rewriter) {
+        if (op.getInputs()[0] != op.getInputs()[1]) {
+          return mlir::failure();
+        }
+
+        mlir::RankedTensorType tensorType =
+            mlir::cast<mlir::RankedTensorType>(op.getInputs()[0].getType());
+        auto elementType = tensorType.getElementType();
+        Attribute zeroAttr;
+        if (mlir::isa<mlir::FloatType>(elementType)) {
+          zeroAttr = mlir::FloatAttr::get(elementType, 0.0);
+        } else if (mlir::isa<mlir::IntegerType>(elementType)) {
+          zeroAttr = mlir::IntegerAttr::get(elementType, 0);
+        } else {
+          return mlir::failure();
+        }
+        auto resultAttr = mlir::SplatElementsAttr::get(tensorType, zeroAttr);
+
+        rewriter.replaceOpWithNewOp<mlir::tt::ttir::ConstantOp>(
+            op, op->getOperand(0).getType(), resultAttr);
+        return mlir::success();
+      });
+}
+
+//===----------------------------------------------------------------------===//
+// BroadcastOp
+//===----------------------------------------------------------------------===//
+
+// BroadcastOp folder
+::mlir::OpFoldResult mlir::tt::ttir::BroadcastOp::fold(FoldAdaptor adaptor) {
+  // If the broadcast doesn't change the shape of the input (all broadcast
+  // dimensions are 1), we can fold the operation away.
+  if (llvm::all_of(getBroadcastDimensions(),
+                   [](const int32_t dim) { return dim == 1; })) {
+    return getInput();
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // ClampOp
 //===----------------------------------------------------------------------===//
@@ -91,16 +137,6 @@ ::mlir::OpFoldResult mlir::tt::ttir::ConstantOp::fold(FoldAdaptor adaptor) {
 // GetDimensionSizeOp
 //===----------------------------------------------------------------------===//
 
-// GetDimensionSizeOp folder
-::mlir::OpFoldResult
-mlir::tt::ttir::GetDimensionSizeOp::fold(FoldAdaptor adaptor) {
-  RankedTensorType inputTensorType = getOperand().getType();
-  uint32_t dimensionIndex = getDimension();
-  int32_t dimSize = inputTensorType.getShape()[dimensionIndex];
-
-  return mlir::DenseElementsAttr::get(getType(), dimSize);
-}
-
 // GetDimensionSizeOp verification
 ::mlir::LogicalResult mlir::tt::ttir::GetDimensionSizeOp::verify() {
   RankedTensorType inputTensorType = getOperand().getType();
@@ -115,6 +151,16 @@ ::mlir::LogicalResult mlir::tt::ttir::GetDimensionSizeOp::verify() {
   return success();
 }
 
+// GetDimensionSizeOp folder
+::mlir::OpFoldResult
+mlir::tt::ttir::GetDimensionSizeOp::fold(FoldAdaptor adaptor) {
+  RankedTensorType inputTensorType = getOperand().getType();
+  uint32_t dimensionIndex = getDimension();
+  int32_t dimSize = inputTensorType.getShape()[dimensionIndex];
+
+  return mlir::DenseElementsAttr::get(getType(), dimSize);
+}
+
 //===----------------------------------------------------------------------===//
 // Conv2dOp
 //===----------------------------------------------------------------------===//
@@ -376,7 +422,6 @@ ::mlir::LogicalResult mlir::tt::ttir::ReshapeOp::verify() {
 
 // ReshapeOp folder
 ::mlir::OpFoldResult mlir::tt::ttir::ReshapeOp::fold(FoldAdaptor adaptor) {
-
   if (getType() == getOperand(0).getType()) {
     return getOperand(0);
   }
@@ -405,8 +450,8 @@ ::mlir::LogicalResult mlir::tt::ttir::BroadcastOp::verify() {
 
   // Verify that inputShape can be legally broadcasted to outputShape.
   llvm::SmallVector<int64_t> broadcastedShape;
-  if (!OpTrait::util::getBroadcastedShape(inputShape, outputShape,
-                                          broadcastedShape)) {
+  if (!mlir::OpTrait::util::getBroadcastedShape(inputShape, outputShape,
+                                                broadcastedShape)) {
     return emitOpError() << "Input tensor shape ("
                          << ttmlir::utils::join(inputShape, ",")
                          << ") is not broadcastable to output shape ("
@@ -859,6 +904,75 @@ ::mlir::LogicalResult mlir::tt::ttir::TransposeOp::verify() {
   return success();
 }
 
+// TransposeOp canonicalization
+void mlir::tt::ttir::TransposeOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
+  // TransposeOp can be removed if both 'dim0' and 'dim1' are the same.
+  patterns.add(
+      +[](mlir::tt::ttir::TransposeOp op, mlir::PatternRewriter &rewriter) {
+        if (op.getDim0() != op.getDim1()) {
+          return mlir::failure();
+        }
+
+        rewriter.replaceAllOpUsesWith(op, op.getInput());
+        return mlir::success();
+      });
+
+  // Rewrite a transpose to a canonical form where 'dim0' is less than 'dim1'.
+  patterns.add(
+      +[](mlir::tt::ttir::TransposeOp op, mlir::PatternRewriter &rewriter) {
+        if (op.getDim0() <= op.getDim1()) {
+          return mlir::failure();
+        }
+
+        rewriter.replaceOpWithNewOp<mlir::tt::ttir::TransposeOp>(
+            op, op.getType(), op.getInput(), op.getOutput(), op.getDim1(),
+            op.getDim0());
+        return mlir::success();
+      });
+
+  // Rewrite transpose dims to a canonical form where 'dim0' and 'dim1' are in
+  // the range [0, N), where N is the rank of the input tensor.
+  patterns.add(
+      +[](mlir::tt::ttir::TransposeOp op, mlir::PatternRewriter &rewriter) {
+        int64_t rank = op.getInput().getType().getRank();
+        int32_t dim0 = op.getDim0();
+        int32_t dim1 = op.getDim1();
+
+        if (dim0 >= 0 && dim1 >= 0) {
+          return mlir::failure();
+        }
+
+        rewriter.modifyOpInPlace(op, [&]() {
+          if (dim0 < 0) {
+            op.setDim0(dim0 + rank);
+          }
+          if (dim1 < 0) {
+            op.setDim1(dim1 + rank);
+          }
+        });
+        return mlir::success();
+      });
+
+  // Transposing twice in a row over the same dimensions results in the
+  // identity, hence y = T(T(x)) can be replaced with y = x.
+  patterns.add(
+      +[](mlir::tt::ttir::TransposeOp op, mlir::PatternRewriter &rewriter) {
+        auto producerOp =
+            op.getInput().getDefiningOp<mlir::tt::ttir::TransposeOp>();
+        if (!producerOp) {
+          return mlir::failure();
+        }
+
+        if (op.getDim0() != producerOp.getDim0() ||
+            op.getDim1() != producerOp.getDim1()) {
+          return mlir::failure();
+        }
+
+        rewriter.replaceAllOpUsesWith(op, producerOp.getInput());
+        return mlir::success();
+      });
+}
+
 //===----------------------------------------------------------------------===//
 // TypecastOp
 //===----------------------------------------------------------------------===//
@@ -1087,8 +1201,8 @@ ::mlir::LogicalResult mlir::tt::ttir::LinearOp::verify() {
   // Verify that the batch dimensions of input A and B are broadcast
   // compatible.
   llvm::SmallVector<int64_t> broadcastedShape;
-  if (!OpTrait::util::getBroadcastedShape(inputABatchDims, inputBBatchDims,
-                                          broadcastedShape)) {
+  if (!mlir::OpTrait::util::getBroadcastedShape(
+          inputABatchDims, inputBBatchDims, broadcastedShape)) {
 
     return emitOpError("Batch dimensions of input A(" +
                        ttmlir::utils::join(inputABatchDims, ",") +
@@ -1125,8 +1239,8 @@ ::mlir::LogicalResult mlir::tt::ttir::LinearOp::verify() {
 
   // Verify that the dimensions of the matmul of A and B are broadcast
   // compatible with input bias.
   llvm::SmallVector<int64_t> matmulShape = expectedOutputShape;
-  if (!OpTrait::util::getBroadcastedShape(matmulShape, biasShape,
-                                          expectedOutputShape)) {
+  if (!mlir::OpTrait::util::getBroadcastedShape(matmulShape, biasShape,
+                                                expectedOutputShape)) {
     return emitOpError("Bias shape(" + ttmlir::utils::join(biasShape, ",") +
                        ") is not broadcast compatible with the matmul output "
                        "shape(" +
@@ -1172,6 +1286,19 @@ ::mlir::LogicalResult mlir::tt::ttir::LinearOp::verify() {
   return success();
 }
 
+// LinearOp canonicalize method
+::mlir::LogicalResult
+mlir::tt::ttir::LinearOp::canonicalize(ttir::LinearOp op,
+                                       mlir::PatternRewriter &rewriter) {
+  if (op.getBias()) {
+    return mlir::failure();
+  }
+
+  rewriter.replaceOpWithNewOp<mlir::tt::ttir::MatmulOp>(
+      op, op.getType(), op.getA(), op.getB(), op.getOutput());
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // MatmulOp
 //===----------------------------------------------------------------------===//
@@ -1238,8 +1365,8 @@ ::mlir::LogicalResult mlir::tt::ttir::MatmulOp::verify() {
   // Verify that the batch dimensions of input A and B are broadcast
   // compatible
   llvm::SmallVector<int64_t> broadcastedShape;
-  if (!OpTrait::util::getBroadcastedShape(inputABatchDims, inputBBatchDims,
-                                          broadcastedShape)) {
+  if (!mlir::OpTrait::util::getBroadcastedShape(
+          inputABatchDims, inputBBatchDims, broadcastedShape)) {
 
     return emitOpError("Batch dimensions of input A(" +
                        ttmlir::utils::join(inputABatchDims, ",") +
@@ -1604,6 +1731,7 @@ ::mlir::LogicalResult mlir::tt::ttir::FillCacheOp::verify() {
 // ReverseOp
 //===----------------------------------------------------------------------===//
 
+// ReverseOp verification
 ::mlir::LogicalResult mlir::tt::ttir::ReverseOp::verify() {
   llvm::ArrayRef<int64_t> dimensions = getDimensions();
 
@@ -1634,6 +1762,49 @@ ::mlir::LogicalResult mlir::tt::ttir::ReverseOp::verify() {
   return success();
 }
 
+// ReverseOp canonicalization
+void mlir::tt::ttir::ReverseOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
+  // Reverse dimensions of two consecutive ReverseOps can be folded into a
+  // single ReverseOp where the dimensions are the symmetric difference of the
+  // two sets of dimensions.
+  patterns.add(+[](mlir::tt::ttir::ReverseOp op,
+                   mlir::PatternRewriter &rewriter) {
+    auto producerOp = op.getInput().getDefiningOp<mlir::tt::ttir::ReverseOp>();
+    if (!producerOp) {
+      return mlir::failure();
+    }
+
+    llvm::SmallBitVector reverseDimensions(op.getInput().getType().getRank());
+    llvm::for_each(op.getDimensions(), [&reverseDimensions](int64_t dim) {
+      reverseDimensions.flip(dim);
+    });
+    llvm::for_each(
+        producerOp.getDimensions(),
+        [&reverseDimensions](int64_t dim) { reverseDimensions.flip(dim); });
+
+    llvm::SmallVector<int64_t> setIndices;
+    llvm::copy_if(llvm::seq<int64_t>(reverseDimensions.size()),
+                  std::back_inserter(setIndices),
+                  [&](int64_t i) { return reverseDimensions.test(i); });
+
+    rewriter.replaceOpWithNewOp<mlir::tt::ttir::ReverseOp>(
+        op, op.getType(), producerOp.getInput(), op.getOutput(), setIndices);
+    return mlir::success();
+  });
+
+  // ReverseOp with empty reverse dimensions is a no-op.
+  patterns.add(
+      +[](mlir::tt::ttir::ReverseOp op, mlir::PatternRewriter &rewriter) {
+        if (!op.getDimensions().empty()) {
+          return mlir::failure();
+        }
+
+        rewriter.replaceAllOpUsesWith(op, op.getInput());
+        return mlir::success();
+      });
+}
+
 //===----------------------------------------------------------------------===//
 // PermuteOp
 //===----------------------------------------------------------------------===//
@@ -1670,6 +1841,46 @@ ::mlir::LogicalResult mlir::tt::ttir::PermuteOp::verify() {
   return success();
 }
 
+// PermuteOp canonicalization
+void mlir::tt::ttir::PermuteOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
+  // Permute dimensions of two consecutive PermuteOps can be folded into a
+  // single PermuteOp where the permutation is the composition of the two
+  // permutations.
+  patterns.add(
+      +[](mlir::tt::ttir::PermuteOp op, mlir::PatternRewriter &rewriter) {
+        auto producerOp =
+            op.getInput().getDefiningOp<mlir::tt::ttir::PermuteOp>();
+        if (!producerOp) {
+          return mlir::failure();
+        }
+
+        // I: identity permutation
+        // P1: permutation of producerOp
+        // P2: permutation of op
+        // P: permutation of the composed PermuteOp
+        // P = applyPermutation(applyPermutation(I, P1), P2) =
+        //     applyPermutation(P1, P2)
+        llvm::SmallVector<int64_t> composedPermutation =
+            ttmlir::utils::applyPermutation(producerOp.getPermutation(),
+                                            op.getPermutation());
+
+        rewriter.replaceOpWithNewOp<mlir::tt::ttir::PermuteOp>(
+            op, op.getType(), producerOp.getInput(), op.getOutput(),
+            composedPermutation);
+        return mlir::success();
+      });
+
+  // PermuteOp with identity permutation is a no-op.
+  patterns.add(
+      +[](mlir::tt::ttir::PermuteOp op, mlir::PatternRewriter &rewriter) {
+        if (llvm::is_sorted(op.getPermutation())) {
+          rewriter.replaceAllOpUsesWith(op, op.getInput());
+          return mlir::success();
+        }
+        return mlir::failure();
+      });
+}
+
 //===----------------------------------------------------------------------===//
 // GenericOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/TTIR/IR/TTIRTraits.cpp b/lib/Dialect/TTIR/IR/TTIRTraits.cpp
new file mode 100644
index 000000000..ab8d94bfc
--- /dev/null
+++ b/lib/Dialect/TTIR/IR/TTIRTraits.cpp
@@ -0,0 +1,71 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "ttmlir/Dialect/TTIR/IR/TTIRTraits.h"
+
+// Check if all operands and the result have the same type. The function
+// assumes the op has at least one operand and exactly one result.
+static bool operandAndResultSameType(mlir::Operation *op) {
+  return llvm::all_equal(op->getOperandTypes()) &&
+         op->getOperand(0).getType() == op->getResult(0).getType();
+}
+
+// If an op has the TTIRInvolution trait, it's foldable if:
+// 1. The argument and result types are the same.
+// 2. The argument is defined by the same op.
+// 3. Condition 1. also holds for the producing op of the argument.
+// op(op(T a, T r0), T r1)
+bool mlir::tt::ttir::OpTrait::impl::verifyInvolution(mlir::Operation *op) {
+  if (!operandAndResultSameType(op)) {
+    return false;
+  }
+  mlir::Operation *producerOp = op->getOperand(0).getDefiningOp();
+  if (!producerOp || producerOp->getName() != op->getName()) {
+    return false;
+  }
+  return operandAndResultSameType(producerOp);
+}
+
+// If an op has the TTIRIdempotence trait, it's foldable if:
+// 1. The argument and result types are the same.
+// 2. The argument is defined by the same op.
+// 3. Condition 1. also holds for the producing op of the argument.
+// op(op(T a, T r0), T r1)
+bool mlir::tt::ttir::OpTrait::impl::verifyIdempotence(mlir::Operation *op) {
+  if (!operandAndResultSameType(op)) {
+    return false;
+  }
+  mlir::Operation *producerOp = op->getOperand(0).getDefiningOp();
+  if (!producerOp || producerOp->getName() != op->getName()) {
+    return false;
+  }
+  return operandAndResultSameType(producerOp);
+}
+
+// If an op has the TTIRBinaryIdempotence trait, it's foldable if:
+// 1. Both inputs are the same.
+// 2. The input and result types are the same.
+bool mlir::tt::ttir::OpTrait::impl::verifyBinaryIdempotence(
+    mlir::Operation *op) {
+  if (op->getOperand(0) != op->getOperand(1)) {
+    return false;
+  }
+
+  return op->getResult(0).getType() == op->getOperand(0).getType();
+}
+
+mlir::OpFoldResult
+mlir::tt::ttir::OpTrait::impl::foldInvolution(mlir::Operation *op) {
+  return op->getOperand(0).getDefiningOp()->getOperand(0);
+}
+
+mlir::OpFoldResult
+mlir::tt::ttir::OpTrait::impl::foldIdempotence(mlir::Operation *op) {
+  return op->getOperand(0);
+}
+
+mlir::OpFoldResult
+mlir::tt::ttir::OpTrait::impl::foldBinaryIdempotence(mlir::Operation *op) {
+  return op->getOperand(0);
+}
diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
index f1ec29999..2f70764da 100644
--- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
+++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
@@ -22,7 +22,9 @@ void createTTNNPipelineTTIRPasses(
   ttir::TTIRLoadSystemDescOptions systemDescOptions;
   systemDescOptions.path = options.systemDescPath;
 
+  pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::tt::createTTIRToTTIRDecompositionPass());
+  pm.addPass(mlir::createCanonicalizerPass());
 
   // Inlines all private functions. I.e. flattens the program into the main
   // function. Removes all private functions.
diff --git a/test/ttmlir/Dialect/TTIR/canonicalize/binary_idempotence_tests.mlir b/test/ttmlir/Dialect/TTIR/canonicalize/binary_idempotence_tests.mlir
new file mode 100644
index 000000000..b972f24e6
--- /dev/null
+++ b/test/ttmlir/Dialect/TTIR/canonicalize/binary_idempotence_tests.mlir
@@ -0,0 +1,9 @@
+// RUN: ttmlir-opt -canonicalize %s | FileCheck %s
+module {
+  func.func @binary_idempotence(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32> {
+    // CHECK-NOT: "ttir.logical_and"
+    %0 = tensor.empty() : tensor<64x64xf32>
+    %1 = "ttir.logical_and"(%arg0, %arg0, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
+    return %1 : tensor<64x64xf32>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTIR/canonicalize/bitwise_xor_canonicalize_tests.mlir b/test/ttmlir/Dialect/TTIR/canonicalize/bitwise_xor_canonicalize_tests.mlir
new file mode 100644
index 000000000..007e5fab1
--- /dev/null
+++ b/test/ttmlir/Dialect/TTIR/canonicalize/bitwise_xor_canonicalize_tests.mlir
@@ -0,0 +1,22 @@
+// RUN: ttmlir-opt -canonicalize %s | FileCheck %s
+module {
+  func.func @bitwise_xor_integer(%arg0: tensor<64x128xui16>) -> tensor<64x128xui16> {
+    // CHECK-NOT: "ttir.bitwise_xor"
+    // CHECK: "ttir.constant"
+    // CHECK-SAME: value = dense<0> : tensor<64x128xui16>
+    // CHECK-NOT: "ttir.bitwise_xor"
+    %0 = tensor.empty() : tensor<64x128xui16>
+    %1 = "ttir.bitwise_xor"(%arg0, %arg0, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xui16>, tensor<64x128xui16>, tensor<64x128xui16>) -> tensor<64x128xui16>
+    return %1 : tensor<64x128xui16>
+  }
+
+  func.func @bitwise_xor_float(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
+    // CHECK-NOT: "ttir.bitwise_xor"
+    // CHECK: "ttir.constant"
+    // CHECK-SAME: value = dense<0.000000e+00> : tensor<64x128xbf16>
+    // CHECK-NOT: "ttir.bitwise_xor"
+    %0 = tensor.empty() : tensor<64x128xbf16>
+    %1 = "ttir.bitwise_xor"(%arg0, %arg0, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
+    return %1 : tensor<64x128xbf16>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTIR/canonicalize/broadcast_fold_tests.mlir b/test/ttmlir/Dialect/TTIR/canonicalize/broadcast_fold_tests.mlir
new file mode 100644
index 000000000..7c40ecf1e
--- /dev/null
+++ b/test/ttmlir/Dialect/TTIR/canonicalize/broadcast_fold_tests.mlir
@@ -0,0 +1,9 @@
+// RUN: ttmlir-opt -canonicalize %s | FileCheck %s
+module {
+  func.func @broadcast_noop(%arg0: tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
+    // CHECK-NOT: "ttir.broadcast"
+    %0 = tensor.empty() : tensor<1x2x3x4x5xbf16>
+    %1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array<i32: 1, 1, 1, 1, 1>}> : (tensor<1x2x3x4x5xbf16>, tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16>
+    return %1 : tensor<1x2x3x4x5xbf16>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTIR/canonicalize/idempotence_tests.mlir b/test/ttmlir/Dialect/TTIR/canonicalize/idempotence_tests.mlir
new file mode 100644
index 000000000..c2307b7e1
--- /dev/null
+++ b/test/ttmlir/Dialect/TTIR/canonicalize/idempotence_tests.mlir
@@ -0,0 +1,34 @@
+// RUN: ttmlir-opt -canonicalize %s | FileCheck %s
+module {
+  func.func @idempotence_two_in_the_row(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32> {
+    // CHECK: "ttir.relu"
+    // CHECK-NOT: "ttir.relu"
+    %0 = tensor.empty() : tensor<64x64xf32>
+    %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
+    %2 = tensor.empty() : tensor<64x64xf32>
+    %3 = "ttir.relu"(%1, %2) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
+    return %3 : tensor<64x64xf32>
+  }
+
+  func.func @idempotence_three_in_the_row(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32> {
+    // CHECK: "ttir.relu"
+    // CHECK-NOT: "ttir.relu"
+    %0 = tensor.empty() : tensor<64x64xf32>
+    %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
+    %2 = tensor.empty() : tensor<64x64xf32>
+    %3 = "ttir.relu"(%1, %2) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
+    %4 = tensor.empty() : tensor<64x64xf32>
+    %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
+    return %5 : tensor<64x64xf32>
+  }
+
+  func.func @not_idempotence_different_types(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32> {
+    // CHECK: "ttir.relu"
+    // CHECK: "ttir.relu"
+    %0 = tensor.empty() : tensor<64x64xbf16>
+    %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x64xf32>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
+    %2 = tensor.empty() : tensor<64x64xf32>
+    %3 = "ttir.relu"(%1, %2) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x64xbf16>, tensor<64x64xf32>) -> tensor<64x64xf32>
+    return %3 : tensor<64x64xf32>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTIR/canonicalize/involution_tests.mlir b/test/ttmlir/Dialect/TTIR/canonicalize/involution_tests.mlir
new file mode 100644
index 000000000..5ae4c42ca
--- /dev/null
+++ b/test/ttmlir/Dialect/TTIR/canonicalize/involution_tests.mlir
@@ -0,0 +1,23 @@
+// RUN: ttmlir-opt -canonicalize %s | FileCheck %s
+module {
+  func.func @involution_two_in_the_row(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32> {
+    // CHECK-NOT: "ttir.neg"
+    %0 = tensor.empty() : tensor<64x64xf32>
+    %1 = "ttir.neg"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
+    %2 = tensor.empty() : tensor<64x64xf32>
+    %3 = "ttir.neg"(%1, %2) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
+    return %3 : tensor<64x64xf32>
+  }
+
+  func.func @involution_three_in_the_row(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32> {
+    // CHECK: "ttir.neg"
+    // CHECK-NOT: "ttir.neg"
+    %0 = tensor.empty() : tensor<64x64xf32>
+    %1 = "ttir.neg"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
+    %2 = tensor.empty() : tensor<64x64xf32>
+    %3 = "ttir.neg"(%1, %2) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
+    %4 = tensor.empty() : tensor<64x64xf32>
+    %5 = "ttir.neg"(%3, %4) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
+    return %5 : tensor<64x64xf32>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTIR/canonicalize/linear_op_fold_tests.mlir b/test/ttmlir/Dialect/TTIR/canonicalize/linear_op_fold_tests.mlir
new file mode 100644
index 000000000..4c1780fac
--- /dev/null
+++ b/test/ttmlir/Dialect/TTIR/canonicalize/linear_op_fold_tests.mlir
@@ -0,0 +1,11 @@
+// RUN: ttmlir-opt -canonicalize %s | FileCheck %s
+module {
+  func.func @linear_without_bias(%arg0: tensor<3x64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<3x64x64xbf16> {
+    %0 = tensor.empty() : tensor<3x64x64xbf16>
+    // CHECK-NOT: "ttir.linear"
+    // CHECK: "ttir.matmul"(%arg0, %arg1, %0)
+    // CHECK-NOT: "ttir.linear"
+    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<3x64x128xbf16>, tensor<128x64xbf16>, tensor<3x64x64xbf16>) -> tensor<3x64x64xbf16>
+    return %1 : tensor<3x64x64xbf16>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTIR/canonicalize/permute_op_canonicalize_tests.mlir b/test/ttmlir/Dialect/TTIR/canonicalize/permute_op_canonicalize_tests.mlir
new file mode 100644
index 000000000..0191de8d6
--- /dev/null
+++ b/test/ttmlir/Dialect/TTIR/canonicalize/permute_op_canonicalize_tests.mlir
@@ -0,0 +1,20 @@
+// RUN: ttmlir-opt -canonicalize %s | FileCheck %s
+module {
+  func.func @permute_composition(%arg0: tensor<1x2x3x4x5xbf16>) -> tensor<3x2x1x5x4xbf16> {
+    // CHECK: "ttir.permute"
+    // CHECK-SAME: permutation = array<i64: 2, 1, 0, 4, 3>
+    // CHECK-NOT: "ttir.permute"
+    %0 = tensor.empty() : tensor<3x2x5x4x1xbf16>
+    %1 = "ttir.permute"(%arg0, %0) <{permutation = array<i64: 2, 1, 4, 3, 0>}> : (tensor<1x2x3x4x5xbf16>, tensor<3x2x5x4x1xbf16>) -> tensor<3x2x5x4x1xbf16>
+    %2 = tensor.empty() : tensor<3x2x1x5x4xbf16>
+    %3 = "ttir.permute"(%1, %2) <{permutation = array<i64: 0, 1, 4, 2, 3>}> : (tensor<3x2x5x4x1xbf16>, tensor<3x2x1x5x4xbf16>) -> tensor<3x2x1x5x4xbf16>
+    return %3 : tensor<3x2x1x5x4xbf16>
+  }
+
+  func.func @permute_noop(%arg0: tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
+    // CHECK-NOT: "ttir.permute"
+    %0 = tensor.empty() : tensor<1x2x3x4x5xbf16>
+    %1 = "ttir.permute"(%arg0, %0) <{permutation = array<i64: 0, 1, 2, 3, 4>}> : (tensor<1x2x3x4x5xbf16>, tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16>
+    return %1 : tensor<1x2x3x4x5xbf16>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTIR/canonicalize/revese_op_canonicalize_tests.mlir b/test/ttmlir/Dialect/TTIR/canonicalize/revese_op_canonicalize_tests.mlir
new file mode 100644
index 000000000..753e09a2b
--- /dev/null
+++ b/test/ttmlir/Dialect/TTIR/canonicalize/revese_op_canonicalize_tests.mlir
@@ -0,0 +1,30 @@
+// RUN: ttmlir-opt -canonicalize %s | FileCheck %s
+module {
+  func.func @reverse_composition(%arg0: tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
+    // CHECK: "ttir.reverse"
+    // CHECK-SAME: dimensions = array<i64: 0, 3>
+    // CHECK-NOT: "ttir.reverse"
+    %0 = tensor.empty() : tensor<1x2x3x4x5xbf16>
+    %1 = "ttir.reverse"(%arg0, %0) <{dimensions = array<i64: 0, 2>}> : (tensor<1x2x3x4x5xbf16>, tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16>
+    %2 = tensor.empty() : tensor<1x2x3x4x5xbf16>
+    %3 = "ttir.reverse"(%1, %2) <{dimensions = array<i64: 2, 3>}> : (tensor<1x2x3x4x5xbf16>, tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16>
+    return %3 : tensor<1x2x3x4x5xbf16>
+  }
+
+  func.func @reverse_noop(%arg0: tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
+    // CHECK-NOT: "ttir.reverse"
+    %0 = tensor.empty() : tensor<1x2x3x4x5xbf16>
+    %1 = "ttir.reverse"(%arg0, %0) <{dimensions = array<i64>}> : (tensor<1x2x3x4x5xbf16>, tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16>
+    return %1 : tensor<1x2x3x4x5xbf16>
+  }
+
+  func.func @reverse_composition_noop(%arg0: tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
+    // CHECK-NOT: "ttir.reverse"
+    %0 = tensor.empty() : tensor<1x2x3x4x5xbf16>
+    %1 = "ttir.reverse"(%arg0, %0) <{dimensions = array<i64: 0, 2>}> : (tensor<1x2x3x4x5xbf16>, tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16>
+    %2 = tensor.empty() : tensor<1x2x3x4x5xbf16>
+    %3 = "ttir.reverse"(%1, %2) <{dimensions = array<i64: 0, 2>}> : (tensor<1x2x3x4x5xbf16>, tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16>
+    return %3 : tensor<1x2x3x4x5xbf16>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTIR/canonicalize/transpose_fold_tests.mlir b/test/ttmlir/Dialect/TTIR/canonicalize/transpose_fold_tests.mlir
new file mode 100644
index 000000000..6d5df4476
--- /dev/null
+++ b/test/ttmlir/Dialect/TTIR/canonicalize/transpose_fold_tests.mlir
@@ -0,0 +1,20 @@
+// RUN: ttmlir-opt -canonicalize %s | FileCheck %s
+module {
+  func.func @transpose_involution(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
+    // CHECK-NOT: "ttir.transpose"
+    %0 = tensor.empty() : tensor<128x64xbf16>
+    %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 0 : si32, dim1 = 1 : si32}> : (tensor<64x128xbf16>, tensor<128x64xbf16>) -> tensor<128x64xbf16>
+    %2 = tensor.empty() : tensor<64x128xbf16>
+    %3 = "ttir.transpose"(%1, %2) <{dim0 = 1 : si32, dim1 = 0 : si32}> : (tensor<128x64xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
+    return %3 : tensor<64x128xbf16>
+  }
+
+  func.func @transpose_normalize_range(%arg0: tensor<32x64x128xbf16>) -> tensor<32x128x64xbf16> {
+    // CHECK: "ttir.transpose"
+    // CHECK-SAME: dim0 = 1 : si32
+    // CHECK-SAME: dim1 = 2 : si32
+    %0 = tensor.empty() : tensor<32x128x64xbf16>
+    %1 = "ttir.transpose"(%arg0, %0) <{dim0 = -1 : si32, dim1 = -2 : si32}> : (tensor<32x64x128xbf16>, tensor<32x128x64xbf16>) -> tensor<32x128x64xbf16>
+    return %1 : tensor<32x128x64xbf16>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTNN/convolution/simple_conv1d.mlir b/test/ttmlir/Dialect/TTNN/convolution/simple_conv1d.mlir
index 8719ab3cc..8767a0409 100644
--- a/test/ttmlir/Dialect/TTNN/convolution/simple_conv1d.mlir
+++ b/test/ttmlir/Dialect/TTNN/convolution/simple_conv1d.mlir
@@ -8,8 +8,6 @@ module {
   // CHECK-SAME: shape = [1024 : i32, 256 : i32, 1 : i32, 1 : i32]
   // CHECK: "ttnn.permute"
   // CHECK-SAME: permutation = array
-  // CHECK: "ttnn.permute"
-  // CHECK-SAME: permutation = array
   // CHECK: "ttnn.conv2d"
   // CHECK: "ttnn.permute"
   // CHECK-SAME: permutation = array
diff --git a/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir b/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir
index ef0a6729e..620cd1d5c 100644
--- a/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir
+++ b/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir
@@ -1,18 +1,5 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 module {
-  func.func @linear_1d_1d(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<1xbf16> {
-    // CHECK: "ttnn.empty"
-    // CHECK-SAME: tensor<1xbf16
-    %0 = tensor.empty() : tensor<1xbf16>
-    // CHECK: "ttnn.linear"
-    // CHECK-SAME: tensor<128xbf16
-    // CHECK-SAME: tensor<128xbf16
-    // CHECK-SAME: tensor<1xbf16
-    // CHECK-SAME: tensor<1xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
-    return %1 : tensor<1xbf16>
-  }
-
   func.func @linear_1d_1d_bias(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>, %bias: tensor<1xbf16>) -> tensor<1xbf16> {
     // CHECK: "ttnn.empty"
     // CHECK-SAME: tensor<1xbf16
@@ -41,33 +28,7 @@ module {
     return %1 : tensor<128xbf16>
   }
 
-  func.func @linear_2d_1d(%arg0: tensor<64x128xbf16>, %arg1: tensor<128xbf16>) -> tensor<64xbf16> {
-    // CHECK: "ttnn.empty"
-    // CHECK-SAME: tensor<64xbf16
-    %0 = tensor.empty() : tensor<64xbf16>
-    // CHECK: "ttnn.linear"
-    // CHECK-SAME: tensor<64x128xbf16
-    // CHECK-SAME: tensor<128xbf16
-    // CHECK-SAME: tensor<64xbf16
-    // CHECK-SAME: tensor<64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128xbf16>, tensor<64xbf16>) -> tensor<64xbf16>
-    return %1 : tensor<64xbf16>
-  }
-
-  func.func @linear_2d_2d(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> {
-    // CHECK: "ttnn.empty"
-    // CHECK-SAME: tensor<64x64xbf16
-    %0 = tensor.empty() : tensor<64x64xbf16>
-    // CHECK: "ttnn.linear"
-    // CHECK-SAME: tensor<64x128xbf16
-    // CHECK-SAME: tensor<128x64xbf16
-    // CHECK-SAME: tensor<64x64xbf16
-    // CHECK-SAME: tensor<64x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
-    return %1 : tensor<64x64xbf16>
-  }
-
-   func.func @linear_2d_2d_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> {
+  func.func @linear_2d_2d_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> {
     // CHECK: "ttnn.empty"
     // CHECK-SAME: tensor<64x64xbf16
     %0 = tensor.empty() : tensor<64x64xbf16>
@@ -81,111 +42,7 @@ module {
     return %1 : tensor<64x64xbf16>
   }
 
-  func.func @linear_1d_nd(%arg0: tensor<128xbf16>, %arg1: tensor<12x7x128x64xbf16>) -> tensor<12x7x64xbf16> {
-    // CHECK: "ttnn.empty"
-    // CHECK-SAME: tensor<12x7x64xbf16
-    %0 = tensor.empty() : tensor<12x7x64xbf16>
-    // CHECK: "ttnn.linear"
-    // CHECK-SAME: tensor<128xbf16
-    // CHECK-SAME: tensor<12x7x128x64xbf16
-    // CHECK-SAME: tensor<12x7x64xbf16
-    // CHECK-SAME: tensor<12x7x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64xbf16>) -> tensor<12x7x64xbf16>
-    return %1 : tensor<12x7x64xbf16>
-  }
-
-  func.func @linear_nd_1d(%arg0: tensor<12x7x128x64xbf16>, %arg1: tensor<64xbf16>) -> tensor<12x7x128xbf16> {
-    // CHECK: "ttnn.empty"
-    // CHECK-SAME: tensor<12x7x128xbf16
-    %0 = tensor.empty() : tensor<12x7x128xbf16>
-    // CHECK: "ttnn.linear"
-    // CHECK-SAME: tensor<12x7x128x64xbf16
-    // CHECK-SAME: tensor<64xbf16
-    // CHECK-SAME: tensor<12x7x128xbf16
-    // CHECK-SAME: tensor<12x7x128xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<12x7x128x64xbf16>, tensor<64xbf16>, tensor<12x7x128xbf16>) -> tensor<12x7x128xbf16>
-    return %1 : tensor<12x7x128xbf16>
-  }
-
-  func.func @linear_2d_nd(%arg0: tensor<64x128xbf16>, %arg1: tensor<12x7x128x64xbf16>) -> tensor<12x7x64x64xbf16> {
-    // CHECK: "ttnn.empty"
-    // CHECK-SAME: tensor<12x7x64x64xbf16
-    %0 = tensor.empty() : tensor<12x7x64x64xbf16>
-    // CHECK: "ttnn.linear"
-    // CHECK-SAME: tensor<64x128xbf16
-    // CHECK-SAME: tensor<12x7x128x64xbf16
-    // CHECK-SAME: tensor<12x7x64x64xbf16
-    // CHECK-SAME: tensor<12x7x64x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64x64xbf16>) -> tensor<12x7x64x64xbf16>
-    return %1 : tensor<12x7x64x64xbf16>
-  }
-
-  func.func @linear_nd_2d(%arg0: tensor<12x7x128x64xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<12x7x128x128xbf16> {
-    // CHECK: "ttnn.empty"
-    // CHECK-SAME: tensor<12x7x128x128xbf16
-    %0 = tensor.empty() : tensor<12x7x128x128xbf16>
-    // CHECK: "ttnn.linear"
-    // CHECK-SAME: tensor<12x7x128x64xbf16
-    // CHECK-SAME: tensor<64x128xbf16
-    // CHECK-SAME: tensor<12x7x128x128xbf16
-    // CHECK-SAME: tensor<12x7x128x128xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<12x7x128x64xbf16>, tensor<64x128xbf16>, tensor<12x7x128x128xbf16>) -> tensor<12x7x128x128xbf16>
-    return %1 : tensor<12x7x128x128xbf16>
-  }
-  // linear nd - nd tests
-  func.func @linear_nd_nd_same_rank_same_dims(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<7x128x64xbf16>) -> tensor<7x64x64xbf16> {
-    // CHECK: "ttnn.empty"
-    // CHECK-SAME: tensor<7x64x64xbf16
-    %0 = tensor.empty() : tensor<7x64x64xbf16>
-    // CHECK: "ttnn.linear"
-    // CHECK-SAME: tensor<7x64x128xbf16
-    // CHECK-SAME: tensor<7x128x64xbf16
-    // CHECK-SAME: tensor<7x64x64xbf16
-    // CHECK-SAME: tensor<7x64x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<7x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16>
-    return %1 : tensor<7x64x64xbf16>
-  }
-
-  func.func @linear_nd_nd_same_rank_broadcastable_dims_1(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x128x64xbf16>) -> tensor<7x64x64xbf16> {
-    // CHECK: "ttnn.empty"
-    // CHECK-SAME: tensor<7x64x64xbf16
-    %0 = tensor.empty() : tensor<7x64x64xbf16>
-    // CHECK: "ttnn.linear"
-    // CHECK-SAME: tensor<7x64x128xbf16
-    // CHECK-SAME: tensor<1x128x64xbf16
-    // CHECK-SAME: tensor<7x64x64xbf16
-    // CHECK-SAME: tensor<7x64x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16>
-    return %1 : tensor<7x64x64xbf16>
-  }
-
-  func.func @linear_nd_nd_same_rank_broadcastable_dims_2(%arg0: tensor<1x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<7x7x64x64xbf16> {
-    // CHECK: "ttnn.empty"
-    // CHECK-SAME: tensor<7x7x64x64xbf16
-    %0 = tensor.empty() : tensor<7x7x64x64xbf16>
-    // CHECK: "ttnn.linear"
-    // CHECK-SAME: tensor<1x7x64x128xbf16
-    // CHECK-SAME: tensor<7x1x128x64xbf16
-    // CHECK-SAME: tensor<7x7x64x64xbf16
-    // CHECK-SAME: tensor<7x7x64x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x7x64x64xbf16>) -> tensor<7x7x64x64xbf16>
-    return %1 : tensor<7x7x64x64xbf16>
-  }
-
-  func.func @linear_nd_nd_different_rank_broadcastable_dims_2(%arg0: tensor<12x1x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<12x7x7x64x64xbf16> {
-    // CHECK: "ttnn.empty"
-    // CHECK-SAME: tensor<12x7x7x64x64xbf16
-    %0 = tensor.empty() : tensor<12x7x7x64x64xbf16>
-    // CHECK: "ttnn.linear"
-    // CHECK-SAME: tensor<12x1x7x64x128xbf16
-    // CHECK-SAME: tensor<7x1x128x64xbf16
-    // CHECK-SAME: tensor<12x7x7x64x64xbf16
-    // CHECK-SAME: tensor<12x7x7x64x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<12x1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16>
-    return %1 : tensor<12x7x7x64x64xbf16>
-  }
-
   func.func @linear_nd_nd_bias_broadcast_bias(%arg0: tensor<14x7x32x32xbf16>, %arg1:tensor<14x1x32x64xbf16>, %bias: tensor<64xbf16>) -> tensor<14x7x32x64xbf16> {
     // CHECK: "ttnn.empty"
     // CHECK-SAME: tensor<14x7x32x64xbf16
diff --git a/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir b/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir
index 44165e05d..579bb2e82 100644
--- a/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir
+++ b/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir
@@ -1,19 +1,6 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 module {
-  func.func @simple_linear_without_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> {
-    // CHECK: "ttnn.empty"
-    // CHECK-SAME: tensor<64x64xbf16
-    %0 = tensor.empty() : tensor<64x64xbf16>
-    // CHECK: "ttnn.linear"
-    // CHECK-SAME: tensor<64x128xbf16
-    // CHECK-SAME: tensor<128x64xbf16
-    // CHECK-SAME: tensor<64x64xbf16
-    // CHECK-SAME: tensor<64x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
-    return %1 : tensor<64x64xbf16>
-  }
-
   func.func @simple_linear_with_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> {
     // CHECK: "ttnn.empty"
     // CHECK-SAME: tensor<64x64xbf16
diff --git a/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/fork_join.mlir b/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/fork_join.mlir
index 657da9339..a9d2af1c9 100644
--- a/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/fork_join.mlir
+++ b/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/fork_join.mlir
@@ -27,7 +27,7 @@ module attributes {} {
   // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x64xbf16, #[[LAYOUT_3]]>
   %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
   %2 = tensor.empty() : tensor<64x64xbf16>
-  %3 = "ttir.relu"(%1, %2) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
+  %3 = "ttir.gelu"(%1, %2) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
   %4 = tensor.empty() : tensor<64x32xbf16>
   %5 = "ttir.matmul"(%1, %arg1, %4) : (tensor<64x64xbf16>, tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16>
   %6 = tensor.empty() : tensor<64x32xbf16>
@@ -35,7 +35,7 @@ module attributes {} {
   %8 = tensor.empty() : tensor<64x32xbf16>
   // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_5]]>
   // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_5]]>
-  // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x64xbf16, #[[LAYOUT_5]]>
+  // CHECK: %{{.*}} = "ttnn.gelu"{{.*}} -> tensor<64x64xbf16, #[[LAYOUT_5]]>
   // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_5]]>
   %9 = "ttir.matmul"(%3, %7, %8) : (tensor<64x64xbf16>, tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16>
   return %9 : tensor<64x32xbf16>
diff --git a/test/ttmlir/Dialect/TTNN/permute/permute_tests_positive.mlir b/test/ttmlir/Dialect/TTNN/permute/permute_tests_positive.mlir
index 54988e880..90feda1e7 100644
--- a/test/ttmlir/Dialect/TTNN/permute/permute_tests_positive.mlir
+++ b/test/ttmlir/Dialect/TTNN/permute/permute_tests_positive.mlir
@@ -1,15 +1,5 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 module {
-  func.func @permute_identity(%arg0: tensor<8x32x64x128xf32>) -> tensor<8x32x64x128xf32> {
-    %0 = tensor.empty() : tensor<8x32x64x128xf32>
-    // CHECK: "ttnn.permute"
-    // CHECK-SAME: permutation = array<i64: 0, 1, 2, 3>
-    // CHECK-SAME: tensor<8x32x64x128xf32
-    // CHECK-SAME: tensor<8x32x64x128xf32
-    %1 = "ttir.permute"(%arg0, %0) <{permutation = array<i64: 0, 1, 2, 3>}> : (tensor<8x32x64x128xf32>, tensor<8x32x64x128xf32>) -> tensor<8x32x64x128xf32>
-    return %1 : tensor<8x32x64x128xf32>
-  }
-
   func.func @permute_general(%arg0: tensor<8x32x64x128xf32>) -> tensor<64x8x128x32xf32> {
     %0 = tensor.empty() : tensor<64x8x128x32xf32>
     // CHECK: "ttnn.permute"
@@ -19,14 +9,4 @@ module {
     %1 = "ttir.permute"(%arg0, %0) <{permutation = array<i64: 2, 0, 3, 1>}> : (tensor<8x32x64x128xf32>, tensor<64x8x128x32xf32>) -> tensor<64x8x128x32xf32>
     return %1 : tensor<64x8x128x32xf32>
   }
-
-  func.func @permute_1d(%arg0: tensor<32xf32>) -> tensor<32xf32> {
-    %0 = tensor.empty() : tensor<32xf32>
-    // CHECK: "ttnn.permute"
-    // CHECK-SAME: permutation = array<i64: 0>
-    // CHECK-SAME: tensor<32xf32
-    // CHECK-SAME: tensor<32xf32
-    %1 = "ttir.permute"(%arg0, %0) <{permutation = array<i64: 0>}> : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
-    return %1 : tensor<32xf32>
-  }
 }
diff --git a/test/ttmlir/Silicon/TTNN/simple_linear.mlir b/test/ttmlir/Silicon/TTNN/simple_linear.mlir
index b65bf99db..9d283777f 100644
--- a/test/ttmlir/Silicon/TTNN/simple_linear.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_linear.mlir
@@ -3,19 +3,6 @@
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
 
 module {
-  func.func @simple_linear_without_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> {
-    // CHECK: "ttnn.empty"
-    // CHECK-SAME: tensor<64x64xbf16
-    %0 = tensor.empty() : tensor<64x64xbf16>
-    // CHECK: "ttnn.linear"
-    // CHECK-SAME: tensor<64x128xbf16
-    // CHECK-SAME: tensor<128x64xbf16
-    // CHECK-SAME: tensor<64x64xbf16
-    // CHECK-SAME: tensor<64x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
-    return %1 : tensor<64x64xbf16>
-  }
-
   func.func @simple_linear_with_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> {
     // CHECK: "ttnn.empty"
     // CHECK-SAME: tensor<64x64xbf16