Skip to content

Commit

Permalink
added test for numpy conversions (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
denisri committed Feb 7, 2023
1 parent 186e901 commit 7ba5486
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions pyaims/python/soma/aims/tests/test_volume_strides.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,43 @@
import numpy as np
import sys


class TestVolumeStrides(unittest.TestCase):

def _test_numpy_conversion(self, vol):
arr = vol.arraydata()
test_vol = aims.Volume(arr)
dims = tuple(vol.getSize())
rdims = tuple(reversed(dims))
self.assertEqual(rdims, arr.shape)
self.assertEqual((36499400, 125860, 434, 2), arr.strides)
self.assertEqual(rdims, test_vol.getSize())
# Here, vol and test_vol are *NOT* equal: test_vol is transposed

arr = np.asarray(vol)
test_vol = aims.Volume(arr)
self.assertEqual(dims, arr.shape)
self.assertEqual((2, 434, 125860, 36499400), arr.strides)
self.assertEqual(dims, test_vol.getSize())
dif = (vol == test_vol)
self.assertTrue(np.all((dif).np))

arr = np.asfortranarray(vol)
test_vol = aims.Volume(arr)
self.assertEqual(dims, arr.shape)
self.assertEqual((2, 434, 125860, 36499400), arr.strides)
self.assertEqual(dims, test_vol.getSize())
dif = (vol == test_vol)
self.assertTrue(np.all((dif).np))

arr = np.ascontiguousarray(vol)
test_vol = aims.Volume(arr)
self.assertEqual(dims, arr.shape)
self.assertEqual((168200, 580, 2, 2), arr.strides)
self.assertEqual(dims, test_vol.getSize())
dif = (vol == test_vol)
self.assertTrue(np.all((dif).np))

def test_volume_strides(self):
vol = aims.Volume(4, 5, 6, dtype='S16')
self.assertEqual(vol.shape, (4, 5, 6, 1))
Expand All @@ -29,6 +64,12 @@ def test_volume_strides(self):
self.assertEqual(vol3.shape, (4, 5, 6, 1))
self.assertTrue(np.all(vol3.np == vol.np * 5))

def test_numpy_conversion(self):
vol = aims.Volume((217, 290, 290, 1), dtype='S16')
vol.np[:] = np.arange(vol.np.size).reshape(vol.np.shape)
self._test_numpy_conversion(vol)


def test():
suite = unittest.TestLoader().loadTestsFromTestCase(TestVolumeStrides)
runtime = unittest.TextTestRunner(verbosity=2).run(suite)
Expand Down

0 comments on commit 7ba5486

Please sign in to comment.