#9628: Update addalpha_bw test file to test default value and update golden function
VirdhatchaniKN committed Jul 20, 2024
1 parent 9e122ea commit 490e37b
Showing 2 changed files with 64 additions and 6 deletions.
@@ -31,6 +31,28 @@ def test_bw_addalpha(input_shapes, alpha, device):
     assert status
 
 
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+def test_bw_addalpha_wo_alpha(input_shapes, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    other_data, other_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
+
+    tt_output_tensor_on_device = ttnn.addalpha_bw(grad_tensor, input_tensor, other_tensor)
+
+    golden_function = ttnn.get_golden_function(ttnn.addalpha_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
+
+    status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
+    assert status
+
+
 @pytest.mark.parametrize(
     "input_shapes",
     (
@@ -65,14 +87,50 @@ def test_bw_addalpha_with_opt_output(input_shapes, alpha, device, are_required_outputs):
         queue_id=cq_id,
     )
 
     in_data.retain_grad()
     other_data.retain_grad()
 
-    pyt_y = torch.add(in_data, other_data, alpha=alpha)
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad, other_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.addalpha_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data, alpha)
+
+    status = True
+    for i in range(len(are_required_outputs)):
+        if are_required_outputs[i]:
+            status = status & compare_pcc([tt_output_tensor_on_device[i]], [golden_tensor[i]])
+    assert status
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True], [False, False]])
+def test_bw_addalpha_with_opt_output_wo_alpha(input_shapes, device, are_required_outputs):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -70, 90, device)
+    input_grad = None
+    other_grad = None
+
+    if are_required_outputs[0]:
+        _, input_grad = data_gen_with_range(input_shapes, -1, 1, device)
+    if are_required_outputs[1]:
+        _, other_grad = data_gen_with_range(input_shapes, -1, 1, device)
+
+    cq_id = 0
+    tt_output_tensor_on_device = ttnn.addalpha_bw(
+        grad_tensor,
+        input_tensor,
+        other_tensor,
+        are_required_outputs=are_required_outputs,
+        input_a_grad=input_grad,
+        input_b_grad=other_grad,
+        queue_id=cq_id,
+    )
+
+    golden_function = ttnn.get_golden_function(ttnn.addalpha_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
 
     status = True
     for i in range(len(are_required_outputs)):
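Context for the new _wo_alpha tests above: they call ttnn.addalpha_bw without alpha and compare against a golden function that is likewise called without alpha. This relies on torch.add(input, other, alpha=...) computing input + alpha * other, with alpha defaulting to 1 when omitted. A small standalone check of that semantics (illustrative only, not part of the commit):

import torch

grad = torch.randn(2, 3)
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 3, requires_grad=True)

y = torch.add(a, b)  # alpha omitted, so torch uses its default of 1
y.backward(gradient=grad)

# For y = a + alpha * b: dL/da = grad and dL/db = alpha * grad.
assert torch.allclose(a.grad, grad)
assert torch.allclose(b.grad, grad)  # alpha = 1 here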
ttnn/ttnn/operations/binary_backward.py: 1 addition & 1 deletion
@@ -181,7 +181,7 @@ def _golden_function_comparison_ops(torch_op, grad_tensor, input_tensor_a, input
 
 ttnn.attach_golden_function(
     ttnn.addalpha_bw,
-    golden_function=lambda grad, a, b, alpha, *args, **kwargs: _golden_function_backward_with_float(
+    golden_function=lambda grad, a, b, alpha=None, *args, **kwargs: _golden_function_backward_with_float(
         torch.add, grad, a, b, alpha, *args, **kwargs
     ),
 )
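This one-line change gives the lambda's alpha parameter a default of None, so the function returned by ttnn.get_golden_function(ttnn.addalpha_bw) can be invoked both with and without alpha, matching the two call styles in the updated tests. The helper _golden_function_backward_with_float is defined elsewhere in this file and does not appear in the diff; the sketch below shows one plausible shape for such a reference backward, assuming it maps alpha=None onto torch.add's default. The function name and body are illustrative, not the repository's actual implementation:

import torch

def golden_addalpha_bw(grad, a, b, alpha=None):
    # Recompute the forward pass on autograd-tracked copies of the inputs.
    a = a.clone().detach().requires_grad_(True)
    b = b.clone().detach().requires_grad_(True)
    # torch.add's alpha defaults to 1, so omit it when the caller passed None.
    out = torch.add(a, b) if alpha is None else torch.add(a, b, alpha=alpha)
    out.backward(gradient=grad)
    return [a.grad, b.grad]

Computing the reference gradients through torch autograd keeps the golden path independent of the device kernels under test.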
