From ad1fd9f4ce82f5d6a9b60df8bba1893be760735a Mon Sep 17 00:00:00 2001 From: Virdhatchani Narayanamoorthy <138196495+VirdhatchaniKN@users.noreply.github.com> Date: Thu, 11 Jul 2024 23:34:43 +0530 Subject: [PATCH] #10071 : Move second set of Unary Backward ops to TTNN (#10038) * #10071: Merge floor_bw to TTNN * #10071: Merge round_bw to TTNN * #10071: Merge log_bw to TTNN * #10071: Merge relu6_bw to TTNN * #10071: Merge abs_bw to TTNN * #10071: Merge silu_bw to TTNN * #10071: Merge selu_bw to TTNN * #10071: Fix rebase errors * #10071: Update files * #10071: Move test files to TTNN * #10071: Update CPP files --------- Co-authored-by: mouliraj-mcw --- docs/source/ttnn/ttnn/api.rst | 7 + docs/source/ttnn/ttnn/dependencies/tt_lib.rst | 14 -- docs/source/ttnn/ttnn/ttnn/abs_bw.rst | 6 + docs/source/ttnn/ttnn/ttnn/floor_bw.rst | 6 + docs/source/ttnn/ttnn/ttnn/log_bw.rst | 6 + docs/source/ttnn/ttnn/ttnn/relu6_bw.rst | 6 + docs/source/ttnn/ttnn/ttnn/round_bw.rst | 6 + docs/source/ttnn/ttnn/ttnn/selu_bw.rst | 6 + docs/source/ttnn/ttnn/ttnn/silu_bw.rst | 6 + .../operations/backward}/test_backward_abs.py | 6 +- .../backward}/test_backward_floor.py | 6 +- .../operations/backward}/test_backward_log.py | 8 +- .../backward}/test_backward_relu6.py | 6 +- .../backward}/test_backward_round.py | 6 +- .../backward}/test_backward_selu.py | 6 +- .../backward}/test_backward_silu.py | 6 +- .../op_library/backward/backward_ops.cpp | 120 ------------------ .../op_library/backward/backward_ops.hpp | 30 ----- .../tt_lib_bindings_tensor_backward_ops.cpp | 114 ----------------- .../device/unary_backward_op.cpp | 113 ++++++++++++++++- .../device/unary_backward_op.hpp | 7 + .../eltwise/unary_backward/unary_backward.hpp | 8 +- .../unary_backward/unary_backward_pybind.hpp | 36 ++++++ 23 files changed, 232 insertions(+), 303 deletions(-) create mode 100644 docs/source/ttnn/ttnn/ttnn/abs_bw.rst create mode 100644 docs/source/ttnn/ttnn/ttnn/floor_bw.rst create mode 100644 docs/source/ttnn/ttnn/ttnn/log_bw.rst create mode 100644 docs/source/ttnn/ttnn/ttnn/relu6_bw.rst create mode 100644 docs/source/ttnn/ttnn/ttnn/round_bw.rst create mode 100644 docs/source/ttnn/ttnn/ttnn/selu_bw.rst create mode 100644 docs/source/ttnn/ttnn/ttnn/silu_bw.rst rename tests/{tt_eager/python_api_testing/unit_testing/backward_ops => ttnn/unit_tests/operations/backward}/test_backward_abs.py (77%) rename tests/{tt_eager/python_api_testing/unit_testing/backward_ops => ttnn/unit_tests/operations/backward}/test_backward_floor.py (78%) rename tests/{tt_eager/python_api_testing/unit_testing/backward_ops => ttnn/unit_tests/operations/backward}/test_backward_log.py (84%) rename tests/{tt_eager/python_api_testing/unit_testing/backward_ops => ttnn/unit_tests/operations/backward}/test_backward_relu6.py (78%) rename tests/{tt_eager/python_api_testing/unit_testing/backward_ops => ttnn/unit_tests/operations/backward}/test_backward_round.py (78%) rename tests/{tt_eager/python_api_testing/unit_testing/backward_ops => ttnn/unit_tests/operations/backward}/test_backward_selu.py (77%) rename tests/{tt_eager/python_api_testing/unit_testing/backward_ops => ttnn/unit_tests/operations/backward}/test_backward_silu.py (77%) diff --git a/docs/source/ttnn/ttnn/api.rst b/docs/source/ttnn/ttnn/api.rst index eb913f593a3..a04456f3c6d 100644 --- a/docs/source/ttnn/ttnn/api.rst +++ b/docs/source/ttnn/ttnn/api.rst @@ -193,6 +193,13 @@ Pointwise Unary ttnn/elu_bw ttnn/celu_bw ttnn/rpow_bw + ttnn/floor_bw + ttnn/round_bw + ttnn/log_bw + ttnn/relu6_bw + ttnn/abs_bw + ttnn/silu_bw + 
ttnn/selu_bw Pointwise Binary ================ diff --git a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst index b4b942d6133..8f71224774a 100644 --- a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst +++ b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst @@ -814,10 +814,6 @@ Backward Operations .. autofunction:: tt_lib.tensor.fill_bw -.. autofunction:: tt_lib.tensor.log_bw - -.. autofunction:: tt_lib.tensor.abs_bw - .. autofunction:: tt_lib.tensor.complex_abs_bw .. autofunction:: tt_lib.tensor.lt_bw @@ -880,12 +876,6 @@ Backward Operations .. autofunction:: tt_lib.tensor.reciprocal_bw -.. autofunction:: tt_lib.tensor.relu6_bw - -.. autofunction:: tt_lib.tensor.silu_bw - -.. autofunction:: tt_lib.tensor.selu_bw - .. autofunction:: tt_lib.tensor.square_bw .. autofunction:: tt_lib.tensor.tanhshrink_bw @@ -928,10 +918,6 @@ Backward Operations .. autofunction:: tt_lib.tensor.repeat_bw -.. autofunction:: tt_lib.tensor.floor_bw - -.. autofunction:: tt_lib.tensor.round_bw - .. autofunction:: tt_lib.tensor.unary_div_no_nan_bw Loss Functions diff --git a/docs/source/ttnn/ttnn/ttnn/abs_bw.rst b/docs/source/ttnn/ttnn/ttnn/abs_bw.rst new file mode 100644 index 00000000000..dc397a7c134 --- /dev/null +++ b/docs/source/ttnn/ttnn/ttnn/abs_bw.rst @@ -0,0 +1,6 @@ +.. _ttnn.abs_bw: + +ttnn.abs_bw +########### + + .. autofunction:: ttnn.abs_bw diff --git a/docs/source/ttnn/ttnn/ttnn/floor_bw.rst b/docs/source/ttnn/ttnn/ttnn/floor_bw.rst new file mode 100644 index 00000000000..6baffddfda2 --- /dev/null +++ b/docs/source/ttnn/ttnn/ttnn/floor_bw.rst @@ -0,0 +1,6 @@ +.. _ttnn.floor_bw: + +ttnn.floor_bw +############# + + .. autofunction:: ttnn.floor_bw diff --git a/docs/source/ttnn/ttnn/ttnn/log_bw.rst b/docs/source/ttnn/ttnn/ttnn/log_bw.rst new file mode 100644 index 00000000000..5f6d4ffe8f2 --- /dev/null +++ b/docs/source/ttnn/ttnn/ttnn/log_bw.rst @@ -0,0 +1,6 @@ +.. _ttnn.log_bw: + +ttnn.log_bw +########### + + .. autofunction:: ttnn.log_bw diff --git a/docs/source/ttnn/ttnn/ttnn/relu6_bw.rst b/docs/source/ttnn/ttnn/ttnn/relu6_bw.rst new file mode 100644 index 00000000000..fc9ea28206d --- /dev/null +++ b/docs/source/ttnn/ttnn/ttnn/relu6_bw.rst @@ -0,0 +1,6 @@ +.. _ttnn.relu6_bw: + +ttnn.relu6_bw +############# + + .. autofunction:: ttnn.relu6_bw diff --git a/docs/source/ttnn/ttnn/ttnn/round_bw.rst b/docs/source/ttnn/ttnn/ttnn/round_bw.rst new file mode 100644 index 00000000000..d38af90674c --- /dev/null +++ b/docs/source/ttnn/ttnn/ttnn/round_bw.rst @@ -0,0 +1,6 @@ +.. _ttnn.round_bw: + +ttnn.round_bw +############# + + .. autofunction:: ttnn.round_bw diff --git a/docs/source/ttnn/ttnn/ttnn/selu_bw.rst b/docs/source/ttnn/ttnn/ttnn/selu_bw.rst new file mode 100644 index 00000000000..e3679d997bd --- /dev/null +++ b/docs/source/ttnn/ttnn/ttnn/selu_bw.rst @@ -0,0 +1,6 @@ +.. _ttnn.selu_bw: + +ttnn.selu_bw +############ + + .. autofunction:: ttnn.selu_bw diff --git a/docs/source/ttnn/ttnn/ttnn/silu_bw.rst b/docs/source/ttnn/ttnn/ttnn/silu_bw.rst new file mode 100644 index 00000000000..2f18db474fc --- /dev/null +++ b/docs/source/ttnn/ttnn/ttnn/silu_bw.rst @@ -0,0 +1,6 @@ +.. _ttnn.silu_bw: + +ttnn.silu_bw +############ + + .. 
autofunction:: ttnn.silu_bw diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_abs.py b/tests/ttnn/unit_tests/operations/backward/test_backward_abs.py similarity index 77% rename from tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_abs.py rename to tests/ttnn/unit_tests/operations/backward/test_backward_abs.py index 83e419f7c65..4da39c4a436 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_abs.py +++ b/tests/ttnn/unit_tests/operations/backward/test_backward_abs.py @@ -4,8 +4,8 @@ import torch import pytest -import tt_lib -from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc +import ttnn +from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( @@ -22,7 +22,7 @@ def test_bw_abs(input_shapes, device): pyt_y = torch.abs(in_data) - tt_output_tensor_on_device = tt_lib.tensor.abs_bw(grad_tensor, input_tensor) + tt_output_tensor_on_device = ttnn.abs_bw(grad_tensor, input_tensor) in_data.retain_grad() diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_floor.py b/tests/ttnn/unit_tests/operations/backward/test_backward_floor.py similarity index 78% rename from tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_floor.py rename to tests/ttnn/unit_tests/operations/backward/test_backward_floor.py index b9fd8c12135..bedf8ea2fef 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_floor.py +++ b/tests/ttnn/unit_tests/operations/backward/test_backward_floor.py @@ -4,8 +4,8 @@ import torch import pytest -import tt_lib -from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range +import ttnn +from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( @@ -22,7 +22,7 @@ def test_bw_floor(input_shapes, device): pyt_y = torch.floor(in_data) - tt_output_tensor_on_device = tt_lib.tensor.floor_bw(grad_tensor) + tt_output_tensor_on_device = ttnn.floor_bw(grad_tensor, input_tensor) in_data.retain_grad() diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_log.py b/tests/ttnn/unit_tests/operations/backward/test_backward_log.py similarity index 84% rename from tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_log.py rename to tests/ttnn/unit_tests/operations/backward/test_backward_log.py index c404ff025c4..1fcec1f3958 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_log.py +++ b/tests/ttnn/unit_tests/operations/backward/test_backward_log.py @@ -4,8 +4,8 @@ import torch import pytest -import tt_lib -from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import ( +import ttnn +from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( data_gen_with_val, compare_pcc, data_gen_with_range, @@ -23,7 +23,7 @@ def test_bw_log_0(input_shapes, device): in_data, input_tensor = data_gen_with_val(input_shapes, device, True, val=0) grad_data, grad_tensor = data_gen_with_range(input_shapes, -1, 1, device) - tt_output_tensor_on_device = tt_lib.tensor.log_bw(grad_tensor, input_tensor) + tt_output_tensor_on_device = ttnn.log_bw(grad_tensor, input_tensor) in_data.retain_grad() @@ -47,7 +47,7 @@ def test_bw_log_0(input_shapes, device): def 
test_bw_log(input_shapes, device): in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True) grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device) - tt_output_tensor_on_device = tt_lib.tensor.log_bw(grad_tensor, input_tensor) + tt_output_tensor_on_device = ttnn.log_bw(grad_tensor, input_tensor) in_data.retain_grad() diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_relu6.py b/tests/ttnn/unit_tests/operations/backward/test_backward_relu6.py similarity index 78% rename from tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_relu6.py rename to tests/ttnn/unit_tests/operations/backward/test_backward_relu6.py index 2ee155ac401..21e72e32a28 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_relu6.py +++ b/tests/ttnn/unit_tests/operations/backward/test_backward_relu6.py @@ -4,8 +4,8 @@ import torch import pytest -import tt_lib -from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range +import ttnn +from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( @@ -22,7 +22,7 @@ def test_bw_relu6(input_shapes, device): pyt_y = torch.nn.functional.relu6(in_data) - tt_output_tensor_on_device = tt_lib.tensor.relu6_bw(grad_tensor, input_tensor) + tt_output_tensor_on_device = ttnn.relu6_bw(grad_tensor, input_tensor) in_data.retain_grad() diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_round.py b/tests/ttnn/unit_tests/operations/backward/test_backward_round.py similarity index 78% rename from tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_round.py rename to tests/ttnn/unit_tests/operations/backward/test_backward_round.py index cc9b7109a9c..0ea2865a08c 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_round.py +++ b/tests/ttnn/unit_tests/operations/backward/test_backward_round.py @@ -4,8 +4,8 @@ import torch import pytest -import tt_lib -from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range +import ttnn +from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( @@ -21,7 +21,7 @@ def test_bw_round(input_shapes, device): in_data, input_tensor = data_gen_with_range(input_shapes, -200, 201, device, required_grad=True) pyt_y = torch.round(in_data) - tt_output_tensor_on_device = tt_lib.tensor.round_bw(grad_tensor) + tt_output_tensor_on_device = ttnn.round_bw(grad_tensor, input_tensor) in_data.retain_grad() diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_selu.py b/tests/ttnn/unit_tests/operations/backward/test_backward_selu.py similarity index 77% rename from tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_selu.py rename to tests/ttnn/unit_tests/operations/backward/test_backward_selu.py index b834db1fbe1..7adc3c09ee1 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_selu.py +++ b/tests/ttnn/unit_tests/operations/backward/test_backward_selu.py @@ -4,8 +4,8 @@ import torch import pytest -import tt_lib -from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc +import ttnn +from tests.ttnn.unit_tests.operations.backward.utility_funcs import 
data_gen_with_range, compare_pcc @pytest.mark.parametrize( @@ -22,7 +22,7 @@ def test_bw_selu(input_shapes, device): pyt_y = torch.nn.functional.selu(in_data) - tt_output_tensor_on_device = tt_lib.tensor.selu_bw(grad_tensor, input_tensor) + tt_output_tensor_on_device = ttnn.selu_bw(grad_tensor, input_tensor) in_data.retain_grad() diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_silu.py b/tests/ttnn/unit_tests/operations/backward/test_backward_silu.py similarity index 77% rename from tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_silu.py rename to tests/ttnn/unit_tests/operations/backward/test_backward_silu.py index 455c87ce191..577a929682b 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_silu.py +++ b/tests/ttnn/unit_tests/operations/backward/test_backward_silu.py @@ -4,8 +4,8 @@ import torch import pytest -import tt_lib -from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc +import ttnn +from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( @@ -22,7 +22,7 @@ def test_bw_silu(input_shapes, device): pyt_y = torch.nn.functional.silu(in_data) - tt_output_tensor_on_device = tt_lib.tensor.silu_bw(grad_tensor, input_tensor) + tt_output_tensor_on_device = ttnn.silu_bw(grad_tensor, input_tensor) in_data.retain_grad() diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp index a60a9883150..7f878cc5bf8 100644 --- a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp +++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp @@ -280,36 +280,6 @@ std::vector ne_bw(const Tensor& grad, const MemoryConfig& output_mem_con return operation::decorate_as_composite(__func__, _ne_bw)(grad, output_mem_config); } -std::vector _log_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { - std::vector grad_tensor; - Tensor grad_a = ttnn::multiply(grad, ttnn::reciprocal(input, output_mem_config), std::nullopt, output_mem_config); - Tensor t_inf = full_like(input, std::numeric_limits::infinity(), output_mem_config); - Tensor t_nan = full_like(input, std::nanf(""), output_mem_config); - grad_tensor.emplace_back(where( - ttnn::eqz(input, output_mem_config), - where( - ttnn::eqz(grad, output_mem_config), - t_nan, - ttnn::multiply(t_inf, ttnn::sign(grad, output_mem_config), std::nullopt, output_mem_config), - output_mem_config), - grad_a, - output_mem_config)); - return grad_tensor; -} -std::vector log_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _log_bw)(grad, input, output_mem_config); -} - -std::vector _abs_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { - std::vector grad_tensor; - Tensor result = ttnn::multiply(grad, ttnn::sign(input, output_mem_config), std::nullopt, output_mem_config); - grad_tensor.emplace_back(result); - return grad_tensor; -} -std::vector abs_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _abs_bw)(grad, input, output_mem_config); -} - // bw(expm1) = grad * expm1(input) + 1 std::vector _expm1_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; @@ -990,76 +960,6 @@ std::vector 
reciprocal_bw(const Tensor& grad, const Tensor& input, const return operation::decorate_as_composite(__func__, _reciprocal_bw)(grad, input, output_mem_config); } -std::vector _relu6_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { - std::vector grad_tensor; - Tensor zero_tensor = zeros_like(input, output_mem_config); - Tensor one_tensor = ones_like(input, output_mem_config); - Tensor six_tensor = full_like(input, 6, output_mem_config); - Tensor grad_result = - where(ttnn::le(input, zero_tensor, std::nullopt, output_mem_config), zero_tensor, six_tensor, output_mem_config); - grad_result = where( - ttnn::logical_and( - ttnn::gtz(input, output_mem_config), - ttnn::lt(input, six_tensor, std::nullopt, output_mem_config), - std::nullopt, - output_mem_config), - grad, - grad_result, - output_mem_config); - grad_result = - where(ttnn::ge(input, six_tensor, std::nullopt, output_mem_config), zero_tensor, grad_result, output_mem_config); - - grad_tensor.emplace_back(grad_result); - return grad_tensor; -} -std::vector relu6_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _relu6_bw)(grad, input, output_mem_config); -} - - -// Silu -// result: grad * sigmoid_result * (1 + input * (1 - sigmoid_result)) -std::vector _silu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { - std::vector grad_tensor; - Tensor grad_sigmoid = ttnn::multiply(grad, ttnn::sigmoid(input, output_mem_config), std::nullopt, output_mem_config); - Tensor add_sub = ttnn::add( - ttnn::multiply(ttnn::subtract(ttnn::full_like(input, 1.0f) , ttnn::sigmoid(input, output_mem_config), std::nullopt, output_mem_config), - input, - std::nullopt, - output_mem_config), - 1.0f, - std::nullopt, - output_mem_config); - Tensor grad_result = ttnn::multiply(grad_sigmoid, add_sub, std::nullopt, output_mem_config); - - grad_tensor.emplace_back(grad_result); - return grad_tensor; -} -std::vector silu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _silu_bw)(grad, input, output_mem_config); -} - -// Selu -// result: torch.where(input > 0, grad * lambd, grad * lambd * alpha * torch.exp(input)) -std::vector _selu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { - std::vector grad_tensor; - Tensor grad_lambd = ttnn::multiply(grad, 1.0507f, std::nullopt, output_mem_config); - Tensor grad_result = where( - ttnn::gtz(input, output_mem_config), - grad_lambd, - ttnn::multiply(ttnn::multiply(grad_lambd, 1.673260f, std::nullopt, output_mem_config), - ttnn::exp(input, false, output_mem_config), - std::nullopt, - output_mem_config), - output_mem_config); - grad_tensor.emplace_back(grad_result); - return grad_tensor; -} -std::vector selu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _selu_bw)(grad, input, output_mem_config); -} - - // Autoformat support Tensor change_layout_to_tile(const Tensor& temp, const MemoryConfig& output_mem_config) { auto formatted_input_tensor = temp; @@ -1748,26 +1648,6 @@ std::vector repeat_bw( return operation::decorate_as_composite(__func__, _repeat_bw)(grad, input, shape, output_mem_config); } -std::vector _floor_bw(const Tensor& grad, const MemoryConfig& output_mem_config) { - std::vector grad_tensor; - Tensor t_zero = zeros_like(grad, output_mem_config); - 
grad_tensor.emplace_back(t_zero); - return grad_tensor; -} -std::vector floor_bw(const Tensor& grad, const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _floor_bw)(grad, output_mem_config); -} - -std::vector _round_bw(const Tensor& grad, const MemoryConfig& output_mem_config) { - std::vector grad_tensor; - Tensor t_zero = zeros_like(grad, output_mem_config); - grad_tensor.emplace_back(t_zero); - return grad_tensor; -} -std::vector round_bw(const Tensor& grad, const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _round_bw)(grad, output_mem_config); -} - std::vector _unary_div_no_nan_bw( const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) { std::vector grad_tensor; diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp index 18a5e4b99d8..905f3ad2adf 100644 --- a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp +++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp @@ -101,21 +101,12 @@ std::vector> tanh_bw( std::vector fill_bw( const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); -std::vector log_bw( - const Tensor& grad, - const Tensor& input, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); std::vector binary_le_bw( const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); -std::vector abs_bw( - const Tensor& grad, - const Tensor& input, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); - std::vector complex_abs_bw( const Tensor& grad, const Tensor& input, @@ -284,21 +275,6 @@ std::vector reciprocal_bw( const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); -std::vector relu6_bw( - const Tensor& grad, - const Tensor& input, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); - -std::vector silu_bw( - const Tensor& grad, - const Tensor& input, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); - -std::vector selu_bw( - const Tensor& grad, - const Tensor& input, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); - std::vector square_bw( const Tensor& grad, const Tensor& input, @@ -422,12 +398,6 @@ std::vector complex_sub_bw( std::vector repeat_bw( const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config); -std::vector floor_bw( - const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); - -std::vector round_bw( - const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); - std::vector unary_div_no_nan_bw( const Tensor& grad, const Tensor& input, diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp index cffa0bde7d6..64ab793d044 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp @@ -328,38 +328,6 @@ namespace tt::tt_metal::detail{ "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" )doc"); - m_tensor.def("log_bw", &tt::tt_metal::log_bw, - py::arg("grad").noconvert(), py::arg("input").noconvert(), 
py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( - Performs backward operations for logarithm of ``input`` tensors with given ``grad``. - - Input tensors must have BFLOAT16 data type. - - Output tensors will have BFLOAT16 data type. - - .. csv-table:: - :header: "Argument", "Description", "Data type", "Valid range", "Required" - - "grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" - "input", "Tensor add is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" - "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" - )doc"); - - m_tensor.def("abs_bw", &tt::tt_metal::abs_bw, - py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( - Performs backward operations for abs of ``input`` tensors with given ``grad``. - - Input tensors must have BFLOAT16 data type. - - Output tensor will have BFLOAT16 data type. - - .. csv-table:: - :header: "Argument", "Description", "Data type", "Valid range", "Required" - - "grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" - "input", "Tensor add is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" - "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" - )doc"); - m_tensor.def("complex_abs_bw", py::overload_cast(&complex_abs_bw), py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( Performs backward operations for abs of complex ``input`` tensor with given ``grad``. @@ -823,56 +791,6 @@ namespace tt::tt_metal::detail{ "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" )doc"); - m_tensor.def("relu6_bw", &tt::tt_metal::relu6_bw, - py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( - Returns an tensor of backward operation of relu6 for ``input`` tensor and ``grad`` tensor. - - Input tensors must have BFLOAT16 data type. - - Output tensors will have BFLOAT16 data type. - - .. csv-table:: - :header: "Argument", "Description", "Data type", "Valid range", "Required" - - "grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" - "input", "Tensor relu6 is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" - "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" - )doc"); - - - m_tensor.def("silu_bw", &tt::tt_metal::silu_bw, - py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( - Performs backward operations for silu sin of ``input`` tensors with given ``grad``. - - Input tensors must have BFLOAT16 data type. - - Output tensors will have BFLOAT16 data type. - - .. 
csv-table:: - :header: "Argument", "Description", "Data type", "Valid range", "Required" - - "grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" - "input", "Tensor silu_bw is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" - "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" - )doc"); - - - m_tensor.def("selu_bw", &tt::tt_metal::selu_bw, - py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( - Performs backward operations for selu sin of ``input`` tensors with given ``grad``. - - Input tensors must have BFLOAT16 data type. - - Output tensors will have BFLOAT16 data type. - - .. csv-table:: - :header: "Argument", "Description", "Data type", "Valid range", "Required" - - "grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" - "input", "Tensor selu_bw is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" - "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" - )doc"); - m_tensor.def("square_bw", &tt::tt_metal::square_bw, py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( Performs backward square operations on ``input`` tensors with given ``grad``. @@ -1153,38 +1071,6 @@ namespace tt::tt_metal::detail{ "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" )doc"); - - m_tensor.def("floor_bw", &tt::tt_metal::floor_bw, - py::arg("grad").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( - Returns an tensor of zeros like ``grad`` tensor - - Input tensor must have BFLOAT16 data type. - - Output tensor will have BFLOAT16 data type. - - .. csv-table:: - :header: "Argument", "Description", "Data type", "Valid range", "Required" - - "grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" - "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" - )doc"); - - - m_tensor.def("round_bw", &tt::tt_metal::round_bw, - py::arg("grad").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( - Returns an tensor of zeros like ``grad`` tensor - - Input tensor must have BFLOAT16 data type. - - Output tensor will have BFLOAT16 data type. - - .. csv-table:: - :header: "Argument", "Description", "Data type", "Valid range", "Required" - - "grad", "Gradient Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" - "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" - )doc"); - m_tensor.def("unary_div_no_nan_bw", &tt::tt_metal::unary_div_no_nan_bw, py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("scalar") = 1.0f, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( Performs backward operations for division with given ``grad`` and ``scalar`` with no nan. 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp index 775252685fa..bbcffb294c3 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp @@ -144,7 +144,6 @@ std::vector _log_sigmoid_bw(const Tensor& grad, const Tensor& input, con return grad_tensor; } - std::vector _fill_zero_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor result = tt::tt_metal::zeros_like(grad, output_mem_config); @@ -264,7 +263,6 @@ std::vector _logit_bw(const Tensor& grad, const Tensor& input, const Mem return grad_tensor; } - std::vector _hardshrink_bw( const Tensor& grad, const Tensor& input_tensor, float lambd, const MemoryConfig& output_mem_config) { std::vector grad_tensor; @@ -358,6 +356,103 @@ std::vector _rpow_bw( } +std::vector _floor_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { + std::vector grad_tensor; + Tensor t_zero = ttnn::operations::creation::zeros_like(grad); + grad_tensor.emplace_back(t_zero); + return grad_tensor; +} + +std::vector _round_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { + std::vector grad_tensor; + Tensor t_zero = ttnn::operations::creation::zeros_like(grad); + grad_tensor.emplace_back(t_zero); + return grad_tensor; +} + +std::vector _log_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { + std::vector grad_tensor; + Tensor grad_a = ttnn::multiply(grad, ttnn::reciprocal(input, output_mem_config), std::nullopt, output_mem_config); + Tensor t_inf = ttnn::operations::creation::full_like(input, std::numeric_limits::infinity()); + Tensor t_nan = ttnn::operations::creation::full_like(input, std::nanf("")); + grad_tensor.emplace_back(where( + ttnn::eqz(input, output_mem_config), + where( + ttnn::eqz(grad, output_mem_config), + t_nan, + ttnn::multiply(t_inf, ttnn::sign(grad, output_mem_config), std::nullopt, output_mem_config), + output_mem_config), + grad_a, + output_mem_config)); + return grad_tensor; +} + +std::vector _relu6_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { + std::vector grad_tensor; + Tensor zero_tensor = ttnn::operations::creation::zeros_like(input); + Tensor one_tensor = ttnn::operations::creation::ones_like(input); + Tensor six_tensor = ttnn::operations::creation::full_like(input, 6); + Tensor grad_result = + where(ttnn::le(input, zero_tensor, std::nullopt, output_mem_config), zero_tensor, six_tensor, output_mem_config); + grad_result = where( + ttnn::logical_and( + ttnn::gtz(input, output_mem_config), + ttnn::lt(input, six_tensor, std::nullopt, output_mem_config), + std::nullopt, + output_mem_config), + grad, + grad_result, + output_mem_config); + grad_result = + where(ttnn::ge(input, six_tensor, std::nullopt, output_mem_config), zero_tensor, grad_result, output_mem_config); + + grad_tensor.emplace_back(grad_result); + return grad_tensor; +} + +std::vector _abs_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { + std::vector grad_tensor; + Tensor result = ttnn::multiply(grad, ttnn::sign(input, output_mem_config), std::nullopt, output_mem_config); + grad_tensor.emplace_back(result); + return grad_tensor; +} + +// Silu +// result: grad * sigmoid_result * (1 + input * (1 - sigmoid_result)) 
+std::vector _silu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { + std::vector grad_tensor; + Tensor grad_sigmoid = ttnn::multiply(grad, ttnn::sigmoid(input, output_mem_config), std::nullopt, output_mem_config); + Tensor add_sub = ttnn::add( + ttnn::multiply(ttnn::subtract(ttnn::operations::creation::full_like(input, 1.0f) , ttnn::sigmoid(input, output_mem_config), std::nullopt, output_mem_config), + input, + std::nullopt, + output_mem_config), + 1.0f, + std::nullopt, + output_mem_config); + Tensor grad_result = ttnn::multiply(grad_sigmoid, add_sub, std::nullopt, output_mem_config); + + grad_tensor.emplace_back(grad_result); + return grad_tensor; +} + +// Selu +// result: torch.where(input > 0, grad * lambd, grad * lambd * alpha * torch.exp(input)) +std::vector _selu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { + std::vector grad_tensor; + Tensor grad_lambd = ttnn::multiply(grad, 1.0507f, std::nullopt, output_mem_config); + Tensor grad_result = where( + ttnn::gtz(input, output_mem_config), + grad_lambd, + ttnn::multiply(ttnn::multiply(grad_lambd, 1.673260f, std::nullopt, output_mem_config), + ttnn::exp(input, false, output_mem_config), + std::nullopt, + output_mem_config), + output_mem_config); + grad_tensor.emplace_back(grad_result); + return grad_tensor; +} + std::function(const Tensor&, const Tensor&, const MemoryConfig&)> UnaryBackwardFunction::get_function_type1(UnaryBackwardOpType OpType){ switch (OpType) { case UnaryBackwardOpType::ASSIGN_BW: @@ -388,6 +483,20 @@ std::function(const Tensor&, const Tensor&, const Memo return _relu_bw; case UnaryBackwardOpType::LOGIT_BW: return _logit_bw; + case UnaryBackwardOpType::FLOOR_BW: + return _floor_bw; + case UnaryBackwardOpType::ROUND_BW: + return _round_bw; + case UnaryBackwardOpType::LOG_BW: + return _log_bw; + case UnaryBackwardOpType::RELU6_BW: + return _relu6_bw; + case UnaryBackwardOpType::ABS_BW: + return _abs_bw; + case UnaryBackwardOpType::SILU_BW: + return _silu_bw; + case UnaryBackwardOpType::SELU_BW: + return _selu_bw; default: TT_ASSERT(false && "Undefined op type"); return 0; diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp index 410fb4bd098..4851ff39af3 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp @@ -40,6 +40,13 @@ enum class UnaryBackwardOpType { ELU_BW, CELU_BW, RPOW_BW, + FLOOR_BW, + ROUND_BW, + LOG_BW, + RELU6_BW, + ABS_BW, + SILU_BW, + SELU_BW, }; struct UnaryBackwardFunction{ diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp index c59943916f8..2a5bf896106 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp @@ -92,6 +92,12 @@ constexpr auto leaky_relu_bw = ttnn::register_operation>("ttnn::elu_bw"); constexpr auto celu_bw = ttnn::register_operation>("ttnn::celu_bw"); constexpr auto rpow_bw = ttnn::register_operation>("ttnn::rpow_bw"); - +constexpr auto floor_bw = ttnn::register_operation>("ttnn::floor_bw"); +constexpr auto round_bw = ttnn::register_operation>("ttnn::round_bw"); +constexpr auto log_bw = ttnn::register_operation>("ttnn::log_bw"); +constexpr auto relu6_bw = 
ttnn::register_operation>("ttnn::relu6_bw"); +constexpr auto abs_bw = ttnn::register_operation>("ttnn::abs_bw"); +constexpr auto silu_bw = ttnn::register_operation>("ttnn::silu_bw"); +constexpr auto selu_bw = ttnn::register_operation>("ttnn::selu_bw"); } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp index f1a2e12c1a7..efd7eb2b5b4 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp @@ -262,6 +262,11 @@ void py_module(py::module& module) { module, ttnn::logit_bw, R"doc(Performs backward operations for logit on :attr:`input_tensor` or attr:`input_tensor_a` with given :attr:`grad_tensor`.)doc"); + + detail::bind_unary_backward( + module, + ttnn::floor_bw, + R"doc(Performs backward operations for floor on :attr:`input_tensor` with given :attr:`grad_tensor`)doc"); detail::bind_unary_backward( module, @@ -297,6 +302,37 @@ void py_module(py::module& module) { module, ttnn::rpow_bw, R"doc(Performs backward operations for rpow on :attr:`input_tensor`, :attr:`exponent` with given :attr:`grad_tensor`.)doc"); + + detail::bind_unary_backward( + module, + ttnn::round_bw, + R"doc(Performs backward operations for round on :attr:`input_tensor` with given :attr:`grad_tensor`.)doc"); + + detail::bind_unary_backward( + module, + ttnn::log_bw, + R"doc(Performs backward operations for logarithm on :attr:`input_tensor` with given :attr:`grad_tensor`)doc"); + + detail::bind_unary_backward( + module, + ttnn::relu6_bw, + R"doc(Performs backward operations for relu6 on :attr:`input_tensor` with given :attr:`grad_tensor`)doc"); + + detail::bind_unary_backward( + module, + ttnn::abs_bw, + R"doc(Performs backward operations for abs on :attr:`input_tensor` with given :attr:`grad_tensor`)doc"); + + detail::bind_unary_backward( + module, + ttnn::silu_bw, + R"doc(Performs backward operations for silu on :attr:`input_tensor` with given :attr:`grad_tensor`)doc"); + + detail::bind_unary_backward( + module, + ttnn::selu_bw, + R"doc(Performs backward operations for selu on :attr:`input_tensor` with given :attr:`grad_tensor`)doc"); + } } // namespace binary_backward
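
A minimal sketch, following the pattern of the relocated tests above, of how the migrated entry points are called after this change: the tests now import ttnn and invoke ttnn.<op>_bw instead of tt_lib.tensor.<op>_bw. The ttnn.abs_bw signature and the helper imports come from the test files in this patch; the wrapper name check_abs_bw, the value ranges, and the exact comparison step are illustrative assumptions.

    import torch
    import ttnn
    from tests.ttnn.unit_tests.operations.backward.utility_funcs import (
        data_gen_with_range,
        compare_pcc,
    )


    def check_abs_bw(input_shapes, device):
        # Old call site: tt_lib.tensor.abs_bw(grad_tensor, input_tensor)
        # New call site: ttnn.abs_bw(grad_tensor, input_tensor)
        in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
        grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)

        tt_output_tensor_on_device = ttnn.abs_bw(grad_tensor, input_tensor)

        # Golden reference via torch autograd, as in the moved tests.
        in_data.retain_grad()
        pyt_y = torch.abs(in_data)
        pyt_y.backward(gradient=grad_data)
        golden_tensor = [in_data.grad]

        # compare_pcc checks the device result against the torch gradient.
        assert compare_pcc(tt_output_tensor_on_device, golden_tensor)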