Skip to content

Commit

Permalink
Merge pull request #757 from ANTsX/syn_only_init
Browse files Browse the repository at this point in the history
ENH: Don't do COM initialization for "syn-only" reg

Also includes TV* registrations, which now all support an initial transform with default of identity
  • Loading branch information
cookpa authored Dec 12, 2024
2 parents 80826f5 + d1590d5 commit c06c34f
Showing 1 changed file with 78 additions and 70 deletions.
148 changes: 78 additions & 70 deletions ants/registration/registration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
ANTsPy Registration
"""
__all__ = ["registration",
__all__ = ["registration",
"motion_correction",
"label_image_registration"]

Expand Down Expand Up @@ -64,8 +64,9 @@ def registration(
See Notes below for more.
initial_transform : list of strings (optional)
transforms to prepend. If None, a translation is computed to align the image centers of mass.
To use an identity transform, set this to 'Identity'.
transforms to prepend. If None, a translation is computed to align the image centers of mass, unless the type of
transform is deformable-only (time-varying diffeomorphisms, SyNOnly, or antsRegistrationSyN*[so|bo]).
To force initialization with an identity transform, set this to 'Identity'.
outprefix : string
output will be named with this prefix.
Expand All @@ -84,10 +85,10 @@ def registration(
flow_sigma : scalar
smoothing for update field
At each iteration, the similarity metric and gradient is calculated.
That gradient field is also called the update field and is smoothed
before composing with the total field (i.e., the estimate of the total
transform at that iteration). This total field can also be smoothed
At each iteration, the similarity metric and gradient is calculated.
That gradient field is also called the update field and is smoothed
before composing with the total field (i.e., the estimate of the total
transform at that iteration). This total field can also be smoothed
after each iteration.
total_sigma : scalar
Expand Down Expand Up @@ -155,7 +156,7 @@ def registration(
singleprecision : boolean
if True, use float32 for computations. This is useful for reducing memory
usage for large datasets, at the cost of precision.
usage for large datasets, at the cost of precision.
kwargs : keyword args
extra arguments
Expand Down Expand Up @@ -197,12 +198,7 @@ def registration(
- "SyNRA": Symmetric normalization: Rigid + Affine + deformable
transformation, with mutual information as optimization metric.
- "SyNOnly": Symmetric normalization with no rigid or affine stages.
Uses mutual information as optimization metric. Affine alignment is
from the initial_transform arg, either provide the .mat from linear
registration or use initial_transform='Identity' if the images are
already affinely aligned.
Can be useful if you want to run an unmasked affine followed by
masked deformable registration.
Uses mutual information as optimization metric.
- "SyNCC": SyN, but with cross-correlation as the metric.
- "SyNabp": SyN optimized for abpBrainExtraction.
- "SyNBold": SyN, but optimized for registrations between BOLD and T1 images.
Expand Down Expand Up @@ -446,8 +442,18 @@ def registration(
else:
earlymaskopt = "[NA,NA]"

deformable_only_transforms = ["SyNOnly", "antsRegistrationSyN[so]", "antsRegistrationSyNQuick[so]",
"antsRegistrationSyNRepro[so]", "antsRegistrationSyNQuickRepro[so]",
"antsRegistrationSyN[bo]", "antsRegistrationSyNQuick[bo]",
"antsRegistrationSyNRepro[bo]", "antsRegistrationSyNQuickRepro[bo]",
"TVMSQ", "TVMSQC"] + tvTypes

if initx is None:
initx = ["[%s,%s,1]" % (f, m)]
if type_of_transform in deformable_only_transforms:
initx = ["Identity"]
else:
initx = ["[%s,%s,1]" % (f, m)]

# ------------------------------------------------------------
if type_of_transform == "SyNBold":
args = [
Expand Down Expand Up @@ -1067,7 +1073,8 @@ def registration(
args = [
"-d",
str(fixed.dimension),
# '-r', initx,
'-r'
] + initx + [
"-m",
"%s[%s,%s,1,%s]" % (syn_metric, f, m, syn_sampling),
"-t",
Expand Down Expand Up @@ -1098,7 +1105,8 @@ def registration(
args = [
"-d",
str(fixed.dimension),
# '-r', initx,
'-r'
] + initx + [
"-m",
"demons[%s,%s,0.5,0]" % (f, m),
"-m",
Expand Down Expand Up @@ -1573,7 +1581,7 @@ def motion_correction(
"FD": FD,
}

def label_image_registration(fixed_label_images,
def label_image_registration(fixed_label_images,
moving_label_images,
fixed_intensity_images=None,
moving_intensity_images=None,
Expand All @@ -1587,8 +1595,8 @@ def label_image_registration(fixed_label_images,
verbose=False):

"""
Perform pairwise registration using fixed and moving sets of label
images (and, optionally, sets of corresponding intensity images).
Perform pairwise registration using fixed and moving sets of label
images (and, optionally, sets of corresponding intensity images).
Arguments
---------
Expand All @@ -1607,34 +1615,34 @@ def label_image_registration(fixed_label_images,
fixed_mask : ANTsImage
Defines region for similarity metric calculation in the space
of the fixed image.
moving_mask : ANTsImage
Defines region for similarity metric calculation in the space
of the moving image.
type_of_linear_transform : string
Use label images with the centers of mass to a calculate linear
Use label images with the centers of mass to a calculate linear
transform of type 'rigid', 'similarity', or 'affine'.
type_of_deformable_transform : string
Only works with deformable-only transforms, specifically the family
of antsRegistrationSyN*[so] or antsRegistrationSyN*[bo] transforms.
of antsRegistrationSyN*[so] or antsRegistrationSyN*[bo] transforms.
See 'type_of_transform' in ants.registration. Additionally, one can
use a list to pass a more tailored deformably-only transform
optimization using SyN or BSplineSyN transforms. The order of
use a list to pass a more tailored deformably-only transform
optimization using SyN or BSplineSyN transforms. The order of
parameters in the list would be 1) transform specification, i.e.
"SyN" or "BSplineSyN", 2) gradient (real), 3) intensity metric (string),
4) intensity metric parameter (real), 5) convergence iterations per level
(tuple) 6) smoothing factors per level (tuple), 7) shrink factors per level
(tuple). An example would type_of_deformable_transform = ["SyN", 0.2, "CC",
"SyN" or "BSplineSyN", 2) gradient (real), 3) intensity metric (string),
4) intensity metric parameter (real), 5) convergence iterations per level
(tuple) 6) smoothing factors per level (tuple), 7) shrink factors per level
(tuple). An example would type_of_deformable_transform = ["SyN", 0.2, "CC",
4, (100,50,10), (2,1,0), (4,2,1)].
label_image_weighting : float or list of floats
Relative weighting for the label images.
output_prefix : string
Define the output prefix for the filenames of the output transform
files.
files.
random_seed : integer
Definition for deformable registration.
Expand All @@ -1644,7 +1652,7 @@ def label_image_registration(fixed_label_images,
Returns
-------
Set of transforms definining the mapping to/from the fixed image domain
Set of transforms definining the mapping to/from the fixed image domain
to the moving image domain.
Example
Expand All @@ -1658,7 +1666,7 @@ def label_image_registration(fixed_label_images,
>>> r64_seg1 = ants.threshold_image(r64, "Kmeans", 3) - 1
>>> r64_seg2 = ants.threshold_image(r64, "Kmeans", 5) - 1
>>> reg = ants.label_image_registration([r16_seg1, r16_seg2],
[r64_seg1, r64_seg2],
[r64_seg1, r64_seg2],
fixed_intensity_images=r16,
moving_intensity_images=r64,
type_of_linear_transform='affine',
Expand Down Expand Up @@ -1691,18 +1699,18 @@ def label_image_registration(fixed_label_images,
else:
label_image_weights = tuple(label_image_weighting)
if len(fixed_label_images) != len(label_image_weights):
raise ValueError("The length of label_image_weights must" +
raise ValueError("The length of label_image_weights must" +
"match the number of label image pairs.")

image_dimension = fixed_label_images[0].dimension

if output_prefix == "" or output_prefix is None or len(output_prefix) == 0:
output_prefix = mktemp()

allowable_linear_transforms = ['rigid', 'similarity', 'affine']
allowable_linear_transforms = ['rigid', 'similarity', 'affine']
if not type_of_linear_transform in allowable_linear_transforms:
raise ValueError("Unrecognized linear transform.")
raise ValueError("Unrecognized linear transform.")

do_deformable = True
if type_of_deformable_transform is None or len(type_of_deformable_transform) == 0:
do_deformable = False
Expand All @@ -1720,7 +1728,7 @@ def label_image_registration(fixed_label_images,
print("Common label ids for image pair ", str(i), ": ", common_label_ids[i])
if len(common_label_ids[i]) == 0:
raise ValueError("No common labels for image pair " + str(i))

if verbose:
print("Total number of labels: " + str(total_number_of_labels))

Expand All @@ -1737,12 +1745,12 @@ def label_image_registration(fixed_label_images,
print("\n\nComputing linear transform.\n")

if total_number_of_labels < 3:
raise ValueError(" Number of labels must be >= 3.")
raise ValueError(" Number of labels must be >= 3.")

fixed_centers_of_mass = np.zeros((total_number_of_labels, image_dimension))
fixed_centers_of_mass = np.zeros((total_number_of_labels, image_dimension))
moving_centers_of_mass = np.zeros((total_number_of_labels, image_dimension))
deformable_multivariate_extras = list()

count = 0
for i in range(len(common_label_ids)):
for j in range(len(common_label_ids[i])):
Expand All @@ -1755,17 +1763,17 @@ def label_image_registration(fixed_label_images,
moving_centers_of_mass[count, :] = ants.get_center_of_mass(moving_single_label_image)
count += 1
if do_deformable:
deformable_multivariate_extras.append(["MSQ", fixed_single_label_image,
moving_single_label_image,
deformable_multivariate_extras.append(["MSQ", fixed_single_label_image,
moving_single_label_image,
label_image_weights[i], 0])
linear_xfrm = ants.fit_transform_to_paired_points(moving_centers_of_mass,
fixed_centers_of_mass,

linear_xfrm = ants.fit_transform_to_paired_points(moving_centers_of_mass,
fixed_centers_of_mass,
transform_type=type_of_linear_transform,
verbose=verbose)

linear_xfrm_file = output_prefix + "0GenericAffine.mat"
ants.write_transform(linear_xfrm, linear_xfrm_file)
ants.write_transform(linear_xfrm, linear_xfrm_file)

##############################
#
Expand All @@ -1787,25 +1795,25 @@ def label_image_registration(fixed_label_images,
gradient_step = 0.1
syn_transform = "SyN"

syn_stage = list()
syn_stage = list()

if isinstance(type_of_deformable_transform, list):

if (len(type_of_deformable_transform) != 7 or
not isinstance(type_of_deformable_transform[0], str) or
not isinstance(type_of_deformable_transform[1], float) or
not isinstance(type_of_deformable_transform[2], str) or
not isinstance(type_of_deformable_transform[3], int) or
not isinstance(type_of_deformable_transform[4], tuple) or
not isinstance(type_of_deformable_transform[5], tuple) or
not isinstance(type_of_deformable_transform[3], int) or
not isinstance(type_of_deformable_transform[4], tuple) or
not isinstance(type_of_deformable_transform[5], tuple) or
not isinstance(type_of_deformable_transform[6], tuple)):
raise ValueError("Incorrect specification for type_of_deformable_transform. See help menu.")
raise ValueError("Incorrect specification for type_of_deformable_transform. See help menu.")

syn_transform = type_of_deformable_transform[0]
gradient_step = type_of_deformable_transform[1]
intensity_metric = type_of_deformable_transform[2]
intensity_metric_parameter = type_of_deformable_transform[3]

t = type_of_deformable_transform[4]
tstr = ''.join(map(lambda x: str(x) + 'x', t[:len(t)-1])) + str(t[len(t)-1])
syn_convergence = "[" + tstr + ",1e-6,10]"
Expand All @@ -1816,8 +1824,8 @@ def label_image_registration(fixed_label_images,

t = type_of_deformable_transform[6]
syn_shrink_factors = ''.join(map(lambda x: str(x) + 'x', t[:len(t)-1])) + str(t[len(t)-1])
else:

else:

do_quick = False
if "Quick" in type_of_deformable_transform:
Expand All @@ -1840,10 +1848,10 @@ def label_image_registration(fixed_label_images,
spline_distance = subtype_of_deformable_transform_args[2]

if do_quick:
intensity_metric = "MI"
intensity_metric = "MI"
if intensity_metric_parameter is None:
intensity_metric_parameter = 32
syn_convergence = "[100x70x50x0,1e-6,10]"
syn_convergence = "[100x70x50x0,1e-6,10]"

if fixed_intensity_images is not None and len(fixed_intensity_images) > 0:
for i in range(len(fixed_intensity_images)):
Expand All @@ -1854,15 +1862,15 @@ def label_image_registration(fixed_label_images,
get_pointer_string(moving_intensity_images[i]),
1.0, intensity_metric_parameter)
syn_stage.append(metric_string)

for kk in range(len(deformable_multivariate_extras)):
syn_stage.append("--metric")
metricString = "%s[%s,%s,%s,%s]" % (
"MSQ",
get_pointer_string(deformable_multivariate_extras[kk][1]),
get_pointer_string(deformable_multivariate_extras[kk][2]),
deformable_multivariate_extras[kk][3], 0.0)
syn_stage.append(metricString)
syn_stage.append(metricString)

syn_stage.append("--convergence")
syn_stage.append(syn_convergence)
Expand All @@ -1887,25 +1895,25 @@ def label_image_registration(fixed_label_images,
"-o", output_prefix]
args.append(syn_stage)

fixed_mask_string = 'NA'
fixed_mask_string = 'NA'
if fixed_mask is not None:
fixed_mask_binary = fixed_mask != 0
fixed_mask_string = get_pointer_string(fixed_mask_binary)

moving_mask_string = 'NA'
moving_mask_string = 'NA'
if moving_mask is not None:
moving_mask_binary = moving_mask != 0
moving_mask_string = get_pointer_string(moving_mask_binary)

mask_option = "[%s,%s]" % (fixed_mask_string, moving_mask_string)

args.append("-x")
args.append(mask_option)

args = list(itertools.chain.from_iterable(
itertools.repeat(x, 1)
if isinstance(x, str)
else x for x in args))
itertools.repeat(x, 1)
if isinstance(x, str)
else x for x in args))

args.append("--float")
args.append("1")
Expand All @@ -1929,7 +1937,7 @@ def label_image_registration(fixed_label_images,
raise RuntimeError(f"Registration failed with error code {deformable_registration_exit_error}")

all_xfrms = sorted(set(glob.glob(output_prefix + "*" + "[0-9]*")))

find_inverse_warps = np.where([re.search("[0-9]InverseWarp.nii.gz", ff) for ff in all_xfrms])[0]
find_forward_warps = np.where([re.search("[0-9]Warp.nii.gz", ff) for ff in all_xfrms])[0]

Expand All @@ -1942,8 +1950,8 @@ def label_image_registration(fixed_label_images,

if verbose:
print("\n\nResulting transforms")
print(" fwdtransforms: ", fwdtransforms)
print(" invtransforms: ", invtransforms)
print(" fwdtransforms: ", fwdtransforms)
print(" invtransforms: ", invtransforms)

return {
"fwdtransforms": fwdtransforms,
Expand Down

0 comments on commit c06c34f

Please sign in to comment.