diff --git a/docs/source/ttnn/ttnn/ttnn/binary_eq_bw.rst b/docs/source/ttnn/ttnn/ttnn/binary_eq_bw.rst deleted file mode 100644 index 771788a5b85a..000000000000 --- a/docs/source/ttnn/ttnn/ttnn/binary_eq_bw.rst +++ /dev/null @@ -1,6 +0,0 @@ -.. _ttnn.binary_eq_bw: - -ttnn.binary_eq_bw -################# - -.. autofunction:: ttnn.binary_eq_bw diff --git a/docs/source/ttnn/ttnn/ttnn/eq_bw.rst b/docs/source/ttnn/ttnn/ttnn/eq_bw.rst new file mode 100644 index 000000000000..1be3bae2ccee --- /dev/null +++ b/docs/source/ttnn/ttnn/ttnn/eq_bw.rst @@ -0,0 +1,6 @@ +.. _ttnn.eq_bw: + +ttnn.eq_bw +################# + +.. autofunction:: ttnn.eq_bw diff --git a/docs/source/ttnn/ttnn/ttnn/unary_eq_bw.rst b/docs/source/ttnn/ttnn/ttnn/unary_eq_bw.rst deleted file mode 100644 index 0e413efc9683..000000000000 --- a/docs/source/ttnn/ttnn/ttnn/unary_eq_bw.rst +++ /dev/null @@ -1,6 +0,0 @@ -.. _ttnn.unary_eq_bw: - -ttnn.unary_eq_bw -################# - -.. autofunction:: ttnn.unary_eq_bw diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_addalpha.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_addalpha.py index b95a74ef0a81..6998ff46ef05 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_addalpha.py +++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_addalpha.py @@ -4,7 +4,6 @@ import torch import pytest -import tt_lib import ttnn from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_addcdiv.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_addcdiv.py index 7dd3f47673ee..9b13447c3fa1 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_addcdiv.py +++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_addcdiv.py @@ -4,7 +4,6 @@ import torch import pytest -import tt_lib import ttnn from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_assign.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_assign.py index 6684c915b5fd..771ad27cfb33 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_assign.py +++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_assign.py @@ -4,7 +4,6 @@ import torch import pytest -import tt_lib import ttnn from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc @@ -50,9 +49,7 @@ def test_bw_binary_assign(input_shapes, device): grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device) tt_output_tensor_on_device = ttnn.assign_bw(grad_tensor, input_tensor, other_tensor) - print(tt_output_tensor_on_device) - print(grad_tensor) - print(input_tensor, other_tensor) + in_data.retain_grad() pyt_y = torch.clone(in_data) diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_atan2.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_atan2.py index 62ff3d340b66..de7a1c44fe09 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_atan2.py +++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_atan2.py @@ -4,7 +4,6 @@ import torch import pytest -import tt_lib import ttnn from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import ( data_gen_with_range, diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_bias_gelu.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_bias_gelu.py index 1b69674d4ba3..e4df2c9030b7 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_bias_gelu.py +++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_bias_gelu.py @@ -4,7 +4,6 @@ import torch import pytest -import tt_lib import ttnn from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_binary_ge.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_binary_ge.py index abc7091fbbe9..e07d27b4be2a 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_binary_ge.py +++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_binary_ge.py @@ -4,7 +4,6 @@ import torch import pytest -import tt_lib import ttnn from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_binary_gt.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_binary_gt.py index a0052e82947f..a20f95ea4d6e 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_binary_gt.py +++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_binary_gt.py @@ -4,7 +4,6 @@ import torch import pytest -import tt_lib import ttnn from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_binary_ne.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_binary_ne.py index 5976220d3a26..17869882e71a 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_binary_ne.py +++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_binary_ne.py @@ -4,7 +4,6 @@ import torch import pytest -import tt_lib import ttnn from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_clamp.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_clamp.py index 1351e2990d59..1b446ea59518 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_clamp.py +++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_clamp.py @@ -4,7 +4,6 @@ import torch import pytest -import tt_lib import ttnn from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_clamp_min.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_clamp_min.py index 65fef6bb6924..c94e5d35c8d3 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_clamp_min.py +++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_clamp_min.py @@ -4,7 +4,6 @@ import torch import pytest -import tt_lib import ttnn from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_eq.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_eq.py index d3a6b785267e..6af86cc1cda3 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_eq.py +++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_eq.py @@ -4,7 +4,6 @@ import torch import pytest -import tt_lib import ttnn from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_mvlgamma.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_mvlgamma.py index d9d1ce47643a..c393cbd440bc 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_mvlgamma.py +++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_mvlgamma.py @@ -4,7 +4,6 @@ import torch import pytest -import tt_lib import ttnn from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 9df4e53f40ff..38713beafcdd 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -23,6 +23,7 @@ set(TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/upsample/device/upsample_op_multi_core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/upsample/device/upsample_op_single_core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/unary/device/unary_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/examples/example/device/example_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/examples/example/device/single_core_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/examples/example/device/multi_core_program_factory.cpp diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp index 097e8593da25..f1185034701b 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp @@ -13,12 +13,10 @@ #include "tt_metal/host_api.hpp" #include "tt_metal/tools/profiler/op_profiler.hpp" #include "ttnn/operations/eltwise/unary/unary.hpp" +#include "ttnn/operations/eltwise/binary/binary.hpp" namespace ttnn::operations::unary_backward { -namespace utils { - - std::vector _mul_bw( const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) { std::vector grad_tensor; @@ -36,7 +34,6 @@ std::vector _clamp_min_bw( return grad_tensor; } - std::vector _clamp_bw( const Tensor& grad, const Tensor& input, float min, float max, const MemoryConfig& output_mem_config) { std::vector grad_tensor; @@ -81,7 +78,7 @@ std::vector _add_bw( return grad_tensor; } -std::vector _unary_comp_bw(const Tensor& grad, const Tensor& input, float other, const MemoryConfig& output_mem_config) { +std::vector _unary_comp_bw(const Tensor& grad, const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor zero_grad = tt::tt_metal::zeros_like(grad, output_mem_config); grad_tensor.emplace_back(zero_grad); @@ -90,10 +87,10 @@ std::vector _unary_comp_bw(const Tensor& grad, const Tensor& input, floa std::vector _eq_bw( const Tensor& grad, const Tensor& input, float other, const MemoryConfig& output_mem_config) { - return _unary_comp_bw(grad, input, other, output_mem_config); + return _unary_comp_bw(grad, output_mem_config); } -std::function(const Tensor&, const Tensor&, const MemoryConfig&)> get_function_type1(UnaryBackwardOpType OpType){ +std::function(const Tensor&, const Tensor&, const MemoryConfig&)> UnaryBackwardFunction::get_function_type1(UnaryBackwardOpType OpType){ switch (OpType) { case UnaryBackwardOpType::ASSIGN_BW: return _assign_bw; @@ -105,7 +102,7 @@ std::function(const Tensor&, const Tensor&, const Memo } } -std::function(const Tensor&, const Tensor&, float, const MemoryConfig&)> get_function_type1_w_float(UnaryBackwardOpType OpType){ +std::function(const Tensor&, const Tensor&, float, const MemoryConfig&)> UnaryBackwardFunction::get_function_type1_w_float(UnaryBackwardOpType OpType){ switch (OpType) { case UnaryBackwardOpType::MUL_BW: return _mul_bw; @@ -121,7 +118,7 @@ std::function(const Tensor&, const Tensor&, float, con } } -std::function(const Tensor&, const Tensor&, float, float, const MemoryConfig&)> get_function_type1_w_two_float(UnaryBackwardOpType OpType){ +std::function(const Tensor&, const Tensor&, float, float, const MemoryConfig&)> UnaryBackwardFunction::get_function_type1_w_two_float(UnaryBackwardOpType OpType){ switch (OpType) { case UnaryBackwardOpType::CLAMP_BW: return _clamp_bw; @@ -131,6 +128,4 @@ std::function(const Tensor&, const Tensor&, float, flo } } -} - } // namespace ttnn::operations::unary diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp index 16efcd1548f5..1285aa7da535 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp @@ -6,7 +6,7 @@ #include #include - +#include "tensor/tensor.hpp" #include "third_party/magic_enum/magic_enum.hpp" namespace ttnn::operations::unary_backward { @@ -22,5 +22,10 @@ enum class UnaryBackwardOpType { EQ_BW, }; +struct UnaryBackwardFunction{ + static std::function(const Tensor&, const Tensor&, const MemoryConfig&)> get_function_type1(UnaryBackwardOpType OpType); + static std::function(const Tensor&, const Tensor&, float, const MemoryConfig&)> get_function_type1_w_float(UnaryBackwardOpType OpType); + static std::function(const Tensor&, const Tensor&, float, float, const MemoryConfig&)> get_function_type1_w_two_float(UnaryBackwardOpType OpType); +}; -} // namespace ttnn::operations::unary +} // namespace ttnn::operations::unary_backward diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp index 3c0ac22b6d0e..b06e55143a6b 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp @@ -5,7 +5,7 @@ #pragma once -#include "device/unary_backward_op.cpp" +#include "device/unary_backward_op.hpp" #include "ttnn/device_operation.hpp" #include "ttnn/operations/data_movement.hpp" @@ -29,7 +29,7 @@ struct ExecuteUnaryBackward { const Tensor &input_tensor_arg, const std::optional &memory_config = std::nullopt) { - auto op_type = utils::get_function_type1(unary_backward_op_type); + auto op_type = UnaryBackwardFunction::get_function_type1(unary_backward_op_type); auto output_memory_config = memory_config.value_or(input_tensor_arg.memory_config()); return op_type(grad_tensor_arg, input_tensor_arg, output_memory_config); } @@ -42,7 +42,7 @@ struct ExecuteUnaryBackward { float alpha, const std::optional &memory_config = std::nullopt) { - auto op_type = utils::get_function_type1_w_float(unary_backward_op_type); + auto op_type = UnaryBackwardFunction::get_function_type1_w_float(unary_backward_op_type); auto output_memory_config = memory_config.value_or(input_tensor_arg.memory_config()); return op_type(grad_tensor_arg, input_tensor_arg, alpha, output_memory_config); } @@ -56,7 +56,7 @@ struct ExecuteUnaryBackward { float b, const std::optional &memory_config = std::nullopt) { - auto op_type = utils::get_function_type1_w_two_float(unary_backward_op_type); + auto op_type = UnaryBackwardFunction::get_function_type1_w_two_float(unary_backward_op_type); auto output_memory_config = memory_config.value_or(input_tensor_arg.memory_config()); return op_type(grad_tensor_arg, input_tensor_arg, a, b, output_memory_config); } diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp index 684277306f5e..b463ec45c1bf 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp @@ -160,7 +160,7 @@ void py_module(py::module& module) { detail::bind_unary_backward( module, ttnn::mul_bw, - R"doc(Performs backward operations for multiply on :attr:`input_tensor`, :attr:`alpha` with given :attr:`grad_tensor`.)doc"); + R"doc(Performs backward operations for multiply on :attr:`input_tensor`, :attr:`alpha` or attr:`input_tensor_a`, attr:`input_tensor_b` with given :attr:`grad_tensor`.)doc"); detail::bind_unary_backward( module, @@ -175,7 +175,7 @@ void py_module(py::module& module) { detail::bind_unary_backward( module, ttnn::assign_bw, - R"doc(Performs backward operations for assign on :attr:`input_tensor` with given :attr:`grad_tensor`.)doc"); + R"doc(Performs backward operations for assign on :attr:`input_tensor` or attr:`input_tensor_a`, attr:`input_tensor_b` with given :attr:`grad_tensor`.)doc"); detail::bind_unary_backward( module, @@ -187,12 +187,12 @@ void py_module(py::module& module) { detail::bind_unary_backward( module, ttnn::add_bw, - R"doc(Performs backward operations for addition on :attr:`input_tensor`, :attr:`alpha` with given :attr:`grad_tensor`.)doc"); + R"doc(Performs backward operations for addition on :attr:`input_tensor`, :attr:`alpha` or attr:`input_tensor_a`, attr:`input_tensor_b` with given :attr:`grad_tensor`.)doc"); detail::bind_unary_backward( module, ttnn::eq_bw, - R"doc(Performs backward operations for equal to comparison on :attr:`input_tensor`, :attr:`alpha` or attr:`input_tensor_b` with given :attr:`grad_tensor`. + R"doc(Performs backward operations for equal to comparison on :attr:`input_tensor`, :attr:`alpha` or attr:`input_tensor_a`, attr:`input_tensor_b` with given :attr:`grad_tensor`. Returns an tensor of zeros like input tensors.)doc"); }