Skip to content

Commit

Permalink
TTNNOpsBackendInterface => TNN_OpModelInterface (#1266)
Browse files Browse the repository at this point in the history
  • Loading branch information
mbezuljTT authored Nov 14, 2024
1 parent 0c75727 commit db38351
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 24 deletions.
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ add_mlir_dialect(TTNNOps ttnn)
add_mlir_doc(TTNNBase TTNNDialect src/autogen/md/Dialect/ -gen-dialect-doc)
add_mlir_doc(TTNNOps TTNNOp src/autogen/md/Dialect/ -gen-op-doc)

add_mlir_interface(TTNNOpsBackendInterfaces)
add_mlir_interface(TTNNOpModelInterface)

set(LLVM_TARGET_DEFINITIONS TTNNOpsEnums.td)
mlir_tablegen(TTNNOpsEnums.h.inc -gen-enum-decls)
Expand Down
4 changes: 2 additions & 2 deletions include/ttmlir/Dialect/TTNN/IR/TTNNBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#define TTMLIR_TTMLIR_DIALECT_TTNN_TTNNDIALECT_TD

include "mlir/IR/OpBase.td"
include "ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.td"
include "ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.td"

//===----------------------------------------------------------------------===//
// TTNN dialect definition.
Expand Down Expand Up @@ -44,6 +44,6 @@ def TTNN_Dialect : Dialect {
//===----------------------------------------------------------------------===//

class TTNN_Op<string mnemonic, list<Trait> traits = []> :
Op<TTNN_Dialect, mnemonic, !listconcat(traits, [TTNNOpBackendInterface])>;
Op<TTNN_Dialect, mnemonic, !listconcat(traits, [TTNN_OpModelInterface])>;

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,15 @@

include "mlir/IR/OpBase.td"

def TTNNOpBackendInterface : OpInterface<"TTNNOpBackend"> {
def TTNN_OpModelInterface : OpInterface<"OpModel"> {
let description = [{
Interface to access a registered method to infer the return types for an
operation that can be used during type inference.
Interface to access TTNN op model methods.
}];
let cppNamespace = "::mlir::tt::ttnn";
let methods = [
InterfaceMethod<
/*desc=*/[{
Return the op kernel estimate in clock cycles.
Returns the op kernel estimate in clock cycles.
}],
/*retTy=*/"size_t",
/*methodName=*/"getOpPerfCycles",
Expand All @@ -26,17 +25,20 @@ def TTNNOpBackendInterface : OpInterface<"TTNNOpBackend"> {
>,
InterfaceMethod<
/*desc=*/[{
Return the op kernel estimate in clock cycles.
Returns the op memory L1 usage estimate in bytes. The return value is a tuple of 3 values:
- The first value is CB L1 peak allocation in bytes.
- The second value is Tensor L1 peak allocation in bytes.
- The third value is Output L1 buffer allocation in bytes.
}],
/*retTy=*/"size_t",
/*retTy=*/"std::tuple<size_t, size_t, size_t>",
/*methodName=*/"getOpL1Usage",
/*args=*/(ins "const std::vector<tt::LayoutAttr>&":$input_layouts, "const tt::LayoutAttr&":$output_layout),
/*methodBody=*/"",
/*defaultImplementation=*/"return 0;"
/*defaultImplementation=*/"return std::make_tuple(0,0,0);"
>,
InterfaceMethod<
/*desc=*/[{
Return the op kernel estimate in clock cycles.
Returns if input/output layouts are legal for the op.
}],
/*retTy=*/"bool",
/*methodName=*/"isOpLegal",
Expand Down
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.h.inc"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.h.inc"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"

#define GET_OP_CLASSES
Expand Down
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def TTNN_ReciprocalOp : TTNN_ElementwiseUnaryOp<"reciprocal"> {
}

def TTNN_ReluOp : TTNN_ElementwiseUnaryOp<"relu",
[DeclareOpInterfaceMethods<TTNNOpBackendInterface, ["getOpPerfCycles", "getOpL1Usage", "isOpLegal"]>]
[DeclareOpInterfaceMethods<TTNN_OpModelInterface, ["getOpPerfCycles", "getOpL1Usage", "isOpLegal"]>]
> {
let summary = "Eltwise ReLU.";
let description = [{
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Analysis/ShardSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ bool ShardSolver::checkShardCompatible(
// TEMP : Dummy mock implementation, will be replaced.
//

if (TTNNOpBackend backend = dyn_cast<TTNNOpBackend>(consumerOp)) {
if (OpModel backend = dyn_cast<OpModel>(consumerOp)) {
if (false ==
backend.isOpLegal(std::vector{producerLayout}, consumerLayout)) {
return false;
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
add_mlir_dialect_library(MLIRTTNNDialect
TTNNDialect.cpp
TTNNOps.cpp
TTNNOpsBackendInterfaces.cpp
TTNNOpModelInterface.cpp
TTNNOpsTypes.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,31 @@

#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

#include "ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.cpp.inc"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.cpp.inc"
#include <tuple>

namespace mlir::tt::ttnn {

//===----------------------------------------------------------------------===//
// ReluOp
// ReluOp - TTNN Op Model Interface
//===----------------------------------------------------------------------===//

// // Relu backend interface
size_t ReluOp::getOpPerfCycles(const std::vector<tt::LayoutAttr> &input_layouts,
const tt::LayoutAttr &output_layout) {
// Implement a custom estimate for relu op cycles.
// TODO(mbezulj) wire to tt-metal once we have API
return 5;
}

size_t ReluOp::getOpL1Usage(const std::vector<tt::LayoutAttr> &input_layouts,
const tt::LayoutAttr &output_layout) {
// Implement a custom estimate for relu op L1 usage.
return 10;
std::tuple<size_t, size_t, size_t>
ReluOp::getOpL1Usage(const std::vector<tt::LayoutAttr> &input_layouts,
const tt::LayoutAttr &output_layout) {
// TODO(mbezulj) wire to tt-metal once we have API
return std::make_tuple(1024, 2048, 1024);
}

bool ReluOp::isOpLegal(const std::vector<tt::LayoutAttr> &input_layouts,
const tt::LayoutAttr &output_layout) {
// Implement a custom check for relu op legality.
// TODO(mbezulj) wire to tt-metal once we have API
return true;
}

Expand Down

0 comments on commit db38351

Please sign in to comment.