#9874: Update golden function in test file
VirdhatchaniKN committed Jul 22, 2024
1 parent e86841e commit e49fe86
Showing 13 changed files with 112 additions and 175 deletions.
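
Every file below applies the same mechanical change: instead of rebuilding the reference ("golden") gradients by hand with torch autograd, each test fetches the golden function registered for the backward op and calls it directly. A condensed before/after sketch of the pattern, using ttnn.add_bw as the example (data_gen_with_range and compare_pcc are the repo's existing test helpers):

# Before: golden gradients computed manually with torch autograd
in_data.retain_grad()
other_data.retain_grad()
pyt_y = torch.add(in_data, other_data)
pyt_y.backward(gradient=grad_data)
golden_tensor = [in_data.grad, other_data.grad]

# After: reuse the golden implementation registered against the op
golden_function = ttnn.get_golden_function(ttnn.add_bw)
golden_tensor = golden_function(grad_data, in_data, other_data)

Both paths yield a list of reference gradient tensors that compare_pcc checks against the device output.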
29 changes: 6 additions & 23 deletions tests/ttnn/unit_tests/operations/backward/test_backward_add.py
@@ -23,14 +23,8 @@ def test_bw_add(input_shapes, device):
 
     tt_output_tensor_on_device = ttnn.add_bw(grad_tensor, input_tensor, other_tensor)
 
-    in_data.retain_grad()
-    other_data.retain_grad()
-
-    pyt_y = torch.add(in_data, other_data)
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad, other_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.add_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
@@ -69,14 +63,8 @@ def test_bw_add_with_opt_output(input_shapes, device, are_required_outputs):
         queue_id=cq_id,
     )
 
-    in_data.retain_grad()
-    other_data.retain_grad()
-
-    pyt_y = torch.add(in_data, other_data)
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad, other_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.add_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
 
     status = True
     for i in range(len(are_required_outputs)):
@@ -100,13 +88,8 @@ def test_bw_unary_add(input_shapes, alpha, device):
 
     tt_output_tensor_on_device = ttnn.add_bw(grad_tensor, input_tensor, alpha=alpha)
 
-    in_data.retain_grad()
-
-    pyt_y = torch.add(in_data, torch.tensor(alpha))
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.add_bw)
+    golden_tensor = golden_function(grad_data, in_data, alpha)
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
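
In the optional-output variant above, the test preallocates gradient buffers, marks which outputs are required, and hands them to the op via input_a_grad/input_b_grad and queue_id; the per-output check is folded in this view, so the loop body below is an assumption about how it plausibly continues, not the captured code:

# Sketch only: compare each gradient the caller marked as required
status = True
for i in range(len(are_required_outputs)):
    if are_required_outputs[i]:
        status = status & compare_pcc([tt_output_tensor_on_device[i]], [golden_tensor[i]])
assert status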
18 changes: 4 additions & 14 deletions tests/ttnn/unit_tests/operations/backward/test_backward_assign.py
@@ -22,13 +22,8 @@ def test_bw_unary_assign(input_shapes, device):
 
     tt_output_tensor_on_device = ttnn.assign_bw(grad_tensor, input_tensor)
 
-    in_data.retain_grad()
-
-    pyt_y = torch.clone(in_data)
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.assign_bw)
+    golden_tensor = golden_function(grad_data, in_data)
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
@@ -50,12 +45,7 @@ def test_bw_binary_assign(input_shapes, device):
 
     tt_output_tensor_on_device = ttnn.assign_bw(grad_tensor, input_tensor, other_tensor)
 
-    in_data.retain_grad()
-
-    pyt_y = torch.clone(in_data)
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad, in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.assign_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
tests/ttnn/unit_tests/operations/backward/test_backward_bias_gelu.py
@@ -27,16 +27,9 @@ def test_bw_binary_bias_gelu(input_shapes, approximate, device):
     in_data_a, input_tensor_a = data_gen_with_range(input_shapes, -100, 100, device, True)
     in_data_b, input_tensor_b = data_gen_with_range(input_shapes, -10, 10, device, True)
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)
-    in_data = in_data_a + in_data_b
-
-    pyt_y = torch.nn.functional.gelu(in_data, approximate=approximate)
-
     tt_output_tensor_on_device = ttnn.bias_gelu_bw(grad_tensor, input_tensor_a, input_tensor_b, approximate=approximate)
-    in_data.retain_grad()
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad, in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.bias_gelu_bw)
+    golden_tensor = golden_function(grad_data, in_data_a, in_data_b, approximate)
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
 
@@ -66,17 +59,11 @@ def test_bw_binary_bias_gelu(input_shapes, approximate, device):
 def test_bw_bias_gelu_unary(input_shapes, approximate, bias, device):
     in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device, True)
-    in_data = in_data + bias
-
-    pyt_y = torch.nn.functional.gelu(in_data, approximate=approximate)
-
     tt_output_tensor_on_device = ttnn.bias_gelu_bw(grad_tensor, input_tensor, bias, approximate=approximate)
 
-    in_data.retain_grad()
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.bias_gelu_bw)
+    golden_tensor = golden_function(grad_data, in_data, bias, approximate)
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
 
@@ -99,16 +86,10 @@ def test_bw_bias_gelu_unary(input_shapes, approximate, bias, device):
 def test_bw_bias_gelu_unary_default(input_shapes, bias, device):
     in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device, True)
-    in_data = in_data + bias
-
-    pyt_y = torch.nn.functional.gelu(in_data)
-
     tt_output_tensor_on_device = ttnn.bias_gelu_bw(grad_tensor, input_tensor, bias)
 
-    in_data.retain_grad()
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.bias_gelu_bw)
+    golden_tensor = golden_function(grad_data, in_data, bias)
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
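
Note what the new golden call absorbs in this file: the removed lines first formed in_data = in_data_a + in_data_b (or in_data + bias) and then differentiated gelu on the sum, whereas the new call passes the two raw operands plus the approximate flag and leaves that composition to the golden function. A sketch of the reference behavior the registered golden function presumably implements (an assumption for illustration; the name and signature are hypothetical, not the actual ttnn registration):

import torch

def bias_gelu_bw_golden(grad, input_a, input_b, approximate="none"):
    # Bias-add then gelu, differentiated w.r.t. the sum; both operands
    # receive the same gradient because d(a+b)/da = d(a+b)/db = 1.
    in_data = (input_a + input_b).detach().requires_grad_()
    out = torch.nn.functional.gelu(in_data, approximate=approximate)
    out.backward(gradient=grad)
    return [in_data.grad, in_data.grad]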
53 changes: 12 additions & 41 deletions tests/ttnn/unit_tests/operations/backward/test_backward_div.py
@@ -32,18 +32,13 @@ def test_bw_div_binary(input_shapes, round_mode, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
     other_data, other_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
 
-    pyt_y = torch.div(in_data, other_data, rounding_mode=round_mode)
+    golden_function = ttnn.get_golden_function(ttnn.div_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data, round_mode)
 
     if round_mode == None:
         round_mode = "None"
     tt_output_tensor_on_device = ttnn.div_bw(grad_tensor, input_tensor, other_tensor, round_mode=round_mode)
 
-    in_data.retain_grad()
-    other_data.retain_grad()
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad, other_data.grad]
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
 
@@ -61,16 +56,10 @@ def test_bw_div_binary_default(input_shapes, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
     other_data, other_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
 
-    pyt_y = torch.div(in_data, other_data)
+    golden_function = ttnn.get_golden_function(ttnn.div_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
 
     tt_output_tensor_on_device = ttnn.div_bw(grad_tensor, input_tensor, other_tensor)
 
-    in_data.retain_grad()
-    other_data.retain_grad()
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad, other_data.grad]
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
 
@@ -99,15 +88,10 @@ def test_bw_unary_div_0(input_shapes, scalar, round_mode, device):
 
     tt_output_tensor_on_device = ttnn.div_bw(grad_tensor, input_tensor, scalar, round_mode=round_mode)
 
-    in_data.retain_grad()
-
     if round_mode == "None":
         round_mode = None
-    pyt_y = torch.div(in_data, torch.tensor(scalar), rounding_mode=round_mode)
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.div_bw)
+    golden_tensor = golden_function(grad_data, in_data, scalar, round_mode)
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
@@ -140,11 +124,8 @@ def test_bw_unary_div(input_shapes, scalar, round_mode, device):
 
     if round_mode == "None":
         round_mode = None
-    pyt_y = torch.div(in_data, torch.tensor(scalar), rounding_mode=round_mode)
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.div_bw)
+    golden_tensor = golden_function(grad_data, in_data, scalar, round_mode)
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
@@ -166,13 +147,8 @@ def test_bw_unary_div_0_default(input_shapes, scalar, device):
 
     tt_output_tensor_on_device = ttnn.div_bw(grad_tensor, input_tensor, scalar)
 
-    in_data.retain_grad()
-
-    pyt_y = torch.div(in_data, torch.tensor(scalar))
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.div_bw)
+    golden_tensor = golden_function(grad_data, in_data, scalar)
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
@@ -193,13 +169,8 @@ def test_bw_unary_div_default(input_shapes, scalar, device):
 
     tt_output_tensor_on_device = ttnn.div_bw(grad_tensor, input_tensor, scalar)
 
-    in_data.retain_grad()
-
-    pyt_y = torch.div(in_data, torch.tensor(scalar))
-
-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.div_bw)
+    golden_tensor = golden_function(grad_data, in_data, scalar)
 
     status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert status
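
One ordering detail worth noting in this file: the golden function consumes the Python-level rounding mode (None, or a string such as "trunc" or "floor"), while ttnn.div_bw expects the literal string "None" in place of Python's None, so each test converts in whichever direction its call order requires. Condensed from the hunks above:

# Binary test: golden first (None is fine there), then stringify for the device op
golden_function = ttnn.get_golden_function(ttnn.div_bw)
golden_tensor = golden_function(grad_data, in_data, other_data, round_mode)
if round_mode == None:
    round_mode = "None"
tt_output_tensor_on_device = ttnn.div_bw(grad_tensor, input_tensor, other_tensor, round_mode=round_mode)

# Unary tests: device op first with the string form, then de-stringify for golden
tt_output_tensor_on_device = ttnn.div_bw(grad_tensor, input_tensor, scalar, round_mode=round_mode)
if round_mode == "None":
    round_mode = None
golden_function = ttnn.get_golden_function(ttnn.div_bw)
golden_tensor = golden_function(grad_data, in_data, scalar, round_mode)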
28 changes: 11 additions & 17 deletions tests/ttnn/unit_tests/operations/backward/test_backward_eq.py
@@ -19,13 +19,11 @@
 def test_bw_binary_eq(input_shapes, device):
     in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
     other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True)
-    _, grad_tensor = data_gen_with_range(input_shapes, -20, 40, device)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 40, device)
 
     tt_output_tensor_on_device = ttnn.eq_bw(grad_tensor, input_tensor, other_tensor)
-    in_grad = torch.zeros_like(in_data)
-    other_grad = torch.zeros_like(other_data)
-
-    golden_tensor = [in_grad, other_grad]
+    golden_function = ttnn.get_golden_function(ttnn.eq_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
 
@@ -42,7 +40,7 @@ def test_bw_binary_eq(input_shapes, device):
 def test_bw_binary_eq_opt_output(input_shapes, device, are_required_outputs):
     in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
     other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True)
-    _, grad_tensor = data_gen_with_range(input_shapes, -20, 40, device)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 40, device)
     input_grad = None
     other_grad = None
     if are_required_outputs[0]:
@@ -59,10 +57,8 @@ def test_bw_binary_eq_opt_output(input_shapes, device, are_required_outputs):
         input_b_grad=other_grad,
     )
 
-    in_grad = torch.zeros_like(in_data)
-    other_grad = torch.zeros_like(other_data)
-
-    golden_tensor = [in_grad, other_grad]
+    golden_function = ttnn.get_golden_function(ttnn.eq_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
 
     status = True
     for i in range(len(are_required_outputs)):
@@ -82,7 +78,7 @@ def test_bw_binary_eq_opt_output(input_shapes, device, are_required_outputs):
 def test_bw_binary_eq_opt_output_qid(input_shapes, device, are_required_outputs):
     in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
     other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True)
-    _, grad_tensor = data_gen_with_range(input_shapes, -20, 40, device)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 40, device)
     input_grad = None
     other_grad = None
 
@@ -103,10 +99,8 @@ def test_bw_binary_eq_opt_output_qid(input_shapes, device, are_required_outputs):
         queue_id=cq_id,
     )
 
-    in_grad = torch.zeros_like(in_data)
-    other_grad = torch.zeros_like(other_data)
-
-    golden_tensor = [in_grad, other_grad]
+    golden_function = ttnn.get_golden_function(ttnn.eq_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
 
     status = True
     for i in range(len(are_required_outputs)):
@@ -129,7 +123,7 @@ def test_bw_binary_eq_opt_output_qid(input_shapes, device, are_required_outputs):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
 
     tt_output_tensor_on_device = ttnn.eq_bw(grad_tensor, input_tensor, other)
-    pt_y = torch.zeros_like(grad_data)
-    golden_tensor = [pt_y]
+    golden_function = ttnn.get_golden_function(ttnn.eq_bw)
+    golden_tensor = golden_function(grad_data, in_data, other)
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
10 changes: 5 additions & 5 deletions tests/ttnn/unit_tests/operations/backward/test_backward_ge.py
@@ -21,8 +21,9 @@ def test_bw_binary_ge(input_shapes, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
 
     tt_output_tensor_on_device = ttnn.ge_bw(grad_tensor, input_tensor, input_tensor)
-    pt_y = torch.zeros_like(grad_data)
-    golden_tensor = [pt_y, pt_y]
+    golden_function = ttnn.get_golden_function(ttnn.ge_bw)
+    golden_tensor = golden_function(grad_data, in_data, in_data)
+
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
 
@@ -41,9 +42,8 @@ def test_bw_unary_ge(input_shapes, other, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
     tt_output_tensor_on_device = ttnn.ge_bw(grad_tensor, input_tensor, other)
 
-    pyt_y = torch.zeros_like(grad_data)
-
-    golden_tensor = [pyt_y]
+    golden_function = ttnn.get_golden_function(ttnn.ge_bw)
+    golden_tensor = golden_function(grad_data, in_data, other)
 
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
8 changes: 4 additions & 4 deletions tests/ttnn/unit_tests/operations/backward/test_backward_gt.py
@@ -22,8 +22,8 @@ def test_bw_binary_gt(input_shapes, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
 
     tt_output_tensor_on_device = ttnn.gt_bw(grad_tensor, input_tensor, other_tensor)
-    pt_y = torch.zeros_like(grad_data)
-    golden_tensor = [pt_y, pt_y]
+    golden_function = ttnn.get_golden_function(ttnn.gt_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
 
@@ -42,7 +42,7 @@ def test_bw_unary_gt(input_shapes, other, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
 
     tt_output_tensor_on_device = ttnn.gt_bw(grad_tensor, input_tensor, other)
-    pt_y = torch.zeros_like(grad_data)
-    golden_tensor = [pt_y]
+    golden_function = ttnn.get_golden_function(ttnn.gt_bw)
+    golden_tensor = golden_function(grad_data, in_data, other)
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
8 changes: 4 additions & 4 deletions tests/ttnn/unit_tests/operations/backward/test_backward_le.py
@@ -22,8 +22,8 @@ def test_bw_binary_le(input_shapes, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
 
     tt_output_tensor_on_device = ttnn.le_bw(grad_tensor, input_tensor, other_tensor)
-    pt_y = torch.zeros_like(grad_data)
-    golden_tensor = [pt_y, pt_y]
+    golden_function = ttnn.get_golden_function(ttnn.le_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
 
@@ -42,7 +42,7 @@ def test_bw_unary_le(input_shapes, other, device):
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
 
    tt_output_tensor_on_device = ttnn.le_bw(grad_tensor, input_tensor, other)
-    pt_y = torch.zeros_like(grad_data)
-    golden_tensor = [pt_y]
+    golden_function = ttnn.get_golden_function(ttnn.le_bw)
+    golden_tensor = golden_function(grad_data, in_data, other)
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
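
A final pattern across the eq/ge/gt/le files: the old tests hard-coded the expected gradients as torch.zeros_like(...) — comparison ops are piecewise constant, so their gradient is zero wherever it is defined — and several eq tests even discarded the generated gradient (_, grad_tensor = ...). That knowledge now lives inside the registered golden function, and the tests pass the real grad_data through. A plausible shape for such a golden function (a sketch of the presumed behavior with a hypothetical name, not the actual ttnn registration):

import torch

def comparison_bw_golden(grad, input_tensor, other):
    # Comparison results are flat almost everywhere, so every
    # differentiable input receives a zero gradient of grad's shape.
    num_inputs = 2 if isinstance(other, torch.Tensor) else 1
    return [torch.zeros_like(grad) for _ in range(num_inputs)]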