Fix some more core aten ops (#6414)
wonjoolee95 authored Jan 30, 2024
Parent: 0baf1da · Commit: a83be44
Showing 1 changed file with 18 additions and 10 deletions.
test/test_core_aten_ops.py (28 changes: 18 additions & 10 deletions)
@@ -29,7 +29,8 @@ def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True):
     output2_cpu = output2.detach().cpu()
     if output2_cpu.dtype != output1.dtype:
       output2_cpu = output2_cpu.to(output1.dtype)
-    # import pdb; pdb.set_trace()
+    # import pdb
+    # pdb.set_trace()
     testcase.assertTrue(
         torch.allclose(
             output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan))
@@ -2335,19 +2336,16 @@ def test_aten_leaky_relu_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.leaky_relu, args, kwargs)

-  @unittest.skip
   def test_aten_lift_fresh_copy_0(self):
     args = (torch.randn((10, 10)).to(torch.float32),)
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.lift_fresh_copy, args, kwargs)

-  @unittest.skip
   def test_aten_lift_fresh_copy_1(self):
     args = (torch.randn((10, 10)).to(torch.float16),)
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.lift_fresh_copy, args, kwargs)

-  @unittest.skip
   def test_aten_lift_fresh_copy_2(self):
     args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
     kwargs = dict()
@@ -2700,7 +2698,6 @@ def test_aten_max_pool2d_with_indices_1(self):
     run_export_and_compare(self, torch.ops.aten.max_pool2d_with_indices, args,
                            kwargs)

-  @unittest.skip
   def test_aten_max_pool2d_with_indices_2(self):
     args = (
         torch.randint(0, 10, (3, 2, 10)).to(torch.int32),
@@ -3832,7 +3829,6 @@ def test_aten_scatter_reduce_two_0(self):
     run_export_and_compare(self, torch.ops.aten.scatter_reduce.two, args,
                            kwargs)

-  @unittest.skip
   def test_aten_scatter_reduce_two_1(self):
     args = (
         torch.randn((10, 10)).to(torch.float16),
@@ -3842,8 +3838,14 @@ def test_aten_scatter_reduce_two_1(self):
         "sum",
     )
     kwargs = dict()
-    run_export_and_compare(self, torch.ops.aten.scatter_reduce.two, args,
-                           kwargs)
+    run_export_and_compare(
+        self,
+        torch.ops.aten.scatter_reduce.two,
+        args,
+        kwargs,
+        rtol=0.001,
+        atol=0.01,
+    )

   def test_aten_scatter_reduce_two_2(self):
     args = (
@@ -4388,14 +4390,20 @@ def test_aten_sum_dim_IntList_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.sum.dim_IntList, args, kwargs)

-  @unittest.skip
   def test_aten_sum_dim_IntList_1(self):
     args = (
         torch.randn((10, 10)).to(torch.float16),
         None,
     )
     kwargs = dict()
-    run_export_and_compare(self, torch.ops.aten.sum.dim_IntList, args, kwargs)
+    run_export_and_compare(
+        self,
+        torch.ops.aten.sum.dim_IntList,
+        args,
+        kwargs,
+        rtol=0.001,
+        atol=0.01,
+    )

   def test_aten_sum_dim_IntList_2(self):
     args = (
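For context on the new rtol/atol arguments: diff_output (first hunk above) forwards them to torch.allclose, so the previously skipped float16 tests can now pass with loosened tolerances rather than the tight defaults. Below is a minimal standalone sketch of the comparison this enables, reusing the commit's rtol=0.001 / atol=0.01; the tensor values and the names x16/reference/candidate are made up for illustration and are not part of the repository's run_export_and_compare harness.

import torch

# float16 stores roughly 3 significant decimal digits, so reducing 100
# elements drifts from a float32 reference by a few ulps; the commit's
# looser rtol=0.001 / atol=0.01 absorb that rounding noise.
x16 = torch.randn((10, 10)).to(torch.float16)

reference = x16.to(torch.float32).sum()   # accumulate in float32
candidate = x16.sum().to(torch.float32)   # reduce in float16, then cast

print("max abs diff:", (reference - candidate).abs().item())
assert torch.allclose(
    reference, candidate, rtol=0.001, atol=0.01, equal_nan=True)

The same reasoning applies to the scatter_reduce.two change: its "sum" reduction over float16 inputs accumulates the same kind of rounding error, which is why only the float16 variants of these tests gain explicit tolerances.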
