Skip to content

Commit

Permalink
Increase tolerance for tan
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed Nov 22, 2023
1 parent 3fa035e commit 86d6596
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 8 deletions.
79 changes: 79 additions & 0 deletions docs/fixing_core_aten_ops_log.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@

# Issue being worked on: https://github.com/pytorch/xla/issues/5902

qihqi

## 1. Uncomment and rerun the test

```
LD_LIBRARY_PATH=/mnt/hanq/miniconda3/envs/torch310/lib/:/usr/lib/x86_64-linux-gnu/ PJRT_DEVICE=CPU XLA_STABLEHLO_COMPILE=1 XLA_HLO_DEBUG=1 XLA_IR_DEBUG=1 pytest test/test_core_aten_ops.py -k test_aten_tan_1
```

output:
```
=========================== short test summary info ============================
[torch_xla_diff:0.001] SUBFAIL test/test_core_aten_ops.py::AtenOpTest::test_aten_tan_1 - AssertionError: False is not true
[stablehlo_diff: 0.001] SUBFAIL test/test_core_aten_ops.py::AtenOpTest::test_aten_tan_1 - AssertionError: False is not true
================= 2 failed, 1 passed, 514 deselected in 5.51s ==================
I0000 00:00:1700690393.569658 2513762 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.
(torch310) hanq@hanq-compile-2:/mnt/hanq/git/qihqi/pytorch/xla$
```

This means that the accuracy is not good.

Set a breakpoint here (see the `pdb.set_trace()` line in the diff below):
```
(torch310) hanq@hanq-compile-2:/mnt/hanq/git/qihqi/pytorch/xla$ git diff
diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py
index 46a18494d..ff055ee38 100644
--- a/test/test_core_aten_ops.py
+++ b/test/test_core_aten_ops.py
@@ -36,6 +36,7 @@ def run_export_and_compare(testcase, func, args, kwargs, atol=1e-3):
lambda x: x.to(device=device), kwargs)
res_xla = func(*args2, **kwargs2)
with testcase.subTest('torch_xla_diff:' + str(atol)):
+ import pdb; pdb.set_trace()
diff_output(testcase, res, res_xla, atol)
```

Rerun, print out the difference:
```
(Pdb) p res - res_xla.cpu()
tensor([[ 0.0000e+00, 0.0000e+00, -4.8828e-04, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 6.1035e-05, 0.0000e+00, 0.0000e+00],
[-4.8828e-04, 0.0000e+00, 0.0000e+00, 9.7656e-04, 0.0000e+00,
1.2207e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, -1.5259e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00,
4.8828e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.2207e-04],
[ 0.0000e+00, 2.4414e-04, 0.0000e+00, -1.9531e-03, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, -3.0518e-05, 0.0000e+00],
[ 0.0000e+00, -4.8828e-04, -2.4414e-04, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, -6.1035e-05, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, -1.9531e-03,
0.0000e+00, 0.0000e+00, 1.9531e-03, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, -1.9531e-03, 0.0000e+00, 0.0000e+00,
2.4414e-04, 9.7656e-04, 1.2207e-04, 0.0000e+00, 0.0000e+00],
[ 4.8828e-04, 0.0000e+00, 0.0000e+00, -7.8125e-03, 1.2207e-04,
-9.7656e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 1.5625e-02, 0.0000e+00, 0.0000e+00, -4.8828e-04,
-1.2207e-04, 0.0000e+00, 0.0000e+00, -4.8828e-04, -3.9062e-03],
[ 0.0000e+00, -1.2207e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],
dtype=torch.float16)
```
The result looks good enough; this means that we are probably being too strict in
the test, and setting a larger tolerance will likely make it pass.
```
(Pdb) p torch.max(torch.abs(res - res_xla.cpu()))
tensor(0.0156, dtype=torch.float16)
```
Printing out the maximum difference shows that an `atol` of roughly 0.01, combined
with a slightly larger `rtol`, will probably work.

