diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_add.py b/tests/ttnn/unit_tests/operations/backward/test_backward_add.py
index e7e5eb49434d..9bcc1ce30d8f 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_add.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_add.py
@@ -23,14 +23,8 @@ def test_bw_add(input_shapes, device):
 
     tt_output_tensor_on_device = ttnn.add_bw(grad_tensor, input_tensor, other_tensor)
 
-    in_data.retain_grad()
-    other_data.retain_grad()
-
-    pyt_y = torch.add(in_data, other_data)
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad, other_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.add_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
@@ -69,14 +63,8 @@ def test_bw_add_with_opt_output(input_shapes, device, are_required_outputs):
         queue_id=cq_id,
     )
 
-    in_data.retain_grad()
-    other_data.retain_grad()
-
-    pyt_y = torch.add(in_data, other_data)
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad, other_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.add_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
 
     status = True
     for i in range(len(are_required_outputs)):
@@ -100,13 +88,8 @@ def test_bw_unary_add(input_shapes, alpha, device):
 
     tt_output_tensor_on_device = ttnn.add_bw(grad_tensor, input_tensor, alpha=alpha)
 
-    in_data.retain_grad()
-
-    pyt_y = torch.add(in_data, torch.tensor(alpha))
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.add_bw)
+    golden_tensor = golden_function(grad_data, in_data, alpha)
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_assign.py b/tests/ttnn/unit_tests/operations/backward/test_backward_assign.py
index c35439aba66a..7a2d7395cab7 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_assign.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_assign.py
@@ -22,13 +22,8 @@ def test_bw_unary_assign(input_shapes, device):
 
     tt_output_tensor_on_device = ttnn.assign_bw(grad_tensor, input_tensor)
 
-    in_data.retain_grad()
-
-    pyt_y = torch.clone(in_data)
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.assign_bw)
+    golden_tensor = golden_function(grad_data, in_data)
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
@@ -50,12 +45,7 @@ def test_bw_binary_assign(input_shapes, device):
     tt_output_tensor_on_device = ttnn.assign_bw(grad_tensor, input_tensor, other_tensor)
 
-    in_data.retain_grad()
-
-    pyt_y = torch.clone(in_data)
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad, in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.assign_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_bias_gelu.py b/tests/ttnn/unit_tests/operations/backward/test_backward_bias_gelu.py
index 4e159530fb45..9ce2c7b55f58 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_bias_gelu.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_bias_gelu.py
@@ -27,16 +27,9 @@ def test_bw_binary_bias_gelu(input_shapes, approximate, device):
     in_data_a, input_tensor_a = data_gen_with_range(input_shapes, -100, 100, device, True)
     in_data_b, input_tensor_b = data_gen_with_range(input_shapes, -10, 10, device, True)
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)
-    in_data = in_data_a + in_data_b
-
-    pyt_y = torch.nn.functional.gelu(in_data, approximate=approximate)
-
     tt_output_tensor_on_device = ttnn.bias_gelu_bw(grad_tensor, input_tensor_a, input_tensor_b, approximate=approximate)
 
-    in_data.retain_grad()
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad, in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.bias_gelu_bw)
+    golden_tensor = golden_function(grad_data, in_data_a, in_data_b, approximate)
 
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
@@ -66,17 +59,11 @@ def test_bw_binary_bias_gelu(input_shapes, approximate, device):
 def test_bw_bias_gelu_unary(input_shapes, approximate, bias, device):
     in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device, True)
-    in_data = in_data + bias
-
-    pyt_y = torch.nn.functional.gelu(in_data, approximate=approximate)
 
     tt_output_tensor_on_device = ttnn.bias_gelu_bw(grad_tensor, input_tensor, bias, approximate=approximate)
 
-    in_data.retain_grad()
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.bias_gelu_bw)
+    golden_tensor = golden_function(grad_data, in_data, bias, approximate)
 
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
@@ -99,16 +86,10 @@ def test_bw_bias_gelu_unary(input_shapes, approximate, bias, device):
 def test_bw_bias_gelu_unary_default(input_shapes, bias, device):
     in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device, True)
-    in_data = in_data + bias
-
-    pyt_y = torch.nn.functional.gelu(in_data)
 
     tt_output_tensor_on_device = ttnn.bias_gelu_bw(grad_tensor, input_tensor, bias)
 
-    in_data.retain_grad()
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.bias_gelu_bw)
+    golden_tensor = golden_function(grad_data, in_data, bias)
 
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_div.py b/tests/ttnn/unit_tests/operations/backward/test_backward_div.py
index 9173df3f84c3..dc5f44b71a4c 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_div.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_div.py
@@ -32,18 +32,13 @@ def test_bw_div_binary(input_shapes, round_mode, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
     other_data, other_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
 
-    pyt_y = torch.div(in_data, other_data, rounding_mode=round_mode)
+    golden_function = ttnn.get_golden_function(ttnn.div_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data, round_mode)
 
     if round_mode == None:
         round_mode = "None"
 
     tt_output_tensor_on_device = ttnn.div_bw(grad_tensor, input_tensor, other_tensor, round_mode=round_mode)
 
-    in_data.retain_grad()
-    other_data.retain_grad()
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad, other_data.grad]
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
@@ -61,16 +56,10 @@ def test_bw_div_binary_default(input_shapes, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
     other_data, other_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
 
-    pyt_y = torch.div(in_data, other_data)
+    golden_function = ttnn.get_golden_function(ttnn.div_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
 
     tt_output_tensor_on_device = ttnn.div_bw(grad_tensor, input_tensor, other_tensor)
-
-    in_data.retain_grad()
-    other_data.retain_grad()
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad, other_data.grad]
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
@@ -99,15 +88,10 @@ def test_bw_unary_div_0(input_shapes, scalar, round_mode, device):
 
     tt_output_tensor_on_device = ttnn.div_bw(grad_tensor, input_tensor, scalar, round_mode=round_mode)
 
-    in_data.retain_grad()
-
     if round_mode == "None":
         round_mode = None
-    pyt_y = torch.div(in_data, torch.tensor(scalar), rounding_mode=round_mode)
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.div_bw)
+    golden_tensor = golden_function(grad_data, in_data, scalar, round_mode)
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
@@ -140,11 +124,8 @@ def test_bw_unary_div(input_shapes, scalar, round_mode, device):
     if round_mode == "None":
         round_mode = None
 
-    pyt_y = torch.div(in_data, torch.tensor(scalar), rounding_mode=round_mode)
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.div_bw)
+    golden_tensor = golden_function(grad_data, in_data, scalar, round_mode)
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
@@ -166,13 +147,8 @@ def test_bw_unary_div_0_default(input_shapes, scalar, device):
 
     tt_output_tensor_on_device = ttnn.div_bw(grad_tensor, input_tensor, scalar)
 
-    in_data.retain_grad()
-
-    pyt_y = torch.div(in_data, torch.tensor(scalar))
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.div_bw)
+    golden_tensor = golden_function(grad_data, in_data, scalar)
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
@@ -193,13 +169,8 @@ def test_bw_unary_div_default(input_shapes, scalar, device):
 
     tt_output_tensor_on_device = ttnn.div_bw(grad_tensor, input_tensor, scalar)
 
-    in_data.retain_grad()
-
-    pyt_y = torch.div(in_data, torch.tensor(scalar))
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.div_bw)
+    golden_tensor = golden_function(grad_data, in_data, scalar)
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_eq.py b/tests/ttnn/unit_tests/operations/backward/test_backward_eq.py
index 52868fa580f0..7b63f0b7a239 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_eq.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_eq.py
@@ -19,13 +19,11 @@ def test_bw_binary_eq(input_shapes, device):
     in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
     other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True)
-    _, grad_tensor = data_gen_with_range(input_shapes, -20, 40, device)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 40, device)
     tt_output_tensor_on_device = ttnn.eq_bw(grad_tensor, input_tensor, other_tensor)
 
-    in_grad = torch.zeros_like(in_data)
-    other_grad = torch.zeros_like(other_data)
-
-    golden_tensor = [in_grad, other_grad]
+    golden_function = ttnn.get_golden_function(ttnn.eq_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
 
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
@@ -42,7 +40,7 @@ def test_bw_binary_eq(input_shapes, device):
 def test_bw_binary_eq_opt_output(input_shapes, device, are_required_outputs):
     in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
     other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True)
-    _, grad_tensor = data_gen_with_range(input_shapes, -20, 40, device)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 40, device)
     input_grad = None
     other_grad = None
     if are_required_outputs[0]:
@@ -59,10 +57,8 @@ def test_bw_binary_eq_opt_output(input_shapes, device, are_required_outputs):
         input_b_grad=other_grad,
     )
 
-    in_grad = torch.zeros_like(in_data)
-    other_grad = torch.zeros_like(other_data)
-
-    golden_tensor = [in_grad, other_grad]
+    golden_function = ttnn.get_golden_function(ttnn.eq_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
 
     status = True
     for i in range(len(are_required_outputs)):
@@ -82,7 +78,7 @@ def test_bw_binary_eq_opt_output(input_shapes, device, are_required_outputs):
 def test_bw_binary_eq_opt_output_qid(input_shapes, device, are_required_outputs):
     in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
     other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True)
-    _, grad_tensor = data_gen_with_range(input_shapes, -20, 40, device)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 40, device)
 
     input_grad = None
     other_grad = None
@@ -103,10 +99,8 @@ def test_bw_binary_eq_opt_output_qid(input_shapes, device, are_required_outputs)
         queue_id=cq_id,
     )
 
-    in_grad = torch.zeros_like(in_data)
-    other_grad = torch.zeros_like(other_data)
-
-    golden_tensor = [in_grad, other_grad]
+    golden_function = ttnn.get_golden_function(ttnn.eq_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
 
     status = True
     for i in range(len(are_required_outputs)):
@@ -129,7 +123,7 @@ def test_bw_unary_eq(input_shapes, other, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
     tt_output_tensor_on_device = ttnn.eq_bw(grad_tensor, input_tensor, other)
-    pt_y = torch.zeros_like(grad_data)
-    golden_tensor = [pt_y]
+    golden_function = ttnn.get_golden_function(ttnn.eq_bw)
+    golden_tensor = golden_function(grad_data, in_data, other)
 
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_ge.py b/tests/ttnn/unit_tests/operations/backward/test_backward_ge.py
index 46bf44b49a15..97883b0db6b3 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_ge.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_ge.py
@@ -21,8 +21,9 @@ def test_bw_binary_ge(input_shapes, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
     tt_output_tensor_on_device = ttnn.ge_bw(grad_tensor, input_tensor, input_tensor)
 
-    pt_y = torch.zeros_like(grad_data)
-    golden_tensor = [pt_y, pt_y]
+    golden_function = ttnn.get_golden_function(ttnn.ge_bw)
+    golden_tensor = golden_function(grad_data, in_data, in_data)
+
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
@@ -41,9 +42,8 @@ def test_bw_unary_ge(input_shapes, other, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
     tt_output_tensor_on_device = ttnn.ge_bw(grad_tensor, input_tensor, other)
 
-    pyt_y = torch.zeros_like(grad_data)
-
-    golden_tensor = [pyt_y]
+    golden_function = ttnn.get_golden_function(ttnn.ge_bw)
+    golden_tensor = golden_function(grad_data, in_data, other)
 
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_gt.py b/tests/ttnn/unit_tests/operations/backward/test_backward_gt.py
index b979eaa2ca58..5cbdbcce1525 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_gt.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_gt.py
@@ -22,8 +22,8 @@ def test_bw_binary_gt(input_shapes, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
     tt_output_tensor_on_device = ttnn.gt_bw(grad_tensor, input_tensor, other_tensor)
 
-    pt_y = torch.zeros_like(grad_data)
-    golden_tensor = [pt_y, pt_y]
+    golden_function = ttnn.get_golden_function(ttnn.gt_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
 
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
@@ -42,7 +42,7 @@ def test_bw_unary_gt(input_shapes, other, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
     tt_output_tensor_on_device = ttnn.gt_bw(grad_tensor, input_tensor, other)
-    pt_y = torch.zeros_like(grad_data)
-    golden_tensor = [pt_y]
+    golden_function = ttnn.get_golden_function(ttnn.gt_bw)
+    golden_tensor = golden_function(grad_data, in_data, other)
 
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_le.py b/tests/ttnn/unit_tests/operations/backward/test_backward_le.py
index cb359da322bf..9b4fd3364c38 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_le.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_le.py
@@ -22,8 +22,8 @@ def test_bw_binary_le(input_shapes, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
     tt_output_tensor_on_device = ttnn.le_bw(grad_tensor, input_tensor, other_tensor)
 
-    pt_y = torch.zeros_like(grad_data)
-    golden_tensor = [pt_y, pt_y]
+    golden_function = ttnn.get_golden_function(ttnn.le_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
 
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
@@ -42,7 +42,7 @@ def test_bw_unary_le(input_shapes, other, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
     tt_output_tensor_on_device = ttnn.le_bw(grad_tensor, input_tensor, other)
-    pt_y = torch.zeros_like(grad_data)
-    golden_tensor = [pt_y]
+    golden_function = ttnn.get_golden_function(ttnn.le_bw)
+    golden_tensor = golden_function(grad_data, in_data, other)
 
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_lt.py b/tests/ttnn/unit_tests/operations/backward/test_backward_lt.py
index cc54cefb3705..1f64cf6c57ee 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_lt.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_lt.py
@@ -22,8 +22,8 @@ def test_bw_binary_lt(input_shapes, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)
     tt_output_tensor_on_device = ttnn.lt_bw(grad_tensor, input_tensor, other_tensor)
 
-    pt_y = torch.zeros_like(grad_data)
-    golden_tensor = [pt_y, pt_y]
+    golden_function = ttnn.get_golden_function(ttnn.lt_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
 
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
@@ -42,7 +42,7 @@ def test_bw_unary_lt(input_shapes, other, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
     tt_output_tensor_on_device = ttnn.lt_bw(grad_tensor, input_tensor, other)
-    pt_y = torch.zeros_like(grad_data)
-    golden_tensor = [pt_y]
+    golden_function = ttnn.get_golden_function(ttnn.lt_bw)
+    golden_tensor = golden_function(grad_data, in_data, other)
 
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_ne.py b/tests/ttnn/unit_tests/operations/backward/test_backward_ne.py
index b46136e20bb1..0baa5441d328 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_ne.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_ne.py
@@ -21,8 +21,8 @@ def test_bw_binary_ne(input_shapes, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)
     tt_output_tensor_on_device = ttnn.ne_bw(grad_tensor, input_tensor, input_tensor)
 
-    pt_y = torch.zeros_like(grad_data)
-    golden_tensor = [pt_y, pt_y]
+    golden_function = ttnn.get_golden_function(ttnn.ne_bw)
+    golden_tensor = golden_function(grad_data, in_data, in_data)
 
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
@@ -40,9 +40,8 @@ def test_bw_unary_ne(input_shapes, other, device):
     in_data, input_tensor = data_gen_with_range(input_shapes, -10, 10, device, True)
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
     tt_output_tensor_on_device = ttnn.ne_bw(grad_tensor, input_tensor, other)
-    pyt_y = torch.zeros_like(grad_data)
-
-    golden_tensor = [pyt_y]
+    golden_function = ttnn.get_golden_function(ttnn.ne_bw)
+    golden_tensor = golden_function(grad_data, in_data, other)
 
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_sub.py b/tests/ttnn/unit_tests/operations/backward/test_backward_sub.py
index e6369f7d72ef..a73803150b5c 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_sub.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_sub.py
@@ -23,14 +23,8 @@ def test_bw_sub(input_shapes, device):
 
     tt_output_tensor_on_device = ttnn.sub_bw(grad_tensor, input_tensor, other_tensor)
 
-    in_data.retain_grad()
-    other_data.retain_grad()
-
-    pyt_y = torch.sub(in_data, other_data)
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad, other_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.sub_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
@@ -51,13 +45,8 @@ def test_bw_unary_sub(input_shapes, scalar, device):
 
     tt_output_tensor_on_device = ttnn.sub_bw(grad_tensor, input_tensor, scalar)
 
-    in_data.retain_grad()
-
-    pyt_y = torch.sub(in_data, torch.tensor(scalar))
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.sub_bw)
+    golden_tensor = golden_function(grad_data, in_data, scalar)
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
diff --git a/ttnn/ttnn/operations/binary_backward.py b/ttnn/ttnn/operations/binary_backward.py
index 46f8003ce5f8..0b2c2d14010d 100644
--- a/ttnn/ttnn/operations/binary_backward.py
+++ b/ttnn/ttnn/operations/binary_backward.py
@@ -56,14 +56,32 @@ def _golden_function_backward(torch_op, grad_tensor, input_tensor_a, input_tenso
         return _golden_function_complex_mul(grad_tensor, input_tensor_a, input_tensor_b)
     if torch_op == "torch.squared_difference":
         pyt_y = torch.square(torch.sub(input_tensor_a, input_tensor_b))
-    elif torch_op == torch.clone:
+    else:
+        pyt_y = torch_op(input_tensor_a, input_tensor_b)
+    input_tensor_a.retain_grad()
+    input_tensor_b.retain_grad()
+    pyt_y.backward(gradient=grad_tensor)
+    golden_tensor = [input_tensor_a.grad, input_tensor_b.grad]
+    return golden_tensor
+
+
+def _golden_function_backward_overload(torch_op, grad_tensor, input_tensor_a, input_tensor_b=None, *args, **kwargs):
+    if torch_op == torch.clone:
         pyt_y = torch.clone(input_tensor_a)
+        input_tensor_a.retain_grad()
+        pyt_y.backward(gradient=grad_tensor)
+        if input_tensor_b == None:
+            golden_tensor = [input_tensor_a.grad]
+            return golden_tensor
+        else:
+            golden_tensor = [input_tensor_a.grad, input_tensor_a.grad]
+            return golden_tensor
+    pyt_y = torch_op(input_tensor_a, input_tensor_b)
+    if isinstance(input_tensor_b, (float, int)):
         input_tensor_a.retain_grad()
         pyt_y.backward(gradient=grad_tensor)
         golden_tensor = [input_tensor_a.grad]
         return golden_tensor
-    else:
-        pyt_y = torch_op(input_tensor_a, input_tensor_b)
     input_tensor_a.retain_grad()
     input_tensor_b.retain_grad()
     pyt_y.backward(gradient=grad_tensor)
@@ -96,22 +114,31 @@ def _golden_function_backward_with_float(
     return golden_tensor
 
 
-def _golden_function_backward_with_string(torch_op, grad_tensor, input_tensor_a, input_tensor_b, *args, **kwargs):
+def _golden_function_backward_with_string(
+    torch_op, grad_tensor, input_tensor_a, input_tensor_b, value=None, *args, **kwargs
+):
     if torch.is_complex(input_tensor_a):
         if torch_op == torch.div:
             return _golden_function_complex_div(grad_tensor, input_tensor_a, input_tensor_b)
-    if torch_op == bias_gelu:
+    if torch_op == "bias_gelu_bw":
         sum_result = torch.add(input_tensor_a, input_tensor_b)
-        pyt_y = torch.nn.functional.gelu(sum_result)
+        pyt_y = torch.nn.functional.gelu(sum_result, approximate=value)
         sum_result.retain_grad()
         pyt_y.backward(gradient=grad_tensor)
-        golden_tensor = [sum_result.grad, sum_result.grad]
+        if isinstance(input_tensor_b, (float, int)):
+            golden_tensor = [sum_result.grad]
+        else:
+            golden_tensor = [sum_result.grad, sum_result.grad]
         return golden_tensor
-    value = kwargs.pop("value")
-    if torch_op == torch.div:
+    elif torch_op == torch.div:
        pyt_y = torch_op(input_tensor_a, input_tensor_b, rounding_mode=value)
     else:
         pyt_y = torch_op(input_tensor_a, input_tensor_b, value=value)
+    if isinstance(input_tensor_b, (float, int)):
+        input_tensor_a.retain_grad()
+        pyt_y.backward(gradient=grad_tensor)
+        golden_tensor = [input_tensor_a.grad]
+        return golden_tensor
     input_tensor_a.retain_grad()
     input_tensor_b.retain_grad()
     pyt_y.backward(gradient=grad_tensor)
@@ -120,20 +147,23 @@ def _golden_function_backward_with_string(torch_op, grad_tensor, input_tensor_a,
 
 
 def _golden_function_comparison_ops(torch_op, grad_tensor, input_tensor_a, input_tensor_b, *args, **kwargs):
-    golden_tensor = [torch.zeros_like(input_tensor_a), torch.zeros_like(input_tensor_b)]
+    if isinstance(input_tensor_b, (float, int)):
+        golden_tensor = [torch.zeros_like(input_tensor_a)]
+    else:
+        golden_tensor = [torch.zeros_like(input_tensor_a), torch.zeros_like(input_tensor_b)]
     return golden_tensor
 
 
 ttnn.attach_golden_function(
     ttnn.sub_bw,
-    golden_function=lambda grad, a, b, *args, **kwargs: _golden_function_backward(
+    golden_function=lambda grad, a, b, *args, **kwargs: _golden_function_backward_overload(
         torch.sub, grad, a, b, *args, **kwargs
     ),
 )
 
 ttnn.attach_golden_function(
     ttnn.add_bw,
-    golden_function=lambda grad, a, b, *args, **kwargs: _golden_function_backward(
+    golden_function=lambda grad, a, b, *args, **kwargs: _golden_function_backward_overload(
         torch.add, grad, a, b, *args, **kwargs
     ),
 )
@@ -210,7 +240,7 @@ def _golden_function_comparison_ops(torch_op, grad_tensor, input_tensor_a, input
 
 ttnn.attach_golden_function(
     ttnn.assign_bw,
-    golden_function=lambda grad, a, b, *args, **kwargs: _golden_function_backward(
+    golden_function=lambda grad, a, b=None, *args, **kwargs: _golden_function_backward_overload(
         torch.clone, grad, a, b, *args, **kwargs
     ),
 )
@@ -238,8 +268,8 @@ def _golden_function_comparison_ops(torch_op, grad_tensor, input_tensor_a, input
 
 ttnn.attach_golden_function(
     ttnn.bias_gelu_bw,
-    golden_function=lambda grad, a, b, value, *args, **kwargs: _golden_function_backward_with_string(
-        torch.gelu, grad, a, b, value, *args, **kwargs
+    golden_function=lambda grad, a, b, value="none", *args, **kwargs: _golden_function_backward_with_string(
+        "bias_gelu_bw", grad, a, b, value, *args, **kwargs
     ),
 )
@@ -287,8 +317,8 @@ def _golden_function_comparison_ops(torch_op, grad_tensor, input_tensor_a, input
 
 ttnn.attach_golden_function(
     ttnn.div_bw,
-    golden_function=lambda grad, a, b, *args, **kwargs: _golden_function_backward_with_string(
-        torch.div, grad, a, b, *args, **kwargs
+    golden_function=lambda grad, a, b, value=None, *args, **kwargs: _golden_function_backward_with_string(
+        torch.div, grad, a, b, value, *args, **kwargs
     ),
 )
diff --git a/ttnn/ttnn/operations/unary_backward.py b/ttnn/ttnn/operations/unary_backward.py
index ff6c8e3c391e..d6ceb8767f1e 100644
--- a/ttnn/ttnn/operations/unary_backward.py
+++ b/ttnn/ttnn/operations/unary_backward.py
@@ -13,7 +13,7 @@ __all__ = []
 
 
-def _golden_function_unary_backward(torch_op, grad_tensor, input_tensor, *args, **kwargs):
+def _golden_function_unary_backward(torch_op, grad_tensor, input_tensor, *args, **kwargs):
     if torch_op == "softsign":
         pyt_y = torch.nn.functional.softsign(input_tensor)
     else:
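
The pattern this patch converges on is a small registry: each `*_bw` operation carries a reference ("golden") implementation attached once in the library via `ttnn.attach_golden_function`, and every test retrieves it with `ttnn.get_golden_function` instead of re-deriving gradients by hand with `retain_grad()`/`backward()`. The sketch below illustrates the idea only; the dict-based registry, the `fake_add_bw`/`fake_eq_bw` handles, and the `_golden_*` helper names are assumptions for illustration, not ttnn internals.

    import torch

    _GOLDEN = {}  # illustrative registry; ttnn keeps the golden fn with the registered op

    def attach_golden_function(op, golden_function):
        _GOLDEN[op] = golden_function

    def get_golden_function(op):
        return _GOLDEN[op]

    def _golden_add_bw(grad, a, b):
        # Autograd-backed reference, mirroring _golden_function_backward(torch.add, ...).
        a = a.detach().clone().requires_grad_(True)
        b = b.detach().clone().requires_grad_(True)
        torch.add(a, b).backward(gradient=grad)
        return [a.grad, b.grad]

    def _golden_eq_bw(grad, a, b):
        # Comparison ops are piecewise constant, so the golden gradient is all zeros;
        # a scalar second operand yields one gradient, as in _golden_function_comparison_ops.
        if isinstance(b, (float, int)):
            return [torch.zeros_like(a)]
        return [torch.zeros_like(a), torch.zeros_like(b)]

    fake_add_bw, fake_eq_bw = object(), object()  # stand-ins for device op handles
    attach_golden_function(fake_add_bw, _golden_add_bw)
    attach_golden_function(fake_eq_bw, _golden_eq_bw)

    grad = torch.ones(2, 2)
    a, b = torch.randn(2, 2), torch.randn(2, 2)
    golden = get_golden_function(fake_add_bw)(grad, a, b)
    assert torch.equal(golden[0], grad) and torch.equal(golden[1], grad)  # d(a+b)/da == 1
    assert torch.count_nonzero(get_golden_function(fake_eq_bw)(grad, a, 5.0)[0]) == 0

Centralizing the golden logic this way means the scalar overloads (`alpha`, `bias`, `round_mode`, `approximate`) and edge cases such as `assign_bw` returning the same gradient twice are encoded once in `binary_backward.py`, rather than duplicated across every unit test.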