forked from omarocegueda/registration
-
Notifications
You must be signed in to change notification settings - Fork 2
/
SimilarityMetric.py
174 lines (158 loc) · 6.71 KB
/
SimilarityMetric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
'''
Defines SimilarityMetric, the abstract contract that must be fulfilled by the
specialized similarity metrics to be used with a RegistrationOptimizer.
'''
import abc
class SimilarityMetric(abc.ABCMeta('SimilarityMetricBase', (object,), {})):
    '''
    Abstract base class for the similarity metrics used by a
    RegistrationOptimizer.

    A similarity metric is in charge of keeping track of the numerical value
    of the similarity (or distance) between the two given images. It also
    computes the update field for the forward and inverse displacement fields
    to be used in a gradient-based optimization algorithm.

    Note that this metric does not depend on any transformation (affine or
    non-linear), so it assumes the fixed and moving images are already
    warped.

    Attributes initialized by the constructor:
        dim: expected number of image dimensions, i.e. len(image.shape)
        parameters: dict of parameter values (the subclass defaults,
            overridden by the user-supplied values whose keys are known)
        fixed_image, moving_image: current images (None until set)
        levels_above, levels_below: number of pyramid levels above/below
            the current one (both start at 0)
        symmetric: initialized to False
    '''
    # ABCMeta is applied through a dynamically created base class instead of
    # the Python-2-only ``__metaclass__`` attribute (silently ignored by
    # Python 3), so the abstract methods below are enforced on both versions.

    def __init__(self, dim, parameters=None):
        '''
        Creates the metric for images with `dim` dimensions.

        Parameters
        ----------
        dim : int
            number of dimensions the fixed and moving images must have
        parameters : dict or None
            user-supplied parameter values. Only keys present in the dict
            returned by get_default_parameters() are used; unknown keys are
            reported with a warning and ignored. None means "no overrides".
        '''
        self.dim = dim
        if parameters is None:
            parameters = {}
        # Work on a copy so that overriding values never mutates a dict the
        # subclass may share between instances (e.g. a class-level constant).
        default_parameters = dict(self.get_default_parameters())
        for key, val in parameters.items():
            if key in default_parameters:
                default_parameters[key] = val
            else:
                print("Warning: parameter '%s' unknown. Ignored." % (key,))
        self.parameters = default_parameters
        self.set_fixed_image(None)
        self.set_moving_image(None)
        self.levels_above = 0
        self.levels_below = 0
        self.symmetric = False

    def set_levels_below(self, levels):
        r'''
        Informs this metric the number of pyramid levels below the current one.
        The metric may change its behavior (e.g. number of inner iterations)
        accordingly
        '''
        self.levels_below = levels

    def set_levels_above(self, levels):
        r'''
        Informs this metric the number of pyramid levels above the current one.
        The metric may change its behavior (e.g. number of inner iterations)
        accordingly
        '''
        self.levels_above = levels

    def _check_image_dimension(self, image, name):
        '''
        Raises AttributeError if `image` is not None and its number of
        dimensions (len(image.shape)) differs from self.dim. `name` is the
        image role reported in the error message.
        '''
        # `is None` rather than `!= None`: on a numpy array `!= None` is an
        # elementwise comparison whose truth value is ambiguous.
        if image is None:
            return
        new_dim = len(image.shape)
        if new_dim != self.dim:
            raise AttributeError('Unexpected %s dimension: %s' % (name, new_dim))

    def set_fixed_image(self, fixed_image):
        '''
        Sets the fixed image (None clears it). Verifies that the image
        dimension is consistent with this metric, raising AttributeError
        otherwise.
        '''
        self._check_image_dimension(fixed_image, 'fixed_image')
        self.fixed_image = fixed_image

    @abc.abstractmethod
    def get_metric_name(self):
        '''
        Must return the name of the metric that specializes this generic metric
        '''

    @abc.abstractmethod
    def use_fixed_image_dynamics(self,
                                 original_fixed_image,
                                 transformation):
        '''
        This method provides the metric a chance to compute any useful
        information from knowing how the current fixed image was generated
        (as the transformation of an original fixed image). This method is
        called by the optimizer just after it sets the fixed image.
        Transformation will be an instance of TransformationModel or None if
        original_fixed_image equals self.fixed_image.
        '''

    @abc.abstractmethod
    def use_original_fixed_image(self, original_fixed_image):
        '''
        This method provides the metric a chance to compute any useful
        information from the original fixed image (to be used along with the
        sequence of fixed images during optimization; for example, the binary
        mask delimiting the object of interest can be computed from the
        original image only and then warped, instead of thresholding at each
        iteration, which might cause artifacts due to interpolation)
        '''

    def set_moving_image(self, moving_image):
        '''
        Sets the moving image (None clears it). Verifies that the image
        dimension is consistent with this metric, raising AttributeError
        otherwise.
        '''
        self._check_image_dimension(moving_image, 'moving_image')
        self.moving_image = moving_image

    @abc.abstractmethod
    def use_original_moving_image(self, original_moving_image):
        '''
        This method provides the metric a chance to compute any useful
        information from the original moving image (to be used along with the
        sequence of moving images during optimization; for example, the binary
        mask delimiting the object of interest can be computed from the
        original image only and then warped, instead of thresholding at each
        iteration, which might cause artifacts due to interpolation)
        '''

    @abc.abstractmethod
    def use_moving_image_dynamics(self,
                                  original_moving_image,
                                  transformation):
        '''
        This method provides the metric a chance to compute any useful
        information from knowing how the current moving image was generated
        (as the transformation of an original moving image). This method is
        called by the optimizer just after it sets the moving image.
        Transformation will be an instance of TransformationModel or None if
        original_moving_image equals self.moving_image.
        '''

    @abc.abstractmethod
    def initialize_iteration(self):
        '''
        This method will be called before any compute_forward or
        compute_backward call. This gives the metric the chance to precompute
        any useful information for speeding up the update computations. This
        initialization was needed in ANTS because the updates are called once
        per voxel. In Python this is impractical, though.
        '''

    @abc.abstractmethod
    def free_iteration(self):
        '''
        This method is called by the RegistrationOptimizer after the required
        iterations have been computed (forward and/or backward) so that the
        SimilarityMetric can safely delete any data it computed as part of the
        initialization
        '''

    @abc.abstractmethod
    def compute_forward(self):
        '''
        Must return the forward update field for a gradient-based optimization
        algorithm
        '''

    @abc.abstractmethod
    def compute_backward(self):
        '''
        Must return the inverse update field for a gradient-based optimization
        algorithm
        '''

    @abc.abstractmethod
    def get_energy(self):
        '''
        Must return the numeric value of the similarity between the given fixed
        and moving images
        '''

    @abc.abstractmethod
    def get_default_parameters(self):
        r'''
        Derived classes must return a dictionary containing their parameter
        names and default values. The constructor copies this dictionary
        before applying user overrides, so the returned object is not mutated.
        '''

    @abc.abstractmethod
    def report_status(self):
        '''
        This function is called mostly for debugging purposes. The metric
        can for example show the overlaid images or print some statistics
        '''