diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_rsub.py b/tests/ttnn/unit_tests/operations/backward/test_backward_rsub.py
index 8f1834dba51c..5b20643fa977 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_rsub.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_rsub.py
@@ -39,7 +39,7 @@ def test_bw_rsub(input_shapes, device):
         (torch.Size([1, 3, 320, 384])),
     ),
 )
-@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True], [False, False]])
+@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]])
 def test_bw_rsub_opt(input_shapes, device, are_required_outputs):
     in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
     other_data, other_tensor = data_gen_with_range(input_shapes, -5, 5, device, True)
@@ -57,18 +57,16 @@ def test_bw_rsub_opt(input_shapes, device, are_required_outputs):

     cq_id = 0

-    if are_required_outputs[0] and are_required_outputs[1]:
-        pages_before = ttnn._ttnn.reports.get_buffer_pages()
-        ttnn.rsub_bw(
-            grad_tensor, input_tensor, other_tensor, input_grad=input_grad, other_grad=other_grad, queue_id=cq_id
-        )
-        assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
-        tt_output_tensor_on_device = [input_grad, other_grad]
-    else:
-        tt_output_tensor_on_device = ttnn.rsub_bw(grad_tensor, input_tensor, other_tensor, queue_id=cq_id)
+    pages_before = ttnn._ttnn.reports.get_buffer_pages()
+    ttnn.rsub_bw(grad_tensor, input_tensor, other_tensor, input_grad=input_grad, other_grad=other_grad, queue_id=cq_id)
+    assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
+    tt_output_tensor_on_device = [input_grad, other_grad]

     golden_function = ttnn.get_golden_function(ttnn.rsub_bw)
     golden_tensor = golden_function(grad_data, in_data, other_data)

-    status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
+    status = True
+    for i in range(len(are_required_outputs)):
+        if are_required_outputs[i]:
+            status = status & compare_pcc([tt_output_tensor_on_device[i]], [golden_tensor[i]])
     assert status
diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_sub.py b/tests/ttnn/unit_tests/operations/backward/test_backward_sub.py
index a73803150b5c..41bd48a5735a 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_sub.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_sub.py
@@ -50,3 +50,43 @@ def test_bw_unary_sub(input_shapes, scalar, device):

     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]])
+def test_bw_sub_opt(input_shapes, device, are_required_outputs):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    other_data, other_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
+
+    input_grad = None
+    other_grad = None
+    tt_output_tensor_on_device = None
+
+    if are_required_outputs[0]:
+        _, input_grad = data_gen_with_range(input_shapes, -1, 1, device)
+    if are_required_outputs[1]:
+        _, other_grad = data_gen_with_range(input_shapes, -1, 1, device)
+
+    cq_id = 0
+
+    pages_before = ttnn._ttnn.reports.get_buffer_pages()
+    ttnn.sub_bw(grad_tensor, input_tensor, other_tensor, input_grad=input_grad, other_grad=other_grad, queue_id=cq_id)
+    assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
+    tt_output_tensor_on_device = [input_grad, other_grad]
+
+    golden_function = ttnn.get_golden_function(ttnn.sub_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
+
+    status = True
+    for i in range(len(are_required_outputs)):
+        if are_required_outputs[i]:
+            status = status & compare_pcc([tt_output_tensor_on_device[i]], [golden_tensor[i]])
+    assert status
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp
index 91a18c5db8b4..3fbc7a6429b3 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp
@@ -151,11 +151,15 @@ struct ExecuteBackwardAdd {
         float scalar,
         const std::optional<MemoryConfig> &memory_config = std::nullopt);

-    static std::vector<Tensor> invoke(
+    static std::vector<std::optional<Tensor>> invoke(
+        uint8_t queue_id,
         const Tensor &grad_tensor_arg,
         const Tensor &input_tensor_a_arg,
         const Tensor &input_tensor_b_arg,
-        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+        const std::optional<MemoryConfig> &memory_config = std::nullopt,
+        const std::vector<bool> &are_required_outputs = std::vector<bool>{true, true},
+        std::optional<Tensor> input_grad = std::nullopt,
+        std::optional<Tensor> other_grad = std::nullopt);

     static std::vector<ComplexTensor> invoke(
         const ComplexTensor &grad_tensor_arg,
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp
index f74972db545b..cb1661aa45df 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp
@@ -583,6 +583,104 @@ void bind_binary_bw_mul(py::module& module, const binary_backward_operation_t& o
             py::arg("memory_config") = std::nullopt});
 }

+
+template <typename binary_backward_operation_t>
+void bind_binary_bw_sub(py::module& module, const binary_backward_operation_t& operation, std::string_view description, std::string_view supported_dtype) {
+    auto doc = fmt::format(
+        R"doc({0}(input_tensor_a: Union[ttnn.Tensor, ComplexTensor], input_tensor_b: Union[ComplexTensor, ttnn.Tensor, int, float], *, memory_config: Optional[ttnn.MemoryConfig] = None, dtype: Optional[ttnn.DataType] = None, activations: Optional[List[str]] = None) -> ttnn.Tensor or ComplexTensor
+
+        {2}
+        Supports broadcasting.
+
+        Args:
+            * :attr:`input_tensor_a` (ComplexTensor or ttnn.Tensor)
+            * :attr:`input_tensor_b` (ComplexTensor or ttnn.Tensor or Number): the tensor or number to subtract from :attr:`input_tensor_a`.
+
+        Keyword args:
+            * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor
+            * :attr:`dtype` (Optional[ttnn.DataType]): data type for the output tensor
+            * :attr:`output_tensor` (Optional[ttnn.Tensor]): preallocated output tensor
+            * :attr:`activations` (Optional[List[str]]): list of activation functions to apply to the output tensor
+            * :attr:`queue_id` (Optional[uint8]): command queue id
+
+        Supported dtypes, layouts, and ranks:
+
+        {3}
+
+        Note : bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT
+
+        Example:
+
+            >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
+            >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device)
+            >>> output = {1}(tensor1, tensor2)
+
+        )doc",
+        operation.base_name(),
+        operation.python_fully_qualified_name(),
+        description,
+        supported_dtype);
+
+    bind_registered_operation(
+        module,
+        operation,
+        doc,
+        // tensor and scalar
+        ttnn::pybind_overload_t{
+            [](const binary_backward_operation_t& self,
+               const Tensor& grad_tensor,
+               const Tensor& input_tensor_a,
+               const float scalar,
+               const std::optional<ttnn::MemoryConfig>& memory_config) {
+                return self(grad_tensor, input_tensor_a, scalar, memory_config);
+            },
+            py::arg("grad_tensor"),
+            py::arg("input_tensor_a"),
+            py::arg("scalar"),
+            py::kw_only(),
+            py::arg("memory_config") = std::nullopt},
+
+        // tensor and tensor
+        ttnn::pybind_overload_t{
+            [](const binary_backward_operation_t& self,
+               const ttnn::Tensor& grad_tensor,
+               const ttnn::Tensor& input_tensor,
+               const ttnn::Tensor& other_tensor,
+               const std::optional<ttnn::MemoryConfig>& memory_config,
+               const std::vector<bool>& are_required_outputs,
+               const std::optional<ttnn::Tensor>& input_grad,
+               const std::optional<ttnn::Tensor>& other_grad,
+               const uint8_t& queue_id) -> std::vector<std::optional<ttnn::Tensor>> {
+                return self(queue_id, grad_tensor, input_tensor, other_tensor, memory_config, are_required_outputs, input_grad, other_grad);
+            },
+            py::arg("grad_tensor"),
+            py::arg("input_tensor"),
+            py::arg("other_tensor"),
+            py::kw_only(),
+            py::arg("memory_config") = std::nullopt,
+            py::arg("are_required_outputs") = std::vector<bool>{true, true},
+            py::arg("input_grad") = std::nullopt,
+            py::arg("other_grad") = std::nullopt,
+            py::arg("queue_id") = 0},
+
+        // complex tensor
+        ttnn::pybind_overload_t{
+            [](const binary_backward_operation_t& self,
+               const ComplexTensor& grad_tensor,
+               const ComplexTensor& input_tensor_a,
+               const ComplexTensor& input_tensor_b,
+               float alpha,
+               const std::optional<ttnn::MemoryConfig>& memory_config) {
+                return self(grad_tensor, input_tensor_a, input_tensor_b, alpha, memory_config);
+            },
+            py::arg("grad_tensor"),
+            py::arg("input_tensor_a"),
+            py::arg("input_tensor_b"),
+            py::arg("alpha"),
+            py::kw_only(),
+            py::arg("memory_config") = std::nullopt});
+}
+
 template <typename binary_backward_operation_t>
 void bind_binary_bw_div(py::module& module, const binary_backward_operation_t& operation, std::string_view description, std::string_view supported_dtype) {
     auto doc = fmt::format(
@@ -913,7 +1011,7 @@ void py_module(py::module& module) {
        |    BFLOAT16, BFLOAT8_B     |       ROW_MAJOR, TILE           |      2, 3, 4      |
        +----------------------------+---------------------------------+-------------------+)doc");

-    detail::bind_binary_bw_operation(
+    detail::bind_binary_bw_sub(
         module,
         ttnn::sub_bw,
         R"doc(Performs backward operations for sub of :attr:`input_tensor_a` and :attr:`input_tensor_b` or :attr:`scalar` with given :attr:`grad_tensor`.)doc",
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp
index 4dcbebb69d8d..8c6d786b33b3 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp
@@ -177,13 +177,16 @@ std::vector<Tensor> ExecuteBackwardSub::invoke(
     return grad_tensor;
 }

-std::vector<Tensor> ExecuteBackwardSub::invoke(
-    const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config) {
-    std::vector<Tensor> grad_tensor;
-    grad_tensor.emplace_back(grad);
-    Tensor grad_b = ttnn::multiply(ttnn::neg(grad, output_mem_config), 1.0f, std::nullopt, output_mem_config);
-    grad_tensor.emplace_back(grad_b);
-    return grad_tensor;
+std::vector<std::optional<Tensor>> ExecuteBackwardSub::invoke(
+    uint8_t queue_id,
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const std::optional<MemoryConfig>& output_mem_config,
+    const std::vector<bool>& are_required_outputs,
+    std::optional<Tensor> input_grad,
+    std::optional<Tensor> other_grad) {
+    return ttnn::subalpha_bw(queue_id, grad, input, other, 1.0f, output_mem_config, are_required_outputs, input_grad, other_grad);
 }

 std::vector<Tensor> ExecuteBackwardSub::invoke(
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.hpp
index 36b26199ab36..fb8d65fe2e15 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.hpp
@@ -15,7 +15,6 @@ namespace ttnn::operations::binary_backward {
 enum class BinaryBackwardOpType {
     ATAN2_BW,
     ADDALPHA_BW,
-    SUB_BW,
     XLOGY_BW,
     HYPOT_BW,
     LDEXP_BW,
@@ -39,7 +38,6 @@ std::vector<Tensor> _xlogy_bw( const Tensor& grad, const Tensor& input, const Te
 std::vector<Tensor> _hypot_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
 std::vector<Tensor> _ldexp_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
 std::vector<Tensor> _logaddexp_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _sub_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
 std::vector<Tensor> _gt_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
 std::vector<Tensor> _logaddexp2_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
 std::vector<Tensor> _squared_difference_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
@@ -92,13 +90,6 @@ struct OpHandler {
     }
 };

-template <>
-struct OpHandler<BinaryBackwardOpType::SUB_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _sub_bw(grad, input, other, output_mem_config);
-    }
-};
-
 template <>
 struct OpHandler {
     static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config ) {
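
For reference, a minimal usage sketch of the optional-output `ttnn.sub_bw` overload introduced by this patch, mirroring `test_bw_sub_opt`. It assumes the backward-test helpers `data_gen_with_range` and `compare_pcc` and the `tests.ttnn...utility_funcs` import path used by the existing tests, plus a locally opened device; treat those details as illustrative rather than exact.

```python
# Sketch of the new optional-output sub_bw call path (assumed helper import path).
import torch
import ttnn
from tests.ttnn.unit_tests.operations.backward.utility_funcs import (
    compare_pcc,
    data_gen_with_range,
)

device = ttnn.open_device(device_id=0)

shape = torch.Size([1, 1, 32, 32])
in_data, input_tensor = data_gen_with_range(shape, -100, 100, device, True)
other_data, other_tensor = data_gen_with_range(shape, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(shape, -100, 100, device)

# Preallocate only the gradient that is actually required; other_grad is skipped.
_, input_grad = data_gen_with_range(shape, -1, 1, device)

ttnn.sub_bw(
    grad_tensor,
    input_tensor,
    other_tensor,
    are_required_outputs=[True, False],
    input_grad=input_grad,
    queue_id=0,
)

# Compare only the requested gradient against the golden reference.
golden_tensor = ttnn.get_golden_function(ttnn.sub_bw)(grad_data, in_data, other_data)
assert compare_pcc([input_grad], [golden_tensor[0]])

ttnn.close_device(device)
```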