Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TTNNOpsBackendInterface definition #1131

Merged
merged 8 commits into from
Nov 4, 2024
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ add_mlir_dialect(TTNNOps ttnn)
add_mlir_doc(TTNNBase TTNNDialect src/autogen/md/Dialect/ -gen-dialect-doc)
add_mlir_doc(TTNNOps TTNNOp src/autogen/md/Dialect/ -gen-op-doc)

add_mlir_interface(TTNNOpsBackendInterfaces)

set(LLVM_TARGET_DEFINITIONS TTNNOpsEnums.td)
mlir_tablegen(TTNNOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(TTNNOpsEnums.cpp.inc -gen-enum-defs)
Expand Down
3 changes: 2 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define TTMLIR_TTMLIR_DIALECT_TTNN_TTNNDIALECT_TD

include "mlir/IR/OpBase.td"
include "ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.td"

//===----------------------------------------------------------------------===//
// TTNN dialect definition.
Expand Down Expand Up @@ -43,6 +44,6 @@ def TTNN_Dialect : Dialect {
//===----------------------------------------------------------------------===//

// Base class for all TTNN dialect ops. Every op implicitly carries the
// TTNNOpBackendInterface trait (appended to any traits the op declares), so
// backend perf/L1/legality queries can be made uniformly on any TTNN op.
class TTNN_Op<string mnemonic, list<Trait> traits = []> :
    Op<TTNN_Dialect, mnemonic, !listconcat(traits, [TTNNOpBackendInterface])>;

#endif
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.h.inc"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"

#define GET_OP_CLASSES
Expand Down
4 changes: 3 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,9 @@ def TTNN_ReciprocalOp : TTNN_ElementwiseUnaryOp<"reciprocal"> {
}];
}

def TTNN_ReluOp : TTNN_ElementwiseUnaryOp<"relu"> {
def TTNN_ReluOp : TTNN_ElementwiseUnaryOp<"relu",
[DeclareOpInterfaceMethods<TTNNOpBackendInterface, ["getOpPerfCycles", "getOpL1Usage", "isOpLegal"]>]
> {
let summary = "Eltwise ReLU.";
let description = [{
Eltwise ReLU operation.
Expand Down
50 changes: 50 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_TTMLIR_DIALECT_TTNN_TTNNOPSBACKENDINTERFACES_TD
#define TTMLIR_TTMLIR_DIALECT_TTNN_TTNNOPSBACKENDINTERFACES_TD

include "mlir/IR/OpBase.td"

// Backend query interface for TTNN ops. Ops implementing it expose
// performance-cycle and L1 memory-usage estimates, plus a legality check,
// for a candidate output layout. Each method has a conservative default so
// ops may override only what they can estimate.
def TTNNOpBackendInterface : OpInterface<"TTNNOpBackend"> {
  let description = [{
    Interface for querying backend characteristics of a TTNN op: an estimate
    of its kernel runtime in clock cycles, an estimate of its L1 memory
    usage, and whether the op is legal for a given output layout.
  }];
  let cppNamespace = "::mlir::tt::ttnn";
  let methods = [
    InterfaceMethod<
      /*desc=*/[{
        Return the op kernel runtime estimate in clock cycles. Defaults to
        the maximum value, i.e. "no estimate available".
      }],
      /*retTy=*/"size_t",
      /*methodName=*/"getOpPerfCycles",
      /*args=*/(ins "const tt::LayoutAttr&":$output_layout), // Subject to change
      /*methodBody=*/"",
      /*defaultImplementation=*/"return std::numeric_limits<size_t>::max();"
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the op kernel L1 memory usage estimate. Defaults to 0,
        i.e. "no usage estimate available".
      }],
      /*retTy=*/"size_t",
      /*methodName=*/"getOpL1Usage",
      /*args=*/(ins "const tt::LayoutAttr&":$output_layout), // Subject to change
      /*methodBody=*/"",
      /*defaultImplementation=*/"return 0;"
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return whether the op is legal for the given output layout.
        Defaults to true (optimistically legal).
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isOpLegal",
      /*args=*/(ins "const tt::LayoutAttr&":$output_layout), // Subject to change
      /*methodBody=*/"",
      /*defaultImplementation=*/"return true;"
    >
  ];
}

#endif // TTMLIR_TTMLIR_DIALECT_TTNN_TTNNOPSBACKENDINTERFACES_TD
8 changes: 7 additions & 1 deletion lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@ namespace mlir::tt::ttnn {

// Returns whether `op` may legally produce an output tensor with `layout`.
// Delegates to the op's TTNN backend interface; any op reaching this
// analysis is expected to implement it.
bool mock_is_output_tensor_legal_for_op(Operation *op, tt::LayoutAttr layout) {
  // Placeholder, needs to be replaced with a call to the TTNN op interface.
  if (TTNNOpBackend backend = dyn_cast<TTNNOpBackend>(op)) {
    return backend.isOpLegal(layout);
  }

  // Reaching here means a non-TTNN op was passed in: a programming error.
  assert(false && "Op is not a TTNN op.");
  return false;
}

bool tensor_shape_compatible_with_shard(Operation *op, tt::LayoutAttr layout) {
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TTNN/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRTTNNDialect
TTNNDialect.cpp
TTNNOps.cpp
TTNNOpsBackendInterfaces.cpp
TTNNOpsTypes.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
31 changes: 31 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

#include "ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.cpp.inc"

namespace mlir::tt::ttnn {

//===----------------------------------------------------------------------===//
// ReluOp
//===----------------------------------------------------------------------===//

// Backend-interface overrides for ReluOp. These replace the interface
// defaults with relu-specific placeholder estimates.

// Custom estimate of the relu kernel runtime, in clock cycles.
size_t ReluOp::getOpPerfCycles(const tt::LayoutAttr &outputLayout) {
  constexpr size_t kReluPerfCycles = 5;
  return kReluPerfCycles;
}

// Custom estimate of the relu kernel L1 memory usage.
size_t ReluOp::getOpL1Usage(const tt::LayoutAttr &outputLayout) {
  constexpr size_t kReluL1Usage = 10;
  return kReluL1Usage;
}

// Custom legality check: relu is accepted for any output layout.
bool ReluOp::isOpLegal(const tt::LayoutAttr &outputLayout) { return true; }

} // namespace mlir::tt::ttnn
Loading