From 7afeb563c03fad0dc6cdabd29a18d4b48d0acff9 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 28 Feb 2024 11:22:23 -0300 Subject: [PATCH] Add more `as_strided` tests. --- test/test_operations.py | 50 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/test/test_operations.py b/test/test_operations.py index 9e274218c4c6..b32d3680bd5a 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -1498,6 +1498,56 @@ def test_fn(r): self.runAtenTest([torch.arange(144, dtype=torch.int32)], test_fn) + def test_as_strided_with_gap(self): + + def test_fn(r): + return torch.as_strided(r, (4, 4), (8, 1)) + + self.runAtenTest([torch.arange(28, dtype=torch.int32)], test_fn) + + def test_as_strided_with_gap_no_unit_stride(self): + + def test_fn(r): + return torch.as_strided(r, (4, 4), (8, 2)) + + self.runAtenTest([torch.arange(31, dtype=torch.int32)], test_fn) + + def test_as_strided_with_overlap(self): + + def test_fn(r): + return torch.as_strided(r, (4, 4), (2, 1)) + + self.runAtenTest([torch.arange(10, dtype=torch.int32)], test_fn) + + def test_as_strided_with_overlap_and_gap(self): + + def test_fn(r): + return torch.as_strided(r, (4, 4), (4, 2)) + + self.runAtenTest([torch.arange(19, dtype=torch.int32)], test_fn) + + def test_as_strided_with_overlap_zero_stride(self): + + def test_fn(r): + return torch.as_strided(r, (4, 4), (0, 1)) + + self.runAtenTest([torch.arange(19, dtype=torch.int32)], test_fn) + + def test_as_strided_with_gap_no_unit_stride(self): + + def test_fn(r): + x = r.view(8, 4) + return torch.as_strided(r, (4, 4), (6, 2)) + + self.runAtenTest([torch.arange(32, dtype=torch.int32)], test_fn) + + def test_as_strided_with_empty_args(self): + + def test_fn(r): + return torch.as_strided(r, tuple(), tuple()) + + self.runAtenTest([torch.arange(32, dtype=torch.int32)], test_fn) + def test_basic_bfloat16(self): def test_fn(s):