Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add functionality for specifying def. xfrm #752

Merged
merged 2 commits into from
Dec 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 72 additions & 38 deletions ants/registration/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1619,7 +1619,15 @@ def label_image_registration(fixed_label_images,
type_of_deformable_transform : string
Only works with deformable-only transforms, specifically the family
of antsRegistrationSyN*[so] or antsRegistrationSyN*[bo] transforms.
See 'type_of_transform' in ants.registration.
See 'type_of_transform' in ants.registration. Additionally, one can
use a list to pass a more tailored deformably-only transform
optimization using SyN or BSplineSyN transforms. The order of
parameters in the list would be 1) transform specification, i.e.
"SyN" or "BSplineSyN", 2) gradient (real), 3) intensity metric (string),
4) intensity metric parameter (real), 5) convergence iterations per level
(tuple) 6) smoothing factors per level (tuple), 7) shrink factors per level
(tuple). An example would type_of_deformable_transform = ["SyN", 0.2, "CC",
4, (100,50,10), (2,1,0), (4,2,1)].

label_image_weighting : float or list of floats
Relative weighting for the label images.
Expand Down Expand Up @@ -1770,40 +1778,74 @@ def label_image_registration(fixed_label_images,
if verbose:
print("\n\nComputing deformable transform using images.\n")

do_quick = False
do_repro = False

if "Quick" in type_of_deformable_transform:
do_quick = True
elif "Repro" in type_of_deformable_transform:
do_repro = True
random_seed = str(1)

intensity_metric_parameter = None
intensity_metric = "CC"
intensity_metric_parameter = 2
syn_shrink_factors = "8x4x2x1"
syn_smoothing_sigmas = "3x2x1x0vox"
syn_convergence = "[100x70x50x20,1e-6,10]"
spline_distance = 26
if "[" in type_of_deformable_transform and "]" in type_of_deformable_transform:
subtype_of_deformable_transform = type_of_deformable_transform.split("[")[1].split("]")[0]
if not ('bo' in subtype_of_deformable_transform or 'so' in subtype_of_deformable_transform):
raise ValueError("Only 'so' or 'bo' transforms are available.")
if "," in subtype_of_deformable_transform:
subtype_of_deformable_transform_args = subtype_of_deformable_transform.split(",")
subtype_of_deformable_transform = subtype_of_deformable_transform_args[0]
intensity_metric_parameter = subtype_of_deformable_transform_args[1]
if len(subtype_of_deformable_transform_args) > 2:
spline_distance = subtype_of_deformable_transform_args[2]
gradient_step = 0.1
syn_transform = "SyN"

syn_stage = list()

intensity_metric = None
if fixed_intensity_images is not None and len(fixed_intensity_images) > 0:
if isinstance(type_of_deformable_transform, list):

if (len(type_of_deformable_transform) != 7 or
not isinstance(type_of_deformable_transform[0], str) or
not isinstance(type_of_deformable_transform[1], float) or
not isinstance(type_of_deformable_transform[2], str) or
not isinstance(type_of_deformable_transform[3], int) or
not isinstance(type_of_deformable_transform[4], tuple) or
not isinstance(type_of_deformable_transform[5], tuple) or
not isinstance(type_of_deformable_transform[6], tuple)):
raise ValueError("Incorrect specification for type_of_deformable_transform. See help menu.")

syn_transform = type_of_deformable_transform[0]
gradient_step = type_of_deformable_transform[1]
intensity_metric = type_of_deformable_transform[2]
intensity_metric_parameter = type_of_deformable_transform[3]

t = type_of_deformable_transform[4]
tstr = ''.join(map(lambda x: str(x) + 'x', t[:len(t)-1])) + str(t[len(t)-1])
syn_convergence = "[" + tstr + ",1e-6,10]"

t = type_of_deformable_transform[5]
tstr = ''.join(map(lambda x: str(x) + 'x', t[:len(t)-1])) + str(t[len(t)-1])
syn_smoothing_sigmas = tstr + "vox"

t = type_of_deformable_transform[6]
syn_shrink_factors = ''.join(map(lambda x: str(x) + 'x', t[:len(t)-1])) + str(t[len(t)-1])

else:

do_quick = False
if "Quick" in type_of_deformable_transform:
do_quick = True
elif "Repro" in type_of_deformable_transform:
random_seed = str(1)

if "[" in type_of_deformable_transform and "]" in type_of_deformable_transform:
subtype_of_deformable_transform = type_of_deformable_transform.split("[")[1].split("]")[0]
if not ('bo' in subtype_of_deformable_transform or 'so' in subtype_of_deformable_transform):
raise ValueError("Only 'so' or 'bo' transforms are available.")
else:
if 'bo' in subtype_of_deformable_transform:
syn_transform = "BSplineSyN"
if "," in subtype_of_deformable_transform:
subtype_of_deformable_transform_args = subtype_of_deformable_transform.split(",")
subtype_of_deformable_transform = subtype_of_deformable_transform_args[0]
intensity_metric_parameter = subtype_of_deformable_transform_args[1]
if len(subtype_of_deformable_transform_args) > 2:
spline_distance = subtype_of_deformable_transform_args[2]

if do_quick:
intensity_metric = "MI"
if intensity_metric_parameter is None:
intensity_metric_parameter = 32
if not do_quick or do_repro:
intensity_metric = "CC"
if intensity_metric_parameter is None:
intensity_metric_parameter = 2
syn_convergence = "[100x70x50x0,1e-6,10]"

if fixed_intensity_images is not None and len(fixed_intensity_images) > 0:
for i in range(len(fixed_intensity_images)):
syn_stage.append("--metric")
metric_string = "%s[%s,%s,%s,%s]" % (
Expand All @@ -1822,25 +1864,17 @@ def label_image_registration(fixed_label_images,
deformable_multivariate_extras[kk][3], 0.0)
syn_stage.append(metricString)

syn_shrink_factors = "8x4x2x1"
syn_smoothing_sigmas = "3x2x1x0vox"

if do_quick:
syn_convergence = "[100x70x50x0,1e-6,10]"
else:
syn_convergence = "[100x70x50x20,1e-6,10]"

syn_stage.append("--convergence")
syn_stage.append(syn_convergence)
syn_stage.append("--shrink-factors")
syn_stage.append(syn_shrink_factors)
syn_stage.append("--smoothing-sigmas")
syn_stage.append(syn_smoothing_sigmas)

if 'b' in subtype_of_deformable_transform:
syn_stage.insert(0, "BSplineSyN[0.1," + str(spline_distance) + ",0,3]")
if syn_transform == "SyN":
syn_stage.insert(0, "SyN[" + str(gradient_step) + ",3,0]")
else:
syn_stage.insert(0, "SyN[0.1,3,0]")
syn_stage.insert(0, "BSplineSyN[" + str(gradient_step) + "," + str(spline_distance) + ",0,3]")
syn_stage.insert(0, "--transform")

args = None
Expand Down