#6136: Add backward support for unary LE and GE
bharane-ab committed Mar 11, 2024
1 parent d9b88b0 commit 49a6887
Showing 6 changed files with 116 additions and 0 deletions.
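Both ops follow the usual convention for comparison operators: unary ge and le are piecewise-constant in their input, so their derivative is zero wherever it is defined, and the backward pass maps the incoming gradient to zeros of the same shape. A minimal PyTorch sketch of the reference behavior the unit tests below check (plain torch, no tt_lib required):

import torch

# x >= c and x <= c are flat in x, so the input gradient is zero everywhere.
grad = torch.randn(1, 1, 32, 32)
expected_input_grad = torch.zeros_like(grad)
assert not expected_input_grad.any()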
4 changes: 4 additions & 0 deletions docs/source/ttnn/dependencies/tt_lib.rst
@@ -1036,6 +1036,10 @@ Backward Operations

.. autofunction:: tt_lib.tensor.log2_bw

.. autofunction:: tt_lib.tensor.ge_bw

.. autofunction:: tt_lib.tensor.le_bw

Loss Functions
==============

@@ -0,0 +1,28 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_pt_tt, compare_results


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
        (torch.Size([1, 3, 320, 384])),
    ),
)
def test_bw_unary_ge(input_shapes, device):
    grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)
    tt_output_tensor_on_device = tt_lib.tensor.ge_bw(grad_tensor)

    # ge is piecewise-constant, so the golden input gradient is all zeros.
    pyt_y = torch.zeros_like(grad_data)

    golden_tensor = [pyt_y]

    comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
    assert comp_pass
@@ -0,0 +1,28 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_pt_tt, compare_results


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
        (torch.Size([1, 3, 320, 384])),
    ),
)
def test_bw_unary_le(input_shapes, device):
    grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)
    tt_output_tensor_on_device = tt_lib.tensor.le_bw(grad_tensor)

    # le is piecewise-constant, so the golden input gradient is all zeros.
    pyt_y = torch.zeros_like(grad_data)

    golden_tensor = [pyt_y]

    comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
    assert comp_pass
23 changes: 23 additions & 0 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -1528,6 +1528,29 @@ std::vector<Tensor> log2_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
{
return operation::decorate_as_composite(__func__, _log2_bw)(grad, input, output_mem_config);
}

std::vector<Tensor> _ge_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor t_zero = zeros_like(grad, output_mem_config);
grad_tensor.emplace_back(t_zero);
return grad_tensor;
}
std::vector<Tensor> ge_bw(const Tensor& grad, const MemoryConfig& output_mem_config)
{
return operation::decorate_as_composite(__func__, _ge_bw)(grad, output_mem_config);
}

std::vector<Tensor> _le_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor t_zero = zeros_like(grad, output_mem_config);
grad_tensor.emplace_back(t_zero);
return grad_tensor;
}
std::vector<Tensor> le_bw(const Tensor& grad, const MemoryConfig& output_mem_config)
{
return operation::decorate_as_composite(__func__, _le_bw)(grad, output_mem_config);
}

}//namespace tt_metal

}//namespace tt
3 changes: 3 additions & 0 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
@@ -233,6 +233,9 @@ std::vector<Tensor> logiteps_bw(const Tensor& grad, const Tensor& input, float e

std::vector<Tensor> log2_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> ge_bw(const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> le_bw(const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
} //namespace tt_metal

} //namespace tt
30 changes: 30 additions & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
@@ -1781,5 +1781,35 @@ namespace tt::tt_metal::detail{
"input", "Input Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("ge_bw", &tt::tt_metal::ge_bw,
py::arg("grad").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Returns an tensor of zeros like ``grad`` tensor
Input tensor must have BFLOAT16 data type.
Output tensor will have BFLOAT16 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("le_bw", &tt::tt_metal::le_bw,
py::arg("grad").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Returns an tensor of zeros like ``grad`` tensor
Input tensor must have BFLOAT16 data type.
Output tensor will have BFLOAT16 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");
}
}
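For reference, a usage sketch mirroring the unit tests above. It assumes ``device`` is an already-opened TT device handle (setup elided, as in the pytest fixtures) and reuses the tests' data_gen_pt_tt helper; both names come from this commit, nothing else is assumed:

import torch
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_pt_tt

# device: an open TT device handle, as provided by the test fixtures (setup elided).
grad_data, grad_tensor = data_gen_pt_tt(torch.Size([1, 1, 32, 32]), device)
ge_grads = tt_lib.tensor.ge_bw(grad_tensor)  # list containing one all-zeros tensor
le_grads = tt_lib.tensor.le_bw(grad_tensor)  # le_bw behaves identically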
