diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_rsub.py b/tests/ttnn/unit_tests/operations/backward/test_backward_rsub.py
index 7c68db0c2bbb..8f1834dba51c 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_rsub.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_rsub.py
@@ -29,3 +29,46 @@ def test_bw_rsub(input_shapes, device):
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True], [False, False]])
+def test_bw_rsub_opt(input_shapes, device, are_required_outputs):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    other_data, other_tensor = data_gen_with_range(input_shapes, -5, 5, device, True)
+
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, 5, device)
+
+    input_grad = None
+    other_grad = None
+    tt_output_tensor_on_device = None
+
+    if are_required_outputs[0]:
+        _, input_grad = data_gen_with_range(input_shapes, -1, 1, device)
+    if are_required_outputs[1]:
+        _, other_grad = data_gen_with_range(input_shapes, -1, 1, device)
+
+    cq_id = 0
+
+    if are_required_outputs[0] and are_required_outputs[1]:
+        pages_before = ttnn._ttnn.reports.get_buffer_pages()
+        ttnn.rsub_bw(
+            grad_tensor, input_tensor, other_tensor, input_grad=input_grad, other_grad=other_grad, queue_id=cq_id
+        )
+        assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
+        tt_output_tensor_on_device = [input_grad, other_grad]
+    else:
+        tt_output_tensor_on_device = ttnn.rsub_bw(grad_tensor, input_tensor, other_tensor, queue_id=cq_id)
+
+    golden_function = ttnn.get_golden_function(ttnn.rsub_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
+
+    status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
+    assert status
diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_subalpha.py b/tests/ttnn/unit_tests/operations/backward/test_backward_subalpha.py
index 3bb493858164..1c64fc30826a 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_subalpha.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_subalpha.py
@@ -51,3 +51,32 @@ def test_bw_subalpha_default(input_shapes, device):
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+def test_bw_subalpha_opt_output(input_shapes, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    other_data, other_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
+
+    _, input_grad = data_gen_with_range(input_shapes, -1, 1, device)
+
+    cq_id = 0
+    pages_before = ttnn._ttnn.reports.get_buffer_pages()
+    ttnn.subalpha_bw(grad_tensor, input_tensor, other_tensor, input_grad=input_grad, queue_id=cq_id)
+    assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
+
+    tt_output_tensor_on_device = [input_grad]
+
+    golden_function = ttnn.get_golden_function(ttnn.subalpha_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
+
+    status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
+    assert status
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp
index 58f94f5da71a..91a18c5db8b4 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp
@@ -264,10 +264,49 @@ struct ExecuteAddalphaBW {
         std::optional<Tensor> input_b_grad = std::nullopt);
 };
+struct ExecuteBackwardSubAlpha {
+    static std::vector<std::optional<Tensor>> invoke(
+        uint8_t queue_id,
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_a_arg,
+        const Tensor &input_tensor_b_arg,
+        float alpha,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt,
+        const std::vector<bool> &are_required_outputs = std::vector<bool>{true, true},
+        std::optional<Tensor> input_grad = std::nullopt,
+        std::optional<Tensor> other_grad = std::nullopt);
+
+    static std::vector<std::optional<Tensor>> invoke(
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_a_arg,
+        const Tensor &input_tensor_b_arg,
+        float alpha,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+
+};
+
+struct ExecuteBackwardRsub {
+    static std::vector<std::optional<Tensor>> invoke(
+        uint8_t queue_id,
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_a_arg,
+        const Tensor &input_tensor_b_arg,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt,
+        const std::vector<bool> &are_required_outputs = std::vector<bool>{true, true},
+        std::optional<Tensor> input_grad = std::nullopt,
+        std::optional<Tensor> other_grad = std::nullopt);
+
+    static std::vector<std::optional<Tensor>> invoke(
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_a_arg,
+        const Tensor &input_tensor_b_arg,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+
+};
+
 
 } // operations::binary
 
 constexpr auto atan2_bw = ttnn::register_operation<"ttnn::atan2_bw", operations::binary_backward::ExecuteBinaryBackwardTensor>();
-constexpr auto rsub_bw = ttnn::register_operation<"ttnn::rsub_bw", operations::binary_backward::ExecuteBinaryBackwardTensor>();
 constexpr auto xlogy_bw = ttnn::register_operation<"ttnn::xlogy_bw", operations::binary_backward::ExecuteBinaryBackwardTensor>();
 constexpr auto hypot_bw = ttnn::register_operation<"ttnn::hypot_bw", operations::binary_backward::ExecuteBinaryBackwardTensor>();
 constexpr auto ldexp_bw = ttnn::register_operation<"ttnn::ldexp_bw", operations::binary_backward::ExecuteBinaryBackwardTensor>();
@@ -278,7 +317,13 @@ constexpr auto min_bw = ttnn::register_operation<"ttnn::min_bw", operations::bin
 
 constexpr auto max_bw = ttnn::register_operation<"ttnn::max_bw", operations::binary_backward::ExecuteBinaryBackwardTensor>();
 
-constexpr auto subalpha_bw = ttnn::register_operation<"ttnn::subalpha_bw", operations::binary_backward::ExecuteBinaryBackwardFloatDefault>();
+constexpr auto subalpha_bw = ttnn::register_operation<
+    "ttnn::subalpha_bw",
+    operations::binary_backward::ExecuteBackwardSubAlpha>();
+
+constexpr auto rsub_bw = ttnn::register_operation<
+    "ttnn::rsub_bw",
+    operations::binary_backward::ExecuteBackwardRsub>();
 
 constexpr auto concat_bw = ttnn::register_operation<"ttnn::concat_bw", operations::binary_backward::ExecuteBinaryBackwardIntDefault>();
 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp
index bd40ed94052c..f74972db545b 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp
@@ -350,6 +350,141 @@ void bind_binary_backward_float_default(py::module& module, const binary_backwar
     );
 }
 
+template <typename binary_backward_operation_t>
+void bind_binary_backward_sub_alpha(py::module& module, const binary_backward_operation_t& operation, const std::string& parameter_name, const std::string& parameter_doc, float parameter_value, std::string_view description, std::string_view supported_dtype) {
+    auto doc = fmt::format(
+        R"doc({0}(grad_tensor: ttnn.Tensor, input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, {2}: float, *, memory_config: ttnn.MemoryConfig) -> std::vector<ttnn.Tensor>
+
+        {5}
+
+        Args:
+            * :attr:`grad_tensor`
+            * :attr:`input_tensor_a`
+            * :attr:`input_tensor_b`
+            * :attr:`{2}` (float): `{3}`, Default value = {4}
+
+        Keyword args:
+            * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor
+            * :attr:`input_grad`, :attr:`other_grad` (Optional[ttnn.Tensor]): preallocated output gradient tensors
+            * :attr:`queue_id` (Optional[uint8]): command queue id
+
+        Supported dtypes, layouts, and ranks:
+
+        {6}
+
+        Note : bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT
+
+        Example:
+
+            >>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
+            >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
+            >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device)
+            >>> output = {1}(grad_tensor, tensor1, tensor2, {4})
+        )doc",
+        operation.base_name(),
+        operation.python_fully_qualified_name(),
+        parameter_name,
+        parameter_doc,
+        parameter_value,
+        description,
+        supported_dtype);
+
+
+    bind_registered_operation(
+        module,
+        operation,
+        doc,
+        ttnn::pybind_overload_t{
+            [](const binary_backward_operation_t& self,
+               const ttnn::Tensor& grad_tensor,
+               const ttnn::Tensor& input_tensor,
+               const ttnn::Tensor& other_tensor,
+               float alpha,
+               const std::optional<ttnn::MemoryConfig>& memory_config,
+               const std::vector<bool>& are_required_outputs,
+               const std::optional<ttnn::Tensor>& input_grad,
+               const std::optional<ttnn::Tensor>& other_grad,
+               const uint8_t& queue_id) -> std::vector<std::optional<ttnn::Tensor>> {
+                return self(queue_id, grad_tensor, input_tensor, other_tensor, alpha, memory_config, are_required_outputs, input_grad, other_grad);
+            },
+            py::arg("grad_tensor"),
+            py::arg("input_tensor_a"),
+            py::arg("input_tensor_b"),
+            py::arg(parameter_name.c_str()) = parameter_value,
+            py::kw_only(),
+            py::arg("memory_config") = std::nullopt,
+            py::arg("are_required_outputs") = std::vector<bool>{true, true},
+            py::arg("input_grad") = std::nullopt,
+            py::arg("other_grad") = std::nullopt,
+            py::arg("queue_id") = 0}
+    );
+}
+
+template <typename binary_backward_operation_t>
+void bind_binary_backward_rsub(py::module& module, const binary_backward_operation_t& operation, std::string_view description, std::string_view supported_dtype) {
+
+    auto doc = fmt::format(
+        R"doc({0}(grad_tensor: ttnn.Tensor, input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig) -> std::vector<ttnn.Tensor>
+
+        {2}
+
+        Args:
+            * :attr:`grad_tensor`
+            * :attr:`input_tensor_a`
+            * :attr:`input_tensor_b`
+
+        Keyword args:
+            * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor
+            * :attr:`input_grad`, :attr:`other_grad` (Optional[ttnn.Tensor]): preallocated output gradient tensors
+            * :attr:`queue_id` (Optional[uint8]): command queue id
+
+        Supported dtypes, layouts, and ranks:
+
+        {3}
+
+        Note : bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT
+
+        Example:
+
+            >>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
+            >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
+            >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device)
+            >>> output = {1}(grad_tensor, tensor1, tensor2)
+        )doc",
+        operation.base_name(),
+        operation.python_fully_qualified_name(),
+        description,
+        supported_dtype);
+
+
+    bind_registered_operation(
+        module,
+        operation,
+        doc,
+        ttnn::pybind_overload_t{
+            [](const binary_backward_operation_t& self,
+               const ttnn::Tensor& grad_tensor,
+               const ttnn::Tensor& input_tensor,
+               const ttnn::Tensor& other_tensor,
+               const std::optional<ttnn::MemoryConfig>& memory_config,
+               const std::vector<bool>& are_required_outputs,
+               const std::optional<ttnn::Tensor>& input_grad,
+               const std::optional<ttnn::Tensor>& other_grad,
+               const uint8_t& queue_id) -> std::vector<std::optional<ttnn::Tensor>> {
+                return self(queue_id, grad_tensor, input_tensor, other_tensor, memory_config, are_required_outputs, input_grad, other_grad);
+            },
+            py::arg("grad_tensor"),
+            py::arg("input_tensor_a"),
+            py::arg("input_tensor_b"),
+            py::kw_only(),
+            py::arg("memory_config") = std::nullopt,
+            py::arg("are_required_outputs") = std::vector<bool>{true, true},
+            py::arg("input_grad") = std::nullopt,
+            py::arg("other_grad") = std::nullopt,
+            py::arg("queue_id") = 0}
+    );
+}
+
 template <typename binary_backward_operation_t>
 void bind_binary_bw_mul(py::module& module, const binary_backward_operation_t& operation, std::string_view description, std::string_view supported_dtype) {
     auto doc = fmt::format(
@@ -846,7 +981,7 @@ void py_module(py::module& module) {
           |     BFLOAT16, BFLOAT8_B    |       ROW_MAJOR, TILE           |      2, 3, 4      |
           +----------------------------+---------------------------------+-------------------+)doc");
 
-    detail::bind_binary_backward_float_default(
+    detail::bind_binary_backward_sub_alpha(
         module,
         ttnn::subalpha_bw,
         "alpha", "Alpha value", 1.0f,
@@ -948,7 +1083,7 @@ void py_module(py::module& module) {
           |     BFLOAT16, BFLOAT8_B    |       ROW_MAJOR, TILE           |      2, 3, 4      |
           +----------------------------+---------------------------------+-------------------+)doc");
 
-    detail::bind_binary_backward_ops(
+    detail::bind_binary_backward_rsub(
         module,
         ttnn::rsub_bw,
         R"doc(Performs backward operations for subtraction of :attr:`input_tensor_a` from :attr:`input_tensor_b` with given :attr:`grad_tensor` (reversed order of subtraction operator).)doc",
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp
index 0e95f9ef238a..4dcbebb69d8d 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp
@@ -22,6 +22,7 @@
 
 #include "tt_metal/common/constants.hpp"
 #include "ttnn/cpp/ttnn/common/constants.hpp"
+#include "ttnn/common/constants.hpp"
 #include "tt_metal/host_api.hpp"
 #include "tt_metal/tools/profiler/op_profiler.hpp"
 #include "ttnn/operations/eltwise/ternary/where.hpp"
@@ -105,18 +106,39 @@ std::vector<std::optional<Tensor>> ExecuteAddalphaBW::invoke(
     return ExecuteAddalphaBW::invoke(ttnn::DefaultQueueId, grad, input, other, alpha, are_required_outputs, output_mem_config, input_grad, other_grad);
 }
 
-std::vector<Tensor> _subalpha_bw(
-    const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const std::optional<MemoryConfig>& output_mem_config) {
-    std::vector<Tensor> grad_tensor;
-    grad_tensor.emplace_back(grad);
-    Tensor grad_b = ttnn::multiply(ttnn::neg(grad, output_mem_config), alpha, std::nullopt, output_mem_config);
-    grad_tensor.emplace_back(grad_b);
-    return grad_tensor;
-}
+std::vector<std::optional<Tensor>> ExecuteBackwardSubAlpha::invoke(
+    uint8_t queue_id,
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    float alpha,
+    const std::optional<MemoryConfig>& output_mem_config,
+    const std::vector<bool>& are_required_outputs,
+    std::optional<Tensor> input_grad,
+    std::optional<Tensor> other_grad) {
 
-std::vector<Tensor> _sub_bw(
-    const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config) {
-    return _subalpha_bw(grad, input, other, 1.0, output_mem_config);
+    std::vector<std::optional<Tensor>> result;
+    if (are_required_outputs.at(0)) {
+        if(!input_grad.has_value()){
+            input_grad = grad;
+        }
+        ttnn::assign(queue_id, grad, input_grad.value());
+        result.emplace_back(input_grad);
+    } else {
+        result.emplace_back(std::nullopt);
+    }
+
+    if (are_required_outputs.at(1)) {
+        if(!other_grad.has_value()){
+            other_grad = ttnn::neg(queue_id, grad, output_mem_config);
+        }
+        ttnn::neg(queue_id, grad, output_mem_config, other_grad);
+        result.emplace_back(other_grad);
+    } else {
+        result.emplace_back(std::nullopt);
+    }
+
+    return result;
 }
 
 std::vector<Tensor> ExecuteBackwardAdd::invoke(
@@ -435,10 +457,38 @@ std::vector<Tensor> _binary_comp_bw(const Tensor& grad, const Tensor& input, con
     return grad_tensor;
 }
 
-std::vector<Tensor> _rsub_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config) {
-    std::vector<Tensor> grad_tensor = _subalpha_bw(grad, input, other, 1.0f, output_mem_config.value_or(input.memory_config()));
-    std::swap(grad_tensor[0], grad_tensor[1]);
-    return grad_tensor;
+std::vector<std::optional<Tensor>> ExecuteBackwardRsub::invoke(
+    uint8_t queue_id,
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const std::optional<MemoryConfig>& output_mem_config,
+    const std::vector<bool>& are_required_outputs,
+    std::optional<Tensor> input_grad,
+    std::optional<Tensor> other_grad) {
+
+    std::vector<std::optional<Tensor>> result;
+
+    if (are_required_outputs.at(0)) {
+        if(!input_grad.has_value()){
+            input_grad = ttnn::neg(queue_id, grad, output_mem_config);
+        }
+        ttnn::neg(queue_id, grad, output_mem_config, input_grad);
+        result.emplace_back(input_grad);
+    } else {
+        result.emplace_back(std::nullopt);
+    }
+
+    if (are_required_outputs.at(1)) {
+        if(!other_grad.has_value()){
+            other_grad = grad;
+        }
+        ttnn::assign(queue_id, grad, other_grad.value());
+        result.emplace_back(other_grad);
+    } else {
+        result.emplace_back(std::nullopt);
+    }
+    return result;
 }
 
 std::vector<Tensor> ExecuteBackwardBiasGelu::invoke(
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.hpp
index 539d6114e91f..36b26199ab36 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.hpp
@@ -15,7 +15,6 @@ namespace ttnn::operations::binary_backward {
 enum class BinaryBackwardOpType {
     ATAN2_BW,
     ADDALPHA_BW,
-    SUBALPHA_BW,
     SUB_BW,
     XLOGY_BW,
     HYPOT_BW,
@@ -26,7 +25,6 @@ enum class BinaryBackwardOpType {
     ADD_BW,
     ASSIGN_BW,
     CONCAT_BW,
-    RSUB_BW,
     BIAS_GELU_BW,
     MIN_BW,
     MAX_BW,
@@ -37,7 +35,6 @@ enum class BinaryBackwardOpType {
 };
 
 std::vector<Tensor> _atan2_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _rsub_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
 std::vector<Tensor> _xlogy_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
 std::vector<Tensor> _hypot_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
 std::vector<Tensor> _ldexp_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
@@ -49,8 +46,6 @@ std::vector<Tensor> _squared_difference_bw( const Tensor& grad, const Tensor& in
 template
 std::vector<Tensor> _min_or_max_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
 
-std::vector<Tensor> _subalpha_bw( const Tensor& grad, const Tensor& input, const Tensor& other, float alpha = 1.0f, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
-
 std::vector<Tensor> _div_bw( const Tensor& grad, const Tensor& input, const Tensor& other, string round_mode = "None" , const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
 
 std::vector<Tensor> _concat_bw( const Tensor& grad, const Tensor& input, const Tensor& other, int dim = 0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
@@ -97,13 +92,6 @@ struct OpHandler {
     }
 };
 
-template <>
-struct OpHandler<BinaryBackwardOpType::RSUB_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _rsub_bw(grad, input, other, output_mem_config);
-    }
-};
-
 template <>
 struct OpHandler {
     static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config ) {
@@ -139,13 +127,6 @@ struct OpHandler {
     }
 };
 
-template <>
-struct OpHandler<BinaryBackwardOpType::SUBALPHA_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _subalpha_bw(grad, input, other, alpha, output_mem_config);
-    }
-};
-
 template <>
 struct OpHandler {
     static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, const Tensor& other, string round_mode, const std::optional<MemoryConfig>& output_mem_config ) {
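For reference, a minimal sketch of how the optional-output path added above is driven from Python, mirroring the new tests: preallocate the gradient tensors, pass them as input_grad/other_grad, and rsub_bw writes into them in place (which is what the buffer-page checks verify). This sketch is not part of the patch; it assumes an already open ttnn device handle, and the bfloat16/TILE_LAYOUT tensor setup below is illustrative rather than taken from the diff.

import torch
import ttnn

shape = (1, 1, 32, 32)
torch_grad = torch.randn(shape, dtype=torch.bfloat16)
torch_a = torch.randn(shape, dtype=torch.bfloat16)
torch_b = torch.randn(shape, dtype=torch.bfloat16)

grad = ttnn.from_torch(torch_grad, layout=ttnn.TILE_LAYOUT, device=device)
a = ttnn.from_torch(torch_a, layout=ttnn.TILE_LAYOUT, device=device)
b = ttnn.from_torch(torch_b, layout=ttnn.TILE_LAYOUT, device=device)

# Preallocate both gradient outputs; with input_grad/other_grad supplied,
# rsub_bw reuses these buffers instead of allocating new ones.
input_grad = ttnn.from_torch(torch.zeros(shape, dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
other_grad = ttnn.from_torch(torch.zeros(shape, dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)

ttnn.rsub_bw(grad, a, b, input_grad=input_grad, other_grad=other_grad, queue_id=0)

# Reference gradients via the registered golden function, as in the tests above.
golden_grads = ttnn.get_golden_function(ttnn.rsub_bw)(torch_grad, torch_a, torch_b)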