diff --git a/include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt b/include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt index c27f08706..d3f0fb762 100644 --- a/include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt +++ b/include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt @@ -2,6 +2,8 @@ add_mlir_dialect(TTNNOps ttnn) add_mlir_doc(TTNNBase TTNNDialect src/autogen/md/Dialect/ -gen-dialect-doc) add_mlir_doc(TTNNOps TTNNOp src/autogen/md/Dialect/ -gen-op-doc) +add_mlir_interface(TTNNOpsBackendInterfaces) + set(LLVM_TARGET_DEFINITIONS TTNNOpsEnums.td) mlir_tablegen(TTNNOpsEnums.h.inc -gen-enum-decls) mlir_tablegen(TTNNOpsEnums.cpp.inc -gen-enum-defs) diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNBase.td b/include/ttmlir/Dialect/TTNN/IR/TTNNBase.td index 1f733d17b..f2852f33f 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNBase.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNBase.td @@ -6,6 +6,7 @@ #define TTMLIR_TTMLIR_DIALECT_TTNN_TTNNDIALECT_TD include "mlir/IR/OpBase.td" +include "ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.td" //===----------------------------------------------------------------------===// // TTNN dialect definition. 
@@ -43,6 +44,6 @@ def TTNN_Dialect : Dialect { //===----------------------------------------------------------------------===// class TTNN_Op traits = []> : - Op; + Op; #endif diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.h b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.h index 20db8029c..a91d3ec2a 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.h +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.h @@ -15,6 +15,7 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.h.inc" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" #define GET_OP_CLASSES diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 01ebae803..7787f8ee3 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -214,7 +214,9 @@ def TTNN_ReciprocalOp : TTNN_ElementwiseUnaryOp<"reciprocal"> { }]; } -def TTNN_ReluOp : TTNN_ElementwiseUnaryOp<"relu"> { +def TTNN_ReluOp : TTNN_ElementwiseUnaryOp<"relu", + [DeclareOpInterfaceMethods] + > { let summary = "Eltwise ReLU."; let description = [{ Eltwise ReLU operation. diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.td new file mode 100644 index 000000000..6a4a2a8cf --- /dev/null +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.td @@ -0,0 +1,50 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_TTMLIR_DIALECT_TTNN_TTNNOPSPERFINTERFACES_TD +#define TTMLIR_TTMLIR_DIALECT_TTNN_TTNNOPSPERFINTERFACES_TD + +include "mlir/IR/OpBase.td" + +def TTNNOpBackendInterface : OpInterface<"TTNNOpBackend"> { + let description = [{ + Interface for querying backend characteristics of a TTNN op: estimated + kernel cycles, estimated L1 memory usage, and layout legality.
+ }]; + let cppNamespace = "::mlir::tt::ttnn"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Return the op kernel estimate in clock cycles. + }], + /*retTy=*/"size_t", + /*methodName=*/"getOpPerfCycles", + /*args=*/(ins "const std::vector&":$input_layouts, "const tt::LayoutAttr&":$output_layout), + /*methodBody=*/"", + /*defaultImplementation=*/"return std::numeric_limits::max();" + >, + InterfaceMethod< + /*desc=*/[{ + Return the op kernel estimate of L1 memory usage (in bytes). + }], + /*retTy=*/"size_t", + /*methodName=*/"getOpL1Usage", + /*args=*/(ins "const std::vector&":$input_layouts, "const tt::LayoutAttr&":$output_layout), + /*methodBody=*/"", + /*defaultImplementation=*/"return 0;" + >, + InterfaceMethod< + /*desc=*/[{ + Return whether the op is legal for the given input/output layouts. + }], + /*retTy=*/"bool", + /*methodName=*/"isOpLegal", + /*args=*/(ins "const std::vector&":$input_layouts, "const tt::LayoutAttr&":$output_layout), + /*methodBody=*/"", + /*defaultImplementation=*/"return true;" + >, + ]; +} + +#endif // TTMLIR_TTMLIR_DIALECT_TTNN_TTNNOPSPERFINTERFACES_TD diff --git a/lib/Dialect/TTNN/Analysis/ShardSolver.cpp b/lib/Dialect/TTNN/Analysis/ShardSolver.cpp index 6cc7d1eff..1ff49c5c7 100644 --- a/lib/Dialect/TTNN/Analysis/ShardSolver.cpp +++ b/lib/Dialect/TTNN/Analysis/ShardSolver.cpp @@ -4,6 +4,7 @@ #include "ttmlir/Dialect/TTNN/Analysis/ShardSolver.h" #include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" #include #include #include @@ -503,6 +504,13 @@ bool ShardSolver::checkShardCompatible( // TEMP : Dummy mock implementation, will be replaced. // + if (TTNNOpBackend backend = dyn_cast(consumerOp)) { + if (false == + backend.isOpLegal(std::vector{producerLayout}, consumerLayout)) { + return false; + } + } + + // May need to fetch other inputs for consumerOp(weights/join node).
// diff --git a/lib/Dialect/TTNN/IR/CMakeLists.txt b/lib/Dialect/TTNN/IR/CMakeLists.txt index 6a3f54fc8..8a009fb4a 100644 --- a/lib/Dialect/TTNN/IR/CMakeLists.txt +++ b/lib/Dialect/TTNN/IR/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRTTNNDialect TTNNDialect.cpp TTNNOps.cpp + TTNNOpsBackendInterfaces.cpp TTNNOpsTypes.cpp ADDITIONAL_HEADER_DIRS diff --git a/lib/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.cpp b/lib/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.cpp new file mode 100644 index 000000000..9d58b4b39 --- /dev/null +++ b/lib/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.cpp @@ -0,0 +1,34 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" + +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.cpp.inc" + +namespace mlir::tt::ttnn { + +//===----------------------------------------------------------------------===// + // ReluOp +//===----------------------------------------------------------------------===// + +// Relu backend interface +size_t ReluOp::getOpPerfCycles(const std::vector &input_layouts, + const tt::LayoutAttr &output_layout) { + // Implement a custom estimate for relu op cycles. + return 5; +} + +size_t ReluOp::getOpL1Usage(const std::vector &input_layouts, + const tt::LayoutAttr &output_layout) { + // Implement a custom estimate for relu op L1 usage. + return 10; +} + +bool ReluOp::isOpLegal(const std::vector &input_layouts, + const tt::LayoutAttr &output_layout) { + // Implement a custom check for relu op legality. + return true; +} + +} // namespace mlir::tt::ttnn