Skip to content

Commit

Permalink
Merge pull request #619 from ANTsX/add-more-tests
Browse files Browse the repository at this point in the history
add more tests
  • Loading branch information
Nicholas Cullen, PhD authored May 7, 2024
2 parents 4d49f44 + 577c1d6 commit 3c6c614
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 2 deletions.
3 changes: 3 additions & 0 deletions ants/registration/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ def registration(
>>> fi = ants.resample_image(fi, (60,60), 1, 0)
>>> mi = ants.resample_image(mi, (60,60), 1, 0)
>>> mytx = ants.registration(fixed=fi, moving=mi, type_of_transform = 'SyN' )
>>> mytx = ants.registration(fixed=fi, moving=mi, type_of_transform = 'antsRegistrationSyN[t]' )
>>> mytx = ants.registration(fixed=fi, moving=mi, type_of_transform = 'antsRegistrationSyN[b]' )
>>> mytx = ants.registration(fixed=fi, moving=mi, type_of_transform = 'antsRegistrationSyN[s]' )
"""
if isinstance(fixed, list) and (moving is None):
processed_args = utils._int_antsProcessArguments(fixed)
Expand Down
34 changes: 32 additions & 2 deletions tests/test_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,21 +369,37 @@ def tearDown(self):
def test_landmark_transforms(self):
fixed = np.array([[50.0,50.0],[200.0,50.0],[200.0,200.0]])
moving = np.array([[50.0,50.0],[50.0,200.0],[200.0,200.0]])
xfrm = ants.fit_transform_to_paired_points(moving, fixed, transform_type="syn",
domain_image=ants.image_read(ants.get_data('r16')))
xfrm = ants.fit_transform_to_paired_points(moving, fixed, transform_type="tv",
domain_image=ants.image_read(ants.get_data('r16')))
xfrm = ants.fit_transform_to_paired_points(moving, fixed, transform_type="affine")
xfrm = ants.fit_transform_to_paired_points(moving, fixed, transform_type="rigid")
xfrm = ants.fit_transform_to_paired_points(moving, fixed, transform_type="similarity")
domain_image = ants.image_read(ants.get_ants_data("r16"))
xfrm = ants.fit_transform_to_paired_points(moving, fixed, transform_type="bspline", domain_image=domain_image, number_of_fitting_levels=5)
xfrm = ants.fit_transform_to_paired_points(moving, fixed, transform_type="diffeo", domain_image=domain_image, number_of_fitting_levels=6)


res = ants.fit_time_varying_transform_to_point_sets([fixed, moving, moving],
domain_image=ants.image_read(ants.get_data('r16')))

def test_deformation_gradient(self):
fi = ants.image_read( ants.get_ants_data('r16'))
mi = ants.image_read( ants.get_ants_data('r64'))
fi = ants.resample_image(fi,(128,128),1,0)
mi = ants.resample_image(mi,(128,128),1,0)
mytx = ants.registration(fixed=fi , moving=mi, type_of_transform = ('SyN') )
dg = ants.deformation_gradient( ants.image_read( mytx['fwdtransforms'][0] ) )


dg = ants.deformation_gradient( ants.image_read( mytx['fwdtransforms'][0] ),
py_based=True)

dg = ants.deformation_gradient( ants.image_read( mytx['fwdtransforms'][0] ),
to_rotation=True)

dg = ants.deformation_gradient( ants.image_read( mytx['fwdtransforms'][0] ),
to_rotation=True, py_based=True)

def test_jacobian(self):
fi = ants.image_read( ants.get_ants_data('r16'))
mi = ants.image_read( ants.get_ants_data('r64'))
Expand Down Expand Up @@ -418,5 +434,19 @@ def test_warped_grid(self):
mywarpedgrid = ants.create_warped_grid( mi, grid_directions=(False,True),
transform=mytx['fwdtransforms'], fixed_reference_image=fi )

def test_more_registration(self):
fi = ants.image_read(ants.get_ants_data('r16'))
mi = ants.image_read(ants.get_ants_data('r64'))
fi = ants.resample_image(fi, (60,60), 1, 0)
mi = ants.resample_image(mi, (60,60), 1, 0)
mytx = ants.registration(fixed=fi, moving=mi, type_of_transform = 'SyN' )
mytx = ants.registration(fixed=fi, moving=mi, type_of_transform = 'antsRegistrationSyN[t]' )
mytx = ants.registration(fixed=fi, moving=mi, type_of_transform = 'antsRegistrationSyN[b]' )
mytx = ants.registration(fixed=fi, moving=mi, type_of_transform = 'antsRegistrationSyN[s]' )

def test_motion_correction(self):
fi = ants.image_read(ants.get_ants_data('ch2'))
mytx = ants.motion_correction( fi )

if __name__ == "__main__":
run_tests()
29 changes: 29 additions & 0 deletions tests/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,35 @@ def test_example(self):
label_list=seglist, rad=[r]*ref.dimension )
pp = ants.joint_label_fusion(ref,refmask,ilist, r_search=2, rad=2 )

def test_max_lab_plus_one(self):
ref = ants.image_read( ants.get_ants_data('r16'))
ref = ants.resample_image(ref, (50,50),1,0)
ref = ants.iMath(ref,'Normalize')
mi = ants.image_read( ants.get_ants_data('r27'))
mi2 = ants.image_read( ants.get_ants_data('r30'))
mi3 = ants.image_read( ants.get_ants_data('r62'))
mi4 = ants.image_read( ants.get_ants_data('r64'))
mi5 = ants.image_read( ants.get_ants_data('r85'))
refmask = ants.get_mask(ref)
refmask = ants.iMath(refmask,'ME',2) # just to speed things up
ilist = [mi,mi2,mi3,mi4,mi5]
seglist = [None]*len(ilist)
for i in range(len(ilist)):
ilist[i] = ants.iMath(ilist[i],'Normalize')
mytx = ants.registration(fixed=ref , moving=ilist[i] ,
typeofTransform = ('Affine') )
mywarpedimage = ants.apply_transforms(fixed=ref,moving=ilist[i],
transformlist=mytx['fwdtransforms'])
ilist[i] = mywarpedimage
seg = ants.threshold_image(ilist[i],'Otsu', 3)
seglist[i] = ( seg ) + ants.threshold_image( seg, 1, 3 ).morphology( operation='dilate', radius=3 )

r = 2
pp = ants.joint_label_fusion(ref, refmask, ilist, r_search=2,
label_list=seglist, rad=[r]*ref.dimension, max_lab_plus_one=True )
pp = ants.joint_label_fusion(ref,refmask,ilist, r_search=2, rad=2,
max_lab_plus_one=True)



class TestModule_kelly_kapowski(unittest.TestCase):
Expand Down
74 changes: 74 additions & 0 deletions tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,25 @@ def test_plot_example(self):
filename = mktemp(suffix='.png')
for img in self.imgs:
ants.plot(img)
ants.plot(img, overlay=img*2)
ants.plot(img, overlay=img*2)
ants.plot(img, filename=filename)

def test_extra_plot(self):
img = ants.image_read(ants.get_ants_data('r16'))
ants.plot(img, overlay=img*2, domain_image_map=ants.image_read(ants.get_data('r64')))

img = ants.image_read(ants.get_ants_data('r16'))
ants.plot(img, crop=True)

img = ants.image_read(ants.get_ants_data('mni'))
ants.plot(img, overlay=img*2,
domain_image_map=ants.image_read(ants.get_data('mni')).resample_image((4,4,4)))

img = ants.image_read(ants.get_ants_data('mni'))
ants.plot(img, overlay=img*2, reorient=True, crop=True)


class TestModule_plot_ortho(unittest.TestCase):

def setUp(self):
Expand All @@ -52,6 +69,15 @@ def test_plot_example(self):
for img in self.imgs:
ants.plot_ortho(img)
ants.plot_ortho(img, filename=filename)

def test_plot_extra(self):
img = ants.image_read(ants.get_ants_data('mni'))
ants.plot_ortho(img, overlay=img*2,
domain_image_map=ants.image_read(ants.get_data('mni')))

img = ants.image_read(ants.get_ants_data('mni'))
ants.plot_ortho(img, overlay=img*2, reorient=True, crop=True)


class TestModule_plot_ortho_stack(unittest.TestCase):

Expand All @@ -65,6 +91,14 @@ def test_plot_example(self):
filename = mktemp(suffix='.png')
ants.plot_ortho_stack([self.img, self.img])
ants.plot_ortho_stack([self.img, self.img], filename=filename)

def test_extra_ortho_stack(self):
img = ants.image_read(ants.get_ants_data('mni'))
ants.plot_ortho_stack([img, img], overlays=[img*2, img*2],
domain_image_map=ants.image_read(ants.get_data('mni')))

img = ants.image_read(ants.get_ants_data('mni'))
ants.plot_ortho_stack([img, img], overlays=[img*2, img*2], reorient=True, crop=True)

class TestModule_plot_hist(unittest.TestCase):

Expand Down Expand Up @@ -102,6 +136,46 @@ def test_plot_example(self):
ants.plot_grid(self.images3d)
# should work with 2d images
ants.plot_grid(self.images2d)

def test_examples(self):
mni1 = ants.image_read(ants.get_data('mni'))
mni2 = mni1.smooth_image(1.)
mni3 = mni1.smooth_image(2.)
mni4 = mni1.smooth_image(3.)
images = np.asarray([[mni1, mni2],
[mni3, mni4]])
slices = np.asarray([[100, 100],
[100, 100]])
ants.plot_grid(images=images, slices=slices, title='2x2 Grid')
images2d = np.asarray([[mni1.slice_image(2,100), mni2.slice_image(2,100)],
[mni3.slice_image(2,100), mni4.slice_image(2,100)]])
ants.plot_grid(images=images2d, title='2x2 Grid Pre-Sliced')
ants.plot_grid(images.reshape(1,4), slices.reshape(1,4), title='1x4 Grid')
ants.plot_grid(images.reshape(4,1), slices.reshape(4,1), title='4x1 Grid')

# Padding between rows and/or columns
ants.plot_grid(images, slices, cpad=0.02, title='Col Padding')
ants.plot_grid(images, slices, rpad=0.02, title='Row Padding')
ants.plot_grid(images, slices, rpad=0.02, cpad=0.02, title='Row and Col Padding')

# Adding plain row and/or column labels
ants.plot_grid(images, slices, title='Adding Row Labels', rlabels=['Row #1', 'Row #2'])
ants.plot_grid(images, slices, title='Adding Col Labels', clabels=['Col #1', 'Col #2'])
ants.plot_grid(images, slices, title='Row and Col Labels',
rlabels=['Row 1', 'Row 2'], clabels=['Col 1', 'Col 2'])

# Making a publication-quality image
images = np.asarray([[mni1, mni2, mni2],
[mni3, mni4, mni4]])
slices = np.asarray([[100, 100, 100],
[100, 100, 100]])
axes = np.asarray([[0, 1, 2],
[0, 1, 2]])
ants.plot_grid(images, slices, axes, title='Publication Figures with ANTsPy',
tfontsize=20, title_dy=0.03, title_dx=-0.04,
rlabels=['Row 1', 'Row 2'],
clabels=['Col 1', 'Col 2', 'Col 3'],
rfontsize=16, cfontsize=16)



Expand Down

0 comments on commit 3c6c614

Please sign in to comment.