model.py (forked from hkaggarwal/modl)
"""
This code will create the model described in our following paper
MoDL: Model-Based Deep Learning Architecture for Inverse Problems
by H.K. Aggarwal, M.P. Mani, M. Jacob from University of Iowa.
Paper dwonload Link: https://arxiv.org/abs/1712.02862
@author: haggarwal
"""
import tensorflow as tf
import numpy as np
from os.path import expanduser
home = expanduser("~")
epsilon=1e-5
TFeps=tf.constant(1e-5,dtype=tf.float32)
# c2r stacks the real and imaginary parts of a complex input as two real
# channels along a new last axis
c2r=lambda x:tf.stack([tf.real(x),tf.imag(x)],axis=-1)
# r2c converts the two real channels in the last dimension back to complex
r2c=lambda x:tf.complex(x[...,0],x[...,1])
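# A minimal shape sketch of these helpers (illustrative only; the sizes are
# assumptions, not fixed by the model):
#   x  = tf.complex(tf.ones((4,4)), tf.zeros((4,4)))  # complex64, shape (4,4)
#   xr = c2r(x)    # float32,   shape (4,4,2)
#   xc = r2c(xr)   # complex64, shape (4,4) again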
def createLayer(x, szW, training, lastLayer):
    """
    This function creates one CNN layer consisting of convolution, batch
    normalization, and ReLU. The last layer omits the ReLU to avoid
    truncating the negative part of the learned noise and alias patterns.
    """
    W=tf.get_variable('W',shape=szW,initializer=tf.contrib.layers.xavier_initializer())
    x = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
    xbn=tf.layers.batch_normalization(x,training=training,fused=True,name='BN')
    if not lastLayer:
        return tf.nn.relu(xbn)
    else:
        return xbn
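# Usage sketch (hypothetical scope and shapes): each call must live in its own
# variable scope so the weight 'W' gets a unique name, e.g.
#   with tf.variable_scope('demoLayer'):
#       y = createLayer(x, szW=(3,3,2,64), training=True, lastLayer=False)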
def dw(inp, training, nLay):
    """
    This is the Dw block as defined in Fig. 1 of the MoDL paper.
    It creates an nLay-layer residual-learning CNN.
    Convolution filters are of size 3x3, and there are 64 such filters.
    nw: the learned noise
    dw: the output of residual learning after adding the input back
    """
    lastLayer=False
    nw={}
    nw['c'+str(0)]=inp
    szW = {key: (3,3,64,64) for key in range(2,nLay)}
    szW[1]=(3,3,2,64)
    szW[nLay]=(3,3,64,2)
    for i in np.arange(1,nLay+1):
        if i==nLay:
            lastLayer=True
        with tf.variable_scope('Layer'+str(i)):
            nw['c'+str(i)]=createLayer(nw['c'+str(i-1)],szW[i],training,lastLayer)
    with tf.name_scope('Residual'):
        shortcut=tf.identity(inp)
        dw=shortcut+nw['c'+str(nLay)]
    return dw
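# Usage sketch (assumed 2-channel real input; the 256x232 size is only an
# example):
#   inp = tf.placeholder(tf.float32, shape=(1,256,232,2))
#   res = dw(inp, training=True, nLay=5)   # same shape as inp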
class Aclass:
    """
    This class performs the data-consistency (DC) step as described in the paper.
    """
    def __init__(self, csm, mask, lam):
        with tf.name_scope('Ainit'):
            s=tf.shape(mask)
            self.nrow,self.ncol=s[0],s[1]
            self.pixels=self.nrow*self.ncol
            self.mask=mask
            self.csm=csm
            self.SF=tf.complex(tf.sqrt(tf.to_float(self.pixels)),0.)
            self.lam=lam
    def myAtA(self,img):
        with tf.name_scope('AtA'):
            coilImages=self.csm*img
            kspace= tf.fft2d(coilImages)/self.SF
            temp=kspace*self.mask
            coilImgs =tf.ifft2d(temp)*self.SF
            coilComb= tf.reduce_sum(coilImgs*tf.conj(self.csm),axis=0)
            coilComb=coilComb+self.lam*img
            return coilComb
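# In matrix form, myAtA applies the regularized normal operator
# (A^H A + lam*I) x with A = M F S: coil sensitivities S (csm), 2D Fourier
# transform F, and sampling mask M. Construction sketch (shapes assumed):
#   csm  : complex64, (ncoil, nrow, ncol)
#   mask : complex64, (nrow, ncol)
#   Aobj = Aclass(csm, mask, lam=tf.complex(.05, 0.))
#   y    = Aobj.myAtA(x)   # x: complex64, (nrow, ncol)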
def myCG(A,rhs):
    """
    This is my implementation of the CG algorithm in TensorFlow that works on
    complex data and runs on the GPU. It takes the Aclass object as input.
    """
    rhs=r2c(rhs)
    cond=lambda i,rTr,*_: tf.logical_and(tf.less(i,10), rTr>1e-10)
    def body(i,rTr,x,r,p):
        with tf.name_scope('cgBody'):
            Ap=A.myAtA(p)
            alpha = rTr / tf.to_float(tf.reduce_sum(tf.conj(p)*Ap))
            alpha=tf.complex(alpha,0.)
            x = x + alpha * p
            r = r - alpha * Ap
            rTrNew = tf.to_float(tf.reduce_sum(tf.conj(r)*r))
            beta = rTrNew / rTr
            beta=tf.complex(beta,0.)
            p = r + beta * p
            return i+1,rTrNew,x,r,p
    x=tf.zeros_like(rhs)
    i,r,p=0,rhs,rhs
    rTr = tf.to_float(tf.reduce_sum(tf.conj(r)*r))
    loopVar=i,rTr,x,r,p
    out=tf.while_loop(cond,body,loopVar,name='CGwhile',parallel_iterations=1)[2]
    return c2r(out)
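# Solve sketch: myCG solves (A^H A + lam*I) x = rhs for a single image, with
# rhs supplied in the 2-channel real form produced by c2r (names assumed):
#   xhat = myCG(Aobj, rhs)   # rhs, xhat: float32, (nrow, ncol, 2)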
def getLambda():
    """
    Creates the shared regularization variable lambda (stored as 'lam1').
    """
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        lam = tf.get_variable(name='lam1', dtype=tf.float32, initializer=.05)
    return lam
def callCG(rhs):
    """
    This function calls myCG on each image in a batch. It reads the coil
    sensitivities and mask from the dataset iterator op named 'getNext'.
    """
    G=tf.get_default_graph()
    getnext=G.get_operation_by_name('getNext')
    _,_,csm,mask=getnext.outputs
    l=getLambda()
    l2=tf.complex(l,0.)
    def fn(tmp):
        c,m,r=tmp
        Aobj=Aclass(c,m,l2)
        y=myCG(Aobj,r)
        return y
    inp=(csm,mask,rhs)
    rec=tf.map_fn(fn,inp,dtype=tf.float32,name='mapFn2')
    return rec
@tf.custom_gradient
def dcManualGradient(x):
    """
    This function imposes the data-consistency constraint. Rather than relying
    on TensorFlow to backpropagate through the conjugate-gradient iterations,
    we compute the gradient manually: because the DC step solves a linear
    system with a Hermitian operator, its gradient is the same CG solve
    applied to the incoming gradient. Please see Section III(c) in the paper.
    """
    y=callCG(x)
    def grad(inp):
        out=callCG(inp)
        return out
    return y,grad
def dc(rhs,csm,mask,lam1):
    """
    This function is used to create the testing model. It applies CG on each
    image in the batch.
    """
    lam2=tf.complex(lam1,0.)
    def fn(tmp):
        c,m,r=tmp
        Aobj=Aclass(c,m,lam2)
        y=myCG(Aobj,r)
        return y
    inp=(csm,mask,rhs)
    rec=tf.map_fn(fn,inp,dtype=tf.float32,name='mapFn')
    return rec
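# Batch usage sketch (shapes assumed; batch axis first):
#   rhs  : float32,   (batch, nrow, ncol, 2)
#   csm  : complex64, (batch, ncoil, nrow, ncol)
#   mask : complex64, (batch, nrow, ncol)
#   rec = dc(rhs, csm, mask, getLambda())   # float32, same shape as rhs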
def makeModel(atb,csm,mask,training,nLayers,K,gradientMethod):
    """
    This is the main function that creates the K-times unrolled model.
    """
    out={}
    out['dc0']=atb
    with tf.name_scope('myModel'):
        with tf.variable_scope('Wts',reuse=tf.AUTO_REUSE):
            for i in range(1,K+1):
                j=str(i)
                out['dw'+j]=dw(out['dc'+str(i-1)],training,nLayers)
                lam1=getLambda()
                rhs=atb + lam1*out['dw'+j]
                if gradientMethod=='AG':
                    out['dc'+j]=dc(rhs,csm,mask,lam1)
                elif gradientMethod=='MG':
                    if training:
                        out['dc'+j]=dcManualGradient(rhs)
                    else:
                        out['dc'+j]=dc(rhs,csm,mask,lam1)
    return out
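# A minimal end-to-end construction sketch (not part of the original
# repository): all shapes and the choice of the 'AG' gradient method are
# assumptions for illustration. The 'MG' path additionally needs a dataset
# iterator op named 'getNext' in the graph, so it is not exercised here.
if __name__ == '__main__':
    atb =tf.placeholder(tf.float32,  shape=(None,256,232,2),  name='atbDemo')
    csm =tf.placeholder(tf.complex64,shape=(None,12,256,232), name='csmDemo')
    mask=tf.placeholder(tf.complex64,shape=(None,256,232),    name='maskDemo')
    out=makeModel(atb,csm,mask,training=False,nLayers=5,K=1,gradientMethod='AG')
    print(out['dc1'])   # final reconstruction in 2-channel real form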