From 30c6a63015f20edddff903cfcab988082df5b513 Mon Sep 17 00:00:00 2001 From: Johanna Rock <129077594+johanna-rock-tt@users.noreply.github.com> Date: Thu, 4 Jul 2024 23:31:30 +0200 Subject: [PATCH] #8835: cleaned up ttnn operation registration on C++ side --- .../ttnn/ttnn/adding_new_ttnn_operation.rst | 49 ++++- tests/ttnn/unit_tests/test_deallocate.py | 4 +- .../unit_tests/test_validate_decorator.py | 29 --- tt_eager/tensor/types.hpp | 37 ---- ttnn/cpp/pybind11/decorators.hpp | 64 +----- ttnn/cpp/pybind11/operations/copy.hpp | 5 +- ttnn/cpp/pybind11/operations/creation.hpp | 10 +- ttnn/cpp/pybind11/operations/kv_cache.hpp | 4 +- ttnn/cpp/pybind11/operations/pool.hpp | 2 +- ttnn/cpp/ttnn/decorators.hpp | 200 ++++-------------- .../ttnn/op_library/to_dtype/to_dtype_op.hpp | 17 -- .../op_library/to_layout/to_layout_op.hpp | 17 -- .../to_memory_config/to_memory_config_op.hpp | 16 -- ttnn/cpp/ttnn/operations/ccl.hpp | 16 -- ttnn/cpp/ttnn/operations/conv2d.cpp | 32 +-- ttnn/cpp/ttnn/operations/copy.hpp | 27 --- ttnn/cpp/ttnn/operations/core.hpp | 14 -- ttnn/cpp/ttnn/operations/creation.hpp | 2 +- ttnn/cpp/ttnn/operations/data_movement.hpp | 48 ----- .../ttnn/operations/data_movement/pad/pad.hpp | 17 -- .../data_movement/permute/permute.hpp | 23 -- .../ttnn/operations/eltwise/binary/binary.hpp | 41 ---- .../eltwise/binary/binary_pybind.hpp | 2 +- .../binary_backward/binary_backward.hpp | 149 ++++--------- .../binary_backward_pybind.hpp | 4 +- .../ttnn/operations/eltwise/unary/unary.hpp | 104 +++------ .../eltwise/unary/unary_composite.hpp | 41 +++- .../operations/eltwise/unary/unary_pybind.hpp | 20 +- ttnn/cpp/ttnn/operations/embedding.hpp | 13 -- .../device/example_device_operation.hpp | 40 +++- .../operations/examples/example/example.hpp | 17 +- ttnn/cpp/ttnn/operations/kv_cache.hpp | 30 --- ttnn/cpp/ttnn/operations/matmul.cpp | 37 +--- ttnn/cpp/ttnn/operations/matmul.hpp | 2 - ttnn/cpp/ttnn/operations/normalization.hpp | 97 +-------- ttnn/cpp/ttnn/operations/pool.hpp | 22 -- .../operations/reduction/argmax/argmax.hpp | 21 +- .../reduction/generic/generic_reductions.hpp | 12 -- .../generic/generic_reductions_pybind.hpp | 2 +- .../ttnn/operations/reduction/topk/topk.hpp | 15 -- ttnn/cpp/ttnn/operations/transformer.hpp | 30 --- ttnn/cpp/ttnn/validation.hpp | 94 -------- ttnn/ttnn/decorators.py | 2 - ttnn/ttnn/operations/binary.py | 40 ++-- ttnn/ttnn/operations/core.py | 8 +- ttnn/ttnn/operations/creation.py | 16 +- ttnn/ttnn/operations/data_movement.py | 10 +- ttnn/ttnn/operations/embedding.py | 2 +- ttnn/ttnn/operations/normalization.py | 8 +- ttnn/ttnn/operations/pool.py | 2 +- ttnn/ttnn/operations/reduction.py | 16 +- ttnn/ttnn/operations/ternary.py | 2 +- ttnn/ttnn/operations/transformer.py | 10 +- ttnn/ttnn/operations/unary.py | 140 ++++++------ 54 files changed, 421 insertions(+), 1261 deletions(-) delete mode 100644 tests/ttnn/unit_tests/test_validate_decorator.py delete mode 100644 ttnn/cpp/ttnn/validation.hpp diff --git a/docs/source/ttnn/ttnn/adding_new_ttnn_operation.rst b/docs/source/ttnn/ttnn/adding_new_ttnn_operation.rst index 68c5a36590d..b0412ce2cf0 100644 --- a/docs/source/ttnn/ttnn/adding_new_ttnn_operation.rst +++ b/docs/source/ttnn/ttnn/adding_new_ttnn_operation.rst @@ -8,6 +8,9 @@ Adding New ttnn Operation Wormhole, or others). +FAQ +*** + What is a ttnn operation? ------------------------- @@ -25,11 +28,13 @@ What steps are needed to add ttnn operation in Python? 2. (Optional) Attach golden function to the operation using `ttnn.attach_golden_function`. This is useful for debugging and testing. +Example of Adding a new Device Operation +**************************************** C++ Implementation ------------------ -Step 1: Implement device operation (Optional) +Step 1: Implement device operation ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ In order to add a new device operation, follow the directory structure shown below: @@ -110,8 +115,10 @@ Finally, call the module defined in `examples/example/example_pybind.hpp` wherev -Step 2: Add golden function for the operation in Python -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Step 2: (Optional) Add golden function for the operation in Python +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +A golden function can be added to an operation in order to compare its output with an equivalent `torch` implementation Add the following code in a python file: @@ -119,8 +126,38 @@ Add the following code in a python file: import ttnn - def example_golden_function(input_tensor, *args, **kwargs): - output_tensor = ... + # For the golden function, use the same signature as the operation + # Keep in mind that all `ttnn.Tensor`s are converted to `torch.Tensor`s + # And arguments not needed by torch can be ignored using `*args` and `**kwargs` + def golden_function(input_tensor: "torch.Tensor", *args, **kwargs): + output_tensor: "torch.Tensor" = ... return output_tensor - ttnn.attach_golden_function(ttnn.example, example_golden_function) + # ttnn Tensors are converted to torch tensors before calling the golden function automatically + # And the outputs are converted back to ttnn Tensors + # But in some cases you may need to preprocess the inputs and postprocess the outputs manually + + # In order to preprocess the inputs manually, use the following signature + # Note that the arguments are not packed into *args and **kwargs as in the golden function!!! + def preprocess_golden_function_inputs(args, kwargs): + # i.e. + ttnn_input_tensor = args[0] + return ttnn.to_torch(ttnn_input_tensor) + + # In order to postprocess the outputs manually, use the following signature + # Note that the arguments are not packed into *args and **kwargs as in the golden function!!! + def postprocess_golden_function_outputs(args, kwargs, output): + # i.e. + ttnn_input_tensor = args[0] + torch_output_tensor = outputs[0] + return ttnn.from_torch(torch_output_tensor, dtype=ttnn_input_tensor.dtype, device=ttnn_input_tensor.device) + + ttnn.attach_golden_function( + ttnn.example, + golden_function=golden_function, + preprocess_golden_function_inputs=preprocess_golden_function_inputs, # Optional + postprocess_golden_function_outputs=postprocess_golden_function_outputs # Optional + ) + +.. note:: + `ttnn.example` is the name of the operation in Python because the operation was registered as `ttnn::example` in C++. diff --git a/tests/ttnn/unit_tests/test_deallocate.py b/tests/ttnn/unit_tests/test_deallocate.py index b1647ba689f..315067ff7d2 100644 --- a/tests/ttnn/unit_tests/test_deallocate.py +++ b/tests/ttnn/unit_tests/test_deallocate.py @@ -11,7 +11,6 @@ @pytest.mark.parametrize("h", [32]) @pytest.mark.parametrize("w", [2 * 32]) -@pytest.mark.requires_fast_runtime_mode_off def test_deallocate(device, h, w): torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16) @@ -27,4 +26,5 @@ def test_deallocate(device, h, w): ttnn.deallocate(output_tensor) with pytest.raises(RuntimeError) as exception: output_tensor_reference + output_tensor_reference - assert "Tensor must be allocated!" in str(exception.value) + + assert "MemoryConfig can only be obtained if the buffer is not null" in str(exception.value) diff --git a/tests/ttnn/unit_tests/test_validate_decorator.py b/tests/ttnn/unit_tests/test_validate_decorator.py deleted file mode 100644 index dbaffd58a46..00000000000 --- a/tests/ttnn/unit_tests/test_validate_decorator.py +++ /dev/null @@ -1,29 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import pytest - -import torch - -import ttnn -from models.utility_functions import skip_for_wormhole_b0 -from models.utility_functions import torch_random - - -@skip_for_wormhole_b0() -@pytest.mark.requires_fast_runtime_mode_off -@pytest.mark.parametrize("batch_size", [1]) -@pytest.mark.parametrize("h", [32]) -@pytest.mark.parametrize("w", [32]) -def test_add(device, batch_size, h, w): - torch.manual_seed(0) - - torch_input_tensor = torch_random((batch_size, h, w), -1, 1, dtype=torch.bfloat16) - - input_tensor = ttnn.from_torch(torch_input_tensor) - input_tensor = ttnn.to_device(input_tensor, device) - - with pytest.raises(RuntimeError) as exception: - output_tensor = input_tensor + input_tensor - assert "Tensor must be of layout" in str(exception.value) diff --git a/tt_eager/tensor/types.hpp b/tt_eager/tensor/types.hpp index b3fe2dc9a66..cbf1e9bb346 100644 --- a/tt_eager/tensor/types.hpp +++ b/tt_eager/tensor/types.hpp @@ -824,45 +824,8 @@ static std::ostream &operator<<(std::ostream &os, const Shape &self) { return os; } - -struct TensorSchema { - const std::size_t min_rank; - const std::size_t max_rank; - const std::set dtypes; - const std::set layouts; - const bool can_be_on_device; - const bool can_be_on_cpu; - const bool can_be_scalar; - const bool is_optional; - - static constexpr auto attribute_names() { - return std::forward_as_tuple( - "min_rank", - "max_rank", - "dtypes", - "layouts", - "can_be_on_device", - "can_be_on_cpu", - "can_be_scalar", - "is_optional"); - } - - const auto attribute_values() const { - return std::forward_as_tuple( - this->min_rank, - this->max_rank, - this->dtypes, - this->layouts, - this->can_be_on_device, - this->can_be_on_cpu, - this->can_be_scalar, - this->is_optional); - } -}; - } // namespace types using types::Shape; -using types::TensorSchema; } // namespace ttnn diff --git a/ttnn/cpp/pybind11/decorators.hpp b/ttnn/cpp/pybind11/decorators.hpp index 13251a6713b..298fb7e5247 100644 --- a/ttnn/cpp/pybind11/decorators.hpp +++ b/ttnn/cpp/pybind11/decorators.hpp @@ -38,7 +38,7 @@ struct pybind_overload_t { }; template -void add_operator_call(T& py_operation, const pybind_arguments_t& overload) { +void define_call_operator(T& py_operation, const pybind_arguments_t& overload) { std::apply( [&py_operation](auto... args) { py_operation.def( @@ -55,60 +55,12 @@ template < typename T, typename function_t, typename... py_args_t> -void add_operator_call(T& py_operation, const pybind_overload_t& overload) { +void define_call_operator(T& py_operation, const pybind_overload_t& overload) { std::apply( [&py_operation, &overload](auto... args) { py_operation.def("__call__", overload.function, args...); }, overload.args.value); } -template -std::string append_input_tensor_schemas_to_doc( - const operation_t& operation, const std::string& doc) { - std::stringstream updated_doc; - - auto write_row = [&updated_doc](const Tuple& tuple) { - auto index = 0; - - std::apply( - [&index, &updated_doc](const auto&... args) { - ( - [&index, &updated_doc](const auto& item) { - updated_doc << " "; - if (index == 0) { - updated_doc << " * - "; - } else { - updated_doc << " - "; - } - updated_doc << fmt::format("{}", item); - updated_doc << "\n"; - index++; - }(args), - ...); - }, - tuple); - }; - - if constexpr (detail::has_input_tensor_schemas()) { - if constexpr (std::tuple_size_v > 0) { - updated_doc << doc << "\n\n"; - auto tensor_index = 0; - for (const auto& schema : concrete_operation_t::input_tensor_schemas()) { - updated_doc << " .. list-table:: Input Tensor " << tensor_index << "\n\n"; - write_row(ttnn::TensorSchema::attribute_names()); - write_row(schema.attribute_values()); - tensor_index++; - updated_doc << "\n"; - } - updated_doc << "\n"; - return updated_doc.str(); - } else { - return doc; - } - } else { - return doc; - } -} - auto bind_registered_operation_helper( py::module& module, const auto& operation, const std::string& doc, auto attach_call_operator) { using registered_operation_t = std::decay_t; @@ -117,15 +69,11 @@ auto bind_registered_operation_helper( py::class_ py_operation(module, operation.class_name().c_str()); - if constexpr (requires { append_input_tensor_schemas_to_doc(operation, doc); }) { - py_operation.doc() = append_input_tensor_schemas_to_doc(operation, doc).c_str(); - } else { - py_operation.doc() = doc; - } + py_operation.doc() = doc; py_operation.def_property_readonly( "name", - [](const registered_operation_t& self) -> const std::string { return self.name(); }, + [](const registered_operation_t& self) -> const std::string { return self.base_name(); }, "Shortened name of the api"); py_operation.def_property_readonly( @@ -139,7 +87,7 @@ auto bind_registered_operation_helper( attach_call_operator(py_operation); - module.attr(operation.name().c_str()) = operation; // Bind an instance of the operation to the module + module.attr(operation.base_name().c_str()) = operation; // Bind an instance of the operation to the module return py_operation; } @@ -155,7 +103,7 @@ auto bind_registered_operation( auto attach_call_operator = [&](auto& py_operation) { ( [&py_operation](auto&& overload) { - add_operator_call(py_operation, overload); + define_call_operator(py_operation, overload); }(overloads), ...); }; diff --git a/ttnn/cpp/pybind11/operations/copy.hpp b/ttnn/cpp/pybind11/operations/copy.hpp index 4937a821f05..fc0f2876307 100644 --- a/ttnn/cpp/pybind11/operations/copy.hpp +++ b/ttnn/cpp/pybind11/operations/copy.hpp @@ -21,7 +21,7 @@ namespace detail { void bind_global_typecast(py::module& module) { auto doc = fmt::format( -R"doc({0}(input_tensor: ttnn.Tensor, dtype: ttnn.DataType, *, memory_config: Optional[ttnn.MemoryConfig] = None, output_tensor : Optional[ttnn.Tensor] = None, queue_id : Optional[int]) -> ttnn.Tensor + R"doc({0}(input_tensor: ttnn.Tensor, dtype: ttnn.DataType, *, memory_config: Optional[ttnn.MemoryConfig] = None, output_tensor : Optional[ttnn.Tensor] = None, queue_id : Optional[int]) -> ttnn.Tensor Applies {0} to :attr:`input_tensor`. @@ -40,8 +40,7 @@ Example:: >>> tensor = ttnn.typecast(torch.randn((10, 3, 32, 32), dtype=ttnn.bfloat16), ttnn.uint16) )doc", - ttnn::typecast.name()); - + ttnn::typecast.base_name()); using TypecastType = decltype(ttnn::typecast); bind_registered_operation( diff --git a/ttnn/cpp/pybind11/operations/creation.hpp b/ttnn/cpp/pybind11/operations/creation.hpp index 977a71b1471..f27e4240318 100644 --- a/ttnn/cpp/pybind11/operations/creation.hpp +++ b/ttnn/cpp/pybind11/operations/creation.hpp @@ -22,7 +22,7 @@ template void bind_full_operation(py::module& module, const creation_operation_t& operation) { auto doc = fmt::format( R"doc({0}(shape: ttnn.Shape, fill_value: Union[int, float], dtype: Optional[ttnn.DataType] = None, layout: Optional[ttnn.Layout] = None, device: Optional[ttnn.Device] = None, memory_config: Optional[ttnn.MemoryConfig] = None)doc", - operation.name()); + operation.base_name()); bind_registered_operation( module, @@ -66,7 +66,7 @@ template void bind_full_operation_with_hard_coded_value(py::module& module, const creation_operation_t& operation) { auto doc = fmt::format( R"doc({0}(shape: ttnn.Shape, dtype: Optional[ttnn.DataType] = None, layout: Optional[ttnn.Layout] = None, device: Optional[ttnn.Device] = None, memory_config: Optional[ttnn.MemoryConfig] = None)doc", - operation.name()); + operation.base_name()); bind_registered_operation( module, @@ -92,7 +92,7 @@ template void bind_full_like_operation(py::module& module, const creation_operation_t& operation) { auto doc = fmt::format( R"doc({0}(tensor: ttnn.Tensor, fill_value: Union[int, float], dtype: Optional[ttnn.DataType] = None, layout: Optional[ttnn.Layout] = None, device: Optional[ttnn.Device] = None, memory_config: Optional[ttnn.MemoryConfig] = None)doc", - operation.name()); + operation.base_name()); bind_registered_operation( module, @@ -136,7 +136,7 @@ template void bind_full_like_operation_with_hard_coded_value(py::module& module, const creation_operation_t& operation) { auto doc = fmt::format( R"doc({0}(tensor: ttnn.Tensor, dtype: Optional[ttnn.DataType] = None, layout: Optional[ttnn.Layout] = None, device: Optional[ttnn.Device] = None, memory_config: Optional[ttnn.MemoryConfig] = None)doc", - operation.name()); + operation.base_name()); bind_registered_operation( module, @@ -162,7 +162,7 @@ template void bind_arange_operation(py::module& module, const creation_operation_t& operation) { auto doc = fmt::format( R"doc({0}(start: int = 0, stop: int, step: int = 1, dtype: ttnn.DataType = ttnn.bfloat16, device: ttnn.Device = None, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG)doc", - operation.name()); + operation.base_name()); bind_registered_operation( module, diff --git a/ttnn/cpp/pybind11/operations/kv_cache.hpp b/ttnn/cpp/pybind11/operations/kv_cache.hpp index f30ffa60ea1..38799184b5c 100644 --- a/ttnn/cpp/pybind11/operations/kv_cache.hpp +++ b/ttnn/cpp/pybind11/operations/kv_cache.hpp @@ -32,7 +32,7 @@ void bind_fill_cache_for_user_(py::module& module, const kv_cache_operation_t& o * :attr:`batch_index` (int): The index into the cache tensor. )doc", - operation.name(), + operation.base_name(), operation.python_fully_qualified_name()); bind_registered_operation( @@ -64,7 +64,7 @@ void bind_update_cache_for_token_(py::module& module, const kv_cache_operation_t * :attr:`batch_offset` (int): The batch_offset into the cache tensor. )doc", - operation.name(), + operation.base_name(), operation.python_fully_qualified_name()); bind_registered_operation( diff --git a/ttnn/cpp/pybind11/operations/pool.hpp b/ttnn/cpp/pybind11/operations/pool.hpp index c775273cf14..b4056ca4163 100644 --- a/ttnn/cpp/pybind11/operations/pool.hpp +++ b/ttnn/cpp/pybind11/operations/pool.hpp @@ -43,7 +43,7 @@ void bind_global_avg_pool2d(py::module& module) { >>> tensor = ttnn.from_torch(torch.randn((10, 3, 32, 32), dtype=ttnn.bfloat16), device=device) >>> output = {1}(tensor) )doc", - ttnn::global_avg_pool2d.name(), + ttnn::global_avg_pool2d.base_name(), ttnn::global_avg_pool2d.python_fully_qualified_name()); bind_registered_operation( diff --git a/ttnn/cpp/ttnn/decorators.hpp b/ttnn/cpp/ttnn/decorators.hpp index cac0af87279..50fbc22d4d5 100644 --- a/ttnn/cpp/ttnn/decorators.hpp +++ b/ttnn/cpp/ttnn/decorators.hpp @@ -4,31 +4,19 @@ #pragma once -#include "tt_dnn/op_library/run_operation.hpp" #include "tt_eager/tensor/tensor.hpp" +#include "tt_eager/tt_dnn/op_library/operation.hpp" +#include "tt_eager/tt_dnn/op_library/run_operation.hpp" #include "tt_metal/third_party/tracy/public/tracy/Tracy.hpp" -#include "ttnn/validation.hpp" namespace ttnn { namespace decorators { -namespace detail { - -template -using input_tensors_to_validate_return_t = decltype(T::input_tensors_to_validate(std::declval()...)); +using Tensors = tt::tt_metal::operation::Tensors; +using OptionalTensors = tt::tt_metal::operation::OptionalTensors; +using OptionalConstTensors = tt::tt_metal::operation::OptionalConstTensors; -template -constexpr bool has_input_tensors_to_validate() { - return std::experimental::is_detected_v; -} - -template -using input_tensor_schemas_t = decltype(T::input_tensor_schemas); - -template -constexpr bool has_input_tensor_schemas() { - return std::experimental::is_detected_v; -} +namespace detail { template using execute_on_worker_thread_return_t = decltype(T::execute_on_worker_thread(std::declval()...)); @@ -102,55 +90,6 @@ inline Tensors create_async_output_tensors(const Tensors& inputs, const Optional } } -template -constexpr bool is_any_of = (... || std::is_same_v, TypeToCheck>); - -template -constexpr auto conditional_tuple(T&& arg) { - if constexpr (is_any_of) { - return std::forward_as_tuple(std::forward(arg)); - } else { - return std::tuple<>(); - } -} - -template -constexpr auto extract_args(Args&&... args) { - return std::tuple_cat(conditional_tuple(std::forward(args))...); -} - -template -constexpr auto validate(const char* cpp_fully_qualified_name, args_t&&... args) { - if constexpr (has_input_tensor_schemas()) { - if (ttnn::CONFIG.enable_fast_runtime_mode) { - return; - } - - constexpr auto input_tensors_to_validate = [](args_t&&... args) { - if constexpr (has_input_tensors_to_validate()) { - return concrete_operation_t::input_tensors_to_validate(std::forward(args)...); - } else { - return extract_args>(std::forward(args)...); - } - }; - - auto tensors_to_validate = input_tensors_to_validate(std::forward(args)...); - static_assert( - std::tuple_size_v == - std::tuple_size_v, - "Number of tensors to validate must match the number of input tensors schemas"); - if constexpr (std::tuple_size_v > 0) { - [cpp_fully_qualified_name, &tensors_to_validate](std::index_sequence) { - (ttnn::validate_input_tensor( - cpp_fully_qualified_name, - std::get(tensors_to_validate), - concrete_operation_t::input_tensor_schemas().at(Ns)), - ...); - }(std::make_index_sequence>{}); - } - } -} - template auto map_launch_op_args_to_execute_on_worker_thread_args( const Tensors& input_tensors, const OptionalConstTensors& optional_input_tensors, const OptionalTensors& optional_output_tensors, args_t&&... args) { @@ -214,6 +153,35 @@ void log(const std::string& prefix, args_t&&... args) { std::apply([&fmt](const auto&... args) { tt::log_debug(tt::LogOp, fmt.c_str(), args...); }, args_tuple); } +// Get "add" from "ttnn::add" +static const std::string base_name(const char* cpp_fully_qualified_name) { + auto cpp_fully_qualified_name_as_string = std::string(cpp_fully_qualified_name); + auto last_token = cpp_fully_qualified_name_as_string.substr(cpp_fully_qualified_name_as_string.rfind("::") + 2); + return last_token; +} + +// Convert "ttnn::add" to "add_t" +static const std::string class_name(const char* cpp_fully_qualified_name) { + return base_name(cpp_fully_qualified_name) + "_t"; +} + +// Convert "ttnn::add" to "ttnn.add" +static const std::string python_fully_qualified_name(const char* cpp_fully_qualified_name) { + auto replace = [](const std::string& input, const std::string& from, const std::string& to) { + if (from.empty()) { + return input; + } + auto output = input; + size_t start = 0; + while ((start = output.find(from, start)) != std::string::npos) { + output.replace(start, from.length(), to); + start += to.length(); // In case 'to' contains 'from', like replacing 'x' with 'yx' + }; + return output; + }; + return replace(std::string{cpp_fully_qualified_name}, "::", "."); +} + } // namespace detail template @@ -259,8 +227,6 @@ struct operation_t { input_tensors, optional_input_tensors, optional_output_tensors, std::forward(args)...); return std::apply( [cpp_fully_qualified_name](auto&&... args) -> Tensors { - detail::validate( - cpp_fully_qualified_name, std::forward(args)...); return detail::map_execute_on_worker_thread_return_to_launch_op_return< concrete_operation_t>( concrete_operation_t::execute_on_worker_thread(std::forward(args)...)); @@ -290,7 +256,6 @@ struct operation_t { } } else { - detail::validate(cpp_fully_qualified_name, std::forward(args)...); auto output = concrete_operation_t::execute_on_main_thread(std::forward(args)...); tt::log_debug(tt::LogOp, "Finished C++ ttnn operation: {}", this->cpp_fully_qualified_name); return output; @@ -298,30 +263,14 @@ struct operation_t { } // Get "add" from "ttnn::add" - const std::string name() const { - auto cpp_fully_qualified_name = std::string(this->cpp_fully_qualified_name); - auto last_token = cpp_fully_qualified_name.substr(cpp_fully_qualified_name.rfind("::") + 2); - return last_token; - } + const std::string base_name() const { return detail::base_name(this->cpp_fully_qualified_name); } - // Convert "ttnn::add" to "ttnn_add_t" - const std::string class_name() const { return this->name() + "_t"; } + // Convert "ttnn::add" to "add_t" + const std::string class_name() const { return detail::class_name(this->cpp_fully_qualified_name); } // Convert "ttnn::add" to "ttnn.add" const std::string python_fully_qualified_name() const { - auto replace = [](const std::string& input, const std::string& from, const std::string& to) { - if (from.empty()) { - return input; - } - auto output = input; - size_t start = 0; - while ((start = output.find(from, start)) != std::string::npos) { - output.replace(start, from.length(), to); - start += to.length(); // In case 'to' contains 'from', like replacing 'x' with 'yx' - }; - return output; - }; - return replace(std::string{this->cpp_fully_qualified_name}, "::", "."); + return detail::python_fully_qualified_name(this->cpp_fully_qualified_name); } }; @@ -344,84 +293,29 @@ struct lambda_operation_t { } // Get "add" from "ttnn::add" - const std::string name() const { - auto cpp_fully_qualified_name = std::string(this->cpp_fully_qualified_name); - auto last_token = cpp_fully_qualified_name.substr(cpp_fully_qualified_name.rfind("::") + 2); - return last_token; - } + const std::string base_name() const { return detail::base_name(this->cpp_fully_qualified_name); } - // Convert "ttnn::add" to "ttnn_add_t" - const std::string class_name() const { return this->name() + "_t"; } + // Convert "ttnn::add" to "add_t" + const std::string class_name() const { return detail::class_name(this->cpp_fully_qualified_name); } // Convert "ttnn::add" to "ttnn.add" const std::string python_fully_qualified_name() const { - auto replace = [](const std::string& input, const std::string& from, const std::string& to) { - if (from.empty()) { - return input; - } - auto output = input; - size_t start = 0; - while ((start = output.find(from, start)) != std::string::npos) { - output.replace(start, from.length(), to); - start += to.length(); // In case 'to' contains 'from', like replacing 'x' with 'yx' - }; - return output; - }; - return replace(std::string{this->cpp_fully_qualified_name}, "::", "."); + return detail::python_fully_qualified_name(this->cpp_fully_qualified_name); } }; template -constexpr auto register_operation(const char* name) { - return operation_t<__COUNTER__, concrete_operation_t>{name}; +constexpr auto register_operation(const char* cpp_fully_qualified_name) { + return operation_t<__COUNTER__, concrete_operation_t>{cpp_fully_qualified_name}; } template -constexpr auto register_operation(const char* name, const lambda_t& lambda) { - return lambda_operation_t<__COUNTER__, lambda_t>{name, lambda}; -} - -// This function is used to transform the arguments of a function before calling it -// where the lambda is applied to the type that matches T. -// Example: https://godbolt.org/z/3P9YedMdj -template -constexpr auto transform_args_lambda(Func func, Lambda lambda, Args&&... args) -> decltype(auto) { - auto transformer = [lambda](auto&& arg) -> decltype(auto) { - if constexpr (std::is_same_v>) { - return lambda(std::forward(arg)); - } else { - return std::forward(arg); - } - }; - - return func(transformer(std::forward(args))...); -} - -template -auto transform_first_matching_arg(Lambda lambda) { - static_assert(!std::is_same::value, "No matching type found"); -} - -template -auto transform_first_matching_arg(Lambda lambda, First&& first, Rest&&... rest) { - if constexpr (std::is_same_v>) { - return lambda(std::forward(first)); - } else { - return transform_first_matching_arg(lambda, std::forward(rest)...); - } +constexpr auto register_operation(const char* cpp_fully_qualified_name, const lambda_t& lambda) { + return lambda_operation_t<__COUNTER__, lambda_t>{cpp_fully_qualified_name, lambda}; } #define TO_LAMBDA(function) ([](auto&&... args) { return function(std::forward(args)...); }) -#define TO_LAMBDA_WITH_RESHAPE(function) \ - ([](auto&&... args) { \ - const auto original_shape = ttnn::decorators::transform_first_matching_arg( \ - [&](auto&& tensor) { return tensor.get_shape(); }, std::forward(args)...); \ - return ttnn::reshape( \ - ttnn::decorators::transform_args_lambda( \ - function, [&](auto&& tensor) { return ttnn::unsqueeze_to_4D(tensor); }, args...), \ - original_shape); \ - }) } // namespace decorators diff --git a/ttnn/cpp/ttnn/op_library/to_dtype/to_dtype_op.hpp b/ttnn/cpp/ttnn/op_library/to_dtype/to_dtype_op.hpp index 82268ce6ff2..fc7ae2c0149 100644 --- a/ttnn/cpp/ttnn/op_library/to_dtype/to_dtype_op.hpp +++ b/ttnn/cpp/ttnn/op_library/to_dtype/to_dtype_op.hpp @@ -186,23 +186,6 @@ inline Tensor convert_to_dtype(const Tensor& input_tensor, const Layout& input_l } // namespace detail struct ToDtype { - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 1, - 8, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b, ttnn::float32, ttnn::uint16, ttnn::uint32, ttnn::int32}, - {ttnn::ROW_MAJOR_LAYOUT, ttnn::TILE_LAYOUT}, - true, - true, - false, - false}}; - } - - template - static auto input_tensors_to_validate(const Tensor& tensor_arg, Args&&... args) { - return std::forward_as_tuple(tensor_arg); - } - // TODO: Move to cpp once we merge with tt_eager static Tensor execute_on_worker_thread(const ttnn::Tensor& input_tensor, const ttnn::DataType& dtype) { auto input_layout = input_tensor.get_layout(); diff --git a/ttnn/cpp/ttnn/op_library/to_layout/to_layout_op.hpp b/ttnn/cpp/ttnn/op_library/to_layout/to_layout_op.hpp index 4cee9ac2f7c..9a13a86597e 100644 --- a/ttnn/cpp/ttnn/op_library/to_layout/to_layout_op.hpp +++ b/ttnn/cpp/ttnn/op_library/to_layout/to_layout_op.hpp @@ -25,23 +25,6 @@ namespace operations { namespace core { struct ToLayout { - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 1, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b, ttnn::float32, ttnn::uint16, ttnn::uint32, ttnn::int32}, - {ttnn::ROW_MAJOR_LAYOUT, ttnn::TILE_LAYOUT}, - true, - true, - false, - false}}; - } - - template - static auto input_tensors_to_validate(const Tensor& tensor_arg, Args&&... args) { - return std::forward_as_tuple(tensor_arg); - } - static Tensor execute_on_worker_thread( const ttnn::Tensor& tensor_arg, const ttnn::Layout layout, diff --git a/ttnn/cpp/ttnn/op_library/to_memory_config/to_memory_config_op.hpp b/ttnn/cpp/ttnn/op_library/to_memory_config/to_memory_config_op.hpp index 572d12a6f54..6d7a001750c 100644 --- a/ttnn/cpp/ttnn/op_library/to_memory_config/to_memory_config_op.hpp +++ b/ttnn/cpp/ttnn/op_library/to_memory_config/to_memory_config_op.hpp @@ -19,22 +19,6 @@ namespace operations { namespace core { struct ToMemoryConfig { - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 1, - 8, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b, ttnn::float32, ttnn::uint16, ttnn::uint32, ttnn::int32}, - {ttnn::ROW_MAJOR_LAYOUT, ttnn::TILE_LAYOUT}, - true, - true, - false, - false}}; - } - - template - static auto input_tensors_to_validate(const Tensor& tensor_arg, Args&&... args) { - return std::forward_as_tuple(tensor_arg); - } // TODO: Move to cpp once we merge with tt_eager static Tensor execute_on_worker_thread( diff --git a/ttnn/cpp/ttnn/operations/ccl.hpp b/ttnn/cpp/ttnn/operations/ccl.hpp index 145662ea33d..bdf8dc2c615 100644 --- a/ttnn/cpp/ttnn/operations/ccl.hpp +++ b/ttnn/cpp/ttnn/operations/ccl.hpp @@ -12,22 +12,6 @@ namespace operations { namespace ccl { struct ExecuteAllGather { - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 2, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::ROW_MAJOR_LAYOUT, ttnn::TILE_LAYOUT}, - true, - false, - false, - false}}; - } - - template - static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... args) { - return std::forward_as_tuple(input_tensor); - } static ttnn::Tensor execute_on_main_thread( const ttnn::Tensor& input_tensor, diff --git a/ttnn/cpp/ttnn/operations/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv2d.cpp index 259448d21b6..eea5e6e151b 100644 --- a/ttnn/cpp/ttnn/operations/conv2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv2d.cpp @@ -15,32 +15,6 @@ namespace operations { namespace conv2d { -const std::array input_schemas{ - ttnn::TensorSchema{ - 4, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::float32}, - {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, - true, - true, - false}, - ttnn::TensorSchema{ - 4, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::float32}, - {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, - true, - true, - false}, - ttnn::TensorSchema{ - 4, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::float32}, - {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, - true, - true, - false}}; - uint32_t find_closest_largest_divisor(uint32_t num, uint32_t start_divisor) { uint32_t divisor = start_divisor; while (num % divisor != 0) divisor = divisor - 1; @@ -553,11 +527,7 @@ std::tuple bias_tensor, std::optional conv_config_) { - ttnn::validate_input_tensor("ttnn.conv2d", input_tensor, input_schemas[0]); - ttnn::validate_input_tensor("ttnn.conv2d", weight_tensor, input_schemas[1]); - if (bias_tensor.has_value()) { - ttnn::validate_input_tensor("ttnn.conv2d", bias_tensor.value(), input_schemas[2]); - } + Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig()); uint32_t output_height = ((input_height - kernel_size[0] + 2 * padding[0]) / stride[0]) + 1; uint32_t output_width = ((input_width - kernel_size[1] + 2 * padding[1]) / stride[1]) + 1; diff --git a/ttnn/cpp/ttnn/operations/copy.hpp b/ttnn/cpp/ttnn/operations/copy.hpp index 85cbb42cf74..5afb3d6cf40 100644 --- a/ttnn/cpp/ttnn/operations/copy.hpp +++ b/ttnn/cpp/ttnn/operations/copy.hpp @@ -12,29 +12,7 @@ namespace ttnn { namespace operations { namespace copy { -namespace detail { -inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 2, // min rank - 4, // max rank - {ttnn::bfloat16}, - {ttnn::TILE_LAYOUT}, - true, // can_be_on_device - false, // can_be_on_cpu - false, // can_be_scalar - false // is_optional} - }}; -} -} // namespace detail - struct Typecast { - static const std::array input_tensor_schemas() { return detail::input_tensor_schemas(); } - - template - static auto input_tensors_to_validate(uint8_t queue_id, const Tensor& input_tensor, Args&&... args) { - return std::forward_as_tuple(input_tensor); - } - static Tensor execute_on_worker_thread( const uint8_t& queue_id, const Tensor& input, @@ -60,11 +38,6 @@ struct Typecast { return operation::run(eltwise_op, {input}, {}, {optional_output_tensor}, queue_id).at(0); } - template - static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) { - return std::forward_as_tuple(input_tensor); - } - static Tensor execute_on_worker_thread( const Tensor& input, const DataType& output_dtype, diff --git a/ttnn/cpp/ttnn/operations/core.hpp b/ttnn/cpp/ttnn/operations/core.hpp index 94941fd05a6..ddef3ce400a 100644 --- a/ttnn/cpp/ttnn/operations/core.hpp +++ b/ttnn/cpp/ttnn/operations/core.hpp @@ -21,27 +21,13 @@ #include "ttnn/op_library/to_dtype/to_dtype_op.hpp" #include "ttnn/op_library/to_memory_config/to_memory_config_op.hpp" #include "ttnn/types.hpp" -#include "ttnn/validation.hpp" namespace ttnn { namespace operations { namespace core { -static inline const std::array reshape_input_schemas{ - ttnn::TensorSchema{ - 1, - 8, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b, ttnn::uint8, ttnn::uint16, ttnn::uint32, ttnn::int32, ttnn::float32}, - {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, - true, - true, - false, - false}, -}; - inline ttnn::Tensor reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) { - ttnn::validate_input_tensor("ttnn.reshape", tensor, reshape_input_schemas[0]); auto tensor_shape = tensor.get_shape(); if (tensor_shape == shape) { diff --git a/ttnn/cpp/ttnn/operations/creation.hpp b/ttnn/cpp/ttnn/operations/creation.hpp index f082bb2ae99..c514666fab7 100644 --- a/ttnn/cpp/ttnn/operations/creation.hpp +++ b/ttnn/cpp/ttnn/operations/creation.hpp @@ -9,9 +9,9 @@ #include "tt_eager/tensor/types.hpp" #include "tt_eager/tt_numpy/functions.hpp" #include "tt_metal/impl/dispatch/command_queue.hpp" +#include "ttnn/core.hpp" #include "ttnn/decorators.hpp" #include "ttnn/types.hpp" -#include "ttnn/validation.hpp" namespace ttnn { namespace operations { diff --git a/ttnn/cpp/ttnn/operations/data_movement.hpp b/ttnn/cpp/ttnn/operations/data_movement.hpp index f59d0bee47b..9a212962951 100644 --- a/ttnn/cpp/ttnn/operations/data_movement.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement.hpp @@ -19,22 +19,6 @@ namespace operations { namespace data_movement { struct UpSample { - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 2, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b}, - {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, - true, - false, - false, - false}}; - } - - template - static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... args) { - return std::make_tuple(input_tensor); - } static ttnn::Tensor execute_on_worker_thread( const ttnn::Tensor& input_tensor, @@ -93,22 +77,6 @@ struct UpSample { }; struct Repeat { - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 4, // min rank - 4, // max rank - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::int32, ttnn::uint32}, - {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, - true, // can_be_on_device - false, // can_be_on_cpu - false, // can_be_scalar - false}}; // is_optional - } - - template - static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... args) { - return std::make_tuple(input_tensor); - } static ttnn::Tensor execute_on_worker_thread( const ttnn::Tensor& input_tensor, @@ -121,22 +89,6 @@ struct Repeat { }; struct RepeatInterleave { - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 4, // min rank - 4, // max rank - {ttnn::bfloat16}, - {ttnn::TILE_LAYOUT}, - true, // can_be_on_device - true, // can_be_on_cpu - false, // can_be_scalar - false}}; // is_optional - } - - template - static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... args) { - return std::make_tuple(input_tensor); - } // # This operation does not support the following cases: // # - Shape([2[32], 2[32]]) -> repeats = 2, dim = 0 diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp index 33f2df687f2..6d3e9da9400 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp @@ -19,23 +19,6 @@ constexpr uint8_t DefaultQueueId = 0; struct Pad { - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 2, // min rank - 4, // max rank - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::uint16, ttnn::int32, ttnn::uint32}, - {ttnn::TILE_LAYOUT}, - true, // can_be_on_device - false, // can_be_on_cpu - false, // can_be_scalar - false // is_optional} - }}; - } - - template - static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... args) { - return std::make_tuple(input_tensor); - } // Wrapper for TTDNN diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp index a42f2734102..5b20962b196 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp @@ -6,7 +6,6 @@ #include "ttnn/decorators.hpp" #include "ttnn/operations/core.hpp" -#include "ttnn/validation.hpp" #include "tt_eager/tt_dnn/op_library/permute/permute_op.hpp" #include "tt_eager/tt_dnn/op_library/run_operation.hpp" #include "tt_dnn/op_library/transpose/transpose_op.hpp" @@ -135,23 +134,6 @@ Tensor permute_launch(const Tensor &a, std::vector dims, const Mem } struct Permute { - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 2, // min rank - 4, // max rank - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::uint16, ttnn::int32, ttnn::uint32, ttnn::float32}, - {ttnn::TILE_LAYOUT}, - true, // can_be_on_device - true, // can_be_on_cpu - false, // can_be_scalar - false // is_optional} - }}; - } - - template - static auto input_tensors_to_validate(uint8_t queue_id, const Tensor& input_tensor, Args&&... args) { - return std::forward_as_tuple(input_tensor); - } static inline ttnn::Tensor execute_on_worker_thread( uint8_t queue_id, @@ -213,11 +195,6 @@ struct Permute { return output_tensor; } - template - static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) { - return std::forward_as_tuple(input_tensor); - } - static inline auto execute_on_worker_thread( const ttnn::Tensor &input_tensor, const std::vector& dims, diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp index ab7f8d7e1c0..d0db352caad 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp @@ -29,32 +29,6 @@ constexpr bool is_associative(BinaryOpType op) { template struct Binary { - static inline const std::array input_tensor_schemas() { - return { - ttnn::TensorSchema{ - 2, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b, ttnn::uint16}, - {ttnn::TILE_LAYOUT}, - true, - false, - false, - false}, - ttnn::TensorSchema{ - 2, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b, ttnn::uint16}, - {ttnn::TILE_LAYOUT}, - true, - false, - true, - false}}; - } - - template - static auto input_tensors_to_validate(uint8_t queue_id, const Tensor &input_tensor_a, const Tensor &input_tensor_b, Args &&...args) { - return std::forward_as_tuple(input_tensor_a, input_tensor_b); - } static Tensor execute_on_worker_thread( uint8_t queue_id, @@ -115,11 +89,6 @@ struct Binary { BinaryDeviceOperation::tensor_args_t{input_tensor_a, input_tensor_b, optional_output_tensor}); } - template - static auto input_tensors_to_validate(const Tensor &input_tensor_a, const Tensor &input_tensor_b, Args &&...args) { - return std::forward_as_tuple(input_tensor_a, input_tensor_b); - } - static Tensor execute_on_worker_thread( const Tensor &input_tensor_a_arg, const Tensor &input_tensor_b_arg, @@ -131,11 +100,6 @@ struct Binary { return execute_on_worker_thread(DefaultQueueId, input_tensor_a_arg, input_tensor_b_arg, output_dtype, memory_config, optional_output_tensor, activations); } - template - static auto input_tensors_to_validate(const Tensor &input_tensor_a, const float input_tensor_b, Args &&...args) { - return std::forward_as_tuple(input_tensor_a, input_tensor_b); - } - // TODO: this case should use BinaryWithScalarProgramConfig and there should be a custom kernel to run this // Currently, this is exactly how tt::tt_metal::add_unary works static Tensor execute_on_worker_thread( @@ -155,11 +119,6 @@ struct Binary { activations); } - template - static auto input_tensors_to_validate(uint8_t queue_id, const Tensor &input_tensor_a, const float input_tensor_b, Args &&...args) { - return std::forward_as_tuple(input_tensor_a, input_tensor_b); - } - static Tensor execute_on_worker_thread( uint8_t queue_id, const ttnn::Tensor &input_tensor_a, diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp index 50ac00ac663..1f3febcc7d0 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp @@ -43,7 +43,7 @@ void bind_binary_operation(py::module& module, const binary_operation_t& operati >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device) >>> output = {1}(tensor1, tensor2) )doc", - operation.name(), + operation.base_name(), operation.python_fully_qualified_name(), description); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp index 111defe40a8..b62a0b5ed88 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp @@ -15,38 +15,6 @@ namespace operations::binary_backward { template struct ExecuteBinaryBackward { - - static inline const std::array input_tensor_schemas() { - return { - ttnn::TensorSchema{ - 2, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b, ttnn::uint16}, - {ttnn::TILE_LAYOUT}, - true, - false, - false, - false}, - ttnn::TensorSchema{ - 2, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b, ttnn::uint16}, - {ttnn::TILE_LAYOUT}, - true, - false, - false, - false}, - ttnn::TensorSchema{ - 2, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b, ttnn::uint16}, - {ttnn::TILE_LAYOUT}, - true, - false, - false, - false}}; - } - static inline std::vector create_async_output_tensors( const std::vector &input_tensors, const std::vector>& optional_inputs) { const auto& input_tensor = input_tensors.at(0); @@ -54,11 +22,7 @@ struct ExecuteBinaryBackward { Tensor(operation::get_workers_for_op_output({input_tensor}))}; } - //Type 1: 2 inputs, 1 grad tensor - template - static auto input_tensors_to_validate(const Tensor &grad_tensor, const Tensor &input_tensor_a, const Tensor &input_tensor_b, Args &&...args) { - return std::forward_as_tuple(grad_tensor, input_tensor_a, input_tensor_b); - } + // Type 1: 2 inputs, 1 grad tensor static std::vector execute_on_worker_thread( const Tensor &grad_tensor_arg, @@ -70,11 +34,6 @@ struct ExecuteBinaryBackward { return op_type(grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, memory_config); } - template - static auto input_tensors_to_validate(const Tensor &grad_tensor, const Tensor &input_tensor_a, const Tensor &input_tensor_b, const Tensor &input_tensor_c, Args &&...args) { - return std::forward_as_tuple(grad_tensor, input_tensor_a, input_tensor_b, input_tensor_c); - } - static std::vector execute_on_worker_thread( const Tensor &grad_tensor_arg, const Tensor &input_tensor_a_arg, @@ -86,62 +45,56 @@ struct ExecuteBinaryBackward { return op_type(grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, input_tensor_c_arg, memory_config); } - //Type 1: Type 1 with 1 float - template - - static std::vector execute_on_worker_thread( - const Tensor &grad_tensor_arg, - const Tensor &input_tensor_a_arg, - float alpha, - const Tensor &input_tensor_b_arg, - const std::optional &memory_config = std::nullopt) { - - auto op_type = utils::get_function_type1_w_float(binary_backward_op_type); - auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config()); - return op_type(grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, alpha, output_memory_config); + // Type 1: Type 1 with 1 float + + static std::vector execute_on_worker_thread( + const Tensor &grad_tensor_arg, + const Tensor &input_tensor_a_arg, + float alpha, + const Tensor &input_tensor_b_arg, + const std::optional &memory_config = std::nullopt) { + auto op_type = utils::get_function_type1_w_float(binary_backward_op_type); + auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config()); + return op_type(grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, alpha, output_memory_config); } - //Type 1: Type 1 with 1 string - template - - static std::vector execute_on_worker_thread( - const Tensor &grad_tensor_arg, - const Tensor &input_tensor_a_arg, - string value, - const Tensor &input_tensor_b_arg, - const std::optional &memory_config = std::nullopt) { - - auto op_type = utils::get_function_type1_w_string(binary_backward_op_type); - auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config()); - return op_type(grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, value, output_memory_config); + // Type 1: Type 1 with 1 string + static std::vector execute_on_worker_thread( + const Tensor &grad_tensor_arg, + const Tensor &input_tensor_a_arg, + string value, + const Tensor &input_tensor_b_arg, + const std::optional &memory_config = std::nullopt) { + auto op_type = utils::get_function_type1_w_string(binary_backward_op_type); + auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config()); + return op_type(grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, value, output_memory_config); } - //Type 3 : Q_ID, type1 args, optional output tensor for inputs based on are_required_outputs value - template - static auto input_tensors_to_validate(uint8_t queue_id, const Tensor &grad_tensor, const Tensor &input_tensor_a, const Tensor &input_tensor_b, Args &&...args) { - return std::forward_as_tuple(grad_tensor, input_tensor_a, input_tensor_b); - } - - static std::vector> execute_on_main_thread( - uint8_t queue_id, - const Tensor &grad_tensor_arg, - const Tensor &input_tensor_a_arg, - const Tensor &input_tensor_b_arg, - const std::optional &memory_config = std::nullopt, - const std::vector& are_required_outputs = std::vector{true, true}, - std::optional input_a_grad = std::nullopt, - std::optional input_b_grad = std::nullopt) { - - auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config()); - auto op_type = utils::get_function_type3(binary_backward_op_type); - return op_type(queue_id, grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, output_memory_config, are_required_outputs, input_a_grad, input_b_grad); + // Type 3 : Q_ID, type1 args, optional output tensor for inputs based on are_required_outputs value + + static std::vector> execute_on_main_thread( + uint8_t queue_id, + const Tensor &grad_tensor_arg, + const Tensor &input_tensor_a_arg, + const Tensor &input_tensor_b_arg, + const std::optional &memory_config = std::nullopt, + const std::vector &are_required_outputs = std::vector{true, true}, + std::optional input_a_grad = std::nullopt, + std::optional input_b_grad = std::nullopt) { + auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config()); + auto op_type = utils::get_function_type3(binary_backward_op_type); + return op_type( + queue_id, + grad_tensor_arg, + input_tensor_a_arg, + input_tensor_b_arg, + output_memory_config, + are_required_outputs, + input_a_grad, + input_b_grad); } - //Type 3 : type1 args, optional output tensor for inputs based on are_required_outputs value - template - static auto input_tensors_to_validate(const Tensor &grad_tensor, const Tensor &input_tensor_a, const Tensor &input_tensor_b, std::vector are_required_outputs, Args &&...args) { - return std::forward_as_tuple(grad_tensor, input_tensor_a, input_tensor_b); - } + // Type 3 : type1 args, optional output tensor for inputs based on are_required_outputs value static std::vector> execute_on_main_thread( const Tensor &grad_tensor_arg, @@ -157,11 +110,7 @@ struct ExecuteBinaryBackward { return op_type(grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, output_memory_config, are_required_outputs, input_a_grad, input_b_grad); } - //Type 2 : Q_ID, type1 args, optional output tensor for inputs based on are_required_outputs value - template - static auto input_tensors_to_validate(uint8_t queue_id, const Tensor &grad_tensor, const Tensor &input_tensor_a, const Tensor &input_tensor_b, float alpha, Args &&...args) { - return std::forward_as_tuple(grad_tensor, input_tensor_a, input_tensor_b); - } + // Type 2 : Q_ID, type1 args, optional output tensor for inputs based on are_required_outputs value static std::vector> execute_on_main_thread( uint8_t queue_id, @@ -179,11 +128,7 @@ struct ExecuteBinaryBackward { return op_type(queue_id, grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, alpha, output_memory_config, are_required_outputs, input_a_grad, input_b_grad); } - //Type 2 : type1 args, optional output tensor for inputs based on are_required_outputs value - template - static auto input_tensors_to_validate(const Tensor &grad_tensor, const Tensor &input_tensor_a, const Tensor &input_tensor_b, float alpha, std::vector are_required_outputs, Args &&...args) { - return std::forward_as_tuple(grad_tensor, input_tensor_a, input_tensor_b); - } + // Type 2 : type1 args, optional output tensor for inputs based on are_required_outputs value static std::vector> execute_on_main_thread( const Tensor &grad_tensor_arg, diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp index 72a66d8dae7..a205b00a631 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp @@ -22,7 +22,7 @@ namespace detail { template void bind_binary_backward(py::module& module, const binary_backward_operation_t& operation, const std::string& description) { auto doc = fmt::format( -R"doc({0}(grad_tensor: ttnn.Tensor, input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig) -> std::vector + R"doc({0}(grad_tensor: ttnn.Tensor, input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig) -> std::vector {2} @@ -41,7 +41,7 @@ Keyword args: >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device) >>> output = {1}(grad_tensor, tensor1, tensor2) )doc", - operation.name(), + operation.base_name(), operation.python_fully_qualified_name(), description); diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp index 03dec87b7d0..e7e9bcceed5 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp @@ -9,7 +9,6 @@ #include "tt_eager/tt_dnn/op_library/run_operation.hpp" #include "ttnn/decorators.hpp" #include "ttnn/operations/core.hpp" -#include "ttnn/validation.hpp" namespace ttnn { @@ -19,23 +18,6 @@ namespace unary { namespace detail { -inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 2, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b}, - {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, - true, - false, - false, - false}}; -} - -template -inline auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) { - return std::forward_as_tuple(input_tensor); -} - inline Tensor execute_on_worker_thread( uint8_t queue_id, const Tensor& input_tensor, @@ -61,12 +43,6 @@ inline Tensor execute_on_worker_thread( template struct ExecuteUnary { - static const std::array input_tensor_schemas() { return detail::input_tensor_schemas(); } - - template - static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) { - return detail::input_tensors_to_validate(input_tensor, std::forward(args)...); - } static Tensor execute_on_worker_thread( uint8_t queue_id, const Tensor& input_tensor, const std::optional& memory_config = std::nullopt, const std::optional& optional_output_tensor = std::nullopt) { @@ -81,13 +57,6 @@ struct ExecuteUnary { template struct ExecuteUnaryWithFastAndApproximateMode { - static const std::array input_tensor_schemas() { return detail::input_tensor_schemas(); } - - template - static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) { - return detail::input_tensors_to_validate(input_tensor, std::forward(args)...); - } - static Tensor execute_on_worker_thread( uint8_t queue_id, const Tensor& input_tensor, @@ -95,7 +64,11 @@ struct ExecuteUnaryWithFastAndApproximateMode { const std::optional& memory_config = std::nullopt, const std::optional& optional_output_tensor = std::nullopt) { return detail::execute_on_worker_thread( - queue_id, input_tensor, {UnaryWithParam{unary_op_type, static_cast(parameter)}}, memory_config, optional_output_tensor); + queue_id, + input_tensor, + {UnaryWithParam{unary_op_type, static_cast(parameter)}}, + memory_config, + optional_output_tensor); } static Tensor execute_on_worker_thread( const Tensor& input_tensor, @@ -103,19 +76,16 @@ struct ExecuteUnaryWithFastAndApproximateMode { const std::optional& memory_config = std::nullopt, const std::optional& optional_output_tensor = std::nullopt) { return detail::execute_on_worker_thread( - DefaultQueueId, input_tensor, {UnaryWithParam{unary_op_type, static_cast(parameter)}}, memory_config, optional_output_tensor); + DefaultQueueId, + input_tensor, + {UnaryWithParam{unary_op_type, static_cast(parameter)}}, + memory_config, + optional_output_tensor); } }; template struct ExecuteUnaryWithFloatParameter { - static const std::array input_tensor_schemas() { return detail::input_tensor_schemas(); } - - template - static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) { - return detail::input_tensors_to_validate(input_tensor, std::forward(args)...); - } - static Tensor execute_on_worker_thread( uint8_t queue_id, const Tensor& input_tensor, @@ -123,7 +93,11 @@ struct ExecuteUnaryWithFloatParameter { const std::optional& memory_config = std::nullopt, const std::optional& optional_output_tensor = std::nullopt) { return detail::execute_on_worker_thread( - queue_id, input_tensor, {UnaryWithParam{unary_op_type, static_cast(parameter)}}, memory_config, optional_output_tensor); + queue_id, + input_tensor, + {UnaryWithParam{unary_op_type, static_cast(parameter)}}, + memory_config, + optional_output_tensor); } static Tensor execute_on_worker_thread( @@ -132,18 +106,15 @@ struct ExecuteUnaryWithFloatParameter { const std::optional& memory_config = std::nullopt, const std::optional& optional_output_tensor = std::nullopt) { return detail::execute_on_worker_thread( - DefaultQueueId, input_tensor, {UnaryWithParam{unary_op_type, static_cast(parameter)}}, memory_config, optional_output_tensor); + DefaultQueueId, + input_tensor, + {UnaryWithParam{unary_op_type, static_cast(parameter)}}, + memory_config, + optional_output_tensor); } }; struct Softplus { - static const std::array input_tensor_schemas() { return detail::input_tensor_schemas(); } - - template - static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) { - return detail::input_tensors_to_validate(input_tensor, std::forward(args)...); - } - static Tensor execute_on_worker_thread( const Tensor& input, const float beta, @@ -152,40 +123,32 @@ struct Softplus { const std::optional& optional_output_tensor = std::nullopt) { TT_ASSERT(input.device()->arch() != tt::ARCH::GRAYSKULL, "Softplus is not currently supported on Grayskull"); return detail::execute_on_worker_thread( - DefaultQueueId, input, {UnaryWithParam{UnaryOpType::SOFTPLUS, {beta, threshold}}}, memory_config, optional_output_tensor); + DefaultQueueId, + input, + {UnaryWithParam{UnaryOpType::SOFTPLUS, {beta, threshold}}}, + memory_config, + optional_output_tensor); } }; struct Sigmoid_accurate { - static const std::array input_tensor_schemas() { return detail::input_tensor_schemas(); } - - template - static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) { - return detail::input_tensors_to_validate(input_tensor, std::forward(args)...); - } - static Tensor execute_on_worker_thread( const Tensor& input, const std::optional& memory_config = std::nullopt, const std::optional& optional_output_tensor = std::nullopt) { return detail::execute_on_worker_thread( - DefaultQueueId, input, {UnaryWithParam(UnaryOpType::NEG), - UnaryWithParam(UnaryOpType::EXP, 1.0f), - UnaryWithParam(UnaryOpType::ADD_UNARY_SFPU, 1.0f), - UnaryWithParam(UnaryOpType::RECIP)}, - memory_config, - optional_output_tensor); + DefaultQueueId, + input, + {UnaryWithParam(UnaryOpType::NEG), + UnaryWithParam(UnaryOpType::EXP, 1.0f), + UnaryWithParam(UnaryOpType::ADD_UNARY_SFPU, 1.0f), + UnaryWithParam(UnaryOpType::RECIP)}, + memory_config, + optional_output_tensor); } }; struct Unary_chain { - static const std::array input_tensor_schemas() { return detail::input_tensor_schemas(); } - - template - static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) { - return detail::input_tensors_to_validate(input_tensor, std::forward(args)...); - } - static Tensor execute_on_worker_thread( const Tensor& input_tensor, const std::vector& ops_chain, @@ -196,7 +159,6 @@ struct Unary_chain { } }; - } // namespace unary } // namespace operations diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp index 1064bb66b18..5a6a8606343 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp @@ -6,7 +6,6 @@ #include "ttnn/decorators.hpp" #include "ttnn/operations/core.hpp" -#include "ttnn/validation.hpp" namespace ttnn { @@ -118,6 +117,46 @@ Tensor triu( // auto prelu = ttnn::leaky_relu; // Alias for leaky_relu. TODO(#8544): implement PReLU properly // Other unaries + +// This function is used to transform the arguments of a function before calling it +// where the lambda is applied to the type that matches T. +// Example: https://godbolt.org/z/3P9YedMdj +template +constexpr auto transform_args_lambda(Func func, Lambda lambda, Args&&... args) -> decltype(auto) { + auto transformer = [lambda](auto&& arg) -> decltype(auto) { + if constexpr (std::is_same_v>) { + return lambda(std::forward(arg)); + } else { + return std::forward(arg); + } + }; + + return func(transformer(std::forward(args))...); +} + +template +auto transform_first_matching_arg(Lambda lambda) { + static_assert(!std::is_same::value, "No matching type found"); +} + +template +auto transform_first_matching_arg(Lambda lambda, First&& first, Rest&&... rest) { + if constexpr (std::is_same_v>) { + return lambda(std::forward(first)); + } else { + return transform_first_matching_arg(lambda, std::forward(rest)...); + } +} +#define TO_LAMBDA_WITH_RESHAPE(function) \ + ([](auto&&... args) { \ + const auto original_shape = transform_first_matching_arg( \ + [&](auto&& tensor) { return tensor.get_shape(); }, std::forward(args)...); \ + return ttnn::reshape( \ + transform_args_lambda( \ + function, [&](auto&& tensor) { return ttnn::unsqueeze_to_4D(tensor); }, args...), \ + original_shape); \ + }) + constexpr auto deg2rad = ttnn::register_operation("ttnn::deg2rad", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::deg2rad)); constexpr auto rad2deg = ttnn::register_operation("ttnn::rad2deg", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::rad2deg)); diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp index 13277d28785..9e6b6791d5c 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp @@ -43,7 +43,7 @@ void bind_unary_operation(py::module& module, const unary_operation_t& operation >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) >>> output = {1}(tensor) )doc", - operation.name(), + operation.base_name(), operation.python_fully_qualified_name()); bind_registered_operation( @@ -90,7 +90,7 @@ void bind_unary_operation_with_fast_and_approximate_mode(py::module& module, con >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) >>> output = {1}(tensor, fast_and_approximate_mode=true) )doc", - operation.name(), + operation.base_name(), operation.python_fully_qualified_name()); bind_registered_operation( @@ -142,7 +142,7 @@ void bind_unary_operation_with_float_parameter( >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) >>> output = {1}(tensor, {2}=true) )doc", - operation.name(), + operation.base_name(), operation.python_fully_qualified_name(), parameter_name, parameter_doc); @@ -190,7 +190,7 @@ void bind_softplus(py::module& module) { >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) >>> output = {1}(tensor, parameter=true) )doc", - ttnn::softplus.name(), + ttnn::softplus.base_name(), ttnn::softplus.python_fully_qualified_name()); bind_registered_operation( @@ -226,7 +226,7 @@ void bind_sigmoid_accurate(py::module& module) { >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) >>> output = {1}(tensor, parameter=true) )doc", - ttnn::sigmoid_accurate.name(), + ttnn::sigmoid_accurate.base_name(), ttnn::sigmoid_accurate.python_fully_qualified_name()); bind_registered_operation( @@ -262,7 +262,7 @@ void bind_unary_chain(py::module& module) { >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) >>> output = {1}(tensor, ops_chain) )doc", - ttnn::unary_chain.name(), + ttnn::unary_chain.base_name(), ttnn::unary_chain.python_fully_qualified_name()); bind_registered_operation( @@ -301,7 +301,7 @@ void bind_unary_composite_operation(py::module& module, const unary_operation_t& >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) >>> output = {1}(tensor) )doc", - operation.name(), + operation.base_name(), operation.python_fully_qualified_name()); bind_registered_operation( @@ -348,7 +348,7 @@ void bind_unary_operation_with_scale_and_shift(py::module& module, const unary_o >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) >>> output = {1}(tensor) )doc", - operation.name(), + operation.base_name(), operation.python_fully_qualified_name()); bind_registered_operation( @@ -399,7 +399,7 @@ void bind_unary_operation_with_low_and_high(py::module& module, const unary_oper >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) >>> output = {1}(tensor) )doc", - operation.name(), + operation.base_name(), operation.python_fully_qualified_name()); bind_registered_operation( @@ -449,7 +449,7 @@ void bind_unary_operation_with_diag(py::module& module, const unary_operation_t& >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) >>> output = {1}(tensor) )doc", - operation.name(), + operation.base_name(), operation.python_fully_qualified_name()); bind_registered_operation( diff --git a/ttnn/cpp/ttnn/operations/embedding.hpp b/ttnn/cpp/ttnn/operations/embedding.hpp index cd04cdb9432..8146eabab06 100644 --- a/ttnn/cpp/ttnn/operations/embedding.hpp +++ b/ttnn/cpp/ttnn/operations/embedding.hpp @@ -8,7 +8,6 @@ #include "tt_eager/tt_dnn/op_library/run_operation.hpp" #include "ttnn/decorators.hpp" #include "ttnn/operations/core.hpp" -#include "ttnn/validation.hpp" namespace ttnn { @@ -19,18 +18,6 @@ namespace embedding { using EmbeddingsType = tt::tt_metal::EmbeddingsType; struct Embedding { - static const std::array input_tensor_schemas() { - return { - ttnn::TensorSchema{ - 2, 2, {ttnn::uint32, ttnn::bfloat16}, {ttnn::ROW_MAJOR_LAYOUT}, true, false, false, false}, - ttnn::TensorSchema{2, 4, {ttnn::bfloat16}, {ttnn::ROW_MAJOR_LAYOUT}, true, false, false, false}}; - } - - template - static auto input_tensors_to_validate(const Tensor& input_tensor, const Tensor& weight, Args&&... args) { - return std::forward_as_tuple(input_tensor, weight); - } - static Tensor execute_on_worker_thread( const Tensor& input_tensor_arg, const Tensor& weight_arg, diff --git a/ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.hpp b/ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.hpp index 00ccc209222..f586ffe0e4e 100644 --- a/ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.hpp @@ -15,14 +15,21 @@ namespace ttnn::operations::examples { struct ExampleDeviceOperation { + // Define the operation attributes. This is it to store all variables needed by operations that aren't tensors struct operation_attributes_t { bool attribute; int some_other_attribute; }; + + // Define the tensor arguments. This is it to store all tensors passed in and/or out of the operation + // Tensor arguments don't need to be just input tensors, they can be output tensors, input/output tensors, optional + // tensors, etc. struct tensor_args_t { - // An example of the tensor that can only be used as an input + // This example will use a tensor that can only be used as an input const Tensor& input_tensor; + // However, the following examples show what else can be done with tensor_args_t + // An example of the tensor that can be used for input/output or just for pre-allocated output // Tensor& io_tensor; @@ -42,12 +49,18 @@ struct ExampleDeviceOperation { // std::tuple>, std::optional> some_crazy_tuple_of_tensors; }; + // Define the return types for the shape(s) of the operation // Can be a single ttnn::Shape, std::optional, std::vector, std::tuple etc. using shape_return_value_t = ttnn::Shape; + // Define the return types for the tensor(s) of the operation // Can be a single Tensor, std::optional, std::vector, std::tuple etc. using tensor_return_value_t = Tensor; + // Note shape_return_value_t and tensor_return_value_t should follow the same pattern + // i.e. if shape_return_value_t is a std::vector> then tensor_return_value_t should be + // std::vector> + struct SingleCore { struct shared_variables_t { int some_variable_from_create_to_use_in_override_runtime_arguments; @@ -87,15 +100,38 @@ struct ExampleDeviceOperation { using program_factory_t = std::variant; + // Mandatory methods + + // Select the program factory based on the operation attributes and tensor args static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&); + // Validate the operation when it creates a program. Usually will have more checks static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&); + + // Validate the operation when it reuses a program. Usually will have less checks static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&); + // Compute the output shapes based on the operation attributes and tensor args static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&); + // Create the output tensors based on the operation attributes and tensor args static tensor_return_value_t create_output_tensors( const operation_attributes_t& operation_attributes, const tensor_args_t&); + + // Optional methods + + // In case the operation need a custom hash function, the following method can be implemented + /* static tt::stl::hash::hash_t compute_program_hash( + const operation_attributes_t&, const tensor_args_t&); + */ + + // In case the operation needs a custom create_op_performance_model, this method can be implemented + /* + static operation::OpPerformanceModel create_op_performance_model( + const operation_attributes_t& attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + */ }; -} // namespace ttnn::operations::example +} // namespace ttnn::operations::examples diff --git a/ttnn/cpp/ttnn/operations/examples/example/example.hpp b/ttnn/cpp/ttnn/operations/examples/example/example.hpp index 94979242eac..85e761d0dbf 100644 --- a/ttnn/cpp/ttnn/operations/examples/example/example.hpp +++ b/ttnn/cpp/ttnn/operations/examples/example/example.hpp @@ -9,27 +9,26 @@ namespace ttnn::operations::examples { +// This is the main operation that will be called by the user struct ExampleOperation { - - static Tensor execute_on_main_thread( - uint8_t queue_id, - const Tensor &input_tensor) { + // This is the main function that will be called by the user + static Tensor execute_on_main_thread(uint8_t queue_id, const Tensor &input_tensor) { return ttnn::device_operation::run( queue_id, ExampleDeviceOperation::operation_attributes_t{.attribute = true, .some_other_attribute = 42}, ExampleDeviceOperation::tensor_args_t{input_tensor}); } - static Tensor execute_on_main_thread( - const Tensor &input_tensor) { - return execute_on_main_thread(0, input_tensor); - } + // This is the main function that will be called by the user + static Tensor execute_on_main_thread(const Tensor &input_tensor) { return execute_on_main_thread(0, input_tensor); } }; -} // namespace ttnn::operations::binary +} // namespace ttnn::operations::examples namespace ttnn { +// Register the operation. The name, in this case, "ttnn::example" should match the namespace of the operation +// And the name will be directly mapped to python, where it will become "ttnn.example" constexpr auto example = ttnn::register_operation("ttnn::example"); } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/kv_cache.hpp b/ttnn/cpp/ttnn/operations/kv_cache.hpp index 5b8da903fb6..8be266e2524 100644 --- a/ttnn/cpp/ttnn/operations/kv_cache.hpp +++ b/ttnn/cpp/ttnn/operations/kv_cache.hpp @@ -10,36 +10,6 @@ namespace ttnn { namespace operations { namespace kv_cache { -struct UpdateKVCache { - static inline const std::array input_tensor_schemas() { - return { - ttnn::TensorSchema{ - 2, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::ROW_MAJOR_LAYOUT, ttnn::TILE_LAYOUT}, - true, - false, - false, - false}, - ttnn::TensorSchema{ - 2, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::ROW_MAJOR_LAYOUT, ttnn::TILE_LAYOUT}, - true, - false, - false, - false}, - }; - } - - template - static auto input_tensors_to_validate(const ttnn::Tensor& cache, const ttnn::Tensor& token, Args&&... args) { - return std::forward_as_tuple(cache, token); - } -}; - struct ExecuteFillCache { static ttnn::Tensor execute_on_worker_thread( const ttnn::Tensor& cache, const ttnn::Tensor& input, const uint32_t batch_index) { diff --git a/ttnn/cpp/ttnn/operations/matmul.cpp b/ttnn/cpp/ttnn/operations/matmul.cpp index af7e040f171..2bf90fd83e7 100644 --- a/ttnn/cpp/ttnn/operations/matmul.cpp +++ b/ttnn/cpp/ttnn/operations/matmul.cpp @@ -4,9 +4,8 @@ #include "matmul.hpp" -#include "ttnn/cpp/ttnn/operations/core.hpp" -#include "ttnn/cpp/ttnn/validation.hpp" #include "tt_dnn/op_library/transpose/transpose_op.hpp" +#include "ttnn/cpp/ttnn/operations/core.hpp" namespace ttnn { @@ -35,36 +34,6 @@ bool is_input_batched(const ttnn::Shape& shape) { } // namespace detail -const std::array input_tensor_schemas() { - return { - ttnn::TensorSchema{ - 2, - 4, - {ttnn::float32, ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::TILE_LAYOUT}, - true, - false, - true, - false}, - ttnn::TensorSchema{ - 2, - 4, - {ttnn::float32, ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::TILE_LAYOUT}, - true, - false, - true, - false}, - ttnn::TensorSchema{ - 2, - 4, - {ttnn::float32, ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::TILE_LAYOUT}, - true, - false, - true, - true}}; -} std::optional get_fused_activation(const std::optional& activation) { if (!activation.has_value()) { @@ -86,10 +55,6 @@ ttnn::Tensor matmul( const std::optional compute_kernel_config, const std::optional core_grid, const bool propagate_is_b_batched) { - ttnn::validate_input_tensor("ttnn.matmul", input_tensor_a, input_tensor_schemas()[0]); - ttnn::validate_input_tensor("ttnn.matmul", input_tensor_b, input_tensor_schemas()[1]); - ttnn::validate_input_tensor("ttnn.matmul", bias, input_tensor_schemas()[2]); - const auto& input_tensor_a_adjusted = transpose_a ? tt::tt_metal::transpose(input_tensor_a, -1, -2, input_tensor_a.memory_config()) : input_tensor_a; const auto& input_tensor_b_adjusted = transpose_b ? tt::tt_metal::transpose(input_tensor_b, -1, -2, input_tensor_b.memory_config()) : input_tensor_b; diff --git a/ttnn/cpp/ttnn/operations/matmul.hpp b/ttnn/cpp/ttnn/operations/matmul.hpp index 139a020095a..34a38744191 100644 --- a/ttnn/cpp/ttnn/operations/matmul.hpp +++ b/ttnn/cpp/ttnn/operations/matmul.hpp @@ -28,8 +28,6 @@ inline bool is_input_batched(const ttnn::Shape& shape); } // namespace detail -extern const std::array input_tensor_schemas(); - std::optional get_fused_activation(const std::optional& activation); ttnn::Tensor matmul( diff --git a/ttnn/cpp/ttnn/operations/normalization.hpp b/ttnn/cpp/ttnn/operations/normalization.hpp index 4d7524ce4e8..4d57510a5e0 100644 --- a/ttnn/cpp/ttnn/operations/normalization.hpp +++ b/ttnn/cpp/ttnn/operations/normalization.hpp @@ -15,15 +15,6 @@ namespace normalization { template struct Softmax { - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 2, 4, {ttnn::bfloat16, ttnn::bfloat8_b}, {ttnn::TILE_LAYOUT}, true, false, false, false}}; - } - - template - static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... args) { - return std::forward_as_tuple(input_tensor); - } static ttnn::Tensor execute_on_worker_thread( const ttnn::Tensor& input_tensor, @@ -52,56 +43,6 @@ struct Softmax { }; struct LayerNorm { - static inline const std::array input_tensor_schemas() { - return { - ttnn::TensorSchema{ - 2, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::TILE_LAYOUT}, - true, - false, - false, - false}, - ttnn::TensorSchema{ - 1, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, - true, - false, - false, - true}, - ttnn::TensorSchema{ - 1, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, - true, - false, - false, - true}, - ttnn::TensorSchema{ - 1, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, - true, - false, - false, - true}}; - } - - template - static auto input_tensors_to_validate( - const Tensor& input_tensor, - float epsilon = 1e-12, - const std::optional& weight = std::nullopt, - const std::optional& bias = std::nullopt, - const std::optional& residual_input_tensor = std::nullopt, - Args&&... args) { - return std::forward_as_tuple(input_tensor, weight, bias, residual_input_tensor); - } static inline ttnn::Tensor execute_on_worker_thread( const ttnn::Tensor& input_tensor, @@ -125,32 +66,6 @@ struct LayerNorm { }; struct RMSNorm { - static inline const std::array input_tensor_schemas() { - return { - ttnn::TensorSchema{ - 2, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::TILE_LAYOUT}, - true, - false, - false, - false}, - ttnn::TensorSchema{ - 1, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, - true, - false, - false, - false}}; - } - - template - static auto input_tensors_to_validate(const Tensor& input_tensor, const Tensor& weight, Args&&... args) { - return std::forward_as_tuple(input_tensor, weight); - } static inline ttnn::Tensor execute_on_worker_thread( const ttnn::Tensor& input_tensor, @@ -163,17 +78,7 @@ struct RMSNorm { }; struct GroupNorm { - template - static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) { - return std::forward_as_tuple(input_tensor); - } - - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 2, 4, {ttnn::bfloat16}, {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, true, false, false, false}}; - } - - static inline ttnn::Tensor execute_on_worker_thread( + static inline ttnn::Tensor execute_on_worker_thread( const ttnn::Tensor& input_tensor, const int num_groups, const float epsilon, diff --git a/ttnn/cpp/ttnn/operations/pool.hpp b/ttnn/cpp/ttnn/operations/pool.hpp index 23bc067ca22..4292b618022 100644 --- a/ttnn/cpp/ttnn/operations/pool.hpp +++ b/ttnn/cpp/ttnn/operations/pool.hpp @@ -12,29 +12,7 @@ namespace ttnn { namespace operations { namespace pool { -namespace detail { -inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 4, // min rank - 4, // max rank - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::uint16, ttnn::uint32}, - {ttnn::TILE_LAYOUT}, - true, // can_be_on_device - false, // can_be_on_cpu - false, // can_be_scalar - false // is_optional} - }}; -} -} // namespace detail - struct GlobalAveragePool2D { - static const std::array input_tensor_schemas() { return detail::input_tensor_schemas(); } - - template - static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) { - return std::forward_as_tuple(input_tensor); - } - static Tensor execute_on_worker_thread( const Tensor& input, const std::optional& memory_config_arg = std::nullopt, diff --git a/ttnn/cpp/ttnn/operations/reduction/argmax/argmax.hpp b/ttnn/cpp/ttnn/operations/reduction/argmax/argmax.hpp index 00bdb77bba4..3e2b864029e 100644 --- a/ttnn/cpp/ttnn/operations/reduction/argmax/argmax.hpp +++ b/ttnn/cpp/ttnn/operations/reduction/argmax/argmax.hpp @@ -4,27 +4,15 @@ #pragma once +#include "device/argmax_op.hpp" +#include "tt_eager/tt_dnn/op_library/run_operation.hpp" #include "ttnn/decorators.hpp" #include "ttnn/operations/core.hpp" -#include "ttnn/validation.hpp" - -#include "tt_eager/tt_dnn/op_library/run_operation.hpp" - -#include "device/argmax_op.hpp" namespace ttnn { namespace operations::reduction { struct ExecuteArgMax { - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{4, 4, {ttnn::bfloat16}, {ttnn::ROW_MAJOR_LAYOUT}, true, false, false, false}}; - } - - template - static auto input_tensors_to_validate(uint8_t queue_id, const Tensor& input_tensor, Args&&... args) { - return std::forward_as_tuple(input_tensor); - } - static ttnn::Tensor execute_on_worker_thread( uint8_t queue_id, const Tensor& input_tensor, @@ -37,11 +25,6 @@ struct ExecuteArgMax { .at(0); } - template - static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) { - return std::forward_as_tuple(input_tensor); - } - static ttnn::Tensor execute_on_worker_thread( const Tensor& input_tensor, const std::optional dim = std::nullopt, diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.hpp b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.hpp index b66857d6162..326c1b2464a 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.hpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.hpp @@ -7,11 +7,8 @@ #include "tt_dnn/op_library/composite/composite_ops.hpp" #include "tt_eager/tt_dnn/op_library/reduce/reduce_op.hpp" #include "tt_eager/tt_dnn/op_library/run_operation.hpp" - #include "ttnn/decorators.hpp" #include "ttnn/operations/core.hpp" -#include "ttnn/validation.hpp" - #include "ttnn/operations/eltwise/binary/binary.hpp" namespace ttnn { @@ -28,15 +25,6 @@ enum class ReduceType { template struct Reduce { - static const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 2, 4, {ttnn::bfloat8_b, ttnn::bfloat16}, {ttnn::TILE_LAYOUT}, true, false, false, false}}; - } - - template - static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) { - return std::forward_as_tuple(input_tensor); - } static Tensor execute_on_worker_thread( const Tensor& input_tensor_arg, diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions_pybind.hpp b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions_pybind.hpp index ded5824c08c..8733b481ec9 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions_pybind.hpp @@ -16,7 +16,7 @@ void bind_reduction_operation(py::module& module, const reduction_operation_t& o namespace py = pybind11; auto doc = fmt::format( R"doc({0}(input_tensor: ttnn.Tensor, dim: Optional[Union[int, Tuple[int]]] = None, keepdim: bool = True, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor)doc", - operation.name()); + operation.base_name()); bind_registered_operation( module, diff --git a/ttnn/cpp/ttnn/operations/reduction/topk/topk.hpp b/ttnn/cpp/ttnn/operations/reduction/topk/topk.hpp index 18bd102ec64..fe56b8b7555 100644 --- a/ttnn/cpp/ttnn/operations/reduction/topk/topk.hpp +++ b/ttnn/cpp/ttnn/operations/reduction/topk/topk.hpp @@ -6,7 +6,6 @@ #include "ttnn/decorators.hpp" #include "ttnn/operations/core.hpp" -#include "ttnn/validation.hpp" #include "tt_eager/tt_dnn/op_library/run_operation.hpp" @@ -25,15 +24,6 @@ namespace ttnn { namespace operations::reduction { struct ExecuteTopK { - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{4, 4, {ttnn::bfloat8_b, ttnn::bfloat16}, {ttnn::TILE_LAYOUT}, true, false, false, false}}; - } - - template - static auto input_tensors_to_validate(uint8_t queue_id, const Tensor& input_tensor, Args&&... args) { - return std::forward_as_tuple(input_tensor); - } - static inline std::vector execute_on_worker_thread( uint8_t queue_id, const Tensor &input_tensor, @@ -50,11 +40,6 @@ struct ExecuteTopK { queue_id); } - template - static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) { - return std::forward_as_tuple(input_tensor); - } - static inline auto execute_on_worker_thread( const Tensor &input_tensor, const uint16_t k, diff --git a/ttnn/cpp/ttnn/operations/transformer.hpp b/ttnn/cpp/ttnn/operations/transformer.hpp index 2f704897ac5..f9a912ee474 100644 --- a/ttnn/cpp/ttnn/operations/transformer.hpp +++ b/ttnn/cpp/ttnn/operations/transformer.hpp @@ -228,15 +228,6 @@ struct ConcatenateHeads : public tt::tt_metal::NlpConcatHeads { }; struct ExecuteConcatenateHeads { - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 4, 4, {ttnn::bfloat16, ttnn::bfloat8_b}, {ttnn::TILE_LAYOUT}, true, false, false, false}}; - } - - template - static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) { - return std::forward_as_tuple(input_tensor); - } static inline ttnn::Tensor execute_on_worker_thread( const Tensor& input_tensor, const std::optional& memory_config) { @@ -246,13 +237,6 @@ struct ExecuteConcatenateHeads { }; struct ExecuteRotaryEmbedding { - static inline const std::array input_tensor_schemas() { - return { - ttnn::TensorSchema{4, 4, {ttnn::bfloat16, ttnn::bfloat8_b}, {ttnn::TILE_LAYOUT}, true, false, false, false}, - ttnn::TensorSchema{4, 4, {ttnn::bfloat16, ttnn::bfloat8_b}, {ttnn::TILE_LAYOUT}, true, false, false, false}, - ttnn::TensorSchema{ - 4, 4, {ttnn::bfloat16, ttnn::bfloat8_b}, {ttnn::TILE_LAYOUT}, true, false, false, false}}; - } static inline ttnn::Tensor execute_on_worker_thread( const Tensor& input_tensor, @@ -279,20 +263,6 @@ struct ExecuteRotaryEmbedding { template struct ExecuteAttentionSoftmax { - static inline const std::array input_tensor_schemas() { - return { - ttnn::TensorSchema{4, 4, {ttnn::bfloat16, ttnn::bfloat8_b}, {ttnn::TILE_LAYOUT}, true, false, false, false}, - ttnn::TensorSchema{4, 4, {ttnn::bfloat16, ttnn::bfloat8_b}, {ttnn::TILE_LAYOUT}, true, false, false, true}}; - } - - template - static auto input_tensors_to_validate( - const ttnn::Tensor& input_tensor, - const std::optional& head_size = std::nullopt, - const std::optional& attention_mask = std::nullopt, - Args&&... args) { - return std::forward_as_tuple(input_tensor, attention_mask); - } static ttnn::Tensor execute_on_worker_thread( const ttnn::Tensor& input_tensor, diff --git a/ttnn/cpp/ttnn/validation.hpp b/ttnn/cpp/ttnn/validation.hpp deleted file mode 100644 index 676cea07f1c..00000000000 --- a/ttnn/cpp/ttnn/validation.hpp +++ /dev/null @@ -1,94 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once -#include -#include - -#include "ttnn/core.hpp" -#include "ttnn/types.hpp" - -namespace ttnn { - -namespace validation { - -using TensorToValidate = std::variant, int, float>; -using TensorsToValidate = std::vector; - -inline void validate_input_tensor( - const std::string& operation_name, const TensorToValidate& tensor_to_validate, const TensorSchema& schema) { - if (schema.can_be_scalar) { - if (std::holds_alternative(tensor_to_validate) or std::holds_alternative(tensor_to_validate)) { - return; - } - } else { - if (not std::holds_alternative>(tensor_to_validate)) { - TT_THROW("{}: Tensor cannot be a scalar!", operation_name); - } - } - - const auto& optional_tensor = std::get>(tensor_to_validate); - - if (schema.is_optional && not optional_tensor.has_value()) { - return; - } - - const auto& tensor = optional_tensor.value(); - - if (tensor.get_shape().rank() < schema.min_rank or tensor.get_shape().rank() > schema.max_rank) { - TT_THROW( - "{}: Tensor rank is not valid: rank is {} but must be {} <= rank <- {}", - operation_name, - tensor.get_shape().rank(), - schema.min_rank, - schema.max_rank); - } - - if (schema.dtypes.find(tensor.get_dtype()) == schema.dtypes.end()) { - TT_THROW("{}: Tensor must be of type {}, but got {}", operation_name, schema.dtypes, tensor.get_dtype()); - } - - if (schema.layouts.find(tensor.get_layout()) == schema.layouts.end()) { - TT_THROW("{}: Tensor must be of layout {}, but got {}", operation_name, schema.layouts, tensor.get_layout()); - } - - if (schema.can_be_on_device and schema.can_be_on_cpu) { - // pass - } else if (schema.can_be_on_device) { - if (not ttnn::is_tensor_on_device_or_multidevice(tensor)) { - TT_THROW("{}: Tensor must be on device!", operation_name); - } - } else if (schema.can_be_on_cpu) { - if (ttnn::has_storage_type_of(tensor, ttnn::DEVICE_STORAGE_TYPE)) { - TT_THROW("{}: Tensor must be on host!", operation_name); - } - } else { - TT_THROW("{}: Tensor must be on host or device!", operation_name); - } - - if (not tensor.is_allocated()) { - TT_THROW("{}: Tensor must be allocated!", operation_name); - } -} - -template -inline void validate_input_tensors( - const std::string& operation_name, const TensorsToValidate& tensors, const TensorSchemas& schemas) { - if (tensors.size() != schemas.size()) { - TT_THROW( - "{}: Number of tensors ({}) does not match the number of schemas ({})", - operation_name, - tensors.size(), - schemas.size()); - } - for (auto index = 0; index < tensors.size(); index++) { - const auto& tensor = tensors.at(index); - validate_input_tensor(operation_name, tensor, schemas.at(index)); - } -} - -} // namespace validation -using validation::validate_input_tensor; -using validation::validate_input_tensors; -} // namespace ttnn diff --git a/ttnn/ttnn/decorators.py b/ttnn/ttnn/decorators.py index 5903b3a349e..dcbbb3fa095 100644 --- a/ttnn/ttnn/decorators.py +++ b/ttnn/ttnn/decorators.py @@ -671,8 +671,6 @@ def attach_golden_function( postprocess_golden_function_outputs or default_postprocess_golden_function_outputs ) - return operation - def export_operation(python_fully_qualified_name, operation, is_method): """ diff --git a/ttnn/ttnn/operations/binary.py b/ttnn/ttnn/operations/binary.py index 29d5ee295b4..2b60cf51244 100644 --- a/ttnn/ttnn/operations/binary.py +++ b/ttnn/ttnn/operations/binary.py @@ -86,8 +86,8 @@ def _golden_function(input_tensor_a, input_tensor_b, *args, activations=None, ** return apply_activations(output_tensor, activations) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.add, golden_function=_golden_function) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.add_, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.add, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.add_, golden_function=_golden_function) def _golden_function(input_tensor_a, input_tensor_b, *args, activations=None, **kwargs): @@ -95,8 +95,8 @@ def _golden_function(input_tensor_a, input_tensor_b, *args, activations=None, ** return apply_activations(output_tensor, activations) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.subtract, golden_function=_golden_function) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.subtract_, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.subtract, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.subtract_, golden_function=_golden_function) def _golden_function(input_tensor_a, input_tensor_b, *args, activations=None, **kwargs): @@ -104,8 +104,8 @@ def _golden_function(input_tensor_a, input_tensor_b, *args, activations=None, ** return apply_activations(output_tensor, activations) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.multiply, golden_function=_golden_function) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.multiply_, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.multiply, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.multiply_, golden_function=_golden_function) def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): @@ -114,7 +114,7 @@ def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): return torch.eq(input_tensor_a, input_tensor_b) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.eq, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.eq, golden_function=_golden_function) def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): @@ -123,7 +123,7 @@ def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): return torch.ne(input_tensor_a, input_tensor_b) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.ne, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.ne, golden_function=_golden_function) def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): @@ -132,7 +132,7 @@ def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): return torch.gt(input_tensor_a, input_tensor_b) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.gt, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.gt, golden_function=_golden_function) def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): @@ -141,7 +141,7 @@ def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): return torch.ge(input_tensor_a, input_tensor_b) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.ge, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.ge, golden_function=_golden_function) def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): @@ -150,7 +150,7 @@ def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): return torch.lt(input_tensor_a, input_tensor_b) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.lt, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.lt, golden_function=_golden_function) def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): @@ -159,7 +159,7 @@ def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): return torch.le(input_tensor_a, input_tensor_b) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.le, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.le, golden_function=_golden_function) def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): @@ -168,7 +168,7 @@ def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): return torch.logical_and(input_tensor_a, input_tensor_b) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.logical_and, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.logical_and, golden_function=_golden_function) def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): @@ -177,7 +177,7 @@ def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): return torch.logical_or(input_tensor_a, input_tensor_b) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.logical_or, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.logical_or, golden_function=_golden_function) def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): @@ -186,7 +186,7 @@ def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): return torch.ldexp(input_tensor_a, input_tensor_b) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.ldexp, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.ldexp, golden_function=_golden_function) def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): @@ -195,7 +195,7 @@ def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): return torch.logaddexp(input_tensor_a, input_tensor_b) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.logaddexp, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.logaddexp, golden_function=_golden_function) def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): @@ -204,7 +204,7 @@ def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): return torch.logaddexp2(input_tensor_a, input_tensor_b) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.logaddexp2, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.logaddexp2, golden_function=_golden_function) def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): @@ -213,7 +213,7 @@ def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): return torch.divide(input_tensor_a, input_tensor_b) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.divide, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.divide, golden_function=_golden_function) def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): @@ -222,7 +222,7 @@ def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): return torch.nn.functional.gelu(torch.add(x, y)) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.bias_gelu, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.bias_gelu, golden_function=_golden_function) def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): @@ -231,7 +231,7 @@ def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs): return torch.squared_difference(input_tensor_a, input_tensor_b) -ttnn.attach_golden_function(ttnn._ttnn.operations.binary.squared_difference, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.squared_difference, golden_function=_golden_function) def torch_squared_difference(x, y, *args, **kwargs): diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index 3691ae14691..deb4a540add 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -430,14 +430,14 @@ def _golden_function(tensor, *args, **kwargs): return tensor -ttnn.attach_golden_function(ttnn._ttnn.operations.core.to_memory_config, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.to_memory_config, golden_function=_golden_function) def _golden_function(tensor, *args, **kwargs): return tensor -ttnn.attach_golden_function(ttnn._ttnn.operations.core.to_layout, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.to_layout, golden_function=_golden_function) def _golden_function(tensor, *args, **kwargs): @@ -445,8 +445,8 @@ def _golden_function(tensor, *args, **kwargs): # TODO: Merge to_dtype and typecast -ttnn.attach_golden_function(ttnn._ttnn.operations.core.to_dtype, golden_function=_golden_function) -ttnn.attach_golden_function(ttnn._ttnn.operations.copy.typecast, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.to_dtype, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.typecast, golden_function=_golden_function) def _golden_function(tensor, *args, **kwargs): diff --git a/ttnn/ttnn/operations/creation.py b/ttnn/ttnn/operations/creation.py index 35c45dc5427..e58dbf9e07a 100644 --- a/ttnn/ttnn/operations/creation.py +++ b/ttnn/ttnn/operations/creation.py @@ -14,7 +14,7 @@ def _golden_function(input_tensor: ttnn.Tensor, **_): return torch.zeros_like(input_tensor) -ttnn.attach_golden_function(ttnn._ttnn.operations.creation.zeros_like, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.zeros_like, golden_function=_golden_function) def _golden_function(input_tensor: ttnn.Tensor, **_): @@ -23,7 +23,7 @@ def _golden_function(input_tensor: ttnn.Tensor, **_): return torch.ones_like(input_tensor) -ttnn.attach_golden_function(ttnn._ttnn.operations.creation.ones_like, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.ones_like, golden_function=_golden_function) def _golden_function(input_tensor: ttnn.Tensor, *, fill_value: float, **_): @@ -32,7 +32,7 @@ def _golden_function(input_tensor: ttnn.Tensor, *, fill_value: float, **_): return torch.full_like(input_tensor, fill_value) -ttnn.attach_golden_function(ttnn._ttnn.operations.creation.full_like, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.full_like, golden_function=_golden_function) def _golden_function(input_tensor: ttnn.Tensor, *, fill_value: float, **_): @@ -41,7 +41,7 @@ def _golden_function(input_tensor: ttnn.Tensor, *, fill_value: float, **_): return torch.empty_like(input_tensor, fill_value) -ttnn.attach_golden_function(ttnn._ttnn.operations.creation.empty_like, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.empty_like, golden_function=_golden_function) def _golden_function(input_shape: ttnn.Shape, **_): @@ -50,7 +50,7 @@ def _golden_function(input_shape: ttnn.Shape, **_): return torch.zeros(input_shape) -ttnn.attach_golden_function(ttnn._ttnn.operations.creation.zeros, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.zeros, golden_function=_golden_function) def _golden_function(input_shape: ttnn.Shape, **_): @@ -59,7 +59,7 @@ def _golden_function(input_shape: ttnn.Shape, **_): return torch.ones(input_shape) -ttnn.attach_golden_function(ttnn._ttnn.operations.creation.ones, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.ones, golden_function=_golden_function) def _golden_function_full(input_shape: ttnn.Shape, fill_value: float, **_): @@ -68,7 +68,7 @@ def _golden_function_full(input_shape: ttnn.Shape, fill_value: float, **_): return torch.full(input_shape, fill_value=fill_value) -ttnn.attach_golden_function(ttnn._ttnn.operations.creation.full, golden_function=_golden_function_full) +ttnn.attach_golden_function(ttnn.full, golden_function=_golden_function_full) def _golden_function(input_shape: ttnn.Shape, **_): @@ -77,7 +77,7 @@ def _golden_function(input_shape: ttnn.Shape, **_): return torch.empty(input_shape) -ttnn.attach_golden_function(ttnn._ttnn.operations.creation.empty, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.empty, golden_function=_golden_function) def _golden_function(start: int, end: int, step: int, **_): diff --git a/ttnn/ttnn/operations/data_movement.py b/ttnn/ttnn/operations/data_movement.py index f65288aa8b2..4a26914edb6 100644 --- a/ttnn/ttnn/operations/data_movement.py +++ b/ttnn/ttnn/operations/data_movement.py @@ -60,7 +60,7 @@ def _postprocess_golden_function_outputs(output_tensor, args, kwargs): ttnn.attach_golden_function( - ttnn._ttnn.operations.data_movement.pad, + ttnn.pad, golden_function=_golden_function, preprocess_golden_function_inputs=_preprocess_golden_function_inputs, postprocess_golden_function_outputs=_postprocess_golden_function_outputs, @@ -92,7 +92,7 @@ def _golden_function(tensors, dim=0, **_): ttnn.attach_golden_function( - ttnn._ttnn.operations.data_movement.concat, + ttnn.concat, golden_function=_golden_function, ) @@ -103,14 +103,14 @@ def _golden_function(tensor, repeats, dim=0, **_): return torch.repeat_interleave(tensor, repeats, dim=dim) -ttnn.attach_golden_function(ttnn._ttnn.operations.data_movement.repeat_interleave, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.repeat_interleave, golden_function=_golden_function) def _golden_function(tensor, shape, **_): return tensor.repeat(shape[0], shape[1], shape[2], shape[3]) -ttnn.attach_golden_function(ttnn._ttnn.operations.data_movement.repeat, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.repeat, golden_function=_golden_function) def _golden_function(input_tensor: ttnn.Tensor, scale_factor: Tuple[float, float], **_): @@ -123,7 +123,7 @@ def _golden_function(input_tensor: ttnn.Tensor, scale_factor: Tuple[float, float ttnn.attach_golden_function( - ttnn._ttnn.operations.data_movement.upsample, + ttnn.upsample, golden_function=_golden_function, ) diff --git a/ttnn/ttnn/operations/embedding.py b/ttnn/ttnn/operations/embedding.py index b5c0725042f..26c774f6432 100644 --- a/ttnn/ttnn/operations/embedding.py +++ b/ttnn/ttnn/operations/embedding.py @@ -12,7 +12,7 @@ def _golden_function(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, **_): return output_tensor -ttnn.attach_golden_function(ttnn._ttnn.operations.embedding.embedding, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.embedding, golden_function=_golden_function) __all__ = [] diff --git a/ttnn/ttnn/operations/normalization.py b/ttnn/ttnn/operations/normalization.py index bee4fd259b6..b9fa6fbc047 100644 --- a/ttnn/ttnn/operations/normalization.py +++ b/ttnn/ttnn/operations/normalization.py @@ -19,7 +19,7 @@ def _golden_function(input_tensor: ttnn.Tensor, dim: int, **_): ttnn.attach_golden_function( - ttnn._ttnn.operations.normalization.softmax, + ttnn.softmax, golden_function=_golden_function, ) @@ -45,7 +45,7 @@ def _golden_function( return torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), weight, bias, eps=epsilon) -ttnn.attach_golden_function(ttnn._ttnn.operations.normalization.layer_norm, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.layer_norm, golden_function=_golden_function) def _golden_function(input_tensor: ttnn.Tensor, weight=None, *, epsilon=1e-12, **_): @@ -60,7 +60,7 @@ def _golden_function(input_tensor: ttnn.Tensor, weight=None, *, epsilon=1e-12, * return weight * input_tensor -ttnn.attach_golden_function(ttnn._ttnn.operations.normalization.rms_norm, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.rms_norm, golden_function=_golden_function) # group norm helper function @@ -232,7 +232,7 @@ def _postprocess_golden_function_outputs(output, args, kwargs): ttnn.attach_golden_function( - ttnn._ttnn.operations.normalization.group_norm, + ttnn.group_norm, golden_function=_golden_function, postprocess_golden_function_outputs=_postprocess_golden_function_outputs, ) diff --git a/ttnn/ttnn/operations/pool.py b/ttnn/ttnn/operations/pool.py index 1daef250636..3a7225d3b57 100644 --- a/ttnn/ttnn/operations/pool.py +++ b/ttnn/ttnn/operations/pool.py @@ -126,6 +126,6 @@ def _golden_function(input_tensor: ttnn.Tensor): return torch.nn.functional.global_avg_pool2d(input_tensor, output_size) -ttnn.attach_golden_function(ttnn._ttnn.operations.pool.global_avg_pool2d, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.global_avg_pool2d, golden_function=_golden_function) __all__ = [] diff --git a/ttnn/ttnn/operations/reduction.py b/ttnn/ttnn/operations/reduction.py index 9004c84c3ca..f907a0d76b1 100644 --- a/ttnn/ttnn/operations/reduction.py +++ b/ttnn/ttnn/operations/reduction.py @@ -31,17 +31,17 @@ def golden_function(input_tensor: ttnn.Tensor, k: int, dim: Optional[int] = None # Generic reductions -ttnn.attach_golden_function(ttnn._ttnn.operations.reduction.mean, golden_function=_create_golden_function("mean")) -ttnn.attach_golden_function(ttnn._ttnn.operations.reduction.sum, golden_function=_create_golden_function("sum")) -ttnn.attach_golden_function(ttnn._ttnn.operations.reduction.max, golden_function=_create_golden_function("max")) -ttnn.attach_golden_function(ttnn._ttnn.operations.reduction.min, golden_function=_create_golden_function("min")) -ttnn.attach_golden_function(ttnn._ttnn.operations.reduction.var, golden_function=_create_golden_function("var")) -ttnn.attach_golden_function(ttnn._ttnn.operations.reduction.std, golden_function=_create_golden_function("std")) +ttnn.attach_golden_function(ttnn.mean, golden_function=_create_golden_function("mean")) +ttnn.attach_golden_function(ttnn.sum, golden_function=_create_golden_function("sum")) +ttnn.attach_golden_function(ttnn.max, golden_function=_create_golden_function("max")) +ttnn.attach_golden_function(ttnn.min, golden_function=_create_golden_function("min")) +ttnn.attach_golden_function(ttnn.var, golden_function=_create_golden_function("var")) +ttnn.attach_golden_function(ttnn.std, golden_function=_create_golden_function("std")) # Special reductions -ttnn.attach_golden_function(ttnn._ttnn.operations.reduction.argmax, golden_function=_create_golden_function("argmax")) +ttnn.attach_golden_function(ttnn.argmax, golden_function=_create_golden_function("argmax")) -ttnn.attach_golden_function(ttnn._ttnn.operations.reduction.topk, golden_function=_create_golden_function_topk()) +ttnn.attach_golden_function(ttnn.topk, golden_function=_create_golden_function_topk()) __all__ = [] diff --git a/ttnn/ttnn/operations/ternary.py b/ttnn/ttnn/operations/ternary.py index eb9ff8ec3ea..704beb32d51 100644 --- a/ttnn/ttnn/operations/ternary.py +++ b/ttnn/ttnn/operations/ternary.py @@ -195,7 +195,7 @@ def _golden_function(input_tensor: ttnn.Tensor, **_): ttnn.attach_golden_function( - ttnn._ttnn.operations.ternary.where, + ttnn.where, golden_function=_golden_function, ) diff --git a/ttnn/ttnn/operations/transformer.py b/ttnn/ttnn/operations/transformer.py index a8872f72ce6..eb9d0b08937 100644 --- a/ttnn/ttnn/operations/transformer.py +++ b/ttnn/ttnn/operations/transformer.py @@ -48,7 +48,7 @@ def _golden_function( ttnn.attach_golden_function( - ttnn._ttnn.operations.transformer.split_query_key_value_and_split_heads, + ttnn.transformer.split_query_key_value_and_split_heads, golden_function=_golden_function, ) @@ -70,13 +70,13 @@ def _golden_function(input_tensor: ttnn.Tensor, *, head_size: int, attention_mas ttnn.attach_golden_function( - ttnn._ttnn.operations.transformer.attention_softmax, + ttnn.transformer.attention_softmax, golden_function=_golden_function, ) ttnn.attach_golden_function( - ttnn._ttnn.operations.transformer.attention_softmax_, + ttnn.transformer.attention_softmax_, golden_function=_golden_function, ) @@ -93,7 +93,7 @@ def _golden_function(input_tensor: ttnn.Tensor, **_): return output_tensor -ttnn.attach_golden_function(ttnn._ttnn.operations.transformer.concatenate_heads, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.transformer.concatenate_heads, golden_function=_golden_function) def _golden_function(x, cos_cached, sin_cached, token_idx, **_): @@ -114,7 +114,7 @@ def apply_rotary_pos_emb(x, cos_cached, sin_cached, token_idx=0): return pt_out -ttnn.attach_golden_function(ttnn._ttnn.operations.transformer.rotary_embedding, golden_function=_golden_function) +ttnn.attach_golden_function(ttnn.transformer.rotary_embedding, golden_function=_golden_function) __all__ = [] diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py index c4b9d5a2255..920978b68f4 100644 --- a/ttnn/ttnn/operations/unary.py +++ b/ttnn/ttnn/operations/unary.py @@ -122,80 +122,80 @@ def _golden_function(input_tensor: ttnn.Tensor, **_): TTNN_ELTWISE_UNARY_CPP_FUNCTIONS = [ - ttnn._ttnn.operations.unary.abs, - ttnn._ttnn.operations.unary.acos, - ttnn._ttnn.operations.unary.asin, - ttnn._ttnn.operations.unary.atan, - ttnn._ttnn.operations.unary.cos, - ttnn._ttnn.operations.unary.erfinv, - ttnn._ttnn.operations.unary.exp2, - ttnn._ttnn.operations.unary.expm1, - ttnn._ttnn.operations.unary.eqz, - ttnn._ttnn.operations.unary.gez, - ttnn._ttnn.operations.unary.gtz, - ttnn._ttnn.operations.unary.i0, - ttnn._ttnn.operations.unary.isfinite, - ttnn._ttnn.operations.unary.isinf, - ttnn._ttnn.operations.unary.isnan, - ttnn._ttnn.operations.unary.isneginf, - ttnn._ttnn.operations.unary.isposinf, - ttnn._ttnn.operations.unary.lez, - ttnn._ttnn.operations.unary.log, - ttnn._ttnn.operations.unary.log10, - ttnn._ttnn.operations.unary.log2, - ttnn._ttnn.operations.unary.logical_not, - ttnn._ttnn.operations.unary.ltz, - ttnn._ttnn.operations.unary.neg, - ttnn._ttnn.operations.unary.nez, - ttnn._ttnn.operations.unary.reciprocal, - ttnn._ttnn.operations.unary.relu, - ttnn._ttnn.operations.unary.relu6, - ttnn._ttnn.operations.unary.sigmoid, - ttnn._ttnn.operations.unary.sign, - ttnn._ttnn.operations.unary.signbit, - ttnn._ttnn.operations.unary.silu, - ttnn._ttnn.operations.unary.sin, - ttnn._ttnn.operations.unary.sqrt, - ttnn._ttnn.operations.unary.square, - ttnn._ttnn.operations.unary.tan, - ttnn._ttnn.operations.unary.tanh, + ttnn.abs, + ttnn.acos, + ttnn.asin, + ttnn.atan, + ttnn.cos, + ttnn.erfinv, + ttnn.exp2, + ttnn.expm1, + ttnn.eqz, + ttnn.gez, + ttnn.gtz, + ttnn.i0, + ttnn.isfinite, + ttnn.isinf, + ttnn.isnan, + ttnn.isneginf, + ttnn.isposinf, + ttnn.lez, + ttnn.log, + ttnn.log10, + ttnn.log2, + ttnn.logical_not, + ttnn.ltz, + ttnn.neg, + ttnn.nez, + ttnn.reciprocal, + ttnn.relu, + ttnn.relu6, + ttnn.sigmoid, + ttnn.sign, + ttnn.signbit, + ttnn.silu, + ttnn.sin, + ttnn.sqrt, + ttnn.square, + ttnn.tan, + ttnn.tanh, # Unaries with fast_and_approximate_mode - ttnn._ttnn.operations.unary.exp, - ttnn._ttnn.operations.unary.erf, - ttnn._ttnn.operations.unary.erfc, - ttnn._ttnn.operations.unary.gelu, - ttnn._ttnn.operations.unary.rsqrt, + ttnn.exp, + ttnn.erf, + ttnn.erfc, + ttnn.gelu, + ttnn.rsqrt, # Unaries with float parameter - ttnn._ttnn.operations.unary.elu, - ttnn._ttnn.operations.unary.heaviside, - ttnn._ttnn.operations.unary.leaky_relu, - # ttnn._ttnn.operations.unary.prelu, # Alias for leaky_relu. TODO(#8544): implement PReLU properly + ttnn.elu, + ttnn.heaviside, + ttnn.leaky_relu, + # ttnn.prelu, # Alias for leaky_relu. TODO(#8544): implement PReLU properly # Unaries using op_chain - ttnn._ttnn.operations.unary.log_sigmoid, - ttnn._ttnn.operations.unary.softplus, - ttnn._ttnn.operations.unary.sigmoid_accurate, + ttnn.log_sigmoid, + ttnn.softplus, + ttnn.sigmoid_accurate, # Other unaries (composite operations - tt_eager dependency) - ttnn._ttnn.operations.unary.acosh, - ttnn._ttnn.operations.unary.asinh, - ttnn._ttnn.operations.unary.atanh, - ttnn._ttnn.operations.unary.cbrt, - ttnn._ttnn.operations.unary.cosh, - ttnn._ttnn.operations.unary.deg2rad, - ttnn._ttnn.operations.unary.digamma, - ttnn._ttnn.operations.unary.hardswish, - ttnn._ttnn.operations.unary.hardsigmoid, - ttnn._ttnn.operations.unary.hardtanh, - ttnn._ttnn.operations.unary.lgamma, - ttnn._ttnn.operations.unary.log1p, - ttnn._ttnn.operations.unary.mish, - ttnn._ttnn.operations.unary.multigammaln, - ttnn._ttnn.operations.unary.rad2deg, - ttnn._ttnn.operations.unary.sinh, - ttnn._ttnn.operations.unary.softsign, - ttnn._ttnn.operations.unary.swish, - ttnn._ttnn.operations.unary.tanhshrink, - ttnn._ttnn.operations.unary.tril, - ttnn._ttnn.operations.unary.triu, + ttnn.acosh, + ttnn.asinh, + ttnn.atanh, + ttnn.cbrt, + ttnn.cosh, + ttnn.deg2rad, + ttnn.digamma, + ttnn.hardswish, + ttnn.hardsigmoid, + ttnn.hardtanh, + ttnn.lgamma, + ttnn.log1p, + ttnn.mish, + ttnn.multigammaln, + ttnn.rad2deg, + ttnn.sinh, + ttnn.softsign, + ttnn.swish, + ttnn.tanhshrink, + ttnn.tril, + ttnn.triu, ] for unary_function in TTNN_ELTWISE_UNARY_CPP_FUNCTIONS: register_ttnn_cpp_unary_function(unary_function)