#10071: Move second set of Unary Backward ops to TTNN (#10038)
* #10071: Merge floor_bw to TTNN

* #10071: Merge round_bw to TTNN

* #10071: Merge log_bw to TTNN

* #10071: Merge relu6_bw to TTNN

* #10071: Merge abs_bw to TTNN

* #10071: Merge silu_bw to TTNN

* #10071: Merge selu_bw to TTNN

* #10071: Fix rebase errors

* #10071: Update files

* #10071: Move test files to TTNN

* #10071: Update CPP files

---------

Co-authored-by: mouliraj-mcw <[email protected]>
VirdhatchaniKN and mouliraj-mcw authored Jul 11, 2024
1 parent 99b2214 commit ad1fd9f
Showing 23 changed files with 232 additions and 303 deletions.
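Taken together, the change for callers is a namespace move: the seven backward ops are now invoked from `ttnn` instead of `tt_lib.tensor`, and the test updates below also show `floor_bw`/`round_bw` gaining an explicit input-tensor argument. A minimal sketch of the call-site change, assuming `grad_tensor` and `input_tensor` are already TTNN tensors on a device; the wrapper function names are illustrative only, not part of either API:

```python
import ttnn  # replaces `import tt_lib` in the moved tests


def abs_backward(grad_tensor, input_tensor):
    # Previously: tt_lib.tensor.abs_bw(grad_tensor, input_tensor)
    # The op returns a list of gradient tensors (std::vector<Tensor> in C++).
    return ttnn.abs_bw(grad_tensor, input_tensor)


def floor_backward(grad_tensor, input_tensor):
    # Previously: tt_lib.tensor.floor_bw(grad_tensor) -- the input tensor
    # is now passed alongside the gradient, as in the updated test below.
    return ttnn.floor_bw(grad_tensor, input_tensor)
```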
7 changes: 7 additions & 0 deletions docs/source/ttnn/ttnn/api.rst
@@ -193,6 +193,13 @@ Pointwise Unary
ttnn/elu_bw
ttnn/celu_bw
ttnn/rpow_bw
ttnn/floor_bw
ttnn/round_bw
ttnn/log_bw
ttnn/relu6_bw
ttnn/abs_bw
ttnn/silu_bw
ttnn/selu_bw

Pointwise Binary
================
14 changes: 0 additions & 14 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -814,10 +814,6 @@ Backward Operations

.. autofunction:: tt_lib.tensor.fill_bw

.. autofunction:: tt_lib.tensor.log_bw

.. autofunction:: tt_lib.tensor.abs_bw

.. autofunction:: tt_lib.tensor.complex_abs_bw

.. autofunction:: tt_lib.tensor.lt_bw
@@ -880,12 +876,6 @@ Backward Operations

.. autofunction:: tt_lib.tensor.reciprocal_bw

.. autofunction:: tt_lib.tensor.relu6_bw

.. autofunction:: tt_lib.tensor.silu_bw

.. autofunction:: tt_lib.tensor.selu_bw

.. autofunction:: tt_lib.tensor.square_bw

.. autofunction:: tt_lib.tensor.tanhshrink_bw
@@ -928,10 +918,6 @@ Backward Operations

.. autofunction:: tt_lib.tensor.repeat_bw

.. autofunction:: tt_lib.tensor.floor_bw

.. autofunction:: tt_lib.tensor.round_bw

.. autofunction:: tt_lib.tensor.unary_div_no_nan_bw

Loss Functions
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/abs_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.abs_bw:

ttnn.abs_bw
###########

.. autofunction:: ttnn.abs_bw
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/floor_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.floor_bw:

ttnn.floor_bw
#############

.. autofunction:: ttnn.floor_bw
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/log_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.log_bw:

ttnn.log_bw
###########

.. autofunction:: ttnn.log_bw
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/relu6_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.relu6_bw:

ttnn.relu6_bw
#############

.. autofunction:: ttnn.relu6_bw
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/round_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.round_bw:

ttnn.round_bw
#############

.. autofunction:: ttnn.round_bw
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/selu_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.selu_bw:

ttnn.selu_bw
############

.. autofunction:: ttnn.selu_bw
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/silu_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.silu_bw:

ttnn.silu_bw
############

.. autofunction:: ttnn.silu_bw
@@ -4,8 +4,8 @@

import torch
import pytest
-import tt_lib
-from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc
+import ttnn
+from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc


@pytest.mark.parametrize(
@@ -22,7 +22,7 @@ def test_bw_abs(input_shapes, device):

pyt_y = torch.abs(in_data)

-tt_output_tensor_on_device = tt_lib.tensor.abs_bw(grad_tensor, input_tensor)
+tt_output_tensor_on_device = ttnn.abs_bw(grad_tensor, input_tensor)

in_data.retain_grad()

@@ -4,8 +4,8 @@

import torch
import pytest
-import tt_lib
-from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range
+import ttnn
+from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc


@pytest.mark.parametrize(
@@ -22,7 +22,7 @@ def test_bw_floor(input_shapes, device):

pyt_y = torch.floor(in_data)

-tt_output_tensor_on_device = tt_lib.tensor.floor_bw(grad_tensor)
+tt_output_tensor_on_device = ttnn.floor_bw(grad_tensor, input_tensor)

in_data.retain_grad()

@@ -4,8 +4,8 @@

import torch
import pytest
-import tt_lib
-from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import (
+import ttnn
+from tests.ttnn.unit_tests.operations.backward.utility_funcs import (
data_gen_with_val,
compare_pcc,
data_gen_with_range,
@@ -23,7 +23,7 @@
def test_bw_log_0(input_shapes, device):
in_data, input_tensor = data_gen_with_val(input_shapes, device, True, val=0)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -1, 1, device)
-tt_output_tensor_on_device = tt_lib.tensor.log_bw(grad_tensor, input_tensor)
+tt_output_tensor_on_device = ttnn.log_bw(grad_tensor, input_tensor)

in_data.retain_grad()

@@ -47,7 +47,7 @@ def test_bw_log_0(input_shapes, device):
def test_bw_log(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
-tt_output_tensor_on_device = tt_lib.tensor.log_bw(grad_tensor, input_tensor)
+tt_output_tensor_on_device = ttnn.log_bw(grad_tensor, input_tensor)

in_data.retain_grad()

@@ -4,8 +4,8 @@

import torch
import pytest
-import tt_lib
-from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range
+import ttnn
+from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc


@pytest.mark.parametrize(
@@ -22,7 +22,7 @@ def test_bw_relu6(input_shapes, device):

pyt_y = torch.nn.functional.relu6(in_data)

-tt_output_tensor_on_device = tt_lib.tensor.relu6_bw(grad_tensor, input_tensor)
+tt_output_tensor_on_device = ttnn.relu6_bw(grad_tensor, input_tensor)

in_data.retain_grad()

@@ -4,8 +4,8 @@

import torch
import pytest
-import tt_lib
-from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range
+import ttnn
+from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc


@pytest.mark.parametrize(
@@ -21,7 +21,7 @@ def test_bw_round(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -200, 201, device, required_grad=True)
pyt_y = torch.round(in_data)

-tt_output_tensor_on_device = tt_lib.tensor.round_bw(grad_tensor)
+tt_output_tensor_on_device = ttnn.round_bw(grad_tensor, input_tensor)

in_data.retain_grad()

@@ -4,8 +4,8 @@

import torch
import pytest
-import tt_lib
-from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc
+import ttnn
+from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc


@pytest.mark.parametrize(
@@ -22,7 +22,7 @@ def test_bw_selu(input_shapes, device):

pyt_y = torch.nn.functional.selu(in_data)

-tt_output_tensor_on_device = tt_lib.tensor.selu_bw(grad_tensor, input_tensor)
+tt_output_tensor_on_device = ttnn.selu_bw(grad_tensor, input_tensor)

in_data.retain_grad()

@@ -4,8 +4,8 @@

import torch
import pytest
-import tt_lib
-from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc
+import ttnn
+from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc


@pytest.mark.parametrize(
@@ -22,7 +22,7 @@ def test_bw_silu(input_shapes, device):

pyt_y = torch.nn.functional.silu(in_data)

-tt_output_tensor_on_device = tt_lib.tensor.silu_bw(grad_tensor, input_tensor)
+tt_output_tensor_on_device = ttnn.silu_bw(grad_tensor, input_tensor)

in_data.retain_grad()

120 changes: 0 additions & 120 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -280,36 +280,6 @@ std::vector<Tensor> ne_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _ne_bw)(grad, output_mem_config);
}

std::vector<Tensor> _log_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_a = ttnn::multiply(grad, ttnn::reciprocal(input, output_mem_config), std::nullopt, output_mem_config);
Tensor t_inf = full_like(input, std::numeric_limits<float>::infinity(), output_mem_config);
Tensor t_nan = full_like(input, std::nanf(""), output_mem_config);
grad_tensor.emplace_back(where(
ttnn::eqz(input, output_mem_config),
where(
ttnn::eqz(grad, output_mem_config),
t_nan,
ttnn::multiply(t_inf, ttnn::sign(grad, output_mem_config), std::nullopt, output_mem_config),
output_mem_config),
grad_a,
output_mem_config));
return grad_tensor;
}
std::vector<Tensor> log_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _log_bw)(grad, input, output_mem_config);
}
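For reference, the deleted `_log_bw` computes grad / input and then patches the zero-input cases. A PyTorch sketch of the same math; the function name is illustrative, not part of either library:

```python
import torch


def log_bw_reference(grad: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # d/dx log(x) = 1/x, so the regular case is grad / x.
    grad_input = grad * torch.reciprocal(x)
    # At x == 0 the composite above substitutes +/-inf scaled by sign(grad),
    # and NaN where the incoming gradient is also zero.
    t_inf = torch.full_like(x, float("inf"))
    t_nan = torch.full_like(x, float("nan"))
    special = torch.where(grad == 0, t_nan, t_inf * torch.sign(grad))
    return torch.where(x == 0, special, grad_input)
```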

std::vector<Tensor> _abs_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor result = ttnn::multiply(grad, ttnn::sign(input, output_mem_config), std::nullopt, output_mem_config);
grad_tensor.emplace_back(result);
return grad_tensor;
}
std::vector<Tensor> abs_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _abs_bw)(grad, input, output_mem_config);
}
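The deleted `_abs_bw` is a one-liner: the derivative of |x| is sign(x), so the incoming gradient is simply scaled by the sign of the input. A PyTorch sketch with an illustrative name:

```python
import torch


def abs_bw_reference(grad: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # d/dx |x| = sign(x), so grad_input = grad * sign(x).
    return grad * torch.sign(x)
```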

// bw(expm1) = grad * (expm1(input) + 1)
std::vector<Tensor> _expm1_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
@@ -990,76 +960,6 @@ std::vector<Tensor> reciprocal_bw(const Tensor& grad, const Tensor& input, const
return operation::decorate_as_composite(__func__, _reciprocal_bw)(grad, input, output_mem_config);
}

std::vector<Tensor> _relu6_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor zero_tensor = zeros_like(input, output_mem_config);
Tensor one_tensor = ones_like(input, output_mem_config);
Tensor six_tensor = full_like(input, 6, output_mem_config);
Tensor grad_result =
where(ttnn::le(input, zero_tensor, std::nullopt, output_mem_config), zero_tensor, six_tensor, output_mem_config);
grad_result = where(
ttnn::logical_and(
ttnn::gtz(input, output_mem_config),
ttnn::lt(input, six_tensor, std::nullopt, output_mem_config),
std::nullopt,
output_mem_config),
grad,
grad_result,
output_mem_config);
grad_result =
where(ttnn::ge(input, six_tensor, std::nullopt, output_mem_config), zero_tensor, grad_result, output_mem_config);

grad_tensor.emplace_back(grad_result);
return grad_tensor;
}
std::vector<Tensor> relu6_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _relu6_bw)(grad, input, output_mem_config);
}
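The sequence of `where()` calls in the deleted `_relu6_bw` reduces to passing the gradient through only where relu6 has unit slope, i.e. on the open interval (0, 6), and zeroing it elsewhere. A compact PyTorch sketch of that reduction, with an illustrative name:

```python
import torch


def relu6_bw_reference(grad: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Gradient flows only on (0, 6); the slope of relu6 is zero for
    # x <= 0 and for x >= 6.
    return torch.where((x > 0) & (x < 6), grad, torch.zeros_like(grad))
```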


// Silu
// result: grad * sigmoid_result * (1 + input * (1 - sigmoid_result))
std::vector<Tensor> _silu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_sigmoid = ttnn::multiply(grad, ttnn::sigmoid(input, output_mem_config), std::nullopt, output_mem_config);
Tensor add_sub = ttnn::add(
ttnn::multiply(ttnn::subtract(ttnn::full_like(input, 1.0f) , ttnn::sigmoid(input, output_mem_config), std::nullopt, output_mem_config),
input,
std::nullopt,
output_mem_config),
1.0f,
std::nullopt,
output_mem_config);
Tensor grad_result = ttnn::multiply(grad_sigmoid, add_sub, std::nullopt, output_mem_config);

grad_tensor.emplace_back(grad_result);
return grad_tensor;
}
std::vector<Tensor> silu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _silu_bw)(grad, input, output_mem_config);
}
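A PyTorch sketch of the silu backward formula stated in the comment on the deleted composite; the name is illustrative:

```python
import torch


def silu_bw_reference(grad: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # grad * sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    sig = torch.sigmoid(x)
    return grad * sig * (1 + x * (1 - sig))
```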

// Selu
// result: torch.where(input > 0, grad * lambd, grad * lambd * alpha * torch.exp(input))
std::vector<Tensor> _selu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_lambd = ttnn::multiply(grad, 1.0507f, std::nullopt, output_mem_config);
Tensor grad_result = where(
ttnn::gtz(input, output_mem_config),
grad_lambd,
ttnn::multiply(ttnn::multiply(grad_lambd, 1.673260f, std::nullopt, output_mem_config),
ttnn::exp(input, false, output_mem_config),
std::nullopt,
output_mem_config),
output_mem_config);
grad_tensor.emplace_back(grad_result);
return grad_tensor;
}
std::vector<Tensor> selu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _selu_bw)(grad, input, output_mem_config);
}
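The deleted `_selu_bw` implements the formula in its comment with lambda = 1.0507 and alpha = 1.673260. A PyTorch sketch using those same constants; the name is illustrative:

```python
import torch

SELU_LAMBDA = 1.0507
SELU_ALPHA = 1.673260


def selu_bw_reference(grad: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # torch.where(x > 0, grad * lambda, grad * lambda * alpha * exp(x))
    grad_lambda = grad * SELU_LAMBDA
    return torch.where(x > 0, grad_lambda, grad_lambda * SELU_ALPHA * torch.exp(x))
```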


// Autoformat support
Tensor change_layout_to_tile(const Tensor& temp, const MemoryConfig& output_mem_config) {
auto formatted_input_tensor = temp;
@@ -1748,26 +1648,6 @@ std::vector<Tensor> repeat_bw(
return operation::decorate_as_composite(__func__, _repeat_bw)(grad, input, shape, output_mem_config);
}

std::vector<Tensor> _floor_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor t_zero = zeros_like(grad, output_mem_config);
grad_tensor.emplace_back(t_zero);
return grad_tensor;
}
std::vector<Tensor> floor_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _floor_bw)(grad, output_mem_config);
}

std::vector<Tensor> _round_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor t_zero = zeros_like(grad, output_mem_config);
grad_tensor.emplace_back(t_zero);
return grad_tensor;
}
std::vector<Tensor> round_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _round_bw)(grad, output_mem_config);
}
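Both `_floor_bw` and `_round_bw` above return zeros: floor() and round() are piecewise constant, so their derivative is zero almost everywhere. A one-line PyTorch equivalent with an illustrative name:

```python
import torch


def floor_or_round_bw_reference(grad: torch.Tensor) -> torch.Tensor:
    # The forward ops are piecewise constant, so the gradient is zero
    # everywhere it is defined; match the shape of the incoming gradient.
    return torch.zeros_like(grad)
```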

std::vector<Tensor> _unary_div_no_nan_bw(
const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;