forked from omarocegueda/registration
-
Notifications
You must be signed in to change notification settings - Fork 2
/
UpdateRule.py
73 lines (66 loc) · 2.49 KB
/
UpdateRule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
'''
This file contains the abstract UpdateRule which is in charge of updating
a displacement field with the newly computed step. There are three main
update rules: additive, compositive, and compositive with a previous
projection to the diffeomorphism space via displacement field exponentiation
'''
import abc
import numpy as np
import tensorFieldUtils as tf
class UpdateRule(object):
    r'''
    Abstract base class for the strategies that merge a freshly computed
    step into an existing displacement field.

    Concrete rules are expected to provide their own ``update``.
    '''
    # Python 2 style metaclass declaration (Python 3 ignores this attribute).
    __metaclass__ = abc.ABCMeta

    def __init__(self):
        # No state is kept by the base rule.
        return None

    @abc.abstractmethod
    def update(self, new_displacement, current_displacement):
        '''
        Merge ``new_displacement`` into ``current_displacement``.

        Implementations must return a pair ``(updated, mean_norm)``: the
        updated displacement field and the mean norm of the difference
        between the displacements before and after the update.
        '''
class Addition(UpdateRule):
    r'''
    Additive update rule: the new step is summed element-wise onto the
    current displacement field.
    '''
    def __init__(self):
        super(Addition, self).__init__()

    @staticmethod
    def update(new_displacement, current_displacement):
        '''
        Return ``(current_displacement + new_displacement, mean_norm)``.

        ``mean_norm`` is the average Euclidean length of the vectors in
        ``new_displacement``, taken over its last axis. For an additive
        rule this is exactly the size of the change made to the field.
        '''
        step_lengths = np.sqrt((new_displacement ** 2).sum(axis=-1))
        summed = current_displacement + new_displacement
        return summed, step_lengths.mean()
class Composition(UpdateRule):
    r'''
    Compositive update rule: composes the new step with the current
    displacement field using linear interpolation (bilinear in 2D,
    trilinear in 3D).
    '''
    def __init__(self):
        pass

    @staticmethod
    def update(new_displacement, current_displacement):
        '''
        Compose ``new_displacement`` with ``current_displacement``.

        Parameters
        ----------
        new_displacement : array
            The newly computed step. The spatial dimension is taken as
            ``len(shape) - 1``; vectors are assumed to lie along the last
            axis (TODO confirm against tensorFieldUtils).
        current_displacement : ndarray
            The displacement field accumulated so far.

        Returns
        -------
        updated : ndarray
            The composed displacement field.
        mean_norm : 0-d ndarray
            Mean Euclidean norm of ``updated - current_displacement``, i.e.
            of the change produced by this update.

        Raises
        ------
        ValueError
            If the field is neither 2D nor 3D.
        '''
        dim = len(new_displacement.shape) - 1
        if dim == 2:
            compose = tf.compose_vector_fields
        elif dim == 3:
            compose = tf.compose_vector_fields3D
        else:
            raise ValueError('Displacement field must be 2D or 3D, got %dD'
                             % dim)
        composed, _ = compose(new_displacement, current_displacement)
        # np.array converts whatever the composition routine returns
        # (possibly a buffer/memoryview) into an ndarray.
        updated = np.array(composed)
        # Measure the change produced by this update, not the magnitude of
        # the field itself, as required by the UpdateRule.update contract.
        difference = updated - current_displacement
        mean_norm = np.sqrt(np.sum(difference ** 2, -1)).mean()
        return updated, np.array(mean_norm)
class ProjectedComposition(UpdateRule):
    r'''
    Compositive update rule with projection: the new step is first
    exponentiated (projected towards the diffeomorphism space) and then
    composed with the current displacement field using linear
    interpolation.
    '''
    def __init__(self):
        pass

    @staticmethod
    def update(new_displacement, current_displacement):
        '''
        Exponentiate ``new_displacement`` and compose the result with
        ``current_displacement``.

        Returns
        -------
        updated : ndarray
            The composed displacement field.
        stat :
            The first entry of the statistics returned by the composition
            routine (its meaning is defined by tensorFieldUtils; presumably
            a norm summary of the update — confirm).

        Raises
        ------
        ValueError
            If the field is neither 2D nor 3D.
        '''
        # Choose the composition routine that matches the field's dimension
        # before doing the (comparatively expensive) exponentiation.
        dim = len(new_displacement.shape) - 1
        if dim == 2:
            compose = tf.compose_vector_fields
        elif dim == 3:
            compose = tf.compose_vector_fields3D
        else:
            raise ValueError('Displacement field must be 2D or 3D, got %dD'
                             % dim)
        # The exponential routine returns a pair; only its first element is
        # used (the second is presumably the inverse exponential).
        expd, _ = tf.vector_field_exponential(new_displacement, True)
        composed, stats = compose(expd, current_displacement)
        # Return an ndarray, matching Composition.update.
        return np.array(composed), stats[0]