diff --git a/ants/registration/registration.py b/ants/registration/registration.py index 54c06504..2c6739a1 100644 --- a/ants/registration/registration.py +++ b/ants/registration/registration.py @@ -1619,7 +1619,15 @@ def label_image_registration(fixed_label_images, type_of_deformable_transform : string Only works with deformable-only transforms, specifically the family of antsRegistrationSyN*[so] or antsRegistrationSyN*[bo] transforms. - See 'type_of_transform' in ants.registration. + See 'type_of_transform' in ants.registration. Additionally, one can + use a list to pass a more tailored deformably-only transform + optimization using SyN or BSplineSyN transforms. The order of + parameters in the list would be 1) transform specification, i.e. + "SyN" or "BSplineSyN", 2) gradient (real), 3) intensity metric (string), + 4) intensity metric parameter (real), 5) convergence iterations per level + (tuple) 6) smoothing factors per level (tuple), 7) shrink factors per level + (tuple). An example would type_of_deformable_transform = ["SyN", 0.2, "CC", + 4, (100,50,10), (2,1,0), (4,2,1)]. label_image_weighting : float or list of floats Relative weighting for the label images. @@ -1770,40 +1778,74 @@ def label_image_registration(fixed_label_images, if verbose: print("\n\nComputing deformable transform using images.\n") - do_quick = False - do_repro = False - - if "Quick" in type_of_deformable_transform: - do_quick = True - elif "Repro" in type_of_deformable_transform: - do_repro = True - random_seed = str(1) - - intensity_metric_parameter = None + intensity_metric = "CC" + intensity_metric_parameter = 2 + syn_shrink_factors = "8x4x2x1" + syn_smoothing_sigmas = "3x2x1x0vox" + syn_convergence = "[100x70x50x20,1e-6,10]" spline_distance = 26 - if "[" in type_of_deformable_transform and "]" in type_of_deformable_transform: - subtype_of_deformable_transform = type_of_deformable_transform.split("[")[1].split("]")[0] - if not ('bo' in subtype_of_deformable_transform or 'so' in subtype_of_deformable_transform): - raise ValueError("Only 'so' or 'bo' transforms are available.") - if "," in subtype_of_deformable_transform: - subtype_of_deformable_transform_args = subtype_of_deformable_transform.split(",") - subtype_of_deformable_transform = subtype_of_deformable_transform_args[0] - intensity_metric_parameter = subtype_of_deformable_transform_args[1] - if len(subtype_of_deformable_transform_args) > 2: - spline_distance = subtype_of_deformable_transform_args[2] + gradient_step = 0.1 + syn_transform = "SyN" syn_stage = list() - intensity_metric = None - if fixed_intensity_images is not None and len(fixed_intensity_images) > 0: + if isinstance(type_of_deformable_transform, list): + + if (len(type_of_deformable_transform) != 7 or + not isinstance(type_of_deformable_transform[0], str) or + not isinstance(type_of_deformable_transform[1], float) or + not isinstance(type_of_deformable_transform[2], str) or + not isinstance(type_of_deformable_transform[3], int) or + not isinstance(type_of_deformable_transform[4], tuple) or + not isinstance(type_of_deformable_transform[5], tuple) or + not isinstance(type_of_deformable_transform[6], tuple)): + raise ValueError("Incorrect specification for type_of_deformable_transform. See help menu.") + + syn_transform = type_of_deformable_transform[0] + gradient_step = type_of_deformable_transform[1] + intensity_metric = type_of_deformable_transform[2] + intensity_metric_parameter = type_of_deformable_transform[3] + + t = type_of_deformable_transform[4] + tstr = ''.join(map(lambda x: str(x) + 'x', t[:len(t)-1])) + str(t[len(t)-1]) + syn_convergence = "[" + tstr + ",1e-6,10]" + + t = type_of_deformable_transform[5] + tstr = ''.join(map(lambda x: str(x) + 'x', t[:len(t)-1])) + str(t[len(t)-1]) + syn_smoothing_sigmas = tstr + "vox" + + t = type_of_deformable_transform[6] + syn_shrink_factors = ''.join(map(lambda x: str(x) + 'x', t[:len(t)-1])) + str(t[len(t)-1]) + + else: + + do_quick = False + if "Quick" in type_of_deformable_transform: + do_quick = True + elif "Repro" in type_of_deformable_transform: + random_seed = str(1) + + if "[" in type_of_deformable_transform and "]" in type_of_deformable_transform: + subtype_of_deformable_transform = type_of_deformable_transform.split("[")[1].split("]")[0] + if not ('bo' in subtype_of_deformable_transform or 'so' in subtype_of_deformable_transform): + raise ValueError("Only 'so' or 'bo' transforms are available.") + else: + if 'bo' in subtype_of_deformable_transform: + syn_transform = "BSplineSyN" + if "," in subtype_of_deformable_transform: + subtype_of_deformable_transform_args = subtype_of_deformable_transform.split(",") + subtype_of_deformable_transform = subtype_of_deformable_transform_args[0] + intensity_metric_parameter = subtype_of_deformable_transform_args[1] + if len(subtype_of_deformable_transform_args) > 2: + spline_distance = subtype_of_deformable_transform_args[2] + if do_quick: intensity_metric = "MI" if intensity_metric_parameter is None: intensity_metric_parameter = 32 - if not do_quick or do_repro: - intensity_metric = "CC" - if intensity_metric_parameter is None: - intensity_metric_parameter = 2 + syn_convergence = "[100x70x50x0,1e-6,10]" + + if fixed_intensity_images is not None and len(fixed_intensity_images) > 0: for i in range(len(fixed_intensity_images)): syn_stage.append("--metric") metric_string = "%s[%s,%s,%s,%s]" % ( @@ -1822,14 +1864,6 @@ def label_image_registration(fixed_label_images, deformable_multivariate_extras[kk][3], 0.0) syn_stage.append(metricString) - syn_shrink_factors = "8x4x2x1" - syn_smoothing_sigmas = "3x2x1x0vox" - - if do_quick: - syn_convergence = "[100x70x50x0,1e-6,10]" - else: - syn_convergence = "[100x70x50x20,1e-6,10]" - syn_stage.append("--convergence") syn_stage.append(syn_convergence) syn_stage.append("--shrink-factors") @@ -1837,10 +1871,10 @@ def label_image_registration(fixed_label_images, syn_stage.append("--smoothing-sigmas") syn_stage.append(syn_smoothing_sigmas) - if 'b' in subtype_of_deformable_transform: - syn_stage.insert(0, "BSplineSyN[0.1," + str(spline_distance) + ",0,3]") + if syn_transform == "SyN": + syn_stage.insert(0, "SyN[" + str(gradient_step) + ",3,0]") else: - syn_stage.insert(0, "SyN[0.1,3,0]") + syn_stage.insert(0, "BSplineSyN[" + str(gradient_step) + "," + str(spline_distance) + ",0,3]") syn_stage.insert(0, "--transform") args = None