From 15e04bc63445b45f1cb7b124dc9ed2ad919981b6 Mon Sep 17 00:00:00 2001 From: Akhmed Rakhmati Date: Fri, 28 Jun 2024 19:46:45 +0000 Subject: [PATCH] #8835: add example template of a ttnn operation --- .../op_library/composite/composite_ops.hpp | 4 +- tt_metal/tools/profiler/op_profiler.hpp | 8 +- ttnn/CMakeLists.txt | 5 +- ttnn/cpp/pybind11/operations/__init__.hpp | 8 +- ttnn/cpp/ttnn/device_operation.hpp | 6 +- .../ttnn/operations/eltwise/binary/binary.hpp | 105 ++++++++++++------ ...ary_op.cpp => binary_device_operation.cpp} | 94 +++++++++------- ...ary_op.hpp => binary_device_operation.hpp} | 2 +- ...t_and_width_multi_core_program_factory.cpp | 7 +- ...cast_height_multi_core_program_factory.cpp | 21 ++-- ...dcast_width_multi_core_program_factory.cpp | 6 +- ...lement_wise_multi_core_program_factory.cpp | 6 +- .../device/example_device_operation.cpp | 40 +++++++ .../device/example_device_operation.hpp | 101 +++++++++++++++++ .../device/multi_core_program_factory.cpp | 29 +++++ .../device/single_core_program_factory.cpp | 28 +++++ .../operations/example/example/example.hpp | 35 ++++++ .../example/example/example_pybind.hpp | 31 ++++++ 18 files changed, 436 insertions(+), 100 deletions(-) rename ttnn/cpp/ttnn/operations/eltwise/binary/device/{binary_op.cpp => binary_device_operation.cpp} (83%) rename ttnn/cpp/ttnn/operations/eltwise/binary/device/{binary_op.hpp => binary_device_operation.hpp} (99%) create mode 100644 ttnn/cpp/ttnn/operations/example/example/device/example_device_operation.cpp create mode 100644 ttnn/cpp/ttnn/operations/example/example/device/example_device_operation.hpp create mode 100644 ttnn/cpp/ttnn/operations/example/example/device/multi_core_program_factory.cpp create mode 100644 ttnn/cpp/ttnn/operations/example/example/device/single_core_program_factory.cpp create mode 100644 ttnn/cpp/ttnn/operations/example/example/example.hpp create mode 100644 ttnn/cpp/ttnn/operations/example/example/example_pybind.hpp diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp index 31b399fc981a..8e2911a0a2fc 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp @@ -12,9 +12,7 @@ #include "tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" #include "tt_metal/common/constants.hpp" #include "ttnn/cpp/ttnn/operations/creation.hpp" - - -#include "ttnn/operations/eltwise/binary/device/binary_op.hpp" +#include "ttnn/operations/eltwise/binary/device/binary_device_operation.hpp" namespace tt { diff --git a/tt_metal/tools/profiler/op_profiler.hpp b/tt_metal/tools/profiler/op_profiler.hpp index 79c231ea50f9..9383aa6c9f1f 100644 --- a/tt_metal/tools/profiler/op_profiler.hpp +++ b/tt_metal/tools/profiler/op_profiler.hpp @@ -385,7 +385,13 @@ inline std::string op_meta_data_serialized_json( j["optional_input_tensors"] = std::vector{}; - auto perfModel = operation_t::create_op_performance_model(operation_attributes, tensor_args, tensor_return_value); + auto perfModel = [&]() { + if constexpr (requires { operation_t::create_op_performance_model; }) { + return operation_t::create_op_performance_model(operation_attributes, tensor_args, tensor_return_value); + } else { + return operation::OpPerformanceModel{}; + } + }(); j["performance_model"]["compute_ns"] = perfModel.get_compute_ns(); j["performance_model"]["ideal_ns"] = perfModel.get_ideal_ns(); j["performance_model"]["bandwidth_ns"] = perfModel.get_bandwidth_ns(); diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index c8262b3a3a27..d274d10412ee 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -5,12 +5,15 @@ set(TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv2d.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/argmax/device/argmax_op.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/unary/device/unary_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/example/example/device/example_device_operation.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/example/example/device/single_core_program_factory.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/example/example/device/multi_core_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp ) diff --git a/ttnn/cpp/pybind11/operations/__init__.hpp b/ttnn/cpp/pybind11/operations/__init__.hpp index e16c04ceb7ef..a63333751b97 100644 --- a/ttnn/cpp/pybind11/operations/__init__.hpp +++ b/ttnn/cpp/pybind11/operations/__init__.hpp @@ -9,6 +9,7 @@ #include "pybind11/operations/ccl.hpp" #include "pybind11/operations/conv2d.hpp" +#include "pybind11/operations/copy.hpp" #include "pybind11/operations/core.hpp" #include "pybind11/operations/creation.hpp" #include "pybind11/operations/data_movement.hpp" @@ -18,7 +19,6 @@ #include "pybind11/operations/maxpool2d.hpp" #include "pybind11/operations/normalization.hpp" #include "pybind11/operations/pool.hpp" -#include "pybind11/operations/copy.hpp" #include "pybind11/operations/ternary.hpp" #include "pybind11/operations/transformer.hpp" @@ -26,7 +26,8 @@ #include "ttnn/operations/eltwise/unary/unary_pybind.hpp" #include "ttnn/operations/reduction/reduction_pybind.hpp" #include "ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp" - +#include "ttnn/operations/example/example/example_pybind.hpp" +#include "ttnn/operations/reduction/reduction_pybind.hpp" namespace py = pybind11; @@ -35,6 +36,9 @@ namespace ttnn { namespace operations { void py_module(py::module& module) { + auto m_example = module.def_submodule("example", "example operation"); + example::py_module(m_example); + auto m_unary = module.def_submodule("unary", "unary operations"); unary::py_module(m_unary); diff --git a/ttnn/cpp/ttnn/device_operation.hpp b/ttnn/cpp/ttnn/device_operation.hpp index ec9b2a93434d..0c822f6b59fe 100644 --- a/ttnn/cpp/ttnn/device_operation.hpp +++ b/ttnn/cpp/ttnn/device_operation.hpp @@ -80,7 +80,7 @@ concept DeviceOperationConcept = requires { }; template -concept DeviceOperationWithCustomProgramCacheConcept = DeviceOperationConcept and requires { +concept DeviceOperationConceptWithCustomProgramCacheConcept = DeviceOperationConcept and requires { [](auto&& program_factory, const typename device_operation_t::operation_attributes_t& operation_attributes, const typename device_operation_t::tensor_args_t& tensor_args) { @@ -91,7 +91,7 @@ concept DeviceOperationWithCustomProgramCacheConcept = DeviceOperationConcept [[nodiscard]] std::variant constexpr map_index_to_variant(std::size_t i, std::variant) { assert(i < sizeof...(Ts)); - static constexpr std::variant table[] = { Ts{ }... }; + static constexpr std::variant table[] = {Ts{}...}; return table[i]; } @@ -101,7 +101,7 @@ template inline auto compute_program_hash( const typename device_operation_t::operation_attributes_t& operation_attributes, const typename device_operation_t::tensor_args_t& tensor_args) { - if constexpr (DeviceOperationWithCustomProgramCacheConcept) { + if constexpr (DeviceOperationConceptWithCustomProgramCacheConcept) { ZoneScopedN("Compute custom program hash"); return device_operation_t::compute_program_hash(operation_attributes, tensor_args); } else { diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp index a0cf1c9d9381..ab7f8d7e1c0e 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp @@ -5,7 +5,7 @@ #pragma once -#include "device/binary_op.hpp" +#include "device/binary_device_operation.hpp" #include "ttnn/device_operation.hpp" #include "ttnn/operations/data_movement.hpp" @@ -28,7 +28,7 @@ constexpr bool is_associative(BinaryOpType op) { } template -struct ExecuteBinary { +struct Binary { static inline const std::array input_tensor_schemas() { return { ttnn::TensorSchema{ @@ -108,11 +108,11 @@ struct ExecuteBinary { dtype = optional_output_tensor.value().get_dtype(); } - return ttnn::device_operation::run( + return ttnn::device_operation::run( queue_id, - Binary::operation_attributes_t{ + BinaryDeviceOperation::operation_attributes_t{ binary_op_type, in_place, activations, output_memory_config, dtype, std::nullopt}, - Binary::tensor_args_t{input_tensor_a, input_tensor_b, optional_output_tensor}); + BinaryDeviceOperation::tensor_args_t{input_tensor_a, input_tensor_b, optional_output_tensor}); } template @@ -145,8 +145,14 @@ struct ExecuteBinary { const std::optional &memory_config = std::nullopt, const std::optional &optional_output_tensor = std::nullopt, std::optional activations = std::nullopt) { - - return ExecuteBinary::execute_on_worker_thread(DefaultQueueId, input_tensor_a, scalar, dtype, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, optional_output_tensor, activations); + return Binary::execute_on_worker_thread( + DefaultQueueId, + input_tensor_a, + scalar, + dtype, + operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + optional_output_tensor, + activations); } template @@ -172,36 +178,71 @@ struct ExecuteBinary { Layout::TILE); Tensor scalar_tensor_device = scalar_tensor_host.to(input_tensor_a.device()); // TODO(arakhmati): #7637 pass in memory_config instead of operation::DEFAULT_OUTPUT_MEMORY_CONFIG - return ExecuteBinary::execute_on_worker_thread( - input_tensor_a, scalar_tensor_device, dtype, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, optional_output_tensor, activations); + return Binary::execute_on_worker_thread( + input_tensor_a, + scalar_tensor_device, + dtype, + operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + optional_output_tensor, + activations); } }; } // operations::binary -constexpr auto add = ttnn::register_operation>("ttnn::add"); -constexpr auto add_ = ttnn::register_operation>("ttnn::add_"); -constexpr auto subtract = ttnn::register_operation>("ttnn::subtract"); -constexpr auto subtract_ = ttnn::register_operation>("ttnn::subtract_"); -constexpr auto multiply = ttnn::register_operation>("ttnn::multiply"); -constexpr auto multiply_ = ttnn::register_operation>("ttnn::multiply_"); - -constexpr auto eq = ttnn::register_operation>("ttnn::eq"); -constexpr auto ne = ttnn::register_operation>("ttnn::ne"); -constexpr auto ge = ttnn::register_operation>("ttnn::ge"); -constexpr auto gt = ttnn::register_operation>("ttnn::gt"); -constexpr auto le = ttnn::register_operation>("ttnn::le"); -constexpr auto lt = ttnn::register_operation>("ttnn::lt"); -constexpr auto logical_and = ttnn::register_operation>("ttnn::logical_and"); -constexpr auto logical_or = ttnn::register_operation>("ttnn::logical_or"); -constexpr auto ldexp = ttnn::register_operation>("ttnn::ldexp"); - -constexpr auto logaddexp = ttnn::register_operation>("ttnn::logaddexp"); -constexpr auto logaddexp2 = ttnn::register_operation>("ttnn::logaddexp2"); -constexpr auto squared_difference = ttnn::register_operation>("ttnn::squared_difference"); -constexpr auto divide = ttnn::register_operation>("ttnn::divide"); -constexpr auto bias_gelu = ttnn::register_operation>("ttnn::bias_gelu"); - +constexpr auto add = + ttnn::register_operation>("ttnn::add"); +constexpr auto add_ = + ttnn::register_operation>("ttnn::add_"); +constexpr auto subtract = + ttnn::register_operation>( + "ttnn::subtract"); +constexpr auto subtract_ = + ttnn::register_operation>( + "ttnn::subtract_"); +constexpr auto multiply = + ttnn::register_operation>( + "ttnn::multiply"); +constexpr auto multiply_ = + ttnn::register_operation>( + "ttnn::multiply_"); + +constexpr auto eq = + ttnn::register_operation>("ttnn::eq"); +constexpr auto ne = + ttnn::register_operation>("ttnn::ne"); +constexpr auto ge = + ttnn::register_operation>("ttnn::ge"); +constexpr auto gt = + ttnn::register_operation>("ttnn::gt"); +constexpr auto le = + ttnn::register_operation>("ttnn::le"); +constexpr auto lt = + ttnn::register_operation>("ttnn::lt"); +constexpr auto logical_and = + ttnn::register_operation>( + "ttnn::logical_and"); +constexpr auto logical_or = + ttnn::register_operation>( + "ttnn::logical_or"); +constexpr auto ldexp = + ttnn::register_operation>("ttnn::ldexp"); + +constexpr auto logaddexp = + ttnn::register_operation>( + "ttnn::logaddexp"); +constexpr auto logaddexp2 = + ttnn::register_operation>( + "ttnn::logaddexp2"); +constexpr auto squared_difference = + ttnn::register_operation>( + "ttnn::squared_difference"); +constexpr auto divide = + ttnn::register_operation>( + "ttnn::divide"); +constexpr auto bias_gelu = + ttnn::register_operation>( + "ttnn::bias_gelu"); template ttnn::Tensor operator+(const ttnn::Tensor &input_tensor_a, InputBType scalar) { diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp similarity index 83% rename from ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp rename to ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp index 501f5d75ac6e..bf8031e47e55 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp @@ -2,8 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "binary_op.hpp" - +#include "binary_device_operation.hpp" #include "tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp" #include "tt_eager/tt_dnn/op_library/work_split.hpp" #include "tt_metal/common/constants.hpp" @@ -15,7 +14,10 @@ namespace utils { using namespace tt::tt_metal; std::map get_defines( - BinaryOpType op_type, const std::optional input_dtype, const std::optional output_dtype, const std::optional> fused_activations) { + BinaryOpType op_type, + const std::optional input_dtype, + const std::optional output_dtype, + const std::optional> fused_activations) { std::map defines; string op_name = "sub_tiles"; string op_binary_type = "EltwiseBinaryType::ELWSUB"; @@ -102,28 +104,29 @@ std::map get_defines( default: TT_ASSERT(false && "Undefined op type"); } - if(input_dtype.has_value() && output_dtype.has_value() && + if (input_dtype.has_value() && output_dtype.has_value() && ((input_dtype.value() == DataType::BFLOAT16 && output_dtype.value() == DataType::UINT32) || - (input_dtype.value() == DataType::BFLOAT16 && output_dtype.value() == DataType::UINT16) || - (input_dtype.value() == DataType::BFLOAT16 && output_dtype.value() == DataType::INT32) || - (input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::BFLOAT16) || - (input_dtype.value() == DataType::INT32 && output_dtype.value() == DataType::BFLOAT16) || - (input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::BFLOAT16) || - (input_dtype.value() == DataType::BFLOAT16 && output_dtype.value() == DataType::FLOAT32) || - (input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::UINT16) || - (input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::FLOAT32) || - (input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::INT32) || - (input_dtype.value() == DataType::INT32 && output_dtype.value() == DataType::FLOAT32) || - (input_dtype.value() == DataType::BFLOAT8_B && output_dtype.value() == DataType::UINT16) || - (input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::BFLOAT8_B) || - (input_dtype.value() == DataType::BFLOAT8_B && output_dtype.value() == DataType::INT32) || - (input_dtype.value() == DataType::INT32 && output_dtype.value() == DataType::BFLOAT8_B))){ + (input_dtype.value() == DataType::BFLOAT16 && output_dtype.value() == DataType::UINT16) || + (input_dtype.value() == DataType::BFLOAT16 && output_dtype.value() == DataType::INT32) || + (input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::BFLOAT16) || + (input_dtype.value() == DataType::INT32 && output_dtype.value() == DataType::BFLOAT16) || + (input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::BFLOAT16) || + (input_dtype.value() == DataType::BFLOAT16 && output_dtype.value() == DataType::FLOAT32) || + (input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::UINT16) || + (input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::FLOAT32) || + (input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::INT32) || + (input_dtype.value() == DataType::INT32 && output_dtype.value() == DataType::FLOAT32) || + (input_dtype.value() == DataType::BFLOAT8_B && output_dtype.value() == DataType::UINT16) || + (input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::BFLOAT8_B) || + (input_dtype.value() == DataType::BFLOAT8_B && output_dtype.value() == DataType::INT32) || + (input_dtype.value() == DataType::INT32 && output_dtype.value() == DataType::BFLOAT8_B))) { TT_ASSERT(defines.count("SFPU_OP_CHAIN_0") == 0 && "SFPU_OP_CHAIN_0 already defined"); auto in_dataformat = std::to_string((uint32_t)datatype_to_dataformat_converter(input_dtype.value())); auto out_dataformat = std::to_string((uint32_t)datatype_to_dataformat_converter(output_dtype.value())); - defines.insert({"SFPU_OP_CHAIN_0", - fmt::format("typecast_tile_init(); typecast_tile<{0}u, {1}u>(i);", in_dataformat, out_dataformat)}); + defines.insert( + {"SFPU_OP_CHAIN_0", + fmt::format("typecast_tile_init(); typecast_tile<{0}u, {1}u>(i);", in_dataformat, out_dataformat)}); defines.insert({"SFPU_OP_TYPECAST_INCLUDE", "1"}); } @@ -143,9 +146,9 @@ std::map get_defines( } // namespace utils -Binary::program_factory_t Binary::select_program_factory( +BinaryDeviceOperation::program_factory_t BinaryDeviceOperation::select_program_factory( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - ZoneScopedN("Binary::select_program_factory"); + ZoneScopedN("BinaryDeviceOperation::select_program_factory"); const auto& input_shape_a = tensor_args.input_tensor_a.tensor_attributes->shape; const auto& input_shape_b = tensor_args.input_tensor_b.tensor_attributes->shape; @@ -166,16 +169,16 @@ Binary::program_factory_t Binary::select_program_factory( return BroadcastWidthMultiCore{}; } } - TT_THROW("ttnn::operations::binary::Binary: unsupported broadcast"); + TT_THROW("ttnn::operations::binary::BinaryDeviceOperation: unsupported broadcast"); } -void Binary::validate_on_program_cache_miss( +void BinaryDeviceOperation::validate_on_program_cache_miss( const operation_attributes_t& attributes, const tensor_args_t& tensor_args) { const auto& input_tensor_a = tensor_args.input_tensor_a; const auto& input_tensor_b = tensor_args.input_tensor_b; const auto& output_tensor = tensor_args.output_tensor; - Binary::validate_on_program_cache_hit(attributes, tensor_args); + BinaryDeviceOperation::validate_on_program_cache_hit(attributes, tensor_args); TT_FATAL( input_tensor_a.device() == input_tensor_b.device(), @@ -243,7 +246,8 @@ void Binary::validate_on_program_cache_miss( "ignored"); } } -void Binary::validate_on_program_cache_hit(const operation_attributes_t& attributes, const tensor_args_t& tensor_args) { +void BinaryDeviceOperation::validate_on_program_cache_hit( + const operation_attributes_t& attributes, const tensor_args_t& tensor_args) { const auto& input_tensor_a = tensor_args.input_tensor_a; const auto& input_tensor_b = tensor_args.input_tensor_b; const auto& output_tensor = tensor_args.output_tensor; @@ -265,22 +269,24 @@ void Binary::validate_on_program_cache_hit(const operation_attributes_t& attribu if (batch_size_0_a != batch_size_0_b) { TT_ASSERT( batch_size_0_a > batch_size_0_b and batch_size_0_b == 1, - "ttnn::operations::binary::Binary: batch size mismatch"); + "ttnn::operations::binary::BinaryDeviceOperation: batch size mismatch"); } if (batch_size_1_a != batch_size_1_b) { TT_ASSERT( batch_size_1_a > batch_size_1_b and batch_size_1_b == 1, - "ttnn::operations::binary::Binary: batch size mismatch"); + "ttnn::operations::binary::BinaryDeviceOperation: batch size mismatch"); } if (height_a != height_b) { - TT_ASSERT(height_a > height_b and height_b == 1, "ttnn::operations::binary::Binary: height mismatch"); + TT_ASSERT( + height_a > height_b and height_b == 1, "ttnn::operations::binary::BinaryDeviceOperation: height mismatch"); } if (width_a != width_b) { - TT_ASSERT(width_a > width_b and width_b == 1, "ttnn::operations::binary::Binary: width mismatch"); + TT_ASSERT( + width_a > width_b and width_b == 1, "ttnn::operations::binary::BinaryDeviceOperation: width mismatch"); } } -Binary::shape_return_value_t Binary::compute_output_shapes( +BinaryDeviceOperation::shape_return_value_t BinaryDeviceOperation::compute_output_shapes( const operation_attributes_t&, const tensor_args_t& tensor_args) { const auto input_shape_a = tensor_args.input_tensor_a.tensor_attributes->shape; const auto input_shape_b = tensor_args.input_tensor_b.tensor_attributes->shape; @@ -303,7 +309,7 @@ Binary::shape_return_value_t Binary::compute_output_shapes( return ttnn::Shape::from_vector(output_shape, output_shape_with_tile_padding); } -Binary::tensor_return_value_t Binary::create_output_tensors( +BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_output_tensors( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { auto output_shape = compute_output_shapes(operation_attributes, tensor_args); const auto& input_tensor_a = tensor_args.input_tensor_a; @@ -361,15 +367,27 @@ Binary::tensor_return_value_t Binary::create_output_tensors( } } -tt::stl::hash::hash_t Binary::compute_program_hash( +tt::stl::hash::hash_t BinaryDeviceOperation::compute_program_hash( const operation_attributes_t& attributes, const tensor_args_t& tensor_args) { const auto& input_tensor_a = tensor_args.input_tensor_a; const auto& input_tensor_b = tensor_args.input_tensor_b; auto program_factory = select_program_factory(attributes, tensor_args); - TT_ASSERT(std::holds_alternative(input_tensor_a.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(input_tensor_a.get_storage()),__FILE__, __LINE__)); - TT_ASSERT(std::holds_alternative(input_tensor_b.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(input_tensor_b.get_storage()),__FILE__, __LINE__)); - operation::Hash hash = operation::hash_operation( + TT_ASSERT( + std::holds_alternative(input_tensor_a.get_storage()), + fmt::format( + "Unexpected type {} in {}:{} ", + tt::stl::get_active_type_name_in_variant(input_tensor_a.get_storage()), + __FILE__, + __LINE__)); + TT_ASSERT( + std::holds_alternative(input_tensor_b.get_storage()), + fmt::format( + "Unexpected type {} in {}:{} ", + tt::stl::get_active_type_name_in_variant(input_tensor_b.get_storage()), + __FILE__, + __LINE__)); + operation::Hash hash = operation::hash_operation( attributes, program_factory.index(), input_tensor_a.dtype(), @@ -379,7 +397,7 @@ tt::stl::hash::hash_t Binary::compute_program_hash( return hash; } -operation::OpPerformanceModel Binary::create_op_performance_model( +operation::OpPerformanceModel BinaryDeviceOperation::create_op_performance_model( const operation_attributes_t& attributes, const tensor_args_t& tensor_args, tensor_return_value_t& tensor_return_value) { @@ -399,7 +417,7 @@ operation::OpPerformanceModel Binary::create_op_performance_model( // TODO: update OpPerformanceModel to work on variadic arguments operation::OpPerformanceModel result({input_tensor_a, input_tensor_b}, {output_tensor}, ideal_eltwise_cycles); #if 0 - tt::log_info(tt::LogOp, "Binary PerfModel:"); + tt::log_info(tt::LogOp, "BinaryDeviceOperation PerfModel:"); tt::log_info(tt::LogOp, "\t Data (Bytes): {}", total_bytes); tt::log_info(tt::LogOp, "\t ideal_eltwise_cycles: {}", ideal_eltwise_cycles); #endif diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp similarity index 99% rename from ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.hpp rename to ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp index c45bff9fde35..feb043f99c13 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp @@ -57,7 +57,7 @@ std::map get_defines( } // namespace utils -struct Binary { +struct BinaryDeviceOperation { struct operation_attributes_t { BinaryOpType binary_op_type; bool in_place; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp index e3c09d9c5453..18b3f24d435a 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp @@ -4,7 +4,7 @@ #include -#include "binary_op.hpp" +#include "binary_device_operation.hpp" #include "impl/buffers/buffer.hpp" #include "tensor/tensor.hpp" #include "tt_dnn/op_library/bcast/bcast_op.hpp" @@ -25,7 +25,8 @@ static const tt::tt_metal::BcastOpMath binary_op_type_to_bcast_op_math(const Bin } } -Binary::BroadcastHeightAndWidthMultiCore::cached_program_t Binary::BroadcastHeightAndWidthMultiCore::create( +BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::cached_program_t +BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::create( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, tensor_return_value_t& tensor_return_value) { @@ -229,7 +230,7 @@ Binary::BroadcastHeightAndWidthMultiCore::cached_program_t Binary::BroadcastHeig cb_output}}; } -void Binary::BroadcastHeightAndWidthMultiCore::override_runtime_arguments( +void BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::override_runtime_arguments( cached_program_t& cached_program, const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp index 226091271201..5f3a99cb6ef8 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "binary_op.hpp" +#include "binary_device_operation.hpp" #include "tensor/tensor.hpp" #include "tt_dnn/op_library/bcast/bcast_op.hpp" #include "tt_dnn/op_library/work_split.hpp" @@ -22,10 +22,11 @@ static const tt::tt_metal::BcastOpMath binary_op_type_to_bcast_op_math(const Bin } } -Binary::BroadcastHeightMultiCore::cached_program_t Binary :: BroadcastHeightMultiCore::create( - const operation_attributes_t& operation_attributes, - const tensor_args_t& tensor_args, - tensor_return_value_t& tensor_return_value) { +BinaryDeviceOperation::BroadcastHeightMultiCore::cached_program_t +BinaryDeviceOperation ::BroadcastHeightMultiCore::create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value) { using namespace tt; using namespace tt::tt_metal; @@ -200,11 +201,11 @@ Binary::BroadcastHeightMultiCore::cached_program_t Binary :: BroadcastHeightMult }; } -void Binary :: BroadcastHeightMultiCore::override_runtime_arguments( - cached_program_t& cached_program, - const operation_attributes_t& operation_attributes, - const tensor_args_t& tensor_args, - tensor_return_value_t& tensor_return_value) { +void BinaryDeviceOperation ::BroadcastHeightMultiCore::override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value) { using namespace tt; using namespace tt::tt_metal; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp index b595735161c5..3bdd6e935723 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "binary_op.hpp" +#include "binary_device_operation.hpp" #include "tensor/tensor.hpp" #include "tt_dnn/op_library/bcast/bcast_op.hpp" #include "tt_dnn/op_library/work_split.hpp" @@ -22,7 +22,7 @@ static const tt::tt_metal::BcastOpMath binary_op_type_to_bcast_op_math(const Bin } } -Binary::BroadcastWidthMultiCore::cached_program_t Binary::BroadcastWidthMultiCore::create( +BinaryDeviceOperation::BroadcastWidthMultiCore::cached_program_t BinaryDeviceOperation::BroadcastWidthMultiCore::create( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, tensor_return_value_t& tensor_return_value) { @@ -199,7 +199,7 @@ Binary::BroadcastWidthMultiCore::cached_program_t Binary::BroadcastWidthMultiCor {binary_reader_kernel_id, unary_writer_kernel_id, bcast_kernel_id, compute_with_storage_grid_size}}; } -void Binary::BroadcastWidthMultiCore::override_runtime_arguments( +void BinaryDeviceOperation::BroadcastWidthMultiCore::override_runtime_arguments( cached_program_t& cached_program, const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp index 521a83d0d8d6..5f2bfcdcad19 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp @@ -4,7 +4,7 @@ #include -#include "binary_op.hpp" +#include "binary_device_operation.hpp" #include "tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" #include "tt_eager/tt_dnn/op_library/work_split.hpp" #include "tt_metal/common/constants.hpp" @@ -245,7 +245,7 @@ inline __attribute__((always_inline)) void set_eltwise_binary_runtime_args( UpdateCircularBufferTotalSize(program, cb_output, num_tiles_per_core_group_1 * dst_single_tile_size); } } -Binary::ElementWiseMultiCore::cached_program_t Binary::ElementWiseMultiCore::create( +BinaryDeviceOperation::ElementWiseMultiCore::cached_program_t BinaryDeviceOperation::ElementWiseMultiCore::create( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, tensor_return_value_t& tensor_return_value) { @@ -425,7 +425,7 @@ Binary::ElementWiseMultiCore::cached_program_t Binary::ElementWiseMultiCore::cre dst_single_tile_size}}; } -void Binary::ElementWiseMultiCore::override_runtime_arguments( +void BinaryDeviceOperation::ElementWiseMultiCore::override_runtime_arguments( cached_program_t& cached_program, const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, diff --git a/ttnn/cpp/ttnn/operations/example/example/device/example_device_operation.cpp b/ttnn/cpp/ttnn/operations/example/example/device/example_device_operation.cpp new file mode 100644 index 000000000000..62d328d8c8ef --- /dev/null +++ b/ttnn/cpp/ttnn/operations/example/example/device/example_device_operation.cpp @@ -0,0 +1,40 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "example_device_operation.hpp" + +namespace ttnn::operations::example { + +ExampleDeviceOperation::program_factory_t ExampleDeviceOperation::select_program_factory( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + bool some_condition_based_on_operation_attributes_and_or_tensor_args = true; + if (some_condition_based_on_operation_attributes_and_or_tensor_args) { + return SingleCore{}; + } + return MultiCore{}; +} + +void ExampleDeviceOperation::validate_on_program_cache_miss( + const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {} + +void ExampleDeviceOperation::validate_on_program_cache_hit( + const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {} + +ExampleDeviceOperation::shape_return_value_t ExampleDeviceOperation::compute_output_shapes( + const operation_attributes_t&, const tensor_args_t& tensor_args) { + return tensor_args.input_tensor.tensor_attributes->shape; +} + +ExampleDeviceOperation::tensor_return_value_t ExampleDeviceOperation::create_output_tensors( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + auto output_shape = compute_output_shapes(operation_attributes, tensor_args); + const auto& input_tensor = tensor_args.input_tensor; + return create_device_tensor( + output_shape, + input_tensor.tensor_attributes->dtype, + input_tensor.tensor_attributes->layout, + input_tensor.device()); +} + +} // namespace ttnn::operations::example diff --git a/ttnn/cpp/ttnn/operations/example/example/device/example_device_operation.hpp b/ttnn/cpp/ttnn/operations/example/example/device/example_device_operation.hpp new file mode 100644 index 000000000000..7359dda5ccea --- /dev/null +++ b/ttnn/cpp/ttnn/operations/example/example/device/example_device_operation.hpp @@ -0,0 +1,101 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "tensor/tensor.hpp" +#include "ttnn/core.hpp" +#include "ttnn/device_operation.hpp" +#include "ttnn/types.hpp" + +namespace ttnn::operations::example { + +struct ExampleDeviceOperation { + struct operation_attributes_t { + bool attribute; + int some_other_attribute; + }; + struct tensor_args_t { + // An example of the tensor that can only be used as an input + const Tensor& input_tensor; + + // An example of the tensor that can be used for input/output or just for pre-allocated output + // Tensor& io_tensor; + + // An example of an optional tensor + // std::optional optional_output_tensor; + + // An example of a vector of tensors + // std::vector vector_of_tensors; + + // An example of a tuple of tensors + // std::tuple tuple_of_tensors; + + // An example of a vector of optional tensors + // std::vector> vector_of_optional_tensors; + + // An example of a tuple of tensors + // std::tuple>, std::optional> some_crazy_tuple_of_tensors; + }; + + // Can be a single ttnn::Shape, std::optional, std::vector, std::tuple etc. + using shape_return_value_t = ttnn::Shape; + + // Can be a single Tensor, std::optional, std::vector, std::tuple etc. + using tensor_return_value_t = Tensor; + + struct SingleCore { + struct cached_program_attributes_t { + int some_program_attribute_to_share; + }; + using cached_program_t = ttnn::device_operation::CachedProgram; + + static cached_program_t create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + + static void override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + }; + + struct MultiCore { + struct cached_program_attributes_t { + int some_program_attribute_to_share; + int some_other_program_attribute_to_share; + }; + using cached_program_t = ttnn::device_operation::CachedProgram; + + static cached_program_t create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + + static void override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + }; + + using program_factory_t = std::variant; + + static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&); + + static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&); + static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&); + + static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&); + + static tensor_return_value_t create_output_tensors( + const operation_attributes_t& operation_attributes, const tensor_args_t&); +}; + +} // namespace ttnn::operations::example diff --git a/ttnn/cpp/ttnn/operations/example/example/device/multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/example/example/device/multi_core_program_factory.cpp new file mode 100644 index 000000000000..97e089e95449 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/example/example/device/multi_core_program_factory.cpp @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "example_device_operation.hpp" + +namespace ttnn::operations::example { +ExampleDeviceOperation::MultiCore::cached_program_t ExampleDeviceOperation::MultiCore::create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value) { + using namespace tt; + using namespace tt::tt_metal; + + tt_metal::Program program = tt_metal::CreateProgram(); + + int some_program_attribute_to_share = 0; + int some_other_program_attribute_to_share = 0; + + return {std::move(program), {some_program_attribute_to_share, some_other_program_attribute_to_share}}; +} + +void ExampleDeviceOperation::MultiCore::override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value) {} + +} // namespace ttnn::operations::example diff --git a/ttnn/cpp/ttnn/operations/example/example/device/single_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/example/example/device/single_core_program_factory.cpp new file mode 100644 index 000000000000..31852725f60c --- /dev/null +++ b/ttnn/cpp/ttnn/operations/example/example/device/single_core_program_factory.cpp @@ -0,0 +1,28 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "example_device_operation.hpp" + +namespace ttnn::operations::example { +ExampleDeviceOperation::SingleCore::cached_program_t ExampleDeviceOperation::SingleCore::create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value) { + using namespace tt; + using namespace tt::tt_metal; + + tt_metal::Program program = tt_metal::CreateProgram(); + + int some_program_attribute_to_share = 0; + + return {std::move(program), {some_program_attribute_to_share}}; +} + +void ExampleDeviceOperation::SingleCore::override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value) {} + +} // namespace ttnn::operations::example diff --git a/ttnn/cpp/ttnn/operations/example/example/example.hpp b/ttnn/cpp/ttnn/operations/example/example/example.hpp new file mode 100644 index 000000000000..19f7744ec877 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/example/example/example.hpp @@ -0,0 +1,35 @@ + +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "device/example_device_operation.hpp" + +namespace ttnn::operations::example { + +struct ExampleOperation { + + static Tensor execute_on_main_thread( + uint8_t queue_id, + const Tensor &input_tensor) { + return ttnn::device_operation::run( + queue_id, + ExampleDeviceOperation::operation_attributes_t{}, + ExampleDeviceOperation::tensor_args_t{input_tensor}); + } + + static Tensor execute_on_main_thread( + const Tensor &input_tensor) { + return execute_on_main_thread(0, input_tensor); + } +}; + +} // namespace ttnn::operations::binary + +namespace ttnn { + +constexpr auto example = ttnn::register_operation("ttnn::example"); + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/example/example/example_pybind.hpp b/ttnn/cpp/ttnn/operations/example/example/example_pybind.hpp new file mode 100644 index 000000000000..94ae3c8dd098 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/example/example/example_pybind.hpp @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "ttnn/cpp/pybind11/decorators.hpp" +#include "ttnn/operations/example/example/example.hpp" +#include "ttnn/types.hpp" + +namespace py = pybind11; + +namespace ttnn::operations::example { + +void py_module(py::module& module) { + bind_registered_operation( + module, + ttnn::example, + R"doc(example(input_tensor: ttnn.Tensor) -> ttnn.Tensor)doc", + ttnn::pybind_overload_t{ + [](const decltype(ttnn::example)& self, const ttnn::Tensor& input_tensor, const uint8_t& queue_id) + -> ttnn::Tensor { return self(queue_id, input_tensor); }, + py::arg("input_tensor"), + py::kw_only(), + py::arg("queue_id") = 0}); +} + +} // namespace ttnn::operations::example