```
(Pdb) torch.allclose(res, res_xla.cpu(), atol=0.01, rtol=0.001)
True
```
Now it's time to PR:
https://github.com/pytorch/xla/pull/5915
23 changes: 15 additions & 8 deletions test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,24 @@
import unittest


def diff_output(testcase, output1, output2, rtol, atol):
  """Recursively assert that two op outputs match within tolerance.

  Args:
    testcase: the unittest.TestCase whose assert* methods drive the checks.
    output1: reference output — a torch.Tensor, a tuple/list of outputs
      (compared element-wise), or any other value (compared with assertEqual).
    output2: output under test; must mirror output1's structure. Tensors are
      detached, moved to CPU, and cast to output1's dtype before comparing.
    rtol: relative tolerance forwarded to torch.allclose.
    atol: absolute tolerance forwarded to torch.allclose.
  """
  if isinstance(output1, torch.Tensor):
    testcase.assertIsInstance(output2, torch.Tensor)
    # Compare on CPU in the reference dtype so device placement (e.g. XLA)
    # and dtype promotion do not cause spurious mismatches.
    output2_cpu = output2.detach().cpu()
    if output2_cpu.dtype != output1.dtype:
      output2_cpu = output2_cpu.to(output1.dtype)
    testcase.assertTrue(
        torch.allclose(output1, output2_cpu, atol=atol, rtol=rtol))
  elif isinstance(output1, (tuple, list)):
    testcase.assertIsInstance(output2, (tuple, list))
    testcase.assertEqual(len(output1), len(output2))
    for o1, o2 in zip(output1, output2):
      diff_output(testcase, o1, o2, rtol, atol)
  else:
    # Scalars, strings, None, etc. — require exact equality.
    testcase.assertEqual(output1, output2)


def run_export_and_compare(testcase, func, args, kwargs, atol=1e-3):
def run_export_and_compare(testcase, func, args, kwargs, atol=1e-3, rtol=1e-5):
device = xm.xla_device()
with testcase.subTest('torch_eval'):
res = func(*args, **kwargs)
Expand All @@ -36,15 +37,15 @@ def run_export_and_compare(testcase, func, args, kwargs, atol=1e-3):
lambda x: x.to(device=device), kwargs)
res_xla = func(*args2, **kwargs2)
with testcase.subTest('torch_xla_diff:' + str(atol)):
diff_output(testcase, res, res_xla, atol)
diff_output(testcase, res, res_xla, atol=atol, rtol=rtol)
with testcase.subTest('can_export'):
exported = torch.export.export(func, args, kwargs)
with testcase.subTest('can_convert_to_stablehlo'):
shlo = exported_program_to_stablehlo(exported)
with testcase.subTest('stablehlo_can_run'):
res2 = shlo(*args, **kwargs)
with testcase.subTest('stablehlo_diff: ' + str(atol)):
diff_output(testcase, res, res2, atol)
diff_output(testcase, res, res2, rtol=rtol, atol=atol)


class AtenOpTest(unittest.TestCase):
Expand Down Expand Up @@ -4363,11 +4364,17 @@ def test_aten_tan_0(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.tan, args, kwargs)

def test_aten_tan_1(self):
  """aten.tan on a float16 input.

  float16 has ~3 decimal digits of precision, so the default atol=1e-3 is
  too strict for tan (whose output magnitude can be large near pi/2);
  use the widened tolerances established in the commit (atol=0.01,
  rtol=0.001) instead of skipping the test.
  """
  args = (torch.randn((10, 10)).to(torch.float16),)
  kwargs = dict()
  run_export_and_compare(
      self,
      torch.ops.aten.tan,
      args,
      kwargs,
      rtol=0.001,
      atol=0.01,
  )

def test_aten_tan_2(self):
args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
Expand Down

0 comments on commit 86d6596

Please sign in to comment.