diff --git a/docs/source/ttnn/ttnn/adding_new_ttnn_operation.rst b/docs/source/ttnn/ttnn/adding_new_ttnn_operation.rst index d97e86bfc9bd..e16a6a14ff70 100644 --- a/docs/source/ttnn/ttnn/adding_new_ttnn_operation.rst +++ b/docs/source/ttnn/ttnn/adding_new_ttnn_operation.rst @@ -7,174 +7,98 @@ Adding New ttnn Operation Not all operations may be functional on all Tenstorrent hardware (Grayskull, Wormhole, or others). -C++ Implementation ------------------- - - -Add `tt_eager/tt_dnn/op_library//.hpp`: - -.. code-block:: cpp - - #pragma once - - #include - - #include "tensor/tensor.hpp" - #include "tt_dnn/op_library/operation.hpp" - - namespace tt { - namespace tt_metal { - - struct { - bool some_arg; - - // These methods are needed if the operation takes in input tensor and produces output tensors - void validate(const std::vector &input_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; - std::vector create_output_tensors(const std::vector &input_tensors) const; - operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, std::vector &output_tensors) const; - - // This is needed until we get C++20 - static constexpr auto attribute_names = std::forward_as_tuple("some_arg"); - const auto attribute_values() const { - return std::forward_as_tuple(this->some_arg); - } - }; - - Tensor (const Tensor &input_tensor, bool some_arg); - } // namespace tt_metal - } // namespace tt +What is a ttnn operation? +------------------------- +A ttnn operation is a function that takes in one or more input tensors and produces one or more output tensors. It is implemented in C++ and can be called from Python. -.. note: +What steps are needed to add ttnn operation in C++? +--------------------------------------------------- +1. (Optional) Implement device operation in C++. Device operation is a struct that specifies how to create output tensors and a program to run on the device. If the ttnn operation is composed of other ttnn operations, then you can skip this step. +2. Implement ttnn operation in C++ and register it using `ttnn::register_operation`. - If you need optional input tensors or would like to pass in optional output tensors, then refer to :doc:`Operations ` for how to write ops that use them +What steps are needed to add ttnn operation in Python? +------------------------------------------------------ +1. Take an existing registerd C++ operation and add a Python binding for it using `ttnn::bind_registered_operation`. +2. In python, decorate the operation using `ttnn.register_operation`. (This step will be deprecated in the future) -Add `tt_eager/tt_dnn/op_library//.cpp`: -.. code-block:: cpp - - #include "tt_metal/host_api.hpp" - #include "tt_dnn/op_library/run_operation.hpp" - - namespace tt { - namespace tt_metal { - - - void ::validate(const std::vector &input_tensors) const { - ... - } - - std::vector ::compute_output_shapes(const std::vector &input_tensors) const { - std::vector output_shapes = ...; - return output_shapes; - } - - std::vector create_output_tensors(const std::vector &input_tensors) const { - std::vector output_tensors = ...; - return output_tensors; - } - - operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, std::vector &output_tensors) const { - Program program = ...; - return operation::ProgramWithCallbacks{program}; - } - - }; +C++ Implementation +------------------ - Tensor (const Tensor &input_tensor, bool some_arg) { - std::vector input_tensors = {input_tensor}; - std::vector output_tensors operation::run(DeviceOperation({some_arg}, {input_tensor})); - return output_tensors[0]; - } +Step 1: Implement device operation (Optional) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - } // namespace tt_metal - } // namespace tt +In order to add a new device operation, follow the directory structure shown below: +`ttnn/cpp/ttnn/operations///device/_device_operation.hpp` +`ttnn/cpp/ttnn/operations///device/_device_operation.cpp` +`ttnn/cpp/ttnn/operations///device/_program_factory.cpp` -Add pybindings --------------- +.. note:: + Add as many program factories as needed -In `tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp`, add the following lines +A concrete example of a device operation can be found in `ttnn/cpp/ttnn/operations/examples/example/device` -.. code-block:: cpp +`ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.hpp`: - m_tensor.def("", &, py::arg("input_tensor").noconvert(), py::arg("some_arg").noconvert(), R"doc( - runs new operation on input tensor. +.. literalinclude:: examples/example/device/example_device_operation.hpp - .. csv-table:: - :header: "Argument", "Description", "Data type", "Valid range", "Required" +`ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.cpp`: - "input_tensor", "Input tensor", "Tensor", "Tensor of shape [W0, Z0, Y0, X0]", "Yes" - "some_arg", "Some arg", "bool", "Some arg to do some stuff in new operation", "Yes" - )doc"); +.. literalinclude:: examples/example/device/example_device_operation.cpp +`ttnn/cpp/ttnn/operations/examples/example/device/single_core_program_factory.cpp`: +.. literalinclude:: examples/example/device/single_core_program_factory.cpp -Adding a unit test ------------------- +`ttnn/cpp/ttnn/operations/examples/example/device/multi_core_program_factory.cpp`: -Add `tests/ttnn/unit_tests/ttl/test_.py`: +.. literalinclude:: examples/example/device/multi_core_program_factory.cpp -.. code-block:: python - import pytest - import torch - import ttnn +Step 2: Implement the operation in C++ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - from tests.ttnn.utils_for_testing import assert_with_pcc +In order to add a new operation, add the following file: - @pytest.mark.parametrize("height", [32]) - @pytest.mark.parametrize("width", [32]) - def test_(device, height, width): - torch.manual_seed(0) +`ttnn/cpp/ttnn/operations///.hpp` - torch_input_tensor = torch.rand(1, 1, height, width) - torch_output_tensor = torch.exp(torch_input_tensor) +A concrete example: - input_tensor = ttnn.from_torch(torch_input_tensor, device=device) - output_tensor = ttnn.experimental.tensor.(input_tensor) +`ttnn/cpp/ttnn/operations/examples/example/example.hpp`: - output_tensor = ttnn.to_torch(output_tensor) +.. literalinclude:: examples/example/example.hpp - assert_with_pcc(torch_output_tensor, output_tensor) +Python Implementation +--------------------- +Step 1: Add Python binding +~~~~~~~~~~~~~~~~~~~~~~~~~~ -Adding a sweep test -------------------- +In order to add a python binding for the operation, follow the directory structure shown below: -Add `tests/ttnn/sweep_tests/sweeps/ttl_.py`: +`ttnn/python/ttnn/operations///_pybind.hpp` +`ttnn/python/ttnn/operations//_pybind.hpp` -.. code-block:: python +A concrete example: - from typing import Optional, Tuples - import torch - import ttnn - from tests.ttnn.utils_for_testing import check_with_pcc +`ttnn/python/ttnn/operations/examples/example/example_pybind.hpp`: +.. literalinclude:: examples/example/example_pybind.hpp - parameters = { - "height": [384, 1024], - "width": [1024, 4096], - } +`ttnn/python/ttnn/operations/examples/examples_pybind.hpp`: +.. literalinclude:: examples/example/example_pybind.hpp - def run( - height, - width, - *, - device, - ) -> Tuple[bool, Optional[str]]: +Finally, call the module defined in `examples/example/example_pybind.hpp` wherever you want it to be added. - torch_input_tensor = torch.rand(1, 1, height, width) - torch_output_tensor = torch.exp(torch_input_tensor) - input_tensor = ttnn.from_torch(torch_input_tensor, device=device) - output_tensor = ttnn.experimental.tensor.(input_tensor) - output_tensor = ttnn.to_torch(output_tensor) +Step 2: Register the operation in Python +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - assert_with_pcc(torch_output_tensor, output_tensor) +In order to add a new operation, follow the directory structure shown below: diff --git a/docs/source/ttnn/ttnn/examples b/docs/source/ttnn/ttnn/examples new file mode 120000 index 000000000000..518d351ad81f --- /dev/null +++ b/docs/source/ttnn/ttnn/examples @@ -0,0 +1 @@ +../../../../ttnn/cpp/ttnn/operations/examples/ \ No newline at end of file diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp index 5183005b76d2..562c7a91cb5f 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp @@ -12,10 +12,8 @@ #include "tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" #include "tt_metal/common/constants.hpp" #include "ttnn/cpp/ttnn/operations/creation.hpp" - - -#include "ttnn/operations/eltwise/binary/device/binary_op.hpp" #include "ttnn/operations/data_movement/pad/pad.hpp" +#include "ttnn/operations/eltwise/binary/device/binary_device_operation.hpp" namespace tt { diff --git a/tt_metal/tools/profiler/op_profiler.hpp b/tt_metal/tools/profiler/op_profiler.hpp index 79c231ea50f9..9383aa6c9f1f 100644 --- a/tt_metal/tools/profiler/op_profiler.hpp +++ b/tt_metal/tools/profiler/op_profiler.hpp @@ -385,7 +385,13 @@ inline std::string op_meta_data_serialized_json( j["optional_input_tensors"] = std::vector{}; - auto perfModel = operation_t::create_op_performance_model(operation_attributes, tensor_args, tensor_return_value); + auto perfModel = [&]() { + if constexpr (requires { operation_t::create_op_performance_model; }) { + return operation_t::create_op_performance_model(operation_attributes, tensor_args, tensor_return_value); + } else { + return operation::OpPerformanceModel{}; + } + }(); j["performance_model"]["compute_ns"] = perfModel.get_compute_ns(); j["performance_model"]["ideal_ns"] = perfModel.get_ideal_ns(); j["performance_model"]["bandwidth_ns"] = perfModel.get_bandwidth_ns(); diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index c8262b3a3a27..9523974bed23 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -5,12 +5,15 @@ set(TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv2d.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/argmax/device/argmax_op.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/unary/device/unary_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/examples/example/device/example_device_operation.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/examples/example/device/single_core_program_factory.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/examples/example/device/multi_core_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp ) diff --git a/ttnn/cpp/pybind11/operations/__init__.hpp b/ttnn/cpp/pybind11/operations/__init__.hpp index 9a08c23647d6..633a8003d638 100644 --- a/ttnn/cpp/pybind11/operations/__init__.hpp +++ b/ttnn/cpp/pybind11/operations/__init__.hpp @@ -9,6 +9,7 @@ #include "pybind11/operations/ccl.hpp" #include "pybind11/operations/conv2d.hpp" +#include "pybind11/operations/copy.hpp" #include "pybind11/operations/core.hpp" #include "pybind11/operations/creation.hpp" #include "pybind11/operations/embedding.hpp" @@ -17,12 +18,12 @@ #include "pybind11/operations/maxpool2d.hpp" #include "pybind11/operations/normalization.hpp" #include "pybind11/operations/pool.hpp" -#include "pybind11/operations/copy.hpp" #include "pybind11/operations/ternary.hpp" #include "pybind11/operations/transformer.hpp" - #include "ttnn/operations/eltwise/binary/binary_pybind.hpp" +#include "ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp" #include "ttnn/operations/eltwise/unary/unary_pybind.hpp" +#include "ttnn/operations/examples/examples_pybind.hpp" #include "ttnn/operations/reduction/reduction_pybind.hpp" #include "ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp" #include "ttnn/operations/data_movement/data_movement_pybind.hpp" @@ -35,6 +36,9 @@ namespace ttnn { namespace operations { void py_module(py::module& module) { + auto m_example = module.def_submodule("example", "example operation"); + examples::py_module(m_example); + auto m_unary = module.def_submodule("unary", "unary operations"); unary::py_module(m_unary); diff --git a/ttnn/cpp/ttnn/device_operation.hpp b/ttnn/cpp/ttnn/device_operation.hpp index ec9b2a93434d..d7d37f48e858 100644 --- a/ttnn/cpp/ttnn/device_operation.hpp +++ b/ttnn/cpp/ttnn/device_operation.hpp @@ -20,14 +20,14 @@ namespace ttnn { namespace device_operation { -template +template struct CachedProgram { tt::tt_metal::Program program; - // Cached program needs to share program_attributes between create and override_runtime_arguments functions - program_attributes_t program_attributes; + // Cached program needs to share shared_variables between create and override_runtime_arguments functions + shared_variables_t shared_variables; - CachedProgram(tt::tt_metal::Program&& program, program_attributes_t&& program_attributes) : - program{std::move(program)}, program_attributes{program_attributes} {} + CachedProgram(tt::tt_metal::Program&& program, shared_variables_t&& shared_variables) : + program{std::move(program)}, shared_variables{shared_variables} {} }; struct CachedProgramFactory { @@ -38,8 +38,8 @@ struct CachedProgramFactory { // program_factory_index is used to map a runtime value to a program factory type that is being used std::size_t program_factory_index; - template - CachedProgramFactory(CachedProgram&& cached_program, std::size_t program_factory_index) : + template + CachedProgramFactory(CachedProgram&& cached_program, std::size_t program_factory_index) : cached_program{std::move(cached_program)}, program_factory_index{program_factory_index} {} }; @@ -91,7 +91,7 @@ concept DeviceOperationWithCustomProgramCacheConcept = DeviceOperationConcept [[nodiscard]] std::variant constexpr map_index_to_variant(std::size_t i, std::variant) { assert(i < sizeof...(Ts)); - static constexpr std::variant table[] = { Ts{ }... }; + static constexpr std::variant table[] = {Ts{}...}; return table[i]; } diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp index a0cf1c9d9381..ab7f8d7e1c0e 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp @@ -5,7 +5,7 @@ #pragma once -#include "device/binary_op.hpp" +#include "device/binary_device_operation.hpp" #include "ttnn/device_operation.hpp" #include "ttnn/operations/data_movement.hpp" @@ -28,7 +28,7 @@ constexpr bool is_associative(BinaryOpType op) { } template -struct ExecuteBinary { +struct Binary { static inline const std::array input_tensor_schemas() { return { ttnn::TensorSchema{ @@ -108,11 +108,11 @@ struct ExecuteBinary { dtype = optional_output_tensor.value().get_dtype(); } - return ttnn::device_operation::run( + return ttnn::device_operation::run( queue_id, - Binary::operation_attributes_t{ + BinaryDeviceOperation::operation_attributes_t{ binary_op_type, in_place, activations, output_memory_config, dtype, std::nullopt}, - Binary::tensor_args_t{input_tensor_a, input_tensor_b, optional_output_tensor}); + BinaryDeviceOperation::tensor_args_t{input_tensor_a, input_tensor_b, optional_output_tensor}); } template @@ -145,8 +145,14 @@ struct ExecuteBinary { const std::optional &memory_config = std::nullopt, const std::optional &optional_output_tensor = std::nullopt, std::optional activations = std::nullopt) { - - return ExecuteBinary::execute_on_worker_thread(DefaultQueueId, input_tensor_a, scalar, dtype, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, optional_output_tensor, activations); + return Binary::execute_on_worker_thread( + DefaultQueueId, + input_tensor_a, + scalar, + dtype, + operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + optional_output_tensor, + activations); } template @@ -172,36 +178,71 @@ struct ExecuteBinary { Layout::TILE); Tensor scalar_tensor_device = scalar_tensor_host.to(input_tensor_a.device()); // TODO(arakhmati): #7637 pass in memory_config instead of operation::DEFAULT_OUTPUT_MEMORY_CONFIG - return ExecuteBinary::execute_on_worker_thread( - input_tensor_a, scalar_tensor_device, dtype, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, optional_output_tensor, activations); + return Binary::execute_on_worker_thread( + input_tensor_a, + scalar_tensor_device, + dtype, + operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + optional_output_tensor, + activations); } }; } // operations::binary -constexpr auto add = ttnn::register_operation>("ttnn::add"); -constexpr auto add_ = ttnn::register_operation>("ttnn::add_"); -constexpr auto subtract = ttnn::register_operation>("ttnn::subtract"); -constexpr auto subtract_ = ttnn::register_operation>("ttnn::subtract_"); -constexpr auto multiply = ttnn::register_operation>("ttnn::multiply"); -constexpr auto multiply_ = ttnn::register_operation>("ttnn::multiply_"); - -constexpr auto eq = ttnn::register_operation>("ttnn::eq"); -constexpr auto ne = ttnn::register_operation>("ttnn::ne"); -constexpr auto ge = ttnn::register_operation>("ttnn::ge"); -constexpr auto gt = ttnn::register_operation>("ttnn::gt"); -constexpr auto le = ttnn::register_operation>("ttnn::le"); -constexpr auto lt = ttnn::register_operation>("ttnn::lt"); -constexpr auto logical_and = ttnn::register_operation>("ttnn::logical_and"); -constexpr auto logical_or = ttnn::register_operation>("ttnn::logical_or"); -constexpr auto ldexp = ttnn::register_operation>("ttnn::ldexp"); - -constexpr auto logaddexp = ttnn::register_operation>("ttnn::logaddexp"); -constexpr auto logaddexp2 = ttnn::register_operation>("ttnn::logaddexp2"); -constexpr auto squared_difference = ttnn::register_operation>("ttnn::squared_difference"); -constexpr auto divide = ttnn::register_operation>("ttnn::divide"); -constexpr auto bias_gelu = ttnn::register_operation>("ttnn::bias_gelu"); - +constexpr auto add = + ttnn::register_operation>("ttnn::add"); +constexpr auto add_ = + ttnn::register_operation>("ttnn::add_"); +constexpr auto subtract = + ttnn::register_operation>( + "ttnn::subtract"); +constexpr auto subtract_ = + ttnn::register_operation>( + "ttnn::subtract_"); +constexpr auto multiply = + ttnn::register_operation>( + "ttnn::multiply"); +constexpr auto multiply_ = + ttnn::register_operation>( + "ttnn::multiply_"); + +constexpr auto eq = + ttnn::register_operation>("ttnn::eq"); +constexpr auto ne = + ttnn::register_operation>("ttnn::ne"); +constexpr auto ge = + ttnn::register_operation>("ttnn::ge"); +constexpr auto gt = + ttnn::register_operation>("ttnn::gt"); +constexpr auto le = + ttnn::register_operation>("ttnn::le"); +constexpr auto lt = + ttnn::register_operation>("ttnn::lt"); +constexpr auto logical_and = + ttnn::register_operation>( + "ttnn::logical_and"); +constexpr auto logical_or = + ttnn::register_operation>( + "ttnn::logical_or"); +constexpr auto ldexp = + ttnn::register_operation>("ttnn::ldexp"); + +constexpr auto logaddexp = + ttnn::register_operation>( + "ttnn::logaddexp"); +constexpr auto logaddexp2 = + ttnn::register_operation>( + "ttnn::logaddexp2"); +constexpr auto squared_difference = + ttnn::register_operation>( + "ttnn::squared_difference"); +constexpr auto divide = + ttnn::register_operation>( + "ttnn::divide"); +constexpr auto bias_gelu = + ttnn::register_operation>( + "ttnn::bias_gelu"); template ttnn::Tensor operator+(const ttnn::Tensor &input_tensor_a, InputBType scalar) { diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp similarity index 89% rename from ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp rename to ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp index d4896f0bbdbe..fd6903f482df 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp @@ -2,8 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "binary_op.hpp" - +#include "binary_device_operation.hpp" #include "tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp" #include "tt_eager/tt_dnn/op_library/work_split.hpp" #include "tt_metal/common/constants.hpp" @@ -15,7 +14,10 @@ namespace utils { using namespace tt::tt_metal; std::map get_defines( - BinaryOpType op_type, const std::optional input_dtype, const std::optional output_dtype, const std::optional> fused_activations) { + BinaryOpType op_type, + const std::optional input_dtype, + const std::optional output_dtype, + const std::optional> fused_activations) { std::map defines; string op_name = "sub_tiles"; string op_binary_type = "EltwiseBinaryType::ELWSUB"; @@ -127,8 +129,9 @@ std::map get_defines( auto in_dataformat = std::to_string((uint32_t)datatype_to_dataformat_converter(input_dtype.value())); auto out_dataformat = std::to_string((uint32_t)datatype_to_dataformat_converter(output_dtype.value())); - defines.insert({"SFPU_OP_CHAIN_0", - fmt::format("typecast_tile_init(); typecast_tile<{0}u, {1}u>(i);", in_dataformat, out_dataformat)}); + defines.insert( + {"SFPU_OP_CHAIN_0", + fmt::format("typecast_tile_init(); typecast_tile<{0}u, {1}u>(i);", in_dataformat, out_dataformat)}); defines.insert({"SFPU_OP_TYPECAST_INCLUDE", "1"}); } @@ -148,9 +151,9 @@ std::map get_defines( } // namespace utils -Binary::program_factory_t Binary::select_program_factory( +BinaryDeviceOperation::program_factory_t BinaryDeviceOperation::select_program_factory( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - ZoneScopedN("Binary::select_program_factory"); + ZoneScopedN("BinaryDeviceOperation::select_program_factory"); const auto& input_shape_a = tensor_args.input_tensor_a.tensor_attributes->shape; const auto& input_shape_b = tensor_args.input_tensor_b.tensor_attributes->shape; @@ -171,16 +174,16 @@ Binary::program_factory_t Binary::select_program_factory( return BroadcastWidthMultiCore{}; } } - TT_THROW("ttnn::operations::binary::Binary: unsupported broadcast"); + TT_THROW("ttnn::operations::binary::BinaryDeviceOperation: unsupported broadcast"); } -void Binary::validate_on_program_cache_miss( +void BinaryDeviceOperation::validate_on_program_cache_miss( const operation_attributes_t& attributes, const tensor_args_t& tensor_args) { const auto& input_tensor_a = tensor_args.input_tensor_a; const auto& input_tensor_b = tensor_args.input_tensor_b; const auto& output_tensor = tensor_args.output_tensor; - Binary::validate_on_program_cache_hit(attributes, tensor_args); + BinaryDeviceOperation::validate_on_program_cache_hit(attributes, tensor_args); TT_FATAL( input_tensor_a.device() == input_tensor_b.device(), @@ -248,7 +251,8 @@ void Binary::validate_on_program_cache_miss( "ignored"); } } -void Binary::validate_on_program_cache_hit(const operation_attributes_t& attributes, const tensor_args_t& tensor_args) { +void BinaryDeviceOperation::validate_on_program_cache_hit( + const operation_attributes_t& attributes, const tensor_args_t& tensor_args) { const auto& input_tensor_a = tensor_args.input_tensor_a; const auto& input_tensor_b = tensor_args.input_tensor_b; const auto& output_tensor = tensor_args.output_tensor; @@ -270,22 +274,24 @@ void Binary::validate_on_program_cache_hit(const operation_attributes_t& attribu if (batch_size_0_a != batch_size_0_b) { TT_ASSERT( batch_size_0_a > batch_size_0_b and batch_size_0_b == 1, - "ttnn::operations::binary::Binary: batch size mismatch"); + "ttnn::operations::binary::BinaryDeviceOperation: batch size mismatch"); } if (batch_size_1_a != batch_size_1_b) { TT_ASSERT( batch_size_1_a > batch_size_1_b and batch_size_1_b == 1, - "ttnn::operations::binary::Binary: batch size mismatch"); + "ttnn::operations::binary::BinaryDeviceOperation: batch size mismatch"); } if (height_a != height_b) { - TT_ASSERT(height_a > height_b and height_b == 1, "ttnn::operations::binary::Binary: height mismatch"); + TT_ASSERT( + height_a > height_b and height_b == 1, "ttnn::operations::binary::BinaryDeviceOperation: height mismatch"); } if (width_a != width_b) { - TT_ASSERT(width_a > width_b and width_b == 1, "ttnn::operations::binary::Binary: width mismatch"); + TT_ASSERT( + width_a > width_b and width_b == 1, "ttnn::operations::binary::BinaryDeviceOperation: width mismatch"); } } -Binary::shape_return_value_t Binary::compute_output_shapes( +BinaryDeviceOperation::shape_return_value_t BinaryDeviceOperation::compute_output_shapes( const operation_attributes_t&, const tensor_args_t& tensor_args) { const auto input_shape_a = tensor_args.input_tensor_a.tensor_attributes->shape; const auto input_shape_b = tensor_args.input_tensor_b.tensor_attributes->shape; @@ -308,7 +314,7 @@ Binary::shape_return_value_t Binary::compute_output_shapes( return ttnn::Shape::from_vector(output_shape, output_shape_with_tile_padding); } -Binary::tensor_return_value_t Binary::create_output_tensors( +BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_output_tensors( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { auto output_shape = compute_output_shapes(operation_attributes, tensor_args); const auto& input_tensor_a = tensor_args.input_tensor_a; @@ -366,15 +372,27 @@ Binary::tensor_return_value_t Binary::create_output_tensors( } } -tt::stl::hash::hash_t Binary::compute_program_hash( +tt::stl::hash::hash_t BinaryDeviceOperation::compute_program_hash( const operation_attributes_t& attributes, const tensor_args_t& tensor_args) { const auto& input_tensor_a = tensor_args.input_tensor_a; const auto& input_tensor_b = tensor_args.input_tensor_b; auto program_factory = select_program_factory(attributes, tensor_args); - TT_ASSERT(std::holds_alternative(input_tensor_a.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(input_tensor_a.get_storage()),__FILE__, __LINE__)); - TT_ASSERT(std::holds_alternative(input_tensor_b.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(input_tensor_b.get_storage()),__FILE__, __LINE__)); - operation::Hash hash = operation::hash_operation( + TT_ASSERT( + std::holds_alternative(input_tensor_a.get_storage()), + fmt::format( + "Unexpected type {} in {}:{} ", + tt::stl::get_active_type_name_in_variant(input_tensor_a.get_storage()), + __FILE__, + __LINE__)); + TT_ASSERT( + std::holds_alternative(input_tensor_b.get_storage()), + fmt::format( + "Unexpected type {} in {}:{} ", + tt::stl::get_active_type_name_in_variant(input_tensor_b.get_storage()), + __FILE__, + __LINE__)); + operation::Hash hash = operation::hash_operation( attributes, program_factory.index(), input_tensor_a.dtype(), @@ -384,7 +402,7 @@ tt::stl::hash::hash_t Binary::compute_program_hash( return hash; } -operation::OpPerformanceModel Binary::create_op_performance_model( +operation::OpPerformanceModel BinaryDeviceOperation::create_op_performance_model( const operation_attributes_t& attributes, const tensor_args_t& tensor_args, tensor_return_value_t& tensor_return_value) { @@ -404,7 +422,7 @@ operation::OpPerformanceModel Binary::create_op_performance_model( // TODO: update OpPerformanceModel to work on variadic arguments operation::OpPerformanceModel result({input_tensor_a, input_tensor_b}, {output_tensor}, ideal_eltwise_cycles); #if 0 - tt::log_info(tt::LogOp, "Binary PerfModel:"); + tt::log_info(tt::LogOp, "BinaryDeviceOperation PerfModel:"); tt::log_info(tt::LogOp, "\t Data (Bytes): {}", total_bytes); tt::log_info(tt::LogOp, "\t ideal_eltwise_cycles: {}", ideal_eltwise_cycles); #endif diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp similarity index 95% rename from ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.hpp rename to ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp index c45bff9fde35..925b1a381d47 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp @@ -57,7 +57,7 @@ std::map get_defines( } // namespace utils -struct Binary { +struct BinaryDeviceOperation { struct operation_attributes_t { BinaryOpType binary_op_type; bool in_place; @@ -75,7 +75,7 @@ struct Binary { using tensor_return_value_t = Tensor; struct ElementWiseMultiCore { - struct cached_program_attributes_t { + struct shared_variables_t { KernelHandle binary_reader_kernel_id; KernelHandle unary_writer_kernel_id; KernelHandle eltwise_binary_kernel_id; @@ -87,7 +87,7 @@ struct Binary { uint32_t src1_single_tile_size; uint32_t dst_single_tile_size; }; - using cached_program_t = ttnn::device_operation::CachedProgram; + using cached_program_t = ttnn::device_operation::CachedProgram; static cached_program_t create( const operation_attributes_t& operation_attributes, @@ -102,13 +102,13 @@ struct Binary { }; struct BroadcastWidthMultiCore { - struct cached_program_attributes_t { + struct shared_variables_t { KernelHandle binary_reader_kernel_id; KernelHandle unary_writer_kernel_id; KernelHandle bcast_kernel_id; CoreCoord compute_with_storage_grid_size; }; - using cached_program_t = ttnn::device_operation::CachedProgram; + using cached_program_t = ttnn::device_operation::CachedProgram; static cached_program_t create( const operation_attributes_t& operation_attributes, @@ -123,13 +123,13 @@ struct Binary { }; struct BroadcastHeightMultiCore { - struct cached_program_attributes_t { + struct shared_variables_t { KernelHandle binary_reader_kernel_id; KernelHandle unary_writer_kernel_id; KernelHandle bcast_kernel_id; CoreCoord compute_with_storage_grid_size; }; - using cached_program_t = ttnn::device_operation::CachedProgram; + using cached_program_t = ttnn::device_operation::CachedProgram; static cached_program_t create( const operation_attributes_t& operation_attributes, @@ -144,7 +144,7 @@ struct Binary { }; struct BroadcastHeightAndWidthMultiCore { - struct cached_program_attributes_t { + struct shared_variables_t { KernelHandle binary_reader_kernel_id; KernelHandle unary_writer_kernel_id; KernelHandle bcast_kernel_id; @@ -155,7 +155,7 @@ struct Binary { uint32_t dst_single_tile_size; CBHandle cb_output; }; - using cached_program_t = ttnn::device_operation::CachedProgram; + using cached_program_t = ttnn::device_operation::CachedProgram; static cached_program_t create( const operation_attributes_t& operation_attributes, diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp index e3c09d9c5453..d75a92134865 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp @@ -4,7 +4,7 @@ #include -#include "binary_op.hpp" +#include "binary_device_operation.hpp" #include "impl/buffers/buffer.hpp" #include "tensor/tensor.hpp" #include "tt_dnn/op_library/bcast/bcast_op.hpp" @@ -25,7 +25,8 @@ static const tt::tt_metal::BcastOpMath binary_op_type_to_bcast_op_math(const Bin } } -Binary::BroadcastHeightAndWidthMultiCore::cached_program_t Binary::BroadcastHeightAndWidthMultiCore::create( +BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::cached_program_t +BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::create( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, tensor_return_value_t& tensor_return_value) { @@ -229,7 +230,7 @@ Binary::BroadcastHeightAndWidthMultiCore::cached_program_t Binary::BroadcastHeig cb_output}}; } -void Binary::BroadcastHeightAndWidthMultiCore::override_runtime_arguments( +void BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::override_runtime_arguments( cached_program_t& cached_program, const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, @@ -241,15 +242,15 @@ void Binary::BroadcastHeightAndWidthMultiCore::override_runtime_arguments( const auto& input_tensor_b = tensor_args.input_tensor_b; auto& output_tensor = tensor_return_value; - auto& binary_reader_kernel_id = cached_program.program_attributes.binary_reader_kernel_id; - auto& unary_writer_kernel_id = cached_program.program_attributes.unary_writer_kernel_id; - auto& bcast_kernel_id = cached_program.program_attributes.bcast_kernel_id; - auto& compute_with_storage_grid_size = cached_program.program_attributes.compute_with_storage_grid_size; - auto& cb_src0 = cached_program.program_attributes.cb_src0; - auto& src0_single_tile_size = cached_program.program_attributes.src0_single_tile_size; - auto& src1_single_tile_size = cached_program.program_attributes.src1_single_tile_size; - auto& dst_single_tile_size = cached_program.program_attributes.dst_single_tile_size; - auto& cb_output = cached_program.program_attributes.cb_output; + auto& binary_reader_kernel_id = cached_program.shared_variables.binary_reader_kernel_id; + auto& unary_writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id; + auto& bcast_kernel_id = cached_program.shared_variables.bcast_kernel_id; + auto& compute_with_storage_grid_size = cached_program.shared_variables.compute_with_storage_grid_size; + auto& cb_src0 = cached_program.shared_variables.cb_src0; + auto& src0_single_tile_size = cached_program.shared_variables.src0_single_tile_size; + auto& src1_single_tile_size = cached_program.shared_variables.src1_single_tile_size; + auto& dst_single_tile_size = cached_program.shared_variables.dst_single_tile_size; + auto& cb_output = cached_program.shared_variables.cb_output; auto& program = cached_program.program; uint32_t num_cores_x = compute_with_storage_grid_size.x; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp index 226091271201..fbb0f0d630e2 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "binary_op.hpp" +#include "binary_device_operation.hpp" #include "tensor/tensor.hpp" #include "tt_dnn/op_library/bcast/bcast_op.hpp" #include "tt_dnn/op_library/work_split.hpp" @@ -22,10 +22,11 @@ static const tt::tt_metal::BcastOpMath binary_op_type_to_bcast_op_math(const Bin } } -Binary::BroadcastHeightMultiCore::cached_program_t Binary :: BroadcastHeightMultiCore::create( - const operation_attributes_t& operation_attributes, - const tensor_args_t& tensor_args, - tensor_return_value_t& tensor_return_value) { +BinaryDeviceOperation::BroadcastHeightMultiCore::cached_program_t +BinaryDeviceOperation ::BroadcastHeightMultiCore::create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value) { using namespace tt; using namespace tt::tt_metal; @@ -200,11 +201,11 @@ Binary::BroadcastHeightMultiCore::cached_program_t Binary :: BroadcastHeightMult }; } -void Binary :: BroadcastHeightMultiCore::override_runtime_arguments( - cached_program_t& cached_program, - const operation_attributes_t& operation_attributes, - const tensor_args_t& tensor_args, - tensor_return_value_t& tensor_return_value) { +void BinaryDeviceOperation ::BroadcastHeightMultiCore::override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value) { using namespace tt; using namespace tt::tt_metal; @@ -213,7 +214,7 @@ void Binary :: BroadcastHeightMultiCore::override_runtime_arguments( auto& output_tensor = tensor_return_value; auto&& [binary_reader_kernel_id, unary_writer_kernel_id, bcast_kernel_id, compute_with_storage_grid_size] = - cached_program.program_attributes; + cached_program.shared_variables; auto& program = cached_program.program; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp index b595735161c5..6186f96b34df 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "binary_op.hpp" +#include "binary_device_operation.hpp" #include "tensor/tensor.hpp" #include "tt_dnn/op_library/bcast/bcast_op.hpp" #include "tt_dnn/op_library/work_split.hpp" @@ -22,7 +22,7 @@ static const tt::tt_metal::BcastOpMath binary_op_type_to_bcast_op_math(const Bin } } -Binary::BroadcastWidthMultiCore::cached_program_t Binary::BroadcastWidthMultiCore::create( +BinaryDeviceOperation::BroadcastWidthMultiCore::cached_program_t BinaryDeviceOperation::BroadcastWidthMultiCore::create( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, tensor_return_value_t& tensor_return_value) { @@ -199,7 +199,7 @@ Binary::BroadcastWidthMultiCore::cached_program_t Binary::BroadcastWidthMultiCor {binary_reader_kernel_id, unary_writer_kernel_id, bcast_kernel_id, compute_with_storage_grid_size}}; } -void Binary::BroadcastWidthMultiCore::override_runtime_arguments( +void BinaryDeviceOperation::BroadcastWidthMultiCore::override_runtime_arguments( cached_program_t& cached_program, const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, @@ -211,10 +211,10 @@ void Binary::BroadcastWidthMultiCore::override_runtime_arguments( const auto& input_tensor_b = tensor_args.input_tensor_b; auto& output_tensor = tensor_return_value; - auto& binary_reader_kernel_id = cached_program.program_attributes.binary_reader_kernel_id; - auto& unary_writer_kernel_id = cached_program.program_attributes.unary_writer_kernel_id; - auto& bcast_kernel_id = cached_program.program_attributes.bcast_kernel_id; - auto& compute_with_storage_grid_size = cached_program.program_attributes.compute_with_storage_grid_size; + auto& binary_reader_kernel_id = cached_program.shared_variables.binary_reader_kernel_id; + auto& unary_writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id; + auto& bcast_kernel_id = cached_program.shared_variables.bcast_kernel_id; + auto& compute_with_storage_grid_size = cached_program.shared_variables.compute_with_storage_grid_size; auto& program = cached_program.program; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp index 521a83d0d8d6..41bf60f6da77 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp @@ -4,7 +4,7 @@ #include -#include "binary_op.hpp" +#include "binary_device_operation.hpp" #include "tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" #include "tt_eager/tt_dnn/op_library/work_split.hpp" #include "tt_metal/common/constants.hpp" @@ -245,7 +245,7 @@ inline __attribute__((always_inline)) void set_eltwise_binary_runtime_args( UpdateCircularBufferTotalSize(program, cb_output, num_tiles_per_core_group_1 * dst_single_tile_size); } } -Binary::ElementWiseMultiCore::cached_program_t Binary::ElementWiseMultiCore::create( +BinaryDeviceOperation::ElementWiseMultiCore::cached_program_t BinaryDeviceOperation::ElementWiseMultiCore::create( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, tensor_return_value_t& tensor_return_value) { @@ -425,7 +425,7 @@ Binary::ElementWiseMultiCore::cached_program_t Binary::ElementWiseMultiCore::cre dst_single_tile_size}}; } -void Binary::ElementWiseMultiCore::override_runtime_arguments( +void BinaryDeviceOperation::ElementWiseMultiCore::override_runtime_arguments( cached_program_t& cached_program, const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, @@ -434,23 +434,23 @@ void Binary::ElementWiseMultiCore::override_runtime_arguments( const auto& input_tensor_b = tensor_args.input_tensor_b; auto& output_tensor = tensor_return_value; - const auto& program_attributes = cached_program.program_attributes; + const auto& shared_variables = cached_program.shared_variables; set_eltwise_binary_runtime_args( cached_program.program, input_tensor_a, input_tensor_b, output_tensor, - program_attributes.binary_reader_kernel_id, - program_attributes.unary_writer_kernel_id, - program_attributes.eltwise_binary_kernel_id, - program_attributes.cb_src0, - program_attributes.cb_src1, - program_attributes.cb_output, - program_attributes.compute_with_storage_grid_size, - program_attributes.src0_single_tile_size, - program_attributes.src1_single_tile_size, - program_attributes.dst_single_tile_size); + shared_variables.binary_reader_kernel_id, + shared_variables.unary_writer_kernel_id, + shared_variables.eltwise_binary_kernel_id, + shared_variables.cb_src0, + shared_variables.cb_src1, + shared_variables.cb_output, + shared_variables.compute_with_storage_grid_size, + shared_variables.src0_single_tile_size, + shared_variables.src1_single_tile_size, + shared_variables.dst_single_tile_size); } } // namespace ttnn::operations::binary diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp index ab0eceb0d80a..b79e40ef8783 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp @@ -2,19 +2,18 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "ttnn/operations/eltwise/binary_backward/device/binary_backward_op.hpp" #include "third_party/magic_enum/magic_enum.hpp" - +#include "tt_eager/tt_dnn/op_library/backward/backward_ops.cpp" #include "tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp" -#include "tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" #include "tt_eager/tt_dnn/op_library/composite/composite_ops.hpp" -#include "tt_eager/tt_dnn/op_library/backward/backward_ops.cpp" +#include "tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" +#include "tt_eager/tt_dnn/op_library/unpad/unpad_op.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/tools/profiler/op_profiler.hpp" -#include "ttnn/operations/eltwise/binary_backward/device/binary_backward_op.hpp" -#include "tt_eager/tt_dnn/op_library/unpad/unpad_op.hpp" - +#include "ttnn/operations/eltwise/unary/unary.hpp" namespace ttnn::operations::binary_backward { @@ -184,7 +183,7 @@ std::vector> _add_bw_overload( std::vector _xlogy_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - Tensor grad1_result = log(other, output_mem_config); + Tensor grad1_result = ttnn::log(other, output_mem_config); Tensor zero_tensor = ttnn::operations::creation::zeros_like(other, other.get_dtype(), other.get_layout(), std::nullopt, output_mem_config); grad1_result = where( ttnn::logical_and( diff --git a/ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.cpp b/ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.cpp new file mode 100644 index 000000000000..339c0688523a --- /dev/null +++ b/ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.cpp @@ -0,0 +1,40 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "example_device_operation.hpp" + +namespace ttnn::operations::examples { + +ExampleDeviceOperation::program_factory_t ExampleDeviceOperation::select_program_factory( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + bool some_condition_based_on_operation_attributes_and_or_tensor_args = true; + if (some_condition_based_on_operation_attributes_and_or_tensor_args) { + return SingleCore{}; + } + return MultiCore{}; +} + +void ExampleDeviceOperation::validate_on_program_cache_miss( + const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {} + +void ExampleDeviceOperation::validate_on_program_cache_hit( + const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {} + +ExampleDeviceOperation::shape_return_value_t ExampleDeviceOperation::compute_output_shapes( + const operation_attributes_t&, const tensor_args_t& tensor_args) { + return tensor_args.input_tensor.tensor_attributes->shape; +} + +ExampleDeviceOperation::tensor_return_value_t ExampleDeviceOperation::create_output_tensors( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + auto output_shape = compute_output_shapes(operation_attributes, tensor_args); + const auto& input_tensor = tensor_args.input_tensor; + return create_device_tensor( + output_shape, + input_tensor.tensor_attributes->dtype, + input_tensor.tensor_attributes->layout, + input_tensor.device()); +} + +} // namespace ttnn::operations::example diff --git a/ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.hpp b/ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.hpp new file mode 100644 index 000000000000..00ccc209222b --- /dev/null +++ b/ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.hpp @@ -0,0 +1,101 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "tensor/tensor.hpp" +#include "ttnn/core.hpp" +#include "ttnn/device_operation.hpp" +#include "ttnn/types.hpp" + +namespace ttnn::operations::examples { + +struct ExampleDeviceOperation { + struct operation_attributes_t { + bool attribute; + int some_other_attribute; + }; + struct tensor_args_t { + // An example of the tensor that can only be used as an input + const Tensor& input_tensor; + + // An example of the tensor that can be used for input/output or just for pre-allocated output + // Tensor& io_tensor; + + // An example of an optional tensor + // std::optional optional_output_tensor; + + // An example of a vector of tensors + // std::vector vector_of_tensors; + + // An example of a tuple of tensors + // std::tuple tuple_of_tensors; + + // An example of a vector of optional tensors + // std::vector> vector_of_optional_tensors; + + // An example of a tuple of tensors + // std::tuple>, std::optional> some_crazy_tuple_of_tensors; + }; + + // Can be a single ttnn::Shape, std::optional, std::vector, std::tuple etc. + using shape_return_value_t = ttnn::Shape; + + // Can be a single Tensor, std::optional, std::vector, std::tuple etc. + using tensor_return_value_t = Tensor; + + struct SingleCore { + struct shared_variables_t { + int some_variable_from_create_to_use_in_override_runtime_arguments; + }; + using cached_program_t = ttnn::device_operation::CachedProgram; + + static cached_program_t create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + + static void override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + }; + + struct MultiCore { + struct shared_variables_t { + int some_variable_from_create_to_use_in_override_runtime_arguments; + int some_other_variable_from_create_to_use_in_override_runtime_arguments; + }; + using cached_program_t = ttnn::device_operation::CachedProgram; + + static cached_program_t create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + + static void override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + }; + + using program_factory_t = std::variant; + + static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&); + + static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&); + static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&); + + static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&); + + static tensor_return_value_t create_output_tensors( + const operation_attributes_t& operation_attributes, const tensor_args_t&); +}; + +} // namespace ttnn::operations::example diff --git a/ttnn/cpp/ttnn/operations/examples/example/device/multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/examples/example/device/multi_core_program_factory.cpp new file mode 100644 index 000000000000..8089633457f3 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/examples/example/device/multi_core_program_factory.cpp @@ -0,0 +1,35 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "example_device_operation.hpp" + +namespace ttnn::operations::examples { +ExampleDeviceOperation::MultiCore::cached_program_t ExampleDeviceOperation::MultiCore::create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value) { + using namespace tt; + using namespace tt::tt_metal; + + tt_metal::Program program = tt_metal::CreateProgram(); + + return { + std::move(program), + {.some_variable_from_create_to_use_in_override_runtime_arguments = 1, + .some_other_variable_from_create_to_use_in_override_runtime_arguments = 2}}; +} + +void ExampleDeviceOperation::MultiCore::override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value) { + auto& program = cached_program.program; + auto& some_variable_from_create_to_use_in_override_runtime_arguments = + cached_program.shared_variables.some_variable_from_create_to_use_in_override_runtime_arguments; + auto& some_other_variable_from_create_to_use_in_override_runtime_arguments = + cached_program.shared_variables.some_other_variable_from_create_to_use_in_override_runtime_arguments; +} + +} // namespace ttnn::operations::example diff --git a/ttnn/cpp/ttnn/operations/examples/example/device/single_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/examples/example/device/single_core_program_factory.cpp new file mode 100644 index 000000000000..5d32cbc41b61 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/examples/example/device/single_core_program_factory.cpp @@ -0,0 +1,30 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "example_device_operation.hpp" + +namespace ttnn::operations::examples { +ExampleDeviceOperation::SingleCore::cached_program_t ExampleDeviceOperation::SingleCore::create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value) { + using namespace tt; + using namespace tt::tt_metal; + + tt_metal::Program program = tt_metal::CreateProgram(); + + return {std::move(program), {.some_variable_from_create_to_use_in_override_runtime_arguments = 1}}; +} + +void ExampleDeviceOperation::SingleCore::override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value) { + auto& program = cached_program.program; + auto& some_variable_from_create_to_use_in_override_runtime_arguments = + cached_program.shared_variables.some_variable_from_create_to_use_in_override_runtime_arguments; +} + +} // namespace ttnn::operations::examples diff --git a/ttnn/cpp/ttnn/operations/examples/example/example.hpp b/ttnn/cpp/ttnn/operations/examples/example/example.hpp new file mode 100644 index 000000000000..94979242eac9 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/examples/example/example.hpp @@ -0,0 +1,35 @@ + +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "device/example_device_operation.hpp" + +namespace ttnn::operations::examples { + +struct ExampleOperation { + + static Tensor execute_on_main_thread( + uint8_t queue_id, + const Tensor &input_tensor) { + return ttnn::device_operation::run( + queue_id, + ExampleDeviceOperation::operation_attributes_t{.attribute = true, .some_other_attribute = 42}, + ExampleDeviceOperation::tensor_args_t{input_tensor}); + } + + static Tensor execute_on_main_thread( + const Tensor &input_tensor) { + return execute_on_main_thread(0, input_tensor); + } +}; + +} // namespace ttnn::operations::binary + +namespace ttnn { + +constexpr auto example = ttnn::register_operation("ttnn::example"); + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/examples/example/example_pybind.hpp b/ttnn/cpp/ttnn/operations/examples/example/example_pybind.hpp new file mode 100644 index 000000000000..f88c5b9d0640 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/examples/example/example_pybind.hpp @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "ttnn/cpp/pybind11/decorators.hpp" +#include "ttnn/operations/examples/example/example.hpp" +#include "ttnn/types.hpp" + +namespace py = pybind11; + +namespace ttnn::operations::examples { + +void bind_example_operation(py::module& module) { + bind_registered_operation( + module, + ttnn::example, + R"doc(example(input_tensor: ttnn.Tensor) -> ttnn.Tensor)doc", + ttnn::pybind_overload_t{ + [](const decltype(ttnn::example)& self, const ttnn::Tensor& input_tensor, const uint8_t& queue_id) + -> ttnn::Tensor { return self(queue_id, input_tensor); }, + py::arg("input_tensor"), + py::kw_only(), + py::arg("queue_id") = 0}); +} + +} // namespace ttnn::operations::examples diff --git a/ttnn/cpp/ttnn/operations/examples/examples_pybind.hpp b/ttnn/cpp/ttnn/operations/examples/examples_pybind.hpp new file mode 100644 index 000000000000..cad3d11782c8 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/examples/examples_pybind.hpp @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "ttnn/operations/examples/example/example_pybind.hpp" + +namespace ttnn::operations::examples { + +void py_module(py::module& module) { bind_example_operation(module); } + +} // namespace ttnn::operations::examples