From 278a774241e754e25f9aa4ffd9b23eb16ff6a043 Mon Sep 17 00:00:00 2001 From: Jackson Nie Date: Wed, 30 Oct 2024 22:58:19 +0000 Subject: [PATCH] Update eltwise flatbuffer to include extra params table, move eltwise composite to its own file to match ttnn implementation --- include/ttmlir/Target/TTNN/program.fbs | 5 ++ lib/Target/TTNN/TTNNToFlatbuffer.cpp | 6 ++- runtime/lib/ttnn/operations/CMakeLists.txt | 8 +++- .../eltwise/{ => binary}/binary.cpp | 44 +---------------- .../operations/eltwise/{ => binary}/binary.h | 0 .../eltwise/binary/binary_composite.cpp | 47 +++++++++++++++++++ .../eltwise/binary/binary_composite.h | 27 +++++++++++ .../operations/eltwise/{ => unary}/unary.cpp | 32 +------------ .../operations/eltwise/{ => unary}/unary.h | 0 .../eltwise/unary/unary_composite.cpp | 42 +++++++++++++++++ .../eltwise/unary/unary_composite.h | 26 ++++++++++ .../ttnn/operations/eltwise/binary/utils.cpp | 25 ++++++++++ .../ttnn/operations/eltwise/binary/utils.h | 19 ++++++++ .../ttnn/operations/eltwise/unary/utils.cpp | 18 +++++++ .../ttnn/operations/eltwise/unary/utils.h | 20 ++++++++ runtime/lib/ttnn/program.cpp | 42 +++++++++++++---- 16 files changed, 277 insertions(+), 84 deletions(-) rename runtime/lib/ttnn/operations/eltwise/{ => binary}/binary.cpp (66%) rename runtime/lib/ttnn/operations/eltwise/{ => binary}/binary.h (100%) create mode 100644 runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp create mode 100644 runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h rename runtime/lib/ttnn/operations/eltwise/{ => unary}/unary.cpp (75%) rename runtime/lib/ttnn/operations/eltwise/{ => unary}/unary.h (100%) create mode 100644 runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp create mode 100644 runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h create mode 100644 runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp create mode 100644 runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.h create mode 100644 runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp create mode 100644 runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/unary/utils.h diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index 56c80410d0..27da085a87 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -88,10 +88,15 @@ enum EltwiseOpType: uint32 { Cos = 27 } +union EltwiseOpParams { + +} + table EltwiseOp { type: EltwiseOpType; ins: [tt.target.TensorRef]; out: tt.target.TensorRef; + params: EltwiseOpParams; } enum ReductionOpType: uint32 { diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index 39352d1bb9..d731f4b5c3 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -294,6 +294,9 @@ template ::flatbuffers::Offset<::tt::target::ttnn::EltwiseOp> createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) { ::tt::target::ttnn::EltwiseOpType type; + ::tt::target::ttnn::EltwiseOpParams paramsType = + ::tt::target::ttnn::EltwiseOpParams::NONE; + ::flatbuffers::Offset params = 0; if constexpr (std::is_same_v) { type = ::tt::target::ttnn::EltwiseOpType::Abs; } else if constexpr (std::is_same_v) { @@ -360,7 +363,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) { return ::tt::target::ttnn::CreateEltwiseOpDirect( *cache.fbb, type, &ins, cache.at<::tt::target::TensorRef>( - getOperandThroughDPSOps(op.getOutputs().front()))); + getOperandThroughDPSOps(op.getOutputs().front())), + paramsType, params); } template diff --git a/runtime/lib/ttnn/operations/CMakeLists.txt b/runtime/lib/ttnn/operations/CMakeLists.txt index f557d318b8..db67164ef2 100644 --- a/runtime/lib/ttnn/operations/CMakeLists.txt +++ b/runtime/lib/ttnn/operations/CMakeLists.txt @@ -1,5 +1,7 @@ set(TTNN_OPS_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp ${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp @@ -9,8 +11,10 @@ set(TTNN_OPS_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/data_movement/slice.cpp ${CMAKE_CURRENT_SOURCE_DIR}/data_movement/transpose.cpp ${CMAKE_CURRENT_SOURCE_DIR}/deletion/dealloc.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary/binary.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary/binary_composite.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary/unary.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary/unary_composite.cpp ${CMAKE_CURRENT_SOURCE_DIR}/embedding/embedding.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layout/to_device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layout/from_device.cpp diff --git a/runtime/lib/ttnn/operations/eltwise/binary.cpp b/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp similarity index 66% rename from runtime/lib/ttnn/operations/eltwise/binary.cpp rename to runtime/lib/ttnn/operations/eltwise/binary/binary.cpp index 2e35f20ed8..6266dd721f 100644 --- a/runtime/lib/ttnn/operations/eltwise/binary.cpp +++ b/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp @@ -4,26 +4,12 @@ #include "binary.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" +#include "tt/runtime/ttnn/operations/eltwise/binary/utils.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "ttnn/operations/eltwise/binary/binary_composite.hpp" namespace tt::runtime::ttnn::operations::binary { -static void -getEltwiseBinaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op, - ProgramTensorPool &tensorPool, - ::ttnn::Tensor **lhs, ::ttnn::Tensor **rhs) { - LOG_ASSERT(op->ins()->size() == 2, "Expected 2 inputs"); - *lhs = &(tensorPool.at(op->ins()->Get(0)->global_id())); - *rhs = &(tensorPool.at(op->ins()->Get(1)->global_id())); - DEBUG_ASSERT((*lhs)->is_allocated()); - DEBUG_ASSERT((*rhs)->is_allocated()); - - // Switch the order of operands if the second operand requires broadcast - if ((*rhs)->volume() < (*lhs)->volume()) { - std::swap(*lhs, *rhs); - } -} - static void runEltwiseBinaryOP( const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool, std::function<::ttnn::Tensor( @@ -48,24 +34,6 @@ static void runEltwiseBinaryOP( tensorPool.insert_or_assign(op->out()->global_id(), out); } -static void runEltwiseBinaryCompositeOP( - const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool, - std::function< - ::ttnn::Tensor(const ::ttnn::Tensor &, const ::ttnn::Tensor &, - const std::optional<::tt::tt_metal::MemoryConfig> &)> - ttnnOp) { - - ::ttnn::Tensor *lhs = nullptr; - ::ttnn::Tensor *rhs = nullptr; - getEltwiseBinaryOPInputTensors(op, tensorPool, &lhs, &rhs); - - ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); - - ::ttnn::Tensor out = ttnnOp(*lhs, *rhs, outputMemoryConfig); - tensorPool.insert_or_assign(op->out()->global_id(), out); -} - void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { ProgramTensorPool &tensorPool = context.getTensorPool(); switch (op->type()) { @@ -118,14 +86,6 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { runEltwiseBinaryOP(op, tensorPool, ::ttnn::divide); break; } - case ::tt::target::ttnn::EltwiseOpType::Maximum: { - runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::maximum); - break; - } - case ::tt::target::ttnn::EltwiseOpType::Minimum: { - runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::minimum); - break; - } default: throw std::invalid_argument("Unsupported Eltwise Binary operation"); } diff --git a/runtime/lib/ttnn/operations/eltwise/binary.h b/runtime/lib/ttnn/operations/eltwise/binary/binary.h similarity index 100% rename from runtime/lib/ttnn/operations/eltwise/binary.h rename to runtime/lib/ttnn/operations/eltwise/binary/binary.h diff --git a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp new file mode 100644 index 0000000000..e0dbddf8f2 --- /dev/null +++ b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 +#include "binary_composite.h" +#include "tt/runtime/detail/logger.h" +#include "tt/runtime/detail/ttnn.h" +#include "tt/runtime/ttnn/operations/eltwise/binary/utils.h" +#include "tt/runtime/ttnn/operations/utils.h" + +namespace tt::runtime::ttnn::operations::binary::composite { + +static void runEltwiseBinaryCompositeOP( + const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool, + std::function< + ::ttnn::Tensor(const ::ttnn::Tensor &, const ::ttnn::Tensor &, + const std::optional<::tt::tt_metal::MemoryConfig> &)> + ttnnOp) { + + ::ttnn::Tensor *lhs = nullptr; + ::ttnn::Tensor *rhs = nullptr; + getEltwiseBinaryOPInputTensors(op, tensorPool, &lhs, &rhs); + + ::tt::tt_metal::MemoryConfig outputMemoryConfig = + utils::createMemoryConfig(op->out()); + + ::ttnn::Tensor out = ttnnOp(*lhs, *rhs, outputMemoryConfig); + tensorPool.insert_or_assign(op->out()->global_id(), out); +} + +void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { + ProgramTensorPool &tensorPool = context.getTensorPool(); + switch (op->type()) { + case ::tt::target::ttnn::EltwiseOpType::Maximum: { + runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::maximum); + break; + } + case ::tt::target::ttnn::EltwiseOpType::Minimum: { + runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::minimum); + break; + } + default: + throw std::invalid_argument( + "Unsupported Eltwise Binary Composite operation"); + } +} + +} // namespace tt::runtime::ttnn::operations::binary::composite diff --git a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h new file mode 100644 index 0000000000..e04059940e --- /dev/null +++ b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTNN_RUNTIME_ELTWISE_BINARY_COMPOSITE_H +#define TTNN_RUNTIME_ELTWISE_BINARY_COMPOSITE_H + +#include "tt/runtime/ttnn/types.h" +#include "ttmlir/Target/TTNN/program_generated.h" + +namespace tt::runtime::ttnn::operations::binary::composite { + +inline bool isBinaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) { + switch (op->type()) { + case ::tt::target::ttnn::EltwiseOpType::Maximum: + case ::tt::target::ttnn::EltwiseOpType::Minimum: + return true; + default: + return false; + } +} + +void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context); + +} // namespace tt::runtime::ttnn::operations::binary::composite + +#endif diff --git a/runtime/lib/ttnn/operations/eltwise/unary.cpp b/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp similarity index 75% rename from runtime/lib/ttnn/operations/eltwise/unary.cpp rename to runtime/lib/ttnn/operations/eltwise/unary/unary.cpp index a4aca2c482..4b261d1653 100644 --- a/runtime/lib/ttnn/operations/eltwise/unary.cpp +++ b/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp @@ -4,22 +4,12 @@ #include "unary.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" +#include "tt/runtime/ttnn/operations/eltwise/unary/utils.h" #include "tt/runtime/ttnn/operations/utils.h" #include "ttnn/operations/copy.hpp" -#include "ttnn/operations/eltwise/unary/unary_composite.hpp" namespace tt::runtime::ttnn::operations::unary { -static void -getEltwiseUnaryOPInputTensor(const ::tt::target::ttnn::EltwiseOp *op, - ProgramTensorPool &tensorPool, - ::ttnn::Tensor **in) { - LOG_ASSERT(op->ins()->size() == 1, "Expected 1 input, got ", - op->ins()->size()); - *in = &(tensorPool.at(op->ins()->Get(0)->global_id())); - DEBUG_ASSERT((*in)->is_allocated()); -} - static void runEltwiseUnaryOP( const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool, std::function< @@ -38,22 +28,6 @@ static void runEltwiseUnaryOP( tensorPool.insert_or_assign(op->out()->global_id(), out); } -static void runEltwiseUnaryCompositeOP( - const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool, - std::function<::ttnn::Tensor(const ::ttnn::Tensor &, - const ::tt::tt_metal::MemoryConfig &)> - ttnnOp) { - - ::ttnn::Tensor *in = nullptr; - getEltwiseUnaryOPInputTensor(op, tensorPool, &in); - - ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); - - ::ttnn::Tensor out = ttnnOp(*in, outputMemoryConfig); - tensorPool.insert_or_assign(op->out()->global_id(), out); -} - static void runEltwiseUnaryWithFastAndApproximateModeOP( const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool, std::function< @@ -80,10 +54,6 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { runEltwiseUnaryOP(op, tensorPool, ::ttnn::abs); break; } - case ::tt::target::ttnn::EltwiseOpType::Cbrt: { - runEltwiseUnaryCompositeOP(op, tensorPool, ::ttnn::cbrt); - break; - } case ::tt::target::ttnn::EltwiseOpType::Ceil: { runEltwiseUnaryOP(op, tensorPool, ::ttnn::ceil); break; diff --git a/runtime/lib/ttnn/operations/eltwise/unary.h b/runtime/lib/ttnn/operations/eltwise/unary/unary.h similarity index 100% rename from runtime/lib/ttnn/operations/eltwise/unary.h rename to runtime/lib/ttnn/operations/eltwise/unary/unary.h diff --git a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp new file mode 100644 index 0000000000..da4af9c63f --- /dev/null +++ b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 +#include "unary_composite.h" +#include "tt/runtime/detail/logger.h" +#include "tt/runtime/detail/ttnn.h" +#include "tt/runtime/ttnn/operations/eltwise/unary/utils.h" +#include "tt/runtime/ttnn/operations/utils.h" +#include "ttnn/operations/eltwise/unary/unary_composite.hpp" + +namespace tt::runtime::ttnn::operations::unary::composite { + +static void runEltwiseUnaryCompositeOP( + const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool, + std::function<::ttnn::Tensor(const ::ttnn::Tensor &, + const ::tt::tt_metal::MemoryConfig &)> + ttnnOp) { + + ::ttnn::Tensor *in = nullptr; + getEltwiseUnaryOPInputTensor(op, tensorPool, &in); + + ::tt::tt_metal::MemoryConfig outputMemoryConfig = + utils::createMemoryConfig(op->out()); + + ::ttnn::Tensor out = ttnnOp(*in, outputMemoryConfig); + tensorPool.insert_or_assign(op->out()->global_id(), out); +} + +void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { + ProgramTensorPool &tensorPool = context.getTensorPool(); + switch (op->type()) { + case ::tt::target::ttnn::EltwiseOpType::Cbrt: { + runEltwiseUnaryCompositeOP(op, tensorPool, ::ttnn::cbrt); + break; + } + default: + throw std::invalid_argument( + "Unsupported Eltwise Binary Composite operation"); + } +} + +} // namespace tt::runtime::ttnn::operations::unary::composite diff --git a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h new file mode 100644 index 0000000000..11231492ec --- /dev/null +++ b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h @@ -0,0 +1,26 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTNN_RUNTIME_ELTWISE_UNARY_COMPOSITE_H +#define TTNN_RUNTIME_ELTWISE_UNARY_COMPOSITE_H + +#include "tt/runtime/ttnn/types.h" +#include "ttmlir/Target/TTNN/program_generated.h" + +namespace tt::runtime::ttnn::operations::unary::composite { + +inline bool isUnaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) { + switch (op->type()) { + case ::tt::target::ttnn::EltwiseOpType::Cbrt: + return true; + default: + return false; + } +} + +void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context); + +} // namespace tt::runtime::ttnn::operations::unary::composite + +#endif diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp new file mode 100644 index 0000000000..e925303cfb --- /dev/null +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 +#include "utils.h" +#include "tt/runtime/detail/logger.h" + +namespace tt::runtime::ttnn::operations::binary { + +void getEltwiseBinaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op, + ProgramTensorPool &tensorPool, + ::ttnn::Tensor **lhs, + ::ttnn::Tensor **rhs) { + LOG_ASSERT(op->ins()->size() == 2, "Expected 2 inputs"); + *lhs = &(tensorPool.at(op->ins()->Get(0)->global_id())); + *rhs = &(tensorPool.at(op->ins()->Get(1)->global_id())); + DEBUG_ASSERT((*lhs)->is_allocated()); + DEBUG_ASSERT((*rhs)->is_allocated()); + + // Switch the order of operands if the second operand requires broadcast + if ((*rhs)->volume() < (*lhs)->volume()) { + std::swap(*lhs, *rhs); + } +} + +} // namespace tt::runtime::ttnn::operations::binary diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.h b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.h new file mode 100644 index 0000000000..54eb6610fd --- /dev/null +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.h @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTNN_RUNTIME_ELTWISE_BINARY_UTILS_H +#define TTNN_RUNTIME_ELTWISE_BINARY_UTILS_H + +#include "tt/runtime/detail/ttnn.h" +#include "tt/runtime/ttnn/types.h" +#include "ttmlir/Target/TTNN/program_generated.h" + +namespace tt::runtime::ttnn::operations::binary { +void getEltwiseBinaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op, + ProgramTensorPool &tensorPool, + ::ttnn::Tensor **lhs, ::ttnn::Tensor **rhs); + +} // namespace tt::runtime::ttnn::operations::binary + +#endif diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp new file mode 100644 index 0000000000..ee1504cbcd --- /dev/null +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 +#include "utils.h" +#include "tt/runtime/detail/logger.h" + +namespace tt::runtime::ttnn::operations::unary { + +void getEltwiseUnaryOPInputTensor(const ::tt::target::ttnn::EltwiseOp *op, + ProgramTensorPool &tensorPool, + ::ttnn::Tensor **in) { + LOG_ASSERT(op->ins()->size() == 1, "Expected 1 input, got ", + op->ins()->size()); + *in = &(tensorPool.at(op->ins()->Get(0)->global_id())); + DEBUG_ASSERT((*in)->is_allocated()); +} + +} // namespace tt::runtime::ttnn::operations::unary diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/unary/utils.h b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/unary/utils.h new file mode 100644 index 0000000000..9a565f73ce --- /dev/null +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/unary/utils.h @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTNN_RUNTIME_ELTWISE_UNARY_UTILS_H +#define TTNN_RUNTIME_ELTWISE_UNARY_UTILS_H + +#include "tt/runtime/detail/logger.h" +#include "tt/runtime/detail/ttnn.h" +#include "tt/runtime/ttnn/types.h" +#include "ttmlir/Target/TTNN/program_generated.h" + +namespace tt::runtime::ttnn::operations::unary { +void getEltwiseUnaryOPInputTensor(const ::tt::target::ttnn::EltwiseOp *op, + ProgramTensorPool &tensorPool, + ::ttnn::Tensor **in); + +} // namespace tt::runtime::ttnn::operations::unary + +#endif diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp index ab5d651e9c..f150d35c10 100644 --- a/runtime/lib/ttnn/program.cpp +++ b/runtime/lib/ttnn/program.cpp @@ -11,8 +11,10 @@ #include "operations/data_movement/slice.h" #include "operations/data_movement/transpose.h" #include "operations/deletion/dealloc.h" -#include "operations/eltwise/binary.h" -#include "operations/eltwise/unary.h" +#include "operations/eltwise/binary/binary.h" +#include "operations/eltwise/binary/binary_composite.h" +#include "operations/eltwise/unary/unary.h" +#include "operations/eltwise/unary/unary_composite.h" #include "operations/embedding/embedding.h" #include "operations/layout/from_device.h" #include "operations/layout/to_device.h" @@ -29,6 +31,7 @@ namespace tt::runtime::ttnn { using LogType = ::tt::runtime::logger::LogType; + struct ProgramExecutor { ProgramExecutor(const TensorMap &liveTensors, const std::unordered_set &programInputs, @@ -50,8 +53,36 @@ struct ProgramExecutor { private: ProgramContext context; void runOperation(const ::tt::target::ttnn::Operation *op); + void runEltwiseOperation(const ::tt::target::ttnn::EltwiseOp *op); }; +void ProgramExecutor::runEltwiseOperation( + const ::tt::target::ttnn::EltwiseOp *op) { + auto runUnaryOp = [&]() { + if (operations::unary::composite::isUnaryCompositeOp(op)) { + return operations::unary::composite::run(op, context); + } + return operations::unary::run(op, context); + }; + + auto runBinaryOp = [&]() { + if (operations::binary::composite::isBinaryCompositeOp(op)) { + return operations::binary::composite::run(op, context); + } + return operations::binary::run(op, context); + }; + + if (operations::unary::isUnaryOp(op)) { + return runUnaryOp(); + } + + if (operations::binary::isBinaryOp(op)) { + return runBinaryOp(); + } + + throw std::invalid_argument("Unsupported Eltwise operation"); +} + void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) { switch (op->type_type()) { case ::tt::target::ttnn::OpType::GetDeviceOp: { @@ -80,12 +111,7 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) { } case ::tt::target::ttnn::OpType::EltwiseOp: { const ::tt::target::ttnn::EltwiseOp *eltwiseOp = op->type_as_EltwiseOp(); - if (operations::unary::isUnaryOp(eltwiseOp)) { - return operations::unary::run(eltwiseOp, context); - } - LOG_ASSERT(operations::binary::isBinaryOp(eltwiseOp), - "Eltwise op should be either unary or binary"); - return operations::binary::run(eltwiseOp, context); + return runEltwiseOperation(eltwiseOp); } // ANCHOR: adding_an_op_matmul_runtime_program case ::tt::target::ttnn::OpType::MatmulOp: {