Skip to content

Commit

Permalink
#8835: add example template of a ttnn operation
Browse files Browse the repository at this point in the history
  • Loading branch information
arakhmati committed Jun 28, 2024
1 parent 97d9fe8 commit fc23d32
Show file tree
Hide file tree
Showing 18 changed files with 441 additions and 108 deletions.
4 changes: 1 addition & 3 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
#include "tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp"
#include "tt_metal/common/constants.hpp"
#include "ttnn/cpp/ttnn/operations/creation.hpp"


#include "ttnn/operations/eltwise/binary/device/binary_op.hpp"
#include "ttnn/operations/eltwise/binary/device/binary_program_dispatcher.hpp"

namespace tt {

Expand Down
5 changes: 4 additions & 1 deletion ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ set(TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv2d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/argmax/device/argmax_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/binary_program_dispatcher.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/unary/device/unary_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/example/example/device/example_program_dispatcher.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/example/example/device/single_core_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/example/example/device/multi_core_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp
)

Expand Down
8 changes: 6 additions & 2 deletions ttnn/cpp/pybind11/operations/__init__.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "pybind11/operations/ccl.hpp"
#include "pybind11/operations/conv2d.hpp"
#include "pybind11/operations/copy.hpp"
#include "pybind11/operations/core.hpp"
#include "pybind11/operations/creation.hpp"
#include "pybind11/operations/data_movement.hpp"
Expand All @@ -18,15 +19,15 @@
#include "pybind11/operations/maxpool2d.hpp"
#include "pybind11/operations/normalization.hpp"
#include "pybind11/operations/pool.hpp"
#include "pybind11/operations/copy.hpp"
#include "pybind11/operations/ternary.hpp"
#include "pybind11/operations/transformer.hpp"

#include "ttnn/operations/eltwise/binary/binary_pybind.hpp"
#include "ttnn/operations/eltwise/unary/unary_pybind.hpp"
#include "ttnn/operations/reduction/reduction_pybind.hpp"
#include "ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp"

#include "ttnn/operations/example/example/example_pybind.hpp"
#include "ttnn/operations/reduction/reduction_pybind.hpp"

namespace py = pybind11;

Expand All @@ -35,6 +36,9 @@ namespace ttnn {
namespace operations {

void py_module(py::module& module) {
auto m_example = module.def_submodule("example", "example operation");
example::py_module(m_example);

auto m_unary = module.def_submodule("unary", "unary operations");
unary::py_module(m_unary);

Expand Down
107 changes: 74 additions & 33 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

#pragma once

#include "device/binary_op.hpp"
#include "ttnn/device_operation.hpp"
#include "device/binary_program_dispatcher.hpp"
#include "ttnn/operations/data_movement.hpp"
#include "ttnn/program_dispatcher.hpp"

namespace ttnn {

Expand All @@ -28,7 +28,7 @@ constexpr bool is_associative(BinaryOpType op) {
}

template <BinaryOpType binary_op_type, bool in_place>
struct ExecuteBinary {
struct Binary {
static inline const std::array<TensorSchema, 2> input_tensor_schemas() {
return {
ttnn::TensorSchema{
Expand Down Expand Up @@ -108,11 +108,11 @@ struct ExecuteBinary {
dtype = optional_output_tensor.value().get_dtype();
}

return ttnn::device_operation::run<Binary>(
return ttnn::program_dispatcher::run<BinaryProgramDispatcher>(
queue_id,
Binary::operation_attributes_t{
BinaryProgramDispatcher::operation_attributes_t{
binary_op_type, in_place, activations, output_memory_config, dtype, std::nullopt},
Binary::tensor_args_t{input_tensor_a, input_tensor_b, optional_output_tensor});
BinaryProgramDispatcher::tensor_args_t{input_tensor_a, input_tensor_b, optional_output_tensor});
}

template <typename... Args>
Expand Down Expand Up @@ -145,8 +145,14 @@ struct ExecuteBinary {
const std::optional<ttnn::MemoryConfig> &memory_config = std::nullopt,
const std::optional<Tensor> &optional_output_tensor = std::nullopt,
std::optional<FusedActivations> activations = std::nullopt) {

return ExecuteBinary::execute_on_worker_thread(DefaultQueueId, input_tensor_a, scalar, dtype, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, optional_output_tensor, activations);
return Binary::execute_on_worker_thread(
DefaultQueueId,
input_tensor_a,
scalar,
dtype,
operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
optional_output_tensor,
activations);
}

template <typename... Args>
Expand All @@ -172,36 +178,71 @@ struct ExecuteBinary {
Layout::TILE);
Tensor scalar_tensor_device = scalar_tensor_host.to(input_tensor_a.device());
// TODO(arakhmati): #7637 pass in memory_config instead of operation::DEFAULT_OUTPUT_MEMORY_CONFIG
return ExecuteBinary::execute_on_worker_thread(
input_tensor_a, scalar_tensor_device, dtype, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, optional_output_tensor, activations);
return Binary::execute_on_worker_thread(
input_tensor_a,
scalar_tensor_device,
dtype,
operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
optional_output_tensor,
activations);
}
};

} // operations::binary

constexpr auto add = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::ADD, false>>("ttnn::add");
constexpr auto add_ = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::ADD, true>>("ttnn::add_");
constexpr auto subtract = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::SUB, false>>("ttnn::subtract");
constexpr auto subtract_ = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::SUB, true>>("ttnn::subtract_");
constexpr auto multiply = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::MUL, false>>("ttnn::multiply");
constexpr auto multiply_ = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::MUL, true>>("ttnn::multiply_");

constexpr auto eq = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::EQ, false>>("ttnn::eq");
constexpr auto ne = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::NE, false>>("ttnn::ne");
constexpr auto ge = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::GTE, false>>("ttnn::ge");
constexpr auto gt = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::GT, false>>("ttnn::gt");
constexpr auto le = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::LTE, false>>("ttnn::le");
constexpr auto lt = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::LT, false>>("ttnn::lt");
constexpr auto logical_and = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::LOGICAL_AND, false>>("ttnn::logical_and");
constexpr auto logical_or = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::LOGICAL_OR, false>>("ttnn::logical_or");
constexpr auto ldexp = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::LDEXP, false>>("ttnn::ldexp");

constexpr auto logaddexp = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::LOGADDEXP, false>>("ttnn::logaddexp");
constexpr auto logaddexp2 = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::LOGADDEXP2, false>>("ttnn::logaddexp2");
constexpr auto squared_difference = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::SQUARED_DIFFERENCE, false>>("ttnn::squared_difference");
constexpr auto divide = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::DIV_FAST, false>>("ttnn::divide");
constexpr auto bias_gelu = ttnn::register_operation<operations::binary::ExecuteBinary<operations::binary::BinaryOpType::BIAS_GELU, false>>("ttnn::bias_gelu");

constexpr auto add =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::ADD, false>>("ttnn::add");
constexpr auto add_ =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::ADD, true>>("ttnn::add_");
constexpr auto subtract =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::SUB, false>>(
"ttnn::subtract");
constexpr auto subtract_ =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::SUB, true>>(
"ttnn::subtract_");
constexpr auto multiply =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::MUL, false>>(
"ttnn::multiply");
constexpr auto multiply_ =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::MUL, true>>(
"ttnn::multiply_");

constexpr auto eq =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::EQ, false>>("ttnn::eq");
constexpr auto ne =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::NE, false>>("ttnn::ne");
constexpr auto ge =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::GTE, false>>("ttnn::ge");
constexpr auto gt =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::GT, false>>("ttnn::gt");
constexpr auto le =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::LTE, false>>("ttnn::le");
constexpr auto lt =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::LT, false>>("ttnn::lt");
constexpr auto logical_and =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::LOGICAL_AND, false>>(
"ttnn::logical_and");
constexpr auto logical_or =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::LOGICAL_OR, false>>(
"ttnn::logical_or");
constexpr auto ldexp =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::LDEXP, false>>("ttnn::ldexp");

constexpr auto logaddexp =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::LOGADDEXP, false>>(
"ttnn::logaddexp");
constexpr auto logaddexp2 =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::LOGADDEXP2, false>>(
"ttnn::logaddexp2");
constexpr auto squared_difference =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::SQUARED_DIFFERENCE, false>>(
"ttnn::squared_difference");
constexpr auto divide =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::DIV_FAST, false>>(
"ttnn::divide");
constexpr auto bias_gelu =
ttnn::register_operation<operations::binary::Binary<operations::binary::BinaryOpType::BIAS_GELU, false>>(
"ttnn::bias_gelu");

template <typename InputBType>
ttnn::Tensor operator+(const ttnn::Tensor &input_tensor_a, InputBType scalar) {
Expand Down
Loading

0 comments on commit fc23d32

Please sign in to comment.