From 04e1238aeed71f9d42810b272cdba4516966f096 Mon Sep 17 00:00:00 2001 From: Chunnien Chan <121328115+chunnienc@users.noreply.github.com> Date: Thu, 11 Jan 2024 11:33:48 -0800 Subject: [PATCH] [Core Aten] Add and enable tests for aten index_select and logical_and (#6293) --- test/test_core_aten_ops.py | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py index a3c8316229b..e360b2d424d 100644 --- a/test/test_core_aten_ops.py +++ b/test/test_core_aten_ops.py @@ -2162,7 +2162,6 @@ def test_aten_index_select_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs) - @unittest.skip def test_aten_index_select_1(self): args = ( torch.randn((2, 10)).to(torch.float16), @@ -2181,6 +2180,33 @@ def test_aten_index_select_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs) + def test_aten_index_select_3(self): + args = ( + torch.randn((2, 10)).to(torch.float32), + 1, + torch.randint(0, 10, (2,)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs) + + def test_aten_index_select_4(self): + args = ( + torch.randn((2, 10)).to(torch.float16), + 1, + torch.randint(0, 10, (2,)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs) + + def test_aten_index_select_5(self): + args = ( + torch.randint(0, 10, (2, 10)).to(torch.int32), + 1, + torch.randint(0, 10, (2,)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs) + @unittest.skip def test_aten_index_Tensor_0(self): args = ( @@ -2437,7 +2463,6 @@ def test_aten_logical_and_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.logical_and, args, kwargs) - @unittest.skip def test_aten_logical_and_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -2446,6 +2471,14 @@ def test_aten_logical_and_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.logical_and, args, kwargs) + def test_aten_logical_and_3(self): + args = ( + torch.randint(0, 2, (10, 10)).to(torch.bool), + torch.randint(0, 2, (10, 10)).to(torch.bool), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.logical_and, args, kwargs) + def test_aten_logical_not_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict()