#10071 : Move second set of Unary Backward ops to TTNN #10038

Merged · 11 commits · Jul 11, 2024
7 changes: 7 additions & 0 deletions docs/source/ttnn/ttnn/api.rst
@@ -193,6 +193,13 @@ Pointwise Unary
ttnn/elu_bw
ttnn/celu_bw
ttnn/rpow_bw
ttnn/floor_bw
ttnn/round_bw
ttnn/log_bw
ttnn/relu6_bw
ttnn/abs_bw
ttnn/silu_bw
ttnn/selu_bw

Pointwise Binary
================
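The api.rst entries above register a doc page for each unary backward op that this PR moves into the ttnn namespace. As a rough usage sketch (not part of this diff; the device/host helpers, shapes, and dtype are assumptions that mirror how the unit tests below drive the ops), each op takes the incoming gradient and the original input and returns a list of gradient tensors:

    import torch
    import ttnn

    # Assumed setup: a single device and bfloat16 tiles, as in the unit tests.
    device = ttnn.open_device(device_id=0)

    torch_input = torch.randn(1, 1, 32, 32, dtype=torch.bfloat16)
    torch_grad = torch.randn(1, 1, 32, 32, dtype=torch.bfloat16)

    input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)
    grad_tensor = ttnn.from_torch(torch_grad, layout=ttnn.TILE_LAYOUT, device=device)

    # Every migrated *_bw op now follows the same (grad, input) calling convention.
    grad_outputs = ttnn.abs_bw(grad_tensor, input_tensor)
    input_grad = ttnn.to_torch(grad_outputs[0])

    ttnn.close_device(device)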
14 changes: 0 additions & 14 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -814,10 +814,6 @@ Backward Operations

.. autofunction:: tt_lib.tensor.fill_bw

.. autofunction:: tt_lib.tensor.log_bw

.. autofunction:: tt_lib.tensor.abs_bw

.. autofunction:: tt_lib.tensor.complex_abs_bw

.. autofunction:: tt_lib.tensor.lt_bw
@@ -880,12 +876,6 @@ Backward Operations

.. autofunction:: tt_lib.tensor.reciprocal_bw

.. autofunction:: tt_lib.tensor.relu6_bw

.. autofunction:: tt_lib.tensor.silu_bw

.. autofunction:: tt_lib.tensor.selu_bw

.. autofunction:: tt_lib.tensor.square_bw

.. autofunction:: tt_lib.tensor.tanhshrink_bw
@@ -928,10 +918,6 @@ Backward Operations

.. autofunction:: tt_lib.tensor.repeat_bw

.. autofunction:: tt_lib.tensor.floor_bw

.. autofunction:: tt_lib.tensor.round_bw

.. autofunction:: tt_lib.tensor.unary_div_no_nan_bw

Loss Functions
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/abs_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.abs_bw:

ttnn.abs_bw
###########

.. autofunction:: ttnn.abs_bw
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/floor_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.floor_bw:

ttnn.floor_bw
#############

.. autofunction:: ttnn.floor_bw
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/log_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.log_bw:

ttnn.log_bw
###########

.. autofunction:: ttnn.log_bw
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/relu6_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.relu6_bw:

ttnn.relu6_bw
#############

.. autofunction:: ttnn.relu6_bw
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/round_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.round_bw:

ttnn.round_bw
#############

.. autofunction:: ttnn.round_bw
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/selu_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.selu_bw:

ttnn.selu_bw
############

.. autofunction:: ttnn.selu_bw
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/silu_bw.rst
@@ -0,0 +1,6 @@
.. _ttnn.silu_bw:

ttnn.silu_bw
############

.. autofunction:: ttnn.silu_bw
@@ -4,8 +4,8 @@

import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc
import ttnn
from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc


@pytest.mark.parametrize(
@@ -22,7 +22,7 @@ def test_bw_abs(input_shapes, device):

pyt_y = torch.abs(in_data)

tt_output_tensor_on_device = tt_lib.tensor.abs_bw(grad_tensor, input_tensor)
tt_output_tensor_on_device = ttnn.abs_bw(grad_tensor, input_tensor)

in_data.retain_grad()

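The abs test above only swaps the call site from tt_lib.tensor.abs_bw to ttnn.abs_bw; the torch reference it compares against is unchanged. Analytically the gradient of abs is grad * sign(input), which is also what the removed C++ _abs_bw further down in this diff computes. A minimal torch-only sketch of that reference, with shapes assumed:

    import torch

    in_data = torch.randn(1, 1, 32, 32, requires_grad=True)
    grad_data = torch.randn(1, 1, 32, 32)

    # Autograd reference, as used by the test.
    torch.abs(in_data).backward(gradient=grad_data)

    # Closed form: d/dx |x| = sign(x), so the expected gradient is grad * sign(input).
    assert torch.allclose(in_data.grad, grad_data * torch.sign(in_data))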
@@ -4,8 +4,8 @@

import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range
import ttnn
from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc


@pytest.mark.parametrize(
@@ -22,7 +22,7 @@ def test_bw_floor(input_shapes, device):

pyt_y = torch.floor(in_data)

tt_output_tensor_on_device = tt_lib.tensor.floor_bw(grad_tensor)
tt_output_tensor_on_device = ttnn.floor_bw(grad_tensor, input_tensor)

in_data.retain_grad()

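Beyond the namespace change, note the signature change in the floor test: the old tt_lib.tensor.floor_bw took only the gradient, while ttnn.floor_bw follows the common (grad, input) convention; ttnn.round_bw in the later test changes in the same way. Mathematically floor is piecewise constant, so the returned gradient is zero everywhere, matching the removed C++ _floor_bw. A small torch sketch of that expectation, shapes assumed:

    import torch

    in_data = torch.randn(1, 1, 32, 32, requires_grad=True)
    grad_data = torch.ones_like(in_data)

    # floor has zero derivative almost everywhere, so autograd returns all zeros.
    torch.floor(in_data).backward(gradient=grad_data)
    assert torch.equal(in_data.grad, torch.zeros_like(in_data))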
@@ -4,8 +4,8 @@

import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import (
import ttnn
from tests.ttnn.unit_tests.operations.backward.utility_funcs import (
data_gen_with_val,
compare_pcc,
data_gen_with_range,
@@ -23,7 +23,7 @@
def test_bw_log_0(input_shapes, device):
in_data, input_tensor = data_gen_with_val(input_shapes, device, True, val=0)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -1, 1, device)
tt_output_tensor_on_device = tt_lib.tensor.log_bw(grad_tensor, input_tensor)
tt_output_tensor_on_device = ttnn.log_bw(grad_tensor, input_tensor)

in_data.retain_grad()

@@ -47,7 +47,7 @@ def test_bw_log_0(input_shapes, device):
def test_bw_log(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
tt_output_tensor_on_device = tt_lib.tensor.log_bw(grad_tensor, input_tensor)
tt_output_tensor_on_device = ttnn.log_bw(grad_tensor, input_tensor)

in_data.retain_grad()

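For log_bw the analytic gradient is grad / input; the removed C++ implementation additionally maps input == 0 to ±inf (or NaN when the gradient is also zero), which is what test_bw_log_0 exercises via data_gen_with_val(..., val=0). A torch-only sketch of the regular case, with shapes and ranges assumed:

    import torch

    in_data = torch.empty(1, 1, 32, 32).uniform_(-100, 100).requires_grad_(True)
    grad_data = torch.empty(1, 1, 32, 32).uniform_(-100, 100)

    # d/dx log(x) = 1/x, so the expected gradient is grad / input
    # (the forward value is NaN for negative inputs, but the backward formula is unchanged).
    torch.log(in_data).backward(gradient=grad_data)
    assert torch.allclose(in_data.grad, grad_data / in_data, rtol=1e-3)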
@@ -4,8 +4,8 @@

import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range
import ttnn
from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc


@pytest.mark.parametrize(
@@ -22,7 +22,7 @@ def test_bw_relu6(input_shapes, device):

pyt_y = torch.nn.functional.relu6(in_data)

tt_output_tensor_on_device = tt_lib.tensor.relu6_bw(grad_tensor, input_tensor)
tt_output_tensor_on_device = ttnn.relu6_bw(grad_tensor, input_tensor)

in_data.retain_grad()

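relu6 clamps its input to [0, 6], so the backward op passes the incoming gradient through only where 0 < input < 6 and returns zero elsewhere, which is what the removed C++ _relu6_bw builds out of where/le/ge. A torch sketch of that mask behaviour, shapes assumed:

    import torch

    in_data = torch.empty(1, 1, 32, 32).uniform_(-10, 10).requires_grad_(True)
    grad_data = torch.randn(1, 1, 32, 32)

    torch.nn.functional.relu6(in_data).backward(gradient=grad_data)

    # Gradient flows only on the open interval (0, 6); it is zero at and beyond the clamps.
    mask = (in_data > 0) & (in_data < 6)
    assert torch.allclose(in_data.grad, grad_data * mask)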
@@ -4,8 +4,8 @@

import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range
import ttnn
from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc


@pytest.mark.parametrize(
@@ -21,7 +21,7 @@ def test_bw_round(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -200, 201, device, required_grad=True)
pyt_y = torch.round(in_data)

tt_output_tensor_on_device = tt_lib.tensor.round_bw(grad_tensor)
tt_output_tensor_on_device = ttnn.round_bw(grad_tensor, input_tensor)

in_data.retain_grad()

@@ -4,8 +4,8 @@

import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc
import ttnn
from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc


@pytest.mark.parametrize(
@@ -22,7 +22,7 @@ def test_bw_selu(input_shapes, device):

pyt_y = torch.nn.functional.selu(in_data)

tt_output_tensor_on_device = tt_lib.tensor.selu_bw(grad_tensor, input_tensor)
tt_output_tensor_on_device = ttnn.selu_bw(grad_tensor, input_tensor)

in_data.retain_grad()

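For selu the removed C++ comment spells out the expected gradient: where(input > 0, grad * scale, grad * scale * alpha * exp(input)), with scale ≈ 1.0507 and alpha ≈ 1.6733. A torch sketch checking that closed form against autograd; the constants, shapes, and tolerances are assumptions:

    import torch

    SCALE, ALPHA = 1.0507009873554805, 1.6732632423543772

    in_data = torch.empty(1, 1, 32, 32).uniform_(-5, 5).requires_grad_(True)
    grad_data = torch.randn(1, 1, 32, 32)

    torch.nn.functional.selu(in_data).backward(gradient=grad_data)

    # d/dx selu(x) = scale for x > 0, and scale * alpha * exp(x) for x <= 0.
    expected = torch.where(
        in_data > 0,
        grad_data * SCALE,
        grad_data * SCALE * ALPHA * torch.exp(in_data),
    )
    assert torch.allclose(in_data.grad, expected, rtol=1e-3, atol=1e-5)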
@@ -4,8 +4,8 @@

import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc
import ttnn
from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc


@pytest.mark.parametrize(
@@ -22,7 +22,7 @@ def test_bw_silu(input_shapes, device):

pyt_y = torch.nn.functional.silu(in_data)

tt_output_tensor_on_device = tt_lib.tensor.silu_bw(grad_tensor, input_tensor)
tt_output_tensor_on_device = ttnn.silu_bw(grad_tensor, input_tensor)

in_data.retain_grad()

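silu(x) = x * sigmoid(x), so its gradient is grad * sigmoid(x) * (1 + x * (1 - sigmoid(x))), the same formula noted in the removed C++ _silu_bw below. A torch sketch of that identity, shapes and tolerances assumed:

    import torch

    in_data = torch.empty(1, 1, 32, 32).uniform_(-10, 10).requires_grad_(True)
    grad_data = torch.randn(1, 1, 32, 32)

    torch.nn.functional.silu(in_data).backward(gradient=grad_data)

    sig = torch.sigmoid(in_data)
    # Product rule on x * sigmoid(x): sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)).
    expected = grad_data * sig * (1 + in_data * (1 - sig))
    assert torch.allclose(in_data.grad, expected, rtol=1e-3, atol=1e-5)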
120 changes: 0 additions & 120 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -280,36 +280,6 @@ std::vector<Tensor> ne_bw(const Tensor& grad, const MemoryConfig& output_mem_config)
return operation::decorate_as_composite(__func__, _ne_bw)(grad, output_mem_config);
}

std::vector<Tensor> _log_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_a = ttnn::multiply(grad, ttnn::reciprocal(input, output_mem_config), std::nullopt, output_mem_config);
Tensor t_inf = full_like(input, std::numeric_limits<float>::infinity(), output_mem_config);
Tensor t_nan = full_like(input, std::nanf(""), output_mem_config);
grad_tensor.emplace_back(where(
ttnn::eqz(input, output_mem_config),
where(
ttnn::eqz(grad, output_mem_config),
t_nan,
ttnn::multiply(t_inf, ttnn::sign(grad, output_mem_config), std::nullopt, output_mem_config),
output_mem_config),
grad_a,
output_mem_config));
return grad_tensor;
}
std::vector<Tensor> log_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _log_bw)(grad, input, output_mem_config);
}

std::vector<Tensor> _abs_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor result = ttnn::multiply(grad, ttnn::sign(input, output_mem_config), std::nullopt, output_mem_config);
grad_tensor.emplace_back(result);
return grad_tensor;
}
std::vector<Tensor> abs_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _abs_bw)(grad, input, output_mem_config);
}

// bw(expm1) = grad * expm1(input) + 1
std::vector<Tensor> _expm1_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
@@ -990,76 +960,6 @@ std::vector<Tensor> reciprocal_bw(const Tensor& grad, const Tensor& input, const
return operation::decorate_as_composite(__func__, _reciprocal_bw)(grad, input, output_mem_config);
}

std::vector<Tensor> _relu6_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor zero_tensor = zeros_like(input, output_mem_config);
Tensor one_tensor = ones_like(input, output_mem_config);
Tensor six_tensor = full_like(input, 6, output_mem_config);
Tensor grad_result =
where(ttnn::le(input, zero_tensor, std::nullopt, output_mem_config), zero_tensor, six_tensor, output_mem_config);
grad_result = where(
ttnn::logical_and(
ttnn::gtz(input, output_mem_config),
ttnn::lt(input, six_tensor, std::nullopt, output_mem_config),
std::nullopt,
output_mem_config),
grad,
grad_result,
output_mem_config);
grad_result =
where(ttnn::ge(input, six_tensor, std::nullopt, output_mem_config), zero_tensor, grad_result, output_mem_config);

grad_tensor.emplace_back(grad_result);
return grad_tensor;
}
std::vector<Tensor> relu6_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _relu6_bw)(grad, input, output_mem_config);
}


// Silu
// result: grad * sigmoid_result * (1 + input * (1 - sigmoid_result))
std::vector<Tensor> _silu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_sigmoid = ttnn::multiply(grad, ttnn::sigmoid(input, output_mem_config), std::nullopt, output_mem_config);
Tensor add_sub = ttnn::add(
ttnn::multiply(ttnn::subtract(ttnn::full_like(input, 1.0f) , ttnn::sigmoid(input, output_mem_config), std::nullopt, output_mem_config),
input,
std::nullopt,
output_mem_config),
1.0f,
std::nullopt,
output_mem_config);
Tensor grad_result = ttnn::multiply(grad_sigmoid, add_sub, std::nullopt, output_mem_config);

grad_tensor.emplace_back(grad_result);
return grad_tensor;
}
std::vector<Tensor> silu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _silu_bw)(grad, input, output_mem_config);
}

// Selu
// result: torch.where(input > 0, grad * lambd, grad * lambd * alpha * torch.exp(input))
std::vector<Tensor> _selu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_lambd = ttnn::multiply(grad, 1.0507f, std::nullopt, output_mem_config);
Tensor grad_result = where(
ttnn::gtz(input, output_mem_config),
grad_lambd,
ttnn::multiply(ttnn::multiply(grad_lambd, 1.673260f, std::nullopt, output_mem_config),
ttnn::exp(input, false, output_mem_config),
std::nullopt,
output_mem_config),
output_mem_config);
grad_tensor.emplace_back(grad_result);
return grad_tensor;
}
std::vector<Tensor> selu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _selu_bw)(grad, input, output_mem_config);
}


// Autoformat support
Tensor change_layout_to_tile(const Tensor& temp, const MemoryConfig& output_mem_config) {
auto formatted_input_tensor = temp;
@@ -1748,26 +1648,6 @@ std::vector<Tensor> repeat_bw(
return operation::decorate_as_composite(__func__, _repeat_bw)(grad, input, shape, output_mem_config);
}

std::vector<Tensor> _floor_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor t_zero = zeros_like(grad, output_mem_config);
grad_tensor.emplace_back(t_zero);
return grad_tensor;
}
std::vector<Tensor> floor_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _floor_bw)(grad, output_mem_config);
}

std::vector<Tensor> _round_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor t_zero = zeros_like(grad, output_mem_config);
grad_tensor.emplace_back(t_zero);
return grad_tensor;
}
std::vector<Tensor> round_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _round_bw)(grad, output_mem_config);
}

std::vector<Tensor> _unary_div_no_nan_bw(
const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;