From a83be44c487758aefce8d893424abd9f9ec6c3ab Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Tue, 30 Jan 2024 13:02:54 -0800 Subject: [PATCH] Fix some more core aten ops (#6414) --- test/test_core_aten_ops.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py index 8f1dd0a3344..3ab0c8c3c53 100644 --- a/test/test_core_aten_ops.py +++ b/test/test_core_aten_ops.py @@ -29,7 +29,8 @@ def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True): output2_cpu = output2.detach().cpu() if output2_cpu.dtype != output1.dtype: output2_cpu = output2_cpu.to(output1.dtype) - # import pdb; pdb.set_trace() + # import pdb + # pdb.set_trace() testcase.assertTrue( torch.allclose( output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan)) @@ -2335,19 +2336,16 @@ def test_aten_leaky_relu_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.leaky_relu, args, kwargs) - @unittest.skip def test_aten_lift_fresh_copy_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.lift_fresh_copy, args, kwargs) - @unittest.skip def test_aten_lift_fresh_copy_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.lift_fresh_copy, args, kwargs) - @unittest.skip def test_aten_lift_fresh_copy_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() @@ -2700,7 +2698,6 @@ def test_aten_max_pool2d_with_indices_1(self): run_export_and_compare(self, torch.ops.aten.max_pool2d_with_indices, args, kwargs) - @unittest.skip def test_aten_max_pool2d_with_indices_2(self): args = ( torch.randint(0, 10, (3, 2, 10)).to(torch.int32), @@ -3832,7 +3829,6 @@ def test_aten_scatter_reduce_two_0(self): run_export_and_compare(self, torch.ops.aten.scatter_reduce.two, args, kwargs) - @unittest.skip def test_aten_scatter_reduce_two_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -3842,8 +3838,14 @@ def test_aten_scatter_reduce_two_1(self): "sum", ) kwargs = dict() - run_export_and_compare(self, torch.ops.aten.scatter_reduce.two, args, - kwargs) + run_export_and_compare( + self, + torch.ops.aten.scatter_reduce.two, + args, + kwargs, + rtol=0.001, + atol=0.01, + ) def test_aten_scatter_reduce_two_2(self): args = ( @@ -4388,14 +4390,20 @@ def test_aten_sum_dim_IntList_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.sum.dim_IntList, args, kwargs) - @unittest.skip def test_aten_sum_dim_IntList_1(self): args = ( torch.randn((10, 10)).to(torch.float16), None, ) kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sum.dim_IntList, args, kwargs) + run_export_and_compare( + self, + torch.ops.aten.sum.dim_IntList, + args, + kwargs, + rtol=0.001, + atol=0.01, + ) def test_aten_sum_dim_IntList_2(self): args = (