Skip to content

Commit

Permalink
#6443: Fix sqrt and update backward ops silu, selu, tan, tanh, tanhsh…
Browse files Browse the repository at this point in the history
…rink
  • Loading branch information
ruthreshx committed Apr 5, 2024
1 parent f4a84d1 commit e428cf3
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_pt_tt, compare_results
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc


@pytest.mark.parametrize(
Expand All @@ -17,8 +17,8 @@
),
)
def test_bw_selu(input_shapes, device):
in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)
grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)
in_data, input_tensor = data_gen_with_range(input_shapes, -1e4, 1e4, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -1e4, 1e4, device)

pyt_y = torch.nn.functional.selu(in_data)

Expand All @@ -29,6 +29,6 @@ def test_bw_selu(input_shapes, device):
pyt_y.backward(gradient=grad_data)

golden_tensor = [in_data.grad]
comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)

assert comp_pass
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_pt_tt, compare_results
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc


@pytest.mark.parametrize(
Expand All @@ -17,8 +17,8 @@
),
)
def test_bw_sigmoid(input_shapes, device):
in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)
grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)
in_data, input_tensor = data_gen_with_range(input_shapes, -1, 1, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)

pyt_y = torch.sigmoid(in_data)

Expand All @@ -30,5 +30,5 @@ def test_bw_sigmoid(input_shapes, device):

golden_tensor = [in_data.grad]

comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor, 0.90)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor, 0.90)
assert comp_pass
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_pt_tt, compare_results
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc


@pytest.mark.parametrize(
Expand All @@ -17,8 +17,8 @@
),
)
def test_bw_silu(input_shapes, device):
in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)
grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)
in_data, input_tensor = data_gen_with_range(input_shapes, -1e4, 1e4, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -1e4, 1e4, device)

pyt_y = torch.nn.functional.silu(in_data)

Expand All @@ -29,6 +29,6 @@ def test_bw_silu(input_shapes, device):
pyt_y.backward(gradient=grad_data)

golden_tensor = [in_data.grad]
comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)

assert comp_pass
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_pt_tt, compare_results
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc


@pytest.mark.parametrize(
Expand All @@ -17,8 +17,8 @@
),
)
def test_bw_sqrt(input_shapes, device):
in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)
grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)
in_data, input_tensor = data_gen_with_range(input_shapes, -1e4, 1e4, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -1e4, 1e4, device)

pyt_y = torch.sqrt(in_data)

Expand All @@ -29,5 +29,5 @@ def test_bw_sqrt(input_shapes, device):
pyt_y.backward(gradient=grad_data)

golden_tensor = [in_data.grad]
status = compare_results(tt_output_tensor_on_device, golden_tensor)
status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert status
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_pt_tt, compare_results
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import (
data_gen_with_range,
compare_results,
)


@pytest.mark.parametrize(
Expand All @@ -17,13 +20,10 @@
),
)
def test_bw_tan(input_shapes, device):
grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)
# tt tan supports input range [-1.45, 1.45]
in_data = torch.Tensor(size=input_shapes).uniform_(-1.45, 1.45)
in_data.requires_grad = True
input_tensor = (
tt_lib.tensor.Tensor(in_data, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)
)
in_data, input_tensor = data_gen_with_range(input_shapes, -1.45, 1.45, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -1e4, 1e4, device)

pyt_y = torch.tan(in_data)

tt_output_tensor_on_device = tt_lib.tensor.tan_bw(grad_tensor, input_tensor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_results, data_gen_pt_tt
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc


@pytest.mark.parametrize(
Expand All @@ -17,8 +17,9 @@
),
)
def test_bw_tanh(input_shapes, device):
in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)
grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)
# tt tan supports input range [-1.45, 1.45]
in_data, input_tensor = data_gen_with_range(input_shapes, -1.45, 1.45, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -1e4, 1e4, device)
pyt_y = torch.tanh(in_data)

tt_output_tensor_on_device = tt_lib.tensor.tanh_bw(grad_tensor, input_tensor)
Expand All @@ -29,5 +30,5 @@ def test_bw_tanh(input_shapes, device):

golden_tensor = [in_data.grad]

status = compare_results(tt_output_tensor_on_device, golden_tensor, 0.95)
status = compare_pcc(tt_output_tensor_on_device, golden_tensor, 0.95)
assert status
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_results, data_gen_pt_tt
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range


@pytest.mark.parametrize(
Expand All @@ -17,12 +17,8 @@
),
)
def test_bw_tanhshrink(input_shapes, device):
grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)
in_data = torch.Tensor(size=input_shapes).uniform_()
in_data.requires_grad = True
input_tensor = (
tt_lib.tensor.Tensor(in_data, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)
)
in_data, input_tensor = data_gen_with_range(input_shapes, -1e4, 1e4, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -1e4, 1e4, device)

pyt_y = torch.nn.functional.tanhshrink(in_data)
tt_output_tensor_on_device = tt_lib.tensor.tanhshrink_bw(grad_tensor, input_tensor)
Expand All @@ -32,5 +28,5 @@ def test_bw_tanhshrink(input_shapes, device):
pyt_y.backward(gradient=grad_data)

golden_tensor = [in_data.grad]
comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
2 changes: 1 addition & 1 deletion tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ std::vector<Tensor> _sqrt_bw(const Tensor& grad, const Tensor& input, const Memo
Tensor sqrt_result = sqrt(input, output_mem_config);
Tensor result = mul(grad, recip(mul_unary(sqrt_result, 2.0, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
float t_nan = std::nanf("");
result = where(ltz(input, output_mem_config), t_nan, result, output_mem_config);
result = where(lez(input, output_mem_config), t_nan, result, output_mem_config);
grad_tensor.emplace_back(result);
return grad_tensor;
}
Expand Down

0 comments on commit e428cf3

Please sign in to comment.