Skip to content

Commit

Permalink
Merge pull request #752 from ANTsX/LabelRegList
Browse files Browse the repository at this point in the history
ENH:  Add functionality for specifying def. xfrm
  • Loading branch information
ntustison authored Dec 3, 2024
2 parents 3cf8ac9 + c39bfd6 commit 2f55742
Showing 1 changed file with 72 additions and 38 deletions.
110 changes: 72 additions & 38 deletions ants/registration/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1619,7 +1619,15 @@ def label_image_registration(fixed_label_images,
type_of_deformable_transform : string
Only works with deformable-only transforms, specifically the family
of antsRegistrationSyN*[so] or antsRegistrationSyN*[bo] transforms.
See 'type_of_transform' in ants.registration.
See 'type_of_transform' in ants.registration. Additionally, one can
use a list to pass a more tailored deformably-only transform
optimization using SyN or BSplineSyN transforms. The order of
parameters in the list would be 1) transform specification, i.e.
"SyN" or "BSplineSyN", 2) gradient (real), 3) intensity metric (string),
4) intensity metric parameter (real), 5) convergence iterations per level
(tuple) 6) smoothing factors per level (tuple), 7) shrink factors per level
(tuple). An example would type_of_deformable_transform = ["SyN", 0.2, "CC",
4, (100,50,10), (2,1,0), (4,2,1)].
label_image_weighting : float or list of floats
Relative weighting for the label images.
Expand Down Expand Up @@ -1770,40 +1778,74 @@ def label_image_registration(fixed_label_images,
if verbose:
print("\n\nComputing deformable transform using images.\n")

do_quick = False
do_repro = False

if "Quick" in type_of_deformable_transform:
do_quick = True
elif "Repro" in type_of_deformable_transform:
do_repro = True
random_seed = str(1)

intensity_metric_parameter = None
intensity_metric = "CC"
intensity_metric_parameter = 2
syn_shrink_factors = "8x4x2x1"
syn_smoothing_sigmas = "3x2x1x0vox"
syn_convergence = "[100x70x50x20,1e-6,10]"
spline_distance = 26
if "[" in type_of_deformable_transform and "]" in type_of_deformable_transform:
subtype_of_deformable_transform = type_of_deformable_transform.split("[")[1].split("]")[0]
if not ('bo' in subtype_of_deformable_transform or 'so' in subtype_of_deformable_transform):
raise ValueError("Only 'so' or 'bo' transforms are available.")
if "," in subtype_of_deformable_transform:
subtype_of_deformable_transform_args = subtype_of_deformable_transform.split(",")
subtype_of_deformable_transform = subtype_of_deformable_transform_args[0]
intensity_metric_parameter = subtype_of_deformable_transform_args[1]
if len(subtype_of_deformable_transform_args) > 2:
spline_distance = subtype_of_deformable_transform_args[2]
gradient_step = 0.1
syn_transform = "SyN"

syn_stage = list()

intensity_metric = None
if fixed_intensity_images is not None and len(fixed_intensity_images) > 0:
if isinstance(type_of_deformable_transform, list):

if (len(type_of_deformable_transform) != 7 or
not isinstance(type_of_deformable_transform[0], str) or
not isinstance(type_of_deformable_transform[1], float) or
not isinstance(type_of_deformable_transform[2], str) or
not isinstance(type_of_deformable_transform[3], int) or
not isinstance(type_of_deformable_transform[4], tuple) or
not isinstance(type_of_deformable_transform[5], tuple) or
not isinstance(type_of_deformable_transform[6], tuple)):
raise ValueError("Incorrect specification for type_of_deformable_transform. See help menu.")

syn_transform = type_of_deformable_transform[0]
gradient_step = type_of_deformable_transform[1]
intensity_metric = type_of_deformable_transform[2]
intensity_metric_parameter = type_of_deformable_transform[3]

t = type_of_deformable_transform[4]
tstr = ''.join(map(lambda x: str(x) + 'x', t[:len(t)-1])) + str(t[len(t)-1])
syn_convergence = "[" + tstr + ",1e-6,10]"

t = type_of_deformable_transform[5]
tstr = ''.join(map(lambda x: str(x) + 'x', t[:len(t)-1])) + str(t[len(t)-1])
syn_smoothing_sigmas = tstr + "vox"

t = type_of_deformable_transform[6]
syn_shrink_factors = ''.join(map(lambda x: str(x) + 'x', t[:len(t)-1])) + str(t[len(t)-1])

else:

do_quick = False
if "Quick" in type_of_deformable_transform:
do_quick = True
elif "Repro" in type_of_deformable_transform:
random_seed = str(1)

if "[" in type_of_deformable_transform and "]" in type_of_deformable_transform:
subtype_of_deformable_transform = type_of_deformable_transform.split("[")[1].split("]")[0]
if not ('bo' in subtype_of_deformable_transform or 'so' in subtype_of_deformable_transform):
raise ValueError("Only 'so' or 'bo' transforms are available.")
else:
if 'bo' in subtype_of_deformable_transform:
syn_transform = "BSplineSyN"
if "," in subtype_of_deformable_transform:
subtype_of_deformable_transform_args = subtype_of_deformable_transform.split(",")
subtype_of_deformable_transform = subtype_of_deformable_transform_args[0]
intensity_metric_parameter = subtype_of_deformable_transform_args[1]
if len(subtype_of_deformable_transform_args) > 2:
spline_distance = subtype_of_deformable_transform_args[2]

if do_quick:
intensity_metric = "MI"
if intensity_metric_parameter is None:
intensity_metric_parameter = 32
if not do_quick or do_repro:
intensity_metric = "CC"
if intensity_metric_parameter is None:
intensity_metric_parameter = 2
syn_convergence = "[100x70x50x0,1e-6,10]"

if fixed_intensity_images is not None and len(fixed_intensity_images) > 0:
for i in range(len(fixed_intensity_images)):
syn_stage.append("--metric")
metric_string = "%s[%s,%s,%s,%s]" % (
Expand All @@ -1822,25 +1864,17 @@ def label_image_registration(fixed_label_images,
deformable_multivariate_extras[kk][3], 0.0)
syn_stage.append(metricString)

syn_shrink_factors = "8x4x2x1"
syn_smoothing_sigmas = "3x2x1x0vox"

if do_quick:
syn_convergence = "[100x70x50x0,1e-6,10]"
else:
syn_convergence = "[100x70x50x20,1e-6,10]"

syn_stage.append("--convergence")
syn_stage.append(syn_convergence)
syn_stage.append("--shrink-factors")
syn_stage.append(syn_shrink_factors)
syn_stage.append("--smoothing-sigmas")
syn_stage.append(syn_smoothing_sigmas)

if 'b' in subtype_of_deformable_transform:
syn_stage.insert(0, "BSplineSyN[0.1," + str(spline_distance) + ",0,3]")
if syn_transform == "SyN":
syn_stage.insert(0, "SyN[" + str(gradient_step) + ",3,0]")
else:
syn_stage.insert(0, "SyN[0.1,3,0]")
syn_stage.insert(0, "BSplineSyN[" + str(gradient_step) + "," + str(spline_distance) + ",0,3]")
syn_stage.insert(0, "--transform")

args = None
Expand Down

0 comments on commit 2f55742

Please sign in to comment.