From 490e37b1dd85f5a2cd88b2d6440fe30f21b13419 Mon Sep 17 00:00:00 2001
From: VirdhatchaniKN
Date: Sat, 20 Jul 2024 12:02:34 +0000
Subject: [PATCH] #9628: Update addalpha_bw test file to test default value and update golden function

---
 .../backward/test_backward_addalpha.py   | 68 +++++++++++++++++--
 ttnn/ttnn/operations/binary_backward.py  |  2 +-
 2 files changed, 64 insertions(+), 6 deletions(-)

diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_addalpha.py b/tests/ttnn/unit_tests/operations/backward/test_backward_addalpha.py
index 1697393e6286..406521cef609 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_addalpha.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_addalpha.py
@@ -31,6 +31,28 @@ def test_bw_addalpha(input_shapes, alpha, device):
     assert status
 
 
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+def test_bw_addalpha_wo_alpha(input_shapes, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    other_data, other_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
+
+    tt_output_tensor_on_device = ttnn.addalpha_bw(grad_tensor, input_tensor, other_tensor)
+
+    golden_function = ttnn.get_golden_function(ttnn.addalpha_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
+
+    status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
+    assert status
+
+
 @pytest.mark.parametrize(
     "input_shapes",
     (
@@ -65,14 +87,50 @@ def test_bw_addalpha_with_opt_output(input_shapes, alpha, device, are_required_o
         queue_id=cq_id,
     )
 
-    in_data.retain_grad()
-    other_data.retain_grad()
+    golden_function = ttnn.get_golden_function(ttnn.addalpha_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data, alpha)
+
+    status = True
+    for i in range(len(are_required_outputs)):
+        if are_required_outputs[i]:
+            status = status & compare_pcc([tt_output_tensor_on_device[i]], [golden_tensor[i]])
+    assert status
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True], [False, False]])
+def test_bw_addalpha_with_opt_output_wo_alpha(input_shapes, device, are_required_outputs):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -70, 90, device)
+    input_grad = None
+    other_grad = None
 
-    pyt_y = torch.add(in_data, other_data, alpha=alpha)
+    if are_required_outputs[0]:
+        _, input_grad = data_gen_with_range(input_shapes, -1, 1, device)
+    if are_required_outputs[1]:
+        _, other_grad = data_gen_with_range(input_shapes, -1, 1, device)
 
-    pyt_y.backward(gradient=grad_data)
+    cq_id = 0
+    tt_output_tensor_on_device = ttnn.addalpha_bw(
+        grad_tensor,
+        input_tensor,
+        other_tensor,
+        are_required_outputs=are_required_outputs,
+        input_a_grad=input_grad,
+        input_b_grad=other_grad,
+        queue_id=cq_id,
+    )
 
-    golden_tensor = [in_data.grad, other_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.addalpha_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
 
     status = True
     for i in range(len(are_required_outputs)):
diff --git a/ttnn/ttnn/operations/binary_backward.py b/ttnn/ttnn/operations/binary_backward.py
index 3ff1a4aa7d27..71c935f3e94b 100644
--- a/ttnn/ttnn/operations/binary_backward.py
+++ b/ttnn/ttnn/operations/binary_backward.py
@@ -181,7 +181,7 @@ def _golden_function_comparison_ops(torch_op, grad_tensor, input_tensor_a, input
 
 ttnn.attach_golden_function(
     ttnn.addalpha_bw,
-    golden_function=lambda grad, a, b, alpha, *args, **kwargs: _golden_function_backward_with_float(
+    golden_function=lambda grad, a, b, alpha=None, *args, **kwargs: _golden_function_backward_with_float(
         torch.add, grad, a, b, alpha, *args, **kwargs
     ),
 )
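
Note (reviewer sketch, not part of the patch): the golden-function change makes alpha optional by defaulting it to None, and the new *_wo_alpha tests call ttnn.addalpha_bw and its golden function without passing alpha. Below is a minimal, standalone sketch of the reference semantics the tests appear to rely on; the helper name addalpha_bw_reference is hypothetical, and it assumes that an omitted alpha behaves like torch.add's default (alpha = 1).

    import torch

    def addalpha_bw_reference(grad, a, b, alpha=None):
        # Hypothetical reference for addalpha backward: out = a + alpha * b.
        # Assumption: when alpha is omitted, fall back to torch.add's default (alpha = 1),
        # mirroring the alpha=None default added to the golden-function lambda.
        a = a.detach().clone().requires_grad_(True)
        b = b.detach().clone().requires_grad_(True)
        if alpha is None:
            out = torch.add(a, b)
        else:
            out = torch.add(a, b, alpha=alpha)
        out.backward(gradient=grad)
        return [a.grad, b.grad]

    # Usage: gradients for both inputs when no explicit alpha is given.
    grad = torch.ones(1, 1, 32, 32)
    x = torch.randn(1, 1, 32, 32)
    y = torch.randn(1, 1, 32, 32)
    input_grad, other_grad = addalpha_bw_reference(grad, x, y)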