#14985: Update the examples for binary backward doc (#15083)
### Ticket
#14985

### Problem description
The examples in the binary backward documentation need to be updated.

### What's changed
Updated the examples for the following binary backward ops (a usage sketch based on the new example follows the list):

- concat_bw
- addalpha_bw
- mul_bw
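
For reference, the old examples used small 1-D tensors such as `torch.tensor((1, 2))`, while the updated ones use realistic tile-sized inputs. The snippet below is a minimal sketch of the new `concat_bw` pattern, adapted from the `test_bw_concat_default_example` test added in this commit; it assumes `device` is an already-open ttnn device (for example from `ttnn.open_device`).

```python
import torch
import ttnn

# Host-side inputs matching the shapes used in the new example.
x1_torch = torch.rand([12, 1, 30, 32], dtype=torch.bfloat16, requires_grad=True)
x2_torch = torch.rand([2, 1, 30, 32], dtype=torch.bfloat16, requires_grad=True)
grad_torch = torch.rand([14, 1, 30, 32], dtype=torch.bfloat16)

# Move the tensors to the device in TILE_LAYOUT.
grad_tt = ttnn.from_torch(grad_torch, layout=ttnn.TILE_LAYOUT, device=device)
x1_tt = ttnn.from_torch(x1_torch, layout=ttnn.TILE_LAYOUT, device=device)
x2_tt = ttnn.from_torch(x2_torch, layout=ttnn.TILE_LAYOUT, device=device)

# Backward of concat along dim 0: returns one gradient tensor per input.
grads = ttnn.concat_bw(grad_tt, x1_tt, x2_tt, 0)
grad_x1 = ttnn.to_torch(grads[0])
grad_x2 = ttnn.to_torch(grads[1])
```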

### Checklist
- [x] [Post commit CI](https://github.com/tenstorrent/tt-metal/actions/runs/11860056757) passes

### Documentation screenshots
<img width="967" alt="Screenshot 2024-11-15 at 3 26 34 PM"
src="https://github.com/user-attachments/assets/5d7aad2e-92b1-4c83-99ee-730ad95f73be">

<img width="960" alt="Screenshot 2024-11-15 at 3 27 15 PM"
src="https://github.com/user-attachments/assets/83b284bb-3387-43c6-9e9e-266f3201e8c9">

<img width="961" alt="Screenshot 2024-11-15 at 9 41 04 PM"
src="https://github.com/user-attachments/assets/a58456e0-eee8-4325-b81e-99de3a3f075b">
VirdhatchaniKN authored Nov 17, 2024
1 parent 07ec8a9 commit 42a30d7
Showing 2 changed files with 37 additions and 17 deletions.
@@ -69,6 +69,24 @@ def test_bw_concat_Default(input_shapes, input_shapes_2, device):
assert comp_pass


def test_bw_concat_default_example(device):
x1_torch = torch.rand([12, 1, 30, 32], dtype=torch.bfloat16, requires_grad=True)
x2_torch = torch.rand([2, 1, 30, 32], dtype=torch.bfloat16, requires_grad=True)
grad_tensor = torch.rand([14, 1, 30, 32], dtype=torch.bfloat16)
golden_function = ttnn.get_golden_function(ttnn.concat_bw)
golden_tensor = golden_function(grad_tensor, x1_torch, x2_torch)

grad_tt = ttnn.from_torch(grad_tensor, layout=ttnn.TILE_LAYOUT, device=device)
x1_tt = ttnn.from_torch(x1_torch, layout=ttnn.TILE_LAYOUT, device=device)
x2_tt = ttnn.from_torch(x2_torch, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.concat_bw(grad_tt, x1_tt, x2_tt, 0)
tt_out_1 = ttnn.to_torch(y_tt[1])
tt_out_0 = ttnn.to_torch(y_tt[0])
comp_pass_1 = torch.allclose(tt_out_1, golden_tensor[1])
comp_pass_0 = torch.allclose(tt_out_0, golden_tensor[0])
assert comp_pass_1 and comp_pass_0


@pytest.mark.parametrize(
"input_shapes, input_shapes_2",
(
@@ -88,7 +88,7 @@ void bind_binary_backward_ops(py::module& module, const binary_backward_operatio
}

template <typename binary_backward_operation_t>
void bind_binary_backward_int_default(py::module& module, const binary_backward_operation_t& operation, const std::string& parameter_name, const std::string& parameter_doc, int parameter_value, const std::string_view description, const std::string_view supported_dtype = "BFLOAT16") {
void bind_binary_backward_concat(py::module& module, const binary_backward_operation_t& operation, const std::string& parameter_name, const std::string& parameter_doc, int parameter_value, const std::string_view description, const std::string_view supported_dtype = "BFLOAT16") {
auto doc = fmt::format(
R"doc(
{5}
@@ -123,16 +123,17 @@ void bind_binary_backward_int_default(py::module& module, const binary_backward_
- Ranks
* - {6}
- TILE
- 2, 3, 4
- 4
bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT
Example:
>>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device=device)
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device=device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device=device)
>>> output = {1}(grad_tensor, tensor1, tensor2, int)
>>> grad_tensor = ttnn.from_torch(torch.rand([14, 1, 30, 32], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
>>> tensor1 = ttnn.from_torch(torch.rand([12, 1, 30, 32], dtype=torch.bfloat16, requires_grad=True), layout=ttnn.TILE_LAYOUT, device=device)
>>> tensor2 = ttnn.from_torch(torch.rand([2, 1, 30, 32], dtype=torch.bfloat16, requires_grad=True), layout=ttnn.TILE_LAYOUT, device=device)
>>> {2} = {4}
>>> output = ttnn.concat_bw(grad_tensor, tensor1, tensor2, {2})
)doc",
@@ -176,7 +177,7 @@ void bind_binary_backward_int_default(py::module& module, const binary_backward_
}

template <typename binary_backward_operation_t>
void bind_binary_backward_opt_float_default(py::module& module, const binary_backward_operation_t& operation, const std::string& parameter_name, const std::string& parameter_doc, float parameter_value, const std::string_view description, const std::string_view supported_dtype = "BFLOAT16") {
void bind_binary_backward_addalpha(py::module& module, const binary_backward_operation_t& operation, const std::string& parameter_name, const std::string& parameter_doc, float parameter_value, const std::string_view description, const std::string_view supported_dtype = "BFLOAT16") {
auto doc = fmt::format(
R"doc(
{5}
@@ -218,10 +219,11 @@ void bind_binary_backward_opt_float_default(py::module& module, const binary_bac
Example:
>>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device=device)
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device=device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device=device)
>>> output = {1}(grad_tensor, tensor1, tensor2, float)
>>> grad_tensor = ttnn.from_torch(torch.tensor([[1, 2], [3,4]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
>>> tensor1 = ttnn.from_torch(torch.tensor([[1, 2], [3,4]], dtype=torch.bfloat16, requires_grad=True), layout=ttnn.TILE_LAYOUT, device=device)
>>> tensor2 = ttnn.from_torch(torch.tensor([[1, 2], [3,4]], dtype=torch.bfloat16, requires_grad=True), layout=ttnn.TILE_LAYOUT, device=device)
>>> {2} = {4}
>>> output = ttnn.addalpha_bw(grad_tensor, tensor1, tensor2, {2})
)doc",
@@ -555,10 +557,10 @@ void bind_binary_bw_mul(py::module& module, const binary_backward_operation_t& o
bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT
Example:
>>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device)
>>> output = {1}(grad_tensor, tensor1, tensor2/scalar)
>>> grad_tensor = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
>>> tensor1 = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16, requires_grad=True), layout=ttnn.TILE_LAYOUT, device=device)
>>> tensor2 = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16, requires_grad=True), layout=ttnn.TILE_LAYOUT, device=device)
>>> output = ttnn.mul_bw(grad_tensor, tensor1, tensor2)
)doc",
operation.base_name(),
@@ -1085,7 +1087,7 @@ void py_module(py::module& module) {
R"doc(Performs backward operations for subalpha of :attr:`input_tensor_a` and :attr:`input_tensor_b` with given :attr:`grad_tensor`.)doc",
R"doc(BFLOAT16, BFLOAT8_B)doc");

detail::bind_binary_backward_opt_float_default(
detail::bind_binary_backward_addalpha(
module,
ttnn::addalpha_bw,
"alpha", "Alpha value", 1.0f,
@@ -1133,7 +1135,7 @@ void py_module(py::module& module) {
R"doc(BFLOAT16, BFLOAT8_B)doc");


detail::bind_binary_backward_int_default(
detail::bind_binary_backward_concat(
module,
ttnn::concat_bw,
"dim", "Dimension to concatenate", 0,
