Skip to content

Commit

Permalink
[Core Aten] Add and enable tests for aten index_select and logical_and (
Browse files Browse the repository at this point in the history
  • Loading branch information
chunnienc authored Jan 11, 2024
1 parent 8141078 commit 04e1238
Showing 1 changed file with 35 additions and 2 deletions.
37 changes: 35 additions & 2 deletions test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2162,7 +2162,6 @@ def test_aten_index_select_0(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs)

@unittest.skip
def test_aten_index_select_1(self):
args = (
torch.randn((2, 10)).to(torch.float16),
Expand All @@ -2181,6 +2180,33 @@ def test_aten_index_select_2(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs)

def test_aten_index_select_3(self):
args = (
torch.randn((2, 10)).to(torch.float32),
1,
torch.randint(0, 10, (2,)).to(torch.int32),
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs)

def test_aten_index_select_4(self):
args = (
torch.randn((2, 10)).to(torch.float16),
1,
torch.randint(0, 10, (2,)).to(torch.int32),
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs)

def test_aten_index_select_5(self):
args = (
torch.randint(0, 10, (2, 10)).to(torch.int32),
1,
torch.randint(0, 10, (2,)).to(torch.int32),
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs)

@unittest.skip
def test_aten_index_Tensor_0(self):
args = (
Expand Down Expand Up @@ -2437,7 +2463,6 @@ def test_aten_logical_and_1(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.logical_and, args, kwargs)

@unittest.skip
def test_aten_logical_and_2(self):
args = (
torch.randint(0, 10, (10, 10)).to(torch.int32),
Expand All @@ -2446,6 +2471,14 @@ def test_aten_logical_and_2(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.logical_and, args, kwargs)

def test_aten_logical_and_3(self):
args = (
torch.randint(0, 2, (10, 10)).to(torch.bool),
torch.randint(0, 2, (10, 10)).to(torch.bool),
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.logical_and, args, kwargs)

def test_aten_logical_not_0(self):
args = (torch.randn((10, 10)).to(torch.float32),)
kwargs = dict()
Expand Down

0 comments on commit 04e1238

Please sign in to comment.