forked from omarocegueda/registration
-
Notifications
You must be signed in to change notification settings - Fork 2
/
SymmetricRegistrationOptimizer.py
454 lines (438 loc) · 20.7 KB
/
SymmetricRegistrationOptimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
'''
Specialization of the registration optimizer that performs symmetric registration
'''
import time
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
import registrationCommon as rcommon
import tensorFieldUtils as tf
import UpdateRule
from TransformationModel import TransformationModel
from SSDMetric import SSDMetric
from EMMetric import EMMetric
from RegistrationOptimizer import RegistrationOptimizer
from scipy import interpolate
class SymmetricRegistrationOptimizer(RegistrationOptimizer):
    r'''
    Performs the multi-resolution optimization algorithm for non-linear
    registration using a given similarity metric and update rule (this
    scheme was inspired by the ANTS package).

    Two partial transformations are optimized simultaneously: the forward
    model warps the fixed image and the backward model warps the moving
    image. At the end of the optimization they are composed into a single
    transformation, which is stored in the forward model.
    '''
    def get_default_parameters(self):
        r'''
        Returns the default optimizer parameters (maximum iterations per
        pyramid level, field-inversion settings, convergence tolerance and
        whether to report the status after each level).
        '''
        return {'max_iter':[25, 50, 100], 'inversion_iter':20,
                'inversion_tolerance':1e-3, 'tolerance':1e-4,
                'report_status':True}

    def __init__(self,
                 fixed = None,
                 moving = None,
                 affine_fixed = None,
                 affine_moving = None,
                 similarity_metric = None,
                 update_rule = None,
                 parameters = None):
        super(SymmetricRegistrationOptimizer, self).__init__(
            fixed, moving, affine_fixed, affine_moving, similarity_metric,
            update_rule, parameters)
        self.set_max_iter(self.parameters['max_iter'])
        self.tolerance = self.parameters['tolerance']
        self.inversion_tolerance = self.parameters['inversion_tolerance']
        self.inversion_iter = self.parameters['inversion_iter']
        self.report_status = self.parameters['report_status']
        # Number of most recent energy values used to fit the energy profile
        # when estimating convergence (see __get_energy_derivative).
        self.energy_window = 12

    def __connect_functions(self):
        r'''
        Assigns the appropriate functions to be called for displacement field
        inversion, Gaussian pyramid, and affine/dense deformation composition
        according to the dimension of the input images (any dimension other
        than 2 selects the 3D implementations).
        '''
        if self.dim == 2:
            self.invert_vector_field = tf.invert_vector_field_fixed_point
            self.generate_pyramid = rcommon.pyramid_gaussian_2D
            self.append_affine = tf.append_affine_to_displacement_field_2d
            self.prepend_affine = tf.prepend_affine_to_displacement_field_2d
        else:
            self.invert_vector_field = tf.invert_vector_field_fixed_point3D
            self.generate_pyramid = rcommon.pyramid_gaussian_3D
            self.append_affine = tf.append_affine_to_displacement_field_3d
            self.prepend_affine = tf.prepend_affine_to_displacement_field_3d

    def __check_ready(self):
        r'''
        Verifies that the configuration of the optimizer and input data are
        consistent and the optimizer is ready to run.

        Returns
        -------
        ready : bool
            True if every required input is set and the image dimensions
            agree with self.dim, False otherwise (each problem is printed).
        '''
        ready = True
        # 'is None' instead of '== None': the images are expected to be numpy
        # arrays, for which '== None' is an element-wise comparison whose
        # truth value is ambiguous.
        if self.fixed is None:
            ready = False
            print('Error: Fixed image not set.')
        elif self.dim != len(self.fixed.shape):
            ready = False
            print('Error: inconsistent dimensions. Last dimension update: %d. '
                  'Fixed image dimension: %d.' % (self.dim,
                                                  len(self.fixed.shape)))
        if self.moving is None:
            ready = False
            print('Error: Moving image not set.')
        elif self.dim != len(self.moving.shape):
            ready = False
            print('Error: inconsistent dimensions. Last dimension update: %d. '
                  'Moving image dimension: %d.' % (self.dim,
                                                   len(self.moving.shape)))
        if self.similarity_metric is None:
            ready = False
            print('Error: Similarity metric not set.')
        if self.update_rule is None:
            ready = False
            print('Error: Update rule not set.')
        if self.max_iter is None:
            ready = False
            print('Error: Maximum number of iterations per level not set.')
        return ready

    def __init_optimizer(self):
        r'''
        Computes the Gaussian Pyramid of the input images and allocates
        the required memory for the transformation models at the coarsest
        scale.

        Returns
        -------
        ready : bool
            False if the configuration check failed (nothing is allocated in
            that case), True otherwise.
        '''
        ready = self.__check_ready()
        self.__connect_functions()
        if not ready:
            print('Not ready')
            return False
        coarsest = self.levels - 1
        self.moving_pyramid = list(self.generate_pyramid(self.moving,
                                                         coarsest))
        self.fixed_pyramid = list(self.generate_pyramid(self.fixed,
                                                        coarsest))
        fixed_field_shape = self.fixed_pyramid[coarsest].shape + (self.dim,)
        moving_field_shape = self.moving_pyramid[coarsest].shape + (self.dim,)
        # NOTE(review): presumably rescales the affine components to the
        # resolution of the coarsest pyramid level -- confirm in
        # TransformationModel.scale_affines.
        scale = 0.5 ** coarsest
        self.forward_model.scale_affines(scale)
        self.forward_model.set_forward(
            np.zeros(shape = fixed_field_shape, dtype = np.float64))
        self.forward_model.set_backward(
            np.zeros(shape = fixed_field_shape, dtype = np.float64))
        self.backward_model.scale_affines(scale)
        self.backward_model.set_forward(
            np.zeros(shape = moving_field_shape, dtype = np.float64))
        self.backward_model.set_backward(
            np.zeros(shape = fixed_field_shape, dtype = np.float64))
        return True

    def __end_optimizer(self):
        r'''
        Frees the resources allocated during initialization
        '''
        del self.moving_pyramid
        del self.fixed_pyramid

    def __iterate(self, show_images = False):
        r'''
        Performs one symmetric iteration:
        1.Compute forward
        2.Compute backward
        3.Update forward
        4.Update backward
        5.Compute inverses
        6.Invert the inverses to improve invertibility

        Parameters
        ----------
        show_images : bool
            if True, the similarity metric reports its status at the end of
            the iteration

        Returns
        -------
        der : float
            the estimated derivative of the normalized energy profile, or 1
            if it cannot be estimated yet (fewer than energy_window energies
            recorded, or the metric exposes no 'energy' attribute)
        '''
        wmoving = self.backward_model.warp_backward(self.current_moving)
        wfixed = self.forward_model.warp_backward(self.current_fixed)
        self.similarity_metric.set_moving_image(wmoving)
        self.similarity_metric.use_moving_image_dynamics(
            self.current_moving, self.backward_model.inverse())
        self.similarity_metric.set_fixed_image(wfixed)
        self.similarity_metric.use_fixed_image_dynamics(
            self.current_fixed, self.forward_model.inverse())
        self.similarity_metric.initialize_iteration()
        # Remember the field shapes: the backward fields are released below
        # to save memory and recomputed by inversion after the update.
        ff_shape = np.array(self.forward_model.forward.shape).astype(np.int32)
        fb_shape = np.array(self.forward_model.backward.shape).astype(np.int32)
        bf_shape = np.array(self.backward_model.forward.shape).astype(np.int32)
        bb_shape = np.array(self.backward_model.backward.shape).astype(np.int32)
        del self.forward_model.backward
        del self.backward_model.backward
        fw_step = np.array(self.similarity_metric.compute_forward())
        self.forward_model.forward, _ = self.update_rule.update(
            self.forward_model.forward, fw_step)
        del fw_step
        # Not every metric exposes an energy; a missing attribute raises
        # AttributeError, so read it with a default instead.
        fw_energy = getattr(self.similarity_metric, 'energy', None)
        bw_step = np.array(self.similarity_metric.compute_backward())
        self.backward_model.forward, _ = self.update_rule.update(
            self.backward_model.forward, bw_step)
        del bw_step
        bw_energy = getattr(self.similarity_metric, 'energy', None)
        der = None
        if fw_energy is not None and bw_energy is not None:
            n_iter = len(self.energy_list)
            if n_iter >= self.energy_window:
                der = self.__get_energy_derivative()
            print('%d:\t%0.6f\t%0.6f\t%0.6f\t%s' % (
                n_iter, fw_energy, bw_energy, fw_energy + bw_energy,
                '-' if der is None else der))
            self.energy_list.append(fw_energy + bw_energy)
        self.similarity_metric.free_iteration()
        inv_iter = self.inversion_iter
        inv_tol = self.inversion_tolerance
        # Recompute the backward fields from the updated forward fields...
        self.forward_model.backward = np.array(
            self.invert_vector_field(
                self.forward_model.forward, fb_shape, inv_iter, inv_tol, None))
        self.backward_model.backward = np.array(
            self.invert_vector_field(
                self.backward_model.forward, bb_shape, inv_iter, inv_tol, None))
        # ...then invert them back (starting from the current forward fields)
        # to improve invertibility of the pair.
        self.forward_model.forward = np.array(
            self.invert_vector_field(
                self.forward_model.backward, ff_shape, inv_iter, inv_tol,
                self.forward_model.forward))
        self.backward_model.forward = np.array(
            self.invert_vector_field(
                self.backward_model.backward, bf_shape, inv_iter, inv_tol,
                self.backward_model.forward))
        if show_images:
            self.similarity_metric.report_status()
        return 1 if der is None else der

    def __get_energy_derivative(self):
        r'''
        Returns the derivative of the estimated energy as a function of "time"
        (iterations), evaluated at the middle of the last energy_window
        iterations.

        The last energy_window energies are divided by minus the absolute
        value of their sum, fitted with a quadratic smoothing spline, and the
        spline derivative is evaluated at 0.5 * energy_window. Returns 1 (and
        prints an error) if fewer than energy_window energies are available.
        '''
        n_iter = len(self.energy_list)
        if n_iter < self.energy_window:
            print('Error: attempting to fit the energy profile with less '
                  'points (%d) than required (energy_window=%d)'
                  % (n_iter, self.energy_window))
            return 1
        x = np.arange(self.energy_window)
        y = self.energy_list[(n_iter - self.energy_window):n_iter]
        # Normalize by -|sum| so the fit does not depend on the energy scale;
        # a window summing exactly to zero would divide by zero, so fall back
        # to -1 (same sign convention) in that case.
        total = sum(y)
        scale = -abs(total) if total != 0 else -1.0
        y = [v / scale for v in y]
        spline = interpolate.UnivariateSpline(x, y, s = 1e6, k = 2)
        derivative = spline.derivative()
        return derivative(0.5 * self.energy_window)

    def __report_status(self, level):
        r'''
        Shows the current overlaid images either on the common space or the
        reference space

        Parameters
        ----------
        level : int
            pyramid level being processed; only used to rescale the affine
            components when reporting on the reference space
        '''
        show_common_space = True
        if show_common_space:
            wmoving = self.backward_model.warp_backward(self.current_moving)
            wfixed = self.forward_model.warp_backward(self.current_fixed)
            self.similarity_metric.set_moving_image(wmoving)
            self.similarity_metric.use_moving_image_dynamics(
                self.current_moving, self.backward_model.inverse())
            self.similarity_metric.set_fixed_image(wfixed)
            self.similarity_metric.use_fixed_image_dynamics(
                self.current_fixed, self.forward_model.inverse())
            self.similarity_metric.initialize_iteration()
            self.similarity_metric.report_status()
        else:
            phi1 = self.forward_model.forward
            phi2 = self.backward_model.backward
            phi1_inv = self.forward_model.backward
            phi2_inv = self.backward_model.forward
            phi, _ = self.update_rule.update(phi1, phi2)
            phi_inv, _ = self.update_rule.update(phi2_inv, phi1_inv)
            composition = TransformationModel(phi, phi_inv, None, None)
            composition.scale_affines(0.5**level)
            _, stats = composition.compute_inversion_error()
            print('Current inversion error: %0.6f (%0.6f)'%(stats[1], stats[2]))
            wmoving = composition.warp_forward(self.current_moving)
            self.similarity_metric.set_moving_image(wmoving)
            self.similarity_metric.use_moving_image_dynamics(
                self.current_moving, composition)
            self.similarity_metric.set_fixed_image(self.current_fixed)
            self.similarity_metric.use_fixed_image_dynamics(
                self.current_fixed, None)
            self.similarity_metric.initialize_iteration()
            self.similarity_metric.report_status()

    def __optimize(self):
        r'''
        The main multi-scale symmetric optimization algorithm.

        Processes the pyramid from the coarsest level to the finest one and
        finally composes the two partial transformations into the forward
        model. Returns immediately if the optimizer is not properly
        configured (the problems are printed by the readiness check).
        '''
        if not self.__init_optimizer():
            return
        for level in range(self.levels-1, -1, -1):
            print('Processing level %d' % (level,))
            self.current_fixed = self.fixed_pyramid[level]
            self.current_moving = self.moving_pyramid[level]
            self.similarity_metric.use_original_fixed_image(
                self.fixed_pyramid[level])
            # The moving level must go to the moving-image hook; sending it to
            # use_original_fixed_image would overwrite the fixed image above.
            self.similarity_metric.use_original_moving_image(
                self.moving_pyramid[level])
            self.similarity_metric.set_levels_below(self.levels-level)
            self.similarity_metric.set_levels_above(level)
            if level < self.levels - 1:
                self.forward_model.upsample(self.current_fixed.shape,
                                            self.current_fixed.shape)
                self.backward_model.upsample(self.current_moving.shape,
                                             self.current_fixed.shape)
            niter = 0
            self.energy_list = []
            derivative = 1
            # Iterate until the iteration budget of this level is exhausted
            # or the estimated energy derivative drops to the tolerance.
            while ((niter < self.max_iter[level]) and
                   (self.tolerance < derivative)):
                niter += 1
                derivative = self.__iterate()
            if self.report_status:
                self.__report_status(level)
        _, stats = self.forward_model.compute_inversion_error()
        print('Forward Residual error (Symmetric diffeomorphism):%0.6f (%0.6f)'
              %(stats[1], stats[2]))
        _, stats = self.backward_model.compute_inversion_error()
        print('Backward Residual error (Symmetric diffeomorphism):%0.6f (%0.6f)'
              %(stats[1], stats[2]))
        #Compose the two partial transformations
        self.forward_model = self.backward_model.inverse().compose(
            self.forward_model)
        self.forward_model.consolidate()
        del self.backward_model
        _, stats = self.forward_model.compute_inversion_error()
        print('Residual error (Symmetric diffeomorphism):%0.6f (%0.6f)'
              %(stats[1], stats[2]))
        self.__end_optimizer()

    def optimize(self):
        r'''
        Prints the optimizer and metric configuration and runs the
        multi-scale symmetric optimization.
        '''
        print('Optimizer parameters:\n%s' % (self.parameters,))
        print('Metric: %s' % (self.similarity_metric.get_metric_name(),))
        print('Metric parameters:\n%s' % (self.similarity_metric.parameters,))
        self.__optimize()
def test_optimizer_monomodal_2d():
    r'''
    Classical Circle-To-C experiment for 2D Monomodal registration.

    Registers data/circle.png (moving) onto data/C.png (fixed) with the
    symmetric SSD metric, then overlays the warped images and plots the
    obtained deformation together with its inversion residual.
    '''
    fname_moving = 'data/circle.png'
    fname_fixed = 'data/C.png'
    moving = plt.imread(fname_moving)
    fixed = plt.imread(fname_fixed)
    # Keep only the first channel, as C-contiguous float64 arrays
    moving = np.copy(moving[:, :, 0].astype(np.float64), order = 'C')
    fixed = np.copy(fixed[:, :, 0].astype(np.float64), order = 'C')
    # Rescale intensities to [0, 1]
    moving = (moving-moving.min())/(moving.max() - moving.min())
    fixed = (fixed-fixed.min())/(fixed.max() - fixed.min())
    ################Configure and run the Optimizer#####################
    max_iter = [20, 100, 100, 100]
    similarity_metric = SSDMetric(2, {'symmetric':True,
                                      'lambda':5.0,
                                      'stepType':SSDMetric.GAUSS_SEIDEL_STEP})
    optimizer_parameters = {
        'max_iter':max_iter,
        'inversion_iter':40,
        'inversion_tolerance':1e-3,
        'report_status':True}
    update_rule = UpdateRule.Composition()
    registration_optimizer = SymmetricRegistrationOptimizer(fixed, moving,
                                                            None, None,
                                                            similarity_metric,
                                                            update_rule,
                                                            optimizer_parameters)
    registration_optimizer.optimize()
    #######################show results#################################
    displacement = registration_optimizer.get_forward()
    direct_inverse = registration_optimizer.get_backward()
    moving_to_fixed = np.array(tf.warp_image(moving, displacement))
    fixed_to_moving = np.array(tf.warp_image(fixed, direct_inverse))
    rcommon.overlayImages(moving_to_fixed, fixed, True)
    rcommon.overlayImages(fixed_to_moving, moving, True)
    direct_residual, _ = tf.compose_vector_fields(displacement,
                                                  direct_inverse)
    direct_residual = np.array(direct_residual)
    rcommon.plotDiffeomorphism(displacement, direct_inverse, direct_residual,
                               'inv-direct', 7)
def test_optimizer_multimodal_2d(lambda_param):
    r'''
    Registers one of the mid-slices (axial, coronal or sagittal) of each input
    volume (the volumes are expected to be from different modalities and
    should already be affine-registered, for example Brainweb t1 vs t2).

    The fixed slice is first deformed with a synthetic field (the ground
    truth); the estimated forward field is then compared against it.

    Parameters
    ----------
    lambda_param : float
        value passed as the 'lambda' parameter of the EM metric
    '''
    fname_moving = 'data/t2/IBSR_t2template_to_01.nii.gz'
    fname_fixed = 'data/t1/IBSR_template_to_01.nii.gz'
    # fnameMoving = 'data/circle.png'
    # fnameFixed = 'data/C.png'
    nifti = True
    if nifti:
        nib_moving = nib.load(fname_moving)
        nib_fixed = nib.load(fname_fixed)
        moving = nib_moving.get_data().squeeze().astype(np.float64)
        fixed = nib_fixed.get_data().squeeze().astype(np.float64)
        moving = np.copy(moving, order = 'C')
        fixed = np.copy(fixed, order = 'C')
        shape_moving = moving.shape
        shape_fixed = fixed.shape
        # Take the middle slice along the second axis
        moving = moving[:, shape_moving[1]//2, :].copy()
        fixed = fixed[:, shape_fixed[1]//2, :].copy()
        moving = (moving-moving.min())/(moving.max()-moving.min())
        fixed = (fixed-fixed.min())/(fixed.max()-fixed.min())
    else:
        nib_moving = plt.imread(fname_moving)
        nib_fixed = plt.imread(fname_fixed)
        moving = nib_moving[:, :, 0].astype(np.float64)
        fixed = nib_fixed[:, :, 1].astype(np.float64)
        moving = np.copy(moving, order = 'C')
        fixed = np.copy(fixed, order = 'C')
        moving = (moving-moving.min())/(moving.max() - moving.min())
        fixed = (fixed-fixed.min())/(fixed.max() - fixed.min())
    max_iter = [25, 50, 100]
    similarity_metric = EMMetric(2, {'symmetric':True,
                                     'lambda':lambda_param,
                                     'stepType':SSDMetric.GAUSS_SEIDEL_STEP,
                                     'q_levels':256,
                                     'max_inner_iter':20,
                                     'use_double_gradient':True,
                                     'max_step_length':0.25})
    optimizer_parameters = {
        'max_iter':max_iter,
        'inversion_iter':20,
        'inversion_tolerance':1e-3,
        'report_status':True}
    update_rule = UpdateRule.Composition()
    print('Generating synthetic field...')
    #----apply synthetic deformation field to fixed image
    ground_truth = rcommon.createDeformationField2D_type2(fixed.shape[0],
                                                          fixed.shape[1], 8)
    warped_fixed = rcommon.warpImage(fixed, ground_truth)
    print('Registering T2 (template) to deformed T1 (template)...')
    plt.figure()
    rcommon.overlayImages(warped_fixed, moving, False)
    registration_optimizer = SymmetricRegistrationOptimizer(warped_fixed,
                                                            moving,
                                                            None, None,
                                                            similarity_metric,
                                                            update_rule,
                                                            optimizer_parameters)
    registration_optimizer.optimize()
    #######################show results#################################
    displacement = registration_optimizer.get_forward()
    direct_inverse = registration_optimizer.get_backward()
    moving_to_fixed = np.array(tf.warp_image(moving, displacement))
    fixed_to_moving = np.array(tf.warp_image(warped_fixed, direct_inverse))
    rcommon.overlayImages(moving_to_fixed, fixed_to_moving, True)
    direct_residual, _ = tf.compose_vector_fields(displacement,
                                                  direct_inverse)
    direct_residual = np.array(direct_residual)
    rcommon.plotDiffeomorphism(displacement, direct_inverse, direct_residual,
                               'inv-direct', 7)
    # Per-pixel Euclidean error between the estimated and ground-truth fields
    displacement_error = np.sqrt(((displacement-ground_truth)**2).sum(2))
    # Restrict the statistics to the support of the deformed fixed image;
    # averaging over the whole grid would count background pixels as zero
    # error and bias both the mean and the deviation towards zero.
    displacement_error = displacement_error[warped_fixed > 0]
    mean_displacement_error = displacement_error.mean()
    stdev_displacement_error = displacement_error.std()
    print('Mean displacement error: %0.6f (%0.6f)'%
          (mean_displacement_error, stdev_displacement_error))
if __name__ == '__main__':
    # Run the multimodal experiment (EM metric, lambda = 50) and time it.
    start_time = time.time()
    test_optimizer_multimodal_2d(50)
    end_time = time.time()
    print('Registration time: %f sec' % (end_time - start_time))
    # Alternative monomodal experiment:
    # test_optimizer_monomodal_2d()
    # import nibabel as nib
    # result = nib.load('data/circleToC.nii.gz')
    # result = result.get_data().astype(np.double)
    # plt.imshow(result)