From 3e2f2d3b524e5f24ed86f1e2750c0ab02220b1b3 Mon Sep 17 00:00:00 2001
From: Han Qi
Date: Wed, 22 Nov 2023 22:18:36 +0000
Subject: [PATCH] Increase tolerance for tan

---
 FIX_LOWERING_FOR_CORE_ATEN_OPS.md | 11 ++++++++++-
 test/test_core_aten_ops.py        | 23 +++++++++++++++--------
 2 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/FIX_LOWERING_FOR_CORE_ATEN_OPS.md b/FIX_LOWERING_FOR_CORE_ATEN_OPS.md
index 0f889450458..deb3b804d9c 100644
--- a/FIX_LOWERING_FOR_CORE_ATEN_OPS.md
+++ b/FIX_LOWERING_FOR_CORE_ATEN_OPS.md
@@ -23,7 +23,16 @@ This subtest calls the torch_xla version of the op. If you've made changes to lo
 
 ### `torch_xla_diff`
 
-This subtest compares the output of the op between torch and torch_xla. If this subtest fails, it implies that your lowering runs successfully but contains a bug and/or logical error. We recommend you to review your lowering code. And again, feel free to leave a comment in your assigned GitHub issue if you're blocked and/or unable to debug further.
+This subtest compares the output of the op between torch and torch_xla.
+If this subtest fails, it implies that your lowering runs successfully
+but produces a different result than torch eager mode.
+
+If the test uses 16-bit floats (float16, bfloat16), it is very likely
+that the tolerances we pass to `torch.allclose` are too strict, and you
+can relax them a bit. See [this issue](https://github.com/pytorch/xla/issues/5934) for one such example.
+
+If the result torch_xla produces is completely different from what torch produces, that indicates a bug in the lowering code, which probably needs
+more work. Feel free to tag more people (such as qihqi) to take a look.
 
 ### `can_export`, `can_convert_to_stablehlo`, `stablehlo_can_run`, `stablehlo_diff`
 
diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py
index 147043f268f..eedbeca065d 100644
--- a/test/test_core_aten_ops.py
+++ b/test/test_core_aten_ops.py
@@ -9,23 +9,24 @@
 import unittest
 
 
-def diff_output(testcase, output1, output2, atol):
+def diff_output(testcase, output1, output2, rtol, atol):
   if isinstance(output1, torch.Tensor):
     testcase.assertIsInstance(output2, torch.Tensor)
     output2_cpu = output2.detach().cpu()
     if output2_cpu.dtype != output1.dtype:
       output2_cpu = output2_cpu.to(output1.dtype)
-    testcase.assertTrue(torch.allclose(output1, output2_cpu, atol=atol))
+    testcase.assertTrue(
+        torch.allclose(output1, output2_cpu, atol=atol, rtol=rtol))
   elif isinstance(output1, (tuple, list)):
     testcase.assertIsInstance(output2, (tuple, list))
     testcase.assertEqual(len(output1), len(output2))
     for o1, o2 in zip(output1, output2):
-      diff_output(testcase, o1, o2, atol)
+      diff_output(testcase, o1, o2, rtol, atol)
   else:
     testcase.assertEqual(output1, output2)
 
 
-def run_export_and_compare(testcase, func, args, kwargs, atol=1e-3):
+def run_export_and_compare(testcase, func, args, kwargs, atol=1e-3, rtol=1e-5):
   device = xm.xla_device()
   with testcase.subTest('torch_eval'):
     res = func(*args, **kwargs)
@@ -36,7 +37,7 @@ def run_export_and_compare(testcase, func, args, kwargs, atol=1e-3):
                                      lambda x: x.to(device=device), kwargs)
       res_xla = func(*args2, **kwargs2)
       with testcase.subTest('torch_xla_diff:' + str(atol)):
-        diff_output(testcase, res, res_xla, atol)
+        diff_output(testcase, res, res_xla, atol=atol, rtol=rtol)
     with testcase.subTest('can_export'):
       exported = torch.export.export(func, args, kwargs)
       with testcase.subTest('can_convert_to_stablehlo'):
@@ -44,7 +45,7 @@
         with testcase.subTest('stablehlo_can_run'):
           res2 = shlo(*args, **kwargs)
           with testcase.subTest('stablehlo_diff: ' + str(atol)):
-            diff_output(testcase, res, res2, atol)
+            diff_output(testcase, res, res2, rtol=rtol, atol=atol)
 
 
 class AtenOpTest(unittest.TestCase):
@@ -4372,11 +4373,17 @@ def test_aten_tan_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.tan, args, kwargs)
 
-  @unittest.skip
   def test_aten_tan_1(self):
     args = (torch.randn((10, 10)).to(torch.float16),)
     kwargs = dict()
-    run_export_and_compare(self, torch.ops.aten.tan, args, kwargs)
+    run_export_and_compare(
+        self,
+        torch.ops.aten.tan,
+        args,
+        kwargs,
+        rtol=0.001,
+        atol=0.01,
+    )
 
   def test_aten_tan_2(self):
     args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
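Note (editor, not part of the patch): the tolerance semantics above come from `torch.allclose`, which treats two tensors as close when, elementwise, `|input - other| <= atol + rtol * |other|`. Below is a minimal standalone sketch, assuming only stock PyTorch, of why a float16 `tan` comparison tends to need the relaxed `rtol=0.001, atol=0.01` chosen for `test_aten_tan_1`:

```python
import torch

# torch.allclose(a, b, rtol, atol) passes when, elementwise:
#   |a - b| <= atol + rtol * |b|
# float16 carries roughly 3 decimal digits of precision, so tan(x) values
# of large magnitude (x near pi/2) accumulate absolute error well above
# a 1e-3 atol budget.
torch.manual_seed(0)
x = torch.randn(10, 10).to(torch.float16)

ref = torch.tan(x.to(torch.float32))   # higher-precision reference
out = torch.tan(x).to(torch.float32)   # computed in float16

print(torch.allclose(ref, out, rtol=1e-5, atol=1e-3))   # often False
print(torch.allclose(ref, out, rtol=0.001, atol=0.01))  # usually True
```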
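If more 16-bit tests end up needing the same relaxation, one possible follow-up (purely hypothetical, not in this patch; the helper name and tolerance values are illustrative) is to derive default tolerances from the input dtype instead of hard-coding them per test:

```python
import torch

# Hypothetical helper: looser defaults for 16-bit floats, strict otherwise.
# Values are illustrative, mirroring the pair used in test_aten_tan_1.
_DTYPE_TOLERANCES = {
    torch.float16: dict(rtol=1e-3, atol=1e-2),
    torch.bfloat16: dict(rtol=1e-2, atol=1e-2),
}


def tolerances_for(args, rtol=1e-5, atol=1e-3):
  """Picks (rtol, atol) from the first 16-bit tensor argument, if any."""
  for a in args:
    if isinstance(a, torch.Tensor) and a.dtype in _DTYPE_TOLERANCES:
      return _DTYPE_TOLERANCES[a.dtype]
  return dict(rtol=rtol, atol=atol)


# Example: a float16 input selects the relaxed pair.
args = (torch.randn((10, 10)).to(torch.float16),)
print(tolerances_for(args))  # {'rtol': 0.001, 'atol': 0.01}
```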