Skip to content

Commit

Permalink
#6633: Update test files for logaddexp_exp2
Browse files Browse the repository at this point in the history
  • Loading branch information
umadevimcw committed Mar 28, 2024
1 parent b3eefdf commit d8237d0
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_pt_tt, compare_results
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_with_range, compare_pcc


@pytest.mark.parametrize(
Expand All @@ -17,10 +17,10 @@
),
)
def test_bw_logaddexp(input_shapes, device):
in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)
other_data, other_tensor = data_gen_pt_tt(input_shapes, device, True)
in_data, input_tensor = data_gen_with_range(input_shapes, -10, 10, device, True)
other_data, other_tensor = data_gen_with_range(input_shapes, -20, 20, device, True)

grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)

tt_output_tensor_on_device = tt_lib.tensor.logaddexp_bw(grad_tensor, input_tensor, other_tensor)

Expand All @@ -32,5 +32,5 @@ def test_bw_logaddexp(input_shapes, device):
pyt_y.backward(gradient=grad_data)

golden_tensor = [in_data.grad, other_data.grad]
status = compare_results(tt_output_tensor_on_device, golden_tensor)
status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert status
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_pt_tt, compare_results
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import (
data_gen_with_range,
compare_pcc,
)


@pytest.mark.parametrize(
Expand All @@ -17,10 +20,10 @@
),
)
def test_bw_logaddexp2(input_shapes, device):
in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)
other_data, other_tensor = data_gen_pt_tt(input_shapes, device, True)
in_data, input_tensor = data_gen_with_range(input_shapes, -10, 10, device, True)
other_data, other_tensor = data_gen_with_range(input_shapes, -20, 20, device, True)

grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)

tt_output_tensor_on_device = tt_lib.tensor.logaddexp2_bw(grad_tensor, input_tensor, other_tensor)

Expand All @@ -32,5 +35,5 @@ def test_bw_logaddexp2(input_shapes, device):
pyt_y.backward(gradient=grad_data)

golden_tensor = [in_data.grad, other_data.grad]
status = compare_results(tt_output_tensor_on_device, golden_tensor)
status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert status

0 comments on commit d8237d0

Please sign in to comment.