From 42a30d72352e0b534e97084a779f58c62bc1751f Mon Sep 17 00:00:00 2001
From: Virdhatchani Narayanamoorthy
<138196495+VirdhatchaniKN@users.noreply.github.com>
Date: Sun, 17 Nov 2024 14:30:02 +0530
Subject: [PATCH] #14985: Update the examples for binary backward doc (#15083)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
### Ticket
https://github.com/tenstorrent/tt-metal/issues/14985
### Problem description
The doc examples for the binary backward ops are outdated: they build tiny tuple inputs via `ttnn.to_device` and need to be refreshed to use realistic `ttnn.from_torch(..., layout=ttnn.TILE_LAYOUT, device=device)` tensors.
### What's changed
Updated the examples for the following binary backward ops (a usage sketch based on the new examples follows the list):
- concat_bw
- addalpha_bw
- mul_bw
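
For context, here is a minimal sketch stitched together from the updated doc examples in this patch. The `ttnn.open_device` / `ttnn.close_device` calls are an assumption added only to make the snippet self-contained; the doc examples themselves assume an existing `device` handle.

```python
import torch
import ttnn

# Assumption: open a device so the sketch is self-contained.
device = ttnn.open_device(device_id=0)

# concat_bw: the gradient shape is the two input shapes concatenated along `dim` (0 here).
grad = ttnn.from_torch(torch.rand([14, 1, 30, 32], dtype=torch.bfloat16),
                       layout=ttnn.TILE_LAYOUT, device=device)
a = ttnn.from_torch(torch.rand([12, 1, 30, 32], dtype=torch.bfloat16, requires_grad=True),
                    layout=ttnn.TILE_LAYOUT, device=device)
b = ttnn.from_torch(torch.rand([2, 1, 30, 32], dtype=torch.bfloat16, requires_grad=True),
                    layout=ttnn.TILE_LAYOUT, device=device)
concat_grads = ttnn.concat_bw(grad, a, b, 0)  # one gradient tensor per input

# addalpha_bw / mul_bw follow the same pattern with same-shaped inputs.
g = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16),
                    layout=ttnn.TILE_LAYOUT, device=device)
x = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16, requires_grad=True),
                    layout=ttnn.TILE_LAYOUT, device=device)
y = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16, requires_grad=True),
                    layout=ttnn.TILE_LAYOUT, device=device)
addalpha_grads = ttnn.addalpha_bw(g, x, y, 1.0)  # alpha = 1.0, as in the doc example
mul_grads = ttnn.mul_bw(g, x, y)

ttnn.close_device(device)
```

The new `test_bw_concat_default_example` test below exercises exactly this concat_bw pattern and compares each returned gradient against `ttnn.get_golden_function(ttnn.concat_bw)` via `torch.allclose`.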
### Checklist
- [x] [Post commit CI](https://github.com/tenstorrent/tt-metal/actions/runs/11860056757) passes
### Documentation screenshots
---
.../eltwise/backward/test_backward_concat.py | 18 ++++++++++
.../binary_backward_pybind.hpp | 36 ++++++++++---------
2 files changed, 37 insertions(+), 17 deletions(-)
diff --git a/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_concat.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_concat.py
index e589a9d3b86..4abcfd1c6eb 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_concat.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_concat.py
@@ -69,6 +69,24 @@ def test_bw_concat_Default(input_shapes, input_shapes_2, device):
assert comp_pass
+def test_bw_concat_default_example(device):
+ x1_torch = torch.rand([12, 1, 30, 32], dtype=torch.bfloat16, requires_grad=True)
+ x2_torch = torch.rand([2, 1, 30, 32], dtype=torch.bfloat16, requires_grad=True)
+ grad_tensor = torch.rand([14, 1, 30, 32], dtype=torch.bfloat16)
+ golden_function = ttnn.get_golden_function(ttnn.concat_bw)
+ golden_tensor = golden_function(grad_tensor, x1_torch, x2_torch)
+
+ grad_tt = ttnn.from_torch(grad_tensor, layout=ttnn.TILE_LAYOUT, device=device)
+ x1_tt = ttnn.from_torch(x1_torch, layout=ttnn.TILE_LAYOUT, device=device)
+ x2_tt = ttnn.from_torch(x2_torch, layout=ttnn.TILE_LAYOUT, device=device)
+ y_tt = ttnn.concat_bw(grad_tt, x1_tt, x2_tt, 0)
+ tt_out_1 = ttnn.to_torch(y_tt[1])
+ tt_out_0 = ttnn.to_torch(y_tt[0])
+ comp_pass_1 = torch.allclose(tt_out_1, golden_tensor[1])
+ comp_pass_0 = torch.allclose(tt_out_0, golden_tensor[0])
+ assert comp_pass_1 and comp_pass_0
+
+
@pytest.mark.parametrize(
"input_shapes, input_shapes_2",
(
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp
index b99897762e7..d594557f60b 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp
@@ -88,7 +88,7 @@ void bind_binary_backward_ops(py::module& module, const binary_backward_operatio
}
template
-void bind_binary_backward_int_default(py::module& module, const binary_backward_operation_t& operation, const std::string& parameter_name, const std::string& parameter_doc, int parameter_value, const std::string_view description, const std::string_view supported_dtype = "BFLOAT16") {
+void bind_binary_backward_concat(py::module& module, const binary_backward_operation_t& operation, const std::string& parameter_name, const std::string& parameter_doc, int parameter_value, const std::string_view description, const std::string_view supported_dtype = "BFLOAT16") {
auto doc = fmt::format(
R"doc(
{5}
@@ -123,16 +123,17 @@ void bind_binary_backward_int_default(py::module& module, const binary_backward_
- Ranks
* - {6}
- TILE
- - 2, 3, 4
+ - 4
bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT
Example:
- >>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device=device)
- >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device=device)
- >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device=device)
- >>> output = {1}(grad_tensor, tensor1, tensor2, int)
+ >>> grad_tensor = ttnn.from_torch(torch.rand([14, 1, 30, 32], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
+ >>> tensor1 = ttnn.from_torch(torch.rand([12, 1, 30, 32], dtype=torch.bfloat16, requires_grad=True), layout=ttnn.TILE_LAYOUT, device=device)
+ >>> tensor2 = ttnn.from_torch(torch.rand([2, 1, 30, 32], dtype=torch.bfloat16, requires_grad=True), layout=ttnn.TILE_LAYOUT, device=device)
+ >>> {2} = {4}
+ >>> output = ttnn.concat_bw(grad_tensor, tensor1, tensor2, {2})
)doc",
@@ -176,7 +177,7 @@ void bind_binary_backward_int_default(py::module& module, const binary_backward_
}
template
-void bind_binary_backward_opt_float_default(py::module& module, const binary_backward_operation_t& operation, const std::string& parameter_name, const std::string& parameter_doc, float parameter_value, const std::string_view description, const std::string_view supported_dtype = "BFLOAT16") {
+void bind_binary_backward_addalpha(py::module& module, const binary_backward_operation_t& operation, const std::string& parameter_name, const std::string& parameter_doc, float parameter_value, const std::string_view description, const std::string_view supported_dtype = "BFLOAT16") {
auto doc = fmt::format(
R"doc(
{5}
@@ -218,10 +219,11 @@ void bind_binary_backward_opt_float_default(py::module& module, const binary_bac
Example:
- >>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device=device)
- >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device=device)
- >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device=device)
- >>> output = {1}(grad_tensor, tensor1, tensor2, float)
+ >>> grad_tensor = ttnn.from_torch(torch.tensor([[1, 2], [3,4]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
+ >>> tensor1 = ttnn.from_torch(torch.tensor([[1, 2], [3,4]], dtype=torch.bfloat16, requires_grad=True), layout=ttnn.TILE_LAYOUT, device=device)
+ >>> tensor2 = ttnn.from_torch(torch.tensor([[1, 2], [3,4]], dtype=torch.bfloat16, requires_grad=True), layout=ttnn.TILE_LAYOUT, device=device)
+ >>> {2} = {4}
+ >>> output = ttnn.addalpha_bw(grad_tensor, tensor1, tensor2, {2})
)doc",
@@ -555,10 +557,10 @@ void bind_binary_bw_mul(py::module& module, const binary_backward_operation_t& o
bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT
Example:
- >>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
- >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
- >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device)
- >>> output = {1}(grad_tensor, tensor1, tensor2/scalar)
+ >>> grad_tensor = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
+ >>> tensor1 = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16, requires_grad=True), layout=ttnn.TILE_LAYOUT, device=device)
+ >>> tensor2 = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16, requires_grad=True), layout=ttnn.TILE_LAYOUT, device=device)
+ >>> output = ttnn.mul_bw(grad_tensor, tensor1, tensor2)
)doc",
operation.base_name(),
@@ -1085,7 +1087,7 @@ void py_module(py::module& module) {
R"doc(Performs backward operations for subalpha of :attr:`input_tensor_a` and :attr:`input_tensor_b` with given :attr:`grad_tensor`.)doc",
R"doc(BFLOAT16, BFLOAT8_B)doc");
- detail::bind_binary_backward_opt_float_default(
+ detail::bind_binary_backward_addalpha(
module,
ttnn::addalpha_bw,
"alpha", "Alpha value", 1.0f,
@@ -1133,7 +1135,7 @@ void py_module(py::module& module) {
R"doc(BFLOAT16, BFLOAT8_B)doc");
- detail::bind_binary_backward_int_default(
+ detail::bind_binary_backward_concat(
module,
ttnn::concat_bw,
"dim", "Dimension to concatenate", 0,