From 2e5d2933a569320b7c3abc1b3529a8ac48ca2208 Mon Sep 17 00:00:00 2001 From: Bezulj Marko Date: Fri, 1 Nov 2024 15:20:28 +0000 Subject: [PATCH] added tt::LayoutAttr to the backend interface --- include/ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.td | 6 +++--- lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp | 6 +++--- lib/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.cpp | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.td index c4ab4feedb..0f28e18a90 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.td @@ -20,7 +20,7 @@ def TTNNOpBackendInterface : OpInterface<"TTNNOpBackend"> { }], /*retTy=*/"size_t", /*methodName=*/"getOpPerfCycles", - /*args=*/(ins), // TBD + /*args=*/(ins "const tt::LayoutAttr&":$output_layout), // Subject to change /*methodBody=*/"", /*defaultImplementation=*/"return std::numeric_limits::max();" >, @@ -30,7 +30,7 @@ def TTNNOpBackendInterface : OpInterface<"TTNNOpBackend"> { }], /*retTy=*/"size_t", /*methodName=*/"getOpL1Usage", - /*args=*/(ins), // TBD + /*args=*/(ins "const tt::LayoutAttr&":$output_layout), // Subject to change /*methodBody=*/"", /*defaultImplementation=*/"return 0;" >, @@ -40,7 +40,7 @@ def TTNNOpBackendInterface : OpInterface<"TTNNOpBackend"> { }], /*retTy=*/"bool", /*methodName=*/"isOpLegal", - /*args=*/(ins), // TBD + /*args=*/(ins "const tt::LayoutAttr&":$output_layout), // Subject to change /*methodBody=*/"", /*defaultImplementation=*/"return true;" >, diff --git a/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp b/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp index bc1d4b3c2f..08f0a0212b 100644 --- a/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp +++ b/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp @@ -16,9 +16,9 @@ bool mock_is_output_tensor_legal_for_op(Operation *op, tt::LayoutAttr layout) { // if (TTNNOpBackend backend = dyn_cast(op)) { // llvm::outs() << op->getName() << "=" << layout << "\n"; - // llvm::outs() << "\t[Perf] = " << backend.getOpPerfCycles() << "\n"; - // llvm::outs() << "\t[L1] = " << backend.getOpL1Usage() << "\n"; - // llvm::outs() << "\t[Legal] = " << backend.isOpLegal() << "\n"; + // llvm::outs() << "\t[Perf] = " << backend.getOpPerfCycles(layout) << "\n"; + // llvm::outs() << "\t[L1] = " << backend.getOpL1Usage(layout) << "\n"; + // llvm::outs() << "\t[Legal] = " << backend.isOpLegal(layout) << "\n"; // } return true; diff --git a/lib/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.cpp b/lib/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.cpp index 62220f9b6b..0389c3ddb6 100644 --- a/lib/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.cpp @@ -13,17 +13,17 @@ namespace mlir::tt::ttnn { //===----------------------------------------------------------------------===// // // Relu backend interface -size_t ReluOp::getOpPerfCycles() { +size_t ReluOp::getOpPerfCycles(const tt::LayoutAttr &output_layout) { // Implement a custom estimate for relu op cycles. return 5; } -size_t ReluOp::getOpL1Usage() { +size_t ReluOp::getOpL1Usage(const tt::LayoutAttr &output_layout) { // Implement a custom estimate for relu op L1 usage. return 10; } -bool ReluOp::isOpLegal() { +bool ReluOp::isOpLegal(const tt::LayoutAttr &output_layout) { // Implement a custom check for relu op legality. return true; }