Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Canonicalization of TTIR ops #1670

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion include/ttmlir/Dialect/TTIR/IR/TTIRBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"


//===----------------------------------------------------------------------===//
// TTIR dialect definition.
Expand Down Expand Up @@ -41,6 +43,28 @@ def TTIR_Dialect : Dialect {
//===----------------------------------------------------------------------===//

class TTIR_Op<string mnemonic, list<Trait> traits = []> :
Op<TTIR_Dialect, mnemonic, !listconcat(traits, [Pure])>;
Op<TTIR_Dialect, mnemonic, !listconcat([Pure], traits)>;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've changed the order of listconcatarguments everywhere, because of dependent trait list for traits like TTIR_Involution. Otherwise it fails long before compiling (it check for existence of this traits while still parsing .td file), so it doesn't see a trait/interface like DestinationStyleOpInterface on ops that has it, because it builds list in bottom-up fashion ([trait_of_op, trait_of_ops_class, trait_op_ops_class_class,...]).

From what I saw it doesn't make any difference in generated C++ code. @nsmithtt Did you have an experience with this?


//===----------------------------------------------------------------------===//
// TTIR traits definition.
//===----------------------------------------------------------------------===//

def TwoOperands : ParamNativeOpTrait<"NOperands", "2">;
def ThreeOperands : ParamNativeOpTrait<"NOperands", "3">;
def FourOperands : ParamNativeOpTrait<"NOperands", "4">;

class TTIR_Trait<string name, list<Trait> traits = []> : NativeOpTrait<name, traits> {
let cppNamespace = "::mlir::tt::ttir::OpTrait";
}

// Involution is property of an operation where applying the operation twice results in the original value.
// Example: not(not(x)) == x
def TTIR_Involution : TTIR_Trait<"TTIRInvolution", [DestinationStyleOpInterface, TwoOperands, NoMemoryEffect]>;
// Idempotence is property of an operation where applying the operation twice results in the same value as applying it once.
// Example: abs(abs(x)) == abs(x)
def TTIR_Idempotence : TTIR_Trait<"TTIRIdempotence", [DestinationStyleOpInterface, TwoOperands, NoMemoryEffect]>;
// BinaryIdempotence is property of a binary operation where applying the operation on the same value results in the same value.
// Example: and(x, x) == x
def TTIR_BinaryIdempotence : TTIR_Trait<"TTIRBinaryIdempotence", [DestinationStyleOpInterface, ThreeOperands, NoMemoryEffect]>;

#endif
8 changes: 5 additions & 3 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@
#ifndef TTMLIR_DIALECT_TTIR_IR_TTIROPS_H
#define TTMLIR_DIALECT_TTIR_IR_TTIROPS_H

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h"
#include "ttmlir/Dialect/TTIR/IR/TTIRTraits.h"

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"

#include "TTIROpsInterfaces.h"

#include "ttmlir/Dialect/TTIR/IR/TTIROpsEnums.h.inc"

Expand Down
70 changes: 44 additions & 26 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/OpBase.td"

class TTIR_DPSOp<string mnemonic, list<Trait> traits = []> :
TTIR_Op<mnemonic, !listconcat(traits, [TTIROpInterface, DestinationStyleOpInterface])> {
TTIR_Op<mnemonic, !listconcat([TTIROpInterface, DestinationStyleOpInterface], traits)> {
let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
}];
Expand Down Expand Up @@ -98,8 +98,9 @@ def TTIR_GetDimensionSizeOp : TTIR_Op<"get_dimension_size"> {

let results = (outs AnyRankedTensor:$result);

let hasFolder = 1;
let hasVerifier = 1;

let hasFolder = 1;
}

def TTIR_ToLayoutOp : TTIR_Op<"to_layout", [DestinationStyleOpInterface, TTIROpInterface]> {
Expand Down Expand Up @@ -166,12 +167,8 @@ def TTIR_DeallocOp : TTIR_Op<"dealloc"> {
// TTIR top level named ops
//===----------------------------------------------------------------------===//

def TwoOperands : ParamNativeOpTrait<"NOperands", "2">;
def ThreeOperands : ParamNativeOpTrait<"NOperands", "3">;
def FourOperands : ParamNativeOpTrait<"NOperands", "4">;

class TTIR_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
TTIR_DPSOp<mnemonic, !listconcat(traits, [AttrSizedOperandSegments, TTIR_Broadcastable])> {
TTIR_DPSOp<mnemonic, !listconcat([AttrSizedOperandSegments, TTIR_Broadcastable], traits)> {

let description = [{
Base class for elementwise operations. Elementwise operations can take inputs with different shape,
Expand All @@ -184,7 +181,7 @@ class TTIR_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
}

class TTIR_ElementwiseTernaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [FourOperands])> {
TTIR_ElementwiseOp<mnemonic, !listconcat([FourOperands], traits)> {
let summary = "Eltwise ternary op.";
let description = [{
Eltwise ternary op.
Expand All @@ -207,7 +204,7 @@ def TTIR_WhereOp: TTIR_ElementwiseTernaryOp<"where"> {
}

class TTIR_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [TwoOperands])> {
TTIR_ElementwiseOp<mnemonic, !listconcat([TwoOperands], traits)> {
let summary = "Eltwise unary op.";
let description = [{
Eltwise unary op.
Expand All @@ -222,7 +219,7 @@ class TTIR_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
];
}

def TTIR_AbsOp: TTIR_ElementwiseUnaryOp<"abs"> {
def TTIR_AbsOp: TTIR_ElementwiseUnaryOp<"abs", [TTIR_Idempotence]> {
let summary = "Eltwise absolute op.";
let description = [{
Eltwise absolute operation.
Expand All @@ -236,7 +233,7 @@ def TTIR_CbrtOp: TTIR_ElementwiseUnaryOp<"cbrt"> {
}];
}

def TTIR_CeilOp: TTIR_ElementwiseUnaryOp<"ceil"> {
def TTIR_CeilOp: TTIR_ElementwiseUnaryOp<"ceil", [TTIR_Idempotence]> {
let summary = "Eltwise ceil op.";
let description = [{
Eltwise ceil operation.
Expand All @@ -250,7 +247,7 @@ def TTIR_CosOp: TTIR_ElementwiseUnaryOp<"cos"> {
}];
}

def TTIR_FloorOp: TTIR_ElementwiseUnaryOp<"floor"> {
def TTIR_FloorOp: TTIR_ElementwiseUnaryOp<"floor", [TTIR_Idempotence]> {
let summary = "Eltwise floor op.";
let description = [{
Eltwise floor operation.
Expand Down Expand Up @@ -278,7 +275,7 @@ def TTIR_LogicalNotOp: TTIR_ElementwiseUnaryOp<"logical_not"> {
}];
}

def TTIR_BitwiseNotOp : TTIR_ElementwiseUnaryOp<"bitwise_not"> {
def TTIR_BitwiseNotOp : TTIR_ElementwiseUnaryOp<"bitwise_not", [TTIR_Involution]> {
let summary = "Eltwise bitwise NOT.";
let description = [{
Performs element-wise NOT of tensor `operand` and produces a `result` tensor.
Expand All @@ -291,7 +288,7 @@ def TTIR_BitwiseNotOp : TTIR_ElementwiseUnaryOp<"bitwise_not"> {
}];
}

def TTIR_NegOp: TTIR_ElementwiseUnaryOp<"neg"> {
def TTIR_NegOp: TTIR_ElementwiseUnaryOp<"neg", [TTIR_Involution]> {
let summary = "Eltwise negate op.";
let description = [{
Eltwise negate operation.
Expand All @@ -312,14 +309,17 @@ def TTIR_TanhOp: TTIR_ElementwiseUnaryOp<"tanh"> {
}];
}

def TTIR_ReciprocalOp : TTIR_ElementwiseUnaryOp<"reciprocal"> {
// TODO (azecevic): What should we do with 0.0 case?
// 1/0.0 = inf, 1/inf = 0.0, but TTNN isn't IEEE754 compliant.
// https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/Handling_Special_Value/special_values.md
def TTIR_ReciprocalOp : TTIR_ElementwiseUnaryOp<"reciprocal", [TTIR_Involution]> {
let summary = "Eltwise reciprocal.";
let description = [{
Eltwise reciprocal operation.
}];
}

def TTIR_ReluOp : TTIR_ElementwiseUnaryOp<"relu"> {
def TTIR_ReluOp : TTIR_ElementwiseUnaryOp<"relu", [TTIR_Idempotence]> {
let summary = "Eltwise ReLU.";
let description = [{
Eltwise ReLU operation.
Expand All @@ -340,7 +340,7 @@ def TTIR_SigmoidOp: TTIR_ElementwiseUnaryOp<"sigmoid"> {
}];
}

def TTIR_SignOp: TTIR_ElementwiseUnaryOp<"sign"> {
def TTIR_SignOp: TTIR_ElementwiseUnaryOp<"sign", [TTIR_Idempotence]> {
let summary = "Eltwise sign operation.";
let description = [{
Returns the sign of the `operand` element-wise and produces a `result`
Expand Down Expand Up @@ -449,7 +449,7 @@ def TTIR_LeakyReluOp : TTIR_ElementwiseUnaryWithFloatParameterOp<"leaky_relu"> {
}

class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [ThreeOperands])> {
TTIR_ElementwiseOp<mnemonic, !listconcat([ThreeOperands], traits)> {
let summary = "Eltwise binary op.";
let description = [{
Eltwise binary op.
Expand All @@ -464,56 +464,62 @@ class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
];
}

// TODO (azecevic): NaN != NaN, otherwise eq(x, x) == 1.
def TTIR_EqualOp : TTIR_ElementwiseBinaryOp<"eq"> {
let summary = "Eltwise equal to.";
let description = [{
Eltwise equal to operation.
}];
}

// TODO (azecevic): NaN != NaN, otherwise ne(x, x) == 0.
def TTIR_NotEqualOp : TTIR_ElementwiseBinaryOp<"ne"> {
let summary = "Eltwise not equal to.";
let description = [{
Eltwise not equal to operation.
}];
}

// TODO (azecevic): NaN != NaN, otherwise ge(x, x) == 1.
def TTIR_GreaterEqualOp : TTIR_ElementwiseBinaryOp<"ge"> {
let summary = "Eltwise greater than or equal to.";
let description = [{
Eltwise greater than or equal to operation.
}];
}

// TODO (azecevic): NaN != NaN, otherwise gt(x, x) == 0.
def TTIR_GreaterThanOp : TTIR_ElementwiseBinaryOp<"gt"> {
let summary = "Eltwise greater than.";
let description = [{
Eltwise greater than operation.
}];
}

// TODO (azecevic): NaN != NaN, otherwise le(x, x) == 1.
def TTIR_LessEqualOp : TTIR_ElementwiseBinaryOp<"le"> {
let summary = "Eltwise less than or equal to.";
let description = [{
Eltwise less than or equal to operation.
}];
}

// TODO (azecevic): NaN != NaN, otherwise lt(x, x) == 0.
def TTIR_LessThanOp : TTIR_ElementwiseBinaryOp<"lt"> {
let summary = "Eltwise less than.";
let description = [{
Eltwise less than operation.
}];
}

def TTIR_LogicalAndOp : TTIR_ElementwiseBinaryOp<"logical_and"> {
def TTIR_LogicalAndOp : TTIR_ElementwiseBinaryOp<"logical_and", [TTIR_BinaryIdempotence]> {
let summary = "Eltwise logical and.";
let description = [{
Eltwise logical and operation.
}];
}

def TTIR_LogicalOrOp : TTIR_ElementwiseBinaryOp<"logical_or"> {
def TTIR_LogicalOrOp : TTIR_ElementwiseBinaryOp<"logical_or", [TTIR_BinaryIdempotence]> {
let summary = "Eltwise logical or.";
let description = [{
Eltwise logical or operation.
Expand All @@ -527,7 +533,7 @@ def TTIR_LogicalXorOp : TTIR_ElementwiseBinaryOp<"logical_xor"> {
}];
}

def TTIR_BitwiseAndOp : TTIR_ElementwiseBinaryOp<"bitwise_and"> {
def TTIR_BitwiseAndOp : TTIR_ElementwiseBinaryOp<"bitwise_and", [TTIR_BinaryIdempotence]> {
let summary = "Eltwise bitwise AND.";
let description = [{
Performs element-wise bitwise AND of two tensors `lhs` and `rhs`
Expand All @@ -541,7 +547,7 @@ def TTIR_BitwiseAndOp : TTIR_ElementwiseBinaryOp<"bitwise_and"> {
}];
}

def TTIR_BitwiseOrOp : TTIR_ElementwiseBinaryOp<"bitwise_or"> {
def TTIR_BitwiseOrOp : TTIR_ElementwiseBinaryOp<"bitwise_or", [TTIR_BinaryIdempotence]> {
let summary = "Eltwise bitwise OR.";
let description = [{
Performs element-wise bitwise OR of two tensors `lhs` and `rhs`
Expand All @@ -567,9 +573,11 @@ def TTIR_BitwiseXorOp : TTIR_ElementwiseBinaryOp<"bitwise_xor"> {
%result = "ttir.bitwise_xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
}];

let hasCanonicalizer = 1;
}

def TTIR_MinimumOp : TTIR_ElementwiseBinaryOp<"minimum"> {
def TTIR_MinimumOp : TTIR_ElementwiseBinaryOp<"minimum", [TTIR_BinaryIdempotence]> {
let summary = "Eltwise minimum OP.";
let description = [{
Calculates minimum of input tensors' values element-wise and stores result
Expand Down Expand Up @@ -605,7 +613,7 @@ def TTIR_RemainderOp : TTIR_ElementwiseBinaryOp<"remainder"> {
}

class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> :
TTIR_DPSOp<mnemonic, !listconcat(traits, [TTIR_GenericRegionOpInterface])> {
TTIR_DPSOp<mnemonic, !listconcat([TTIR_GenericRegionOpInterface], traits)> {

let summary = "Reduction op.";
let description = [{
Expand Down Expand Up @@ -752,6 +760,8 @@ def TTIR_TransposeOp : TTIR_DPSOp<"transpose"> {
}];

let hasVerifier = 1;

let hasCanonicalizer = 1;
}

def TTIR_ConcatOp : TTIR_DPSOp<"concat"> {
Expand Down Expand Up @@ -827,6 +837,8 @@ def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> {
let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasFolder = 1;
}

def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
Expand Down Expand Up @@ -1218,6 +1230,8 @@ def TTIR_ReverseOp : TTIR_DPSOp<"reverse", [AllShapesMatch<["input", "result"]>]
}];

let hasVerifier = 1;

let hasCanonicalizer = 1;
}

def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike,
Expand Down Expand Up @@ -1287,6 +1301,8 @@ def TTIR_LinearOp : TTIR_DPSOp<"linear"> {
}];

let hasVerifier = 1;

let hasCanonicalizeMethod = 1;
}

// ANCHOR: adding_an_op_matmul_ttir
Expand Down Expand Up @@ -1335,14 +1351,16 @@ def TTIR_PermuteOp : TTIR_DPSOp<"permute"> {
}];

let hasVerifier = 1;

let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// TTIR top level generic ops
//===----------------------------------------------------------------------===//

class TTIR_GenericElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseUnaryOp<mnemonic, !listconcat(traits, [TTIR_GenericRegionOpInterface])> {
TTIR_ElementwiseUnaryOp<mnemonic, !listconcat([TTIR_GenericRegionOpInterface], traits)> {

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
Expand Down Expand Up @@ -1373,7 +1391,7 @@ def TTIR_ExpOp: TTIR_GenericElementwiseUnaryOp<"exp"> {
}

class TTIR_GenericElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseBinaryOp<mnemonic, !listconcat(traits, [TTIR_GenericRegionOpInterface])> {
TTIR_ElementwiseBinaryOp<mnemonic, !listconcat([TTIR_GenericRegionOpInterface], traits)> {

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
Expand Down
Loading
Loading