-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
559 lines (475 loc) · 22.5 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as nnf
from torch.nn import Dropout, Softmax, Linear, Conv3d, LayerNorm
from torch.nn.modules.utils import _pair, _triple
import configs as configs
from torch.distributions.normal import Normal
logger = logging.getLogger(__name__)
# Slash-separated parameter-path fragments for attention, MLP and norm layers.
# NOTE(review): none of these names is referenced in the visible code; they look
# like keys of a pretrained ViT checkpoint -- confirm against the weight-loading
# code before relying on or removing them.
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"
def np2th(weights, conv=False):
    """Wrap a NumPy array as a torch tensor.

    Args:
        weights: NumPy array to convert.
        conv: when true, the axes are first re-ordered as ``(3, 2, 0, 1)``,
            i.e. HWIO -> OIHW for a 4-D convolution kernel.

    Returns:
        The tensor produced by ``torch.from_numpy`` on the (possibly
        transposed) array.
    """
    array = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(array)
def swish(x):
    """Swish activation: the input scaled element-wise by its own sigmoid.

    The function is smooth, non-monotonic, bounded below and unbounded above.
    """
    gate = torch.sigmoid(x)
    return x * gate
# Lookup table from an activation name to the callable that implements it.
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        config: object exposing ``hidden_size`` and a ``transformer`` mapping
            with the keys ``"num_heads"`` and ``"attention_dropout_rate"``.
        vis: when truthy, ``forward`` also returns the attention probabilities.
    """

    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        # Integer floor division: the per-head width must be an exact int.
        self.attention_head_size = config.hidden_size // self.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)
        # The output projection consumes the concatenated heads, whose width is
        # all_head_size. That equals hidden_size only when hidden_size is
        # divisible by num_heads; sizing the input with hidden_size made
        # forward() fail for every other configuration.
        self.out = Linear(self.all_head_size, config.hidden_size)

        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        # The projection dropout deliberately reuses the attention dropout rate.
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Split the feature axis into heads.

        Reshapes ``(B, S, all_head_size)`` into
        ``(B, num_heads, S, head_size)``.
        """
        # Keep every leading axis and replace the last one by (heads, head_size).
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Apply self-attention to ``hidden_states`` of shape ``(B, S, hidden_size)``.

        Returns:
            A pair ``(attention_output, weights)``. ``attention_output`` has
            shape ``(B, S, hidden_size)``; ``weights`` holds the softmax
            probabilities ``(B, num_heads, S, S)`` when ``vis`` is truthy and
            is ``None`` otherwise.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Raw scores: query . key^T, scaled by sqrt(head_size).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        # Expose the probabilities before dropout is applied.
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # Back to (B, S, heads, head_size), then merge the two head axes.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights
class Mlp(nn.Module):
    """Feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        config: object exposing ``hidden_size`` and a ``transformer`` mapping
            with the keys ``"mlp_dim"`` and ``"dropout_rate"``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        inner = config.transformer["mlp_dim"]
        self.fc1 = Linear(width, inner)
        self.fc2 = Linear(inner, width)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and near-zero (std 1e-6) normal biases."""
        layers = (self.fc1, self.fc2)
        # Weights first, then biases, so random numbers are drawn in a fixed order.
        for layer in layers:
            nn.init.xavier_uniform_(layer.weight)
        for layer in layers:
            nn.init.normal_(layer.bias, std=1e-6)

    def forward(self, x):
        """Run the block on ``x``; the last axis must have ``hidden_size`` features."""
        expanded = self.dropout(self.act_fn(self.fc1(x)))
        return self.dropout(self.fc2(expanded))
class Embeddings(nn.Module):
    """Build patch embeddings plus learned position embeddings from a volume.

    The input first goes through a ``CNNEncoder``; its final feature map is cut
    into non-overlapping patches by a strided ``Conv3d`` and each patch becomes
    one token of width ``config.hidden_size``.

    Args:
        config: object exposing ``down_factor``, ``patches["size"]``,
            ``hidden_size``, ``transformer["dropout_rate"]`` and the item
            ``config['encoder_channels']``.
        img_size: spatial size of the input volume (three entries are read).
    """

    def __init__(self, config, img_size):
        super().__init__()
        self.config = config
        # The CNN encoder shrinks each axis by 2**down_factor before patching.
        shrink = 2 ** config.down_factor
        patch_size = _triple(config.patches["size"])
        per_axis = [img_size[axis] / shrink // patch_size[axis] for axis in range(3)]
        n_patches = int(per_axis[0] * per_axis[1] * per_axis[2])

        self.hybrid_model = CNNEncoder(config, n_channels=2)
        self.patch_embeddings = Conv3d(
            in_channels=config['encoder_channels'][-1],
            out_channels=config.hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )
        # Learned position table, one row per patch, initialised to zeros.
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))
        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        """Return ``(embeddings, features)``.

        ``embeddings`` has shape ``(B, n_patches, hidden_size)``; ``features``
        is the second value returned by the CNN encoder, passed through as is.
        """
        feature_map, features = self.hybrid_model(x)
        # (B, hidden, l, h, w) -> (B, hidden, n_patches) -> (B, n_patches, hidden)
        tokens = self.patch_embeddings(feature_map).flatten(2).transpose(-1, -2)
        return self.dropout(tokens + self.position_embeddings), features
class Block(nn.Module):
    """Pre-norm transformer layer: attention and MLP, each with a residual add.

    Args:
        config: forwarded to ``Mlp`` and ``Attention``; ``hidden_size`` is read here.
        vis: forwarded to ``Attention``.
    """

    def __init__(self, config, vis):
        super().__init__()
        width = config.hidden_size
        self.hidden_size = width
        self.attention_norm = LayerNorm(width, eps=1e-6)
        self.ffn_norm = LayerNorm(width, eps=1e-6)
        # The MLP is built before the attention module (fixed parameter-init order).
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        """Return ``(output, attention_weights)`` for the input tokens ``x``."""
        attended, weights = self.attn(self.attention_norm(x))
        x = attended + x
        x = self.ffn(self.ffn_norm(x)) + x
        return x, weights
class Encoder(nn.Module):
    """Stack of ``Block`` layers followed by a final LayerNorm.

    Args:
        config: ``hidden_size`` and ``transformer["num_layers"]`` are read
            here; the object is also forwarded to each ``Block``.
        vis: when truthy, ``forward`` collects each layer's attention weights.
    """

    def __init__(self, config, vis):
        super().__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        depth = config.transformer["num_layers"]
        for _ in range(depth):
            # Each freshly built block is deep-copied before being registered.
            self.layer.append(copy.deepcopy(Block(config, vis)))

    def forward(self, hidden_states):
        """Return ``(encoded, attn_weights)``.

        ``attn_weights`` is a list with one entry per layer when ``vis`` is
        truthy and an empty list otherwise.
        """
        collected = []
        for block in self.layer:
            hidden_states, layer_weights = block(hidden_states)
            if self.vis:
                collected.append(layer_weights)
        return self.encoder_norm(hidden_states), collected
class Transformer(nn.Module):
    """Embedding stage followed by the transformer encoder.

    Args:
        config: forwarded to ``Embeddings`` and ``Encoder``.
        img_size: forwarded to ``Embeddings``.
        vis: forwarded to ``Encoder``.
    """

    def __init__(self, config, img_size, vis):
        super().__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        """Return ``(encoded, attn_weights, features)`` for the input volume."""
        tokens, features = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(tokens)
        return encoded, attn_weights, features
class Conv3dReLU(nn.Sequential):
    """``Conv3d`` -> ``InstanceNorm3d`` -> ``LeakyReLU`` packed into one Sequential.

    Despite the class and flag names, the normalisation layer is always an
    instance norm and the activation is a leaky ReLU. ``use_batchnorm`` only
    controls whether the convolution has a bias (no bias when it is true).
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        stages = (
            nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                bias=not use_batchnorm,
            ),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU(inplace=True),
        )
        super(Conv3dReLU, self).__init__(*stages)
class DecoderBlock(nn.Module):
    """Decoder stage: 2x trilinear upsampling, optional skip concat, two 3x3x3 convs.

    Args:
        in_channels: channels of the tensor being upsampled.
        out_channels: channels produced by both convolutions.
        skip_channels: channels of the skip tensor concatenated after upsampling.
        use_batchnorm: forwarded to ``Conv3dReLU``.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            skip_channels=0,
            use_batchnorm=True,
    ):
        super().__init__()
        shared = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        self.conv1 = Conv3dReLU(in_channels + skip_channels, out_channels, **shared)
        self.conv2 = Conv3dReLU(out_channels, out_channels, **shared)
        # scale_factor=2 doubles every spatial axis using trilinear interpolation.
        self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)

    def forward(self, x, skip=None):
        """Upsample ``x``, concatenate ``skip`` on the channel axis if given, then convolve."""
        upsampled = self.up(x)
        merged = upsampled if skip is None else torch.cat([upsampled, skip], dim=1)
        return self.conv2(self.conv1(merged))
class DecoderCup(nn.Module):
    """Convolutional decoder that turns transformer tokens back into a volume.

    Args:
        config: object exposing ``down_factor``, ``conv_first_channel``,
            ``hidden_size``, ``decoder_channels``, ``skip_channels``,
            ``patches["size"]`` and ``n_skip``.
        img_size: spatial size of the network input (three entries are read).
    """

    def __init__(self, config, img_size):
        super().__init__()
        self.config = config
        self.down_factor = config.down_factor
        head_channels = config.conv_first_channel
        self.img_size = img_size
        self.conv_more = Conv3dReLU(
            config.hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        decoder_channels = config.decoder_channels
        # Each stage consumes the previous stage's output; the first one
        # consumes the output of conv_more.
        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels
        self.patch_size = _triple(config.patches["size"])
        skip_channels = self.config.skip_channels
        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch)
            for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states, features=None):
        """Decode ``hidden_states`` of shape ``(B, n_patch, hidden)``.

        Args:
            hidden_states: token sequence produced by the encoder.
            features: optional indexable of skip tensors; entry ``i`` feeds
                decoder stage ``i`` while ``i < config.n_skip``.

        Returns:
            The feature volume produced by the last decoder stage.
        """
        B, _, hidden = hidden_states.size()
        shrink = 2 ** self.down_factor
        # Token grid per axis. Every axis uses its OWN patch size, matching how
        # Embeddings counts patches; dividing all axes by patch_size[0] broke
        # the view() below for anisotropic patch sizes.
        l, h, w = (
            self.img_size[axis] // shrink // self.patch_size[axis]
            for axis in range(3)
        )
        # (B, n_patch, hidden) -> (B, hidden, n_patch) -> (B, hidden, l, h, w).
        # permute() returns a non-contiguous view, so contiguous() is needed
        # before view().
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, l, h, w)
        x = self.conv_more(x)
        for i, decoder_block in enumerate(self.blocks):
            use_skip = features is not None and i < self.config.n_skip
            x = decoder_block(x, skip=features[i] if use_skip else None)
        return x
class SpatialTransformer(nn.Module):
    """
    N-D spatial transformer: warps an image with a dense displacement field.

    Obtained from https://github.com/voxelmorph/voxelmorph

    Args:
        size: spatial shape of the volumes to be warped.
        mode: interpolation mode handed to ``grid_sample``.
    """

    def __init__(self, size, mode='bilinear'):
        super().__init__()
        self.mode = mode

        # Identity sampling grid in voxel coordinates, shape (1, ndim, *size).
        axes = [torch.arange(0, extent) for extent in size]
        identity = torch.stack(torch.meshgrid(axes)).unsqueeze(0)
        identity = identity.type(torch.FloatTensor)

        # A buffer follows the module across devices, but it is also written to
        # the state dict, which makes saved checkpoints larger than necessary.
        # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
        self.register_buffer('grid', identity)

    def forward(self, src, flow):
        """Sample ``src`` at the identity grid displaced by ``flow``."""
        spatial = flow.shape[2:]
        ndim = len(spatial)
        locations = self.grid + flow

        # grid_sample expects coordinates normalised to [-1, 1].
        for axis, extent in enumerate(spatial):
            locations[:, axis, ...] = 2 * (locations[:, axis, ...] / (extent - 1) - 0.5)

        # Move the coordinate channel last and reverse its order, because
        # grid_sample reads coordinates as (x, y[, z]).
        if ndim == 2:
            locations = locations.permute(0, 2, 3, 1)[..., [1, 0]]
        elif ndim == 3:
            locations = locations.permute(0, 2, 3, 4, 1)[..., [2, 1, 0]]

        return nnf.grid_sample(src, locations, align_corners=True, mode=self.mode)
class DoubleConv(nn.Module):
    """(Conv3d => LeakyReLU => InstanceNorm3d) * 2.

    Args:
        in_channels: channels of the input volume.
        out_channels: channels of the output volume.
        mid_channels: channels between the two convolutions; falls back to
            ``out_channels`` when omitted (or otherwise falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True),
            # This norm follows a conv that emits mid_channels feature maps, so
            # it is sized with mid_channels (it was out_channels before, which
            # is only right when the two widths coincide).
            nn.InstanceNorm3d(mid_channels),
            nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.InstanceNorm3d(out_channels),
        )

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        return self.double_conv(x)
class Down(nn.Module):
    """Halve every spatial axis with a 2x max-pool, then apply a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool3d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Return the pooled-then-convolved version of ``x``."""
        return self.maxpool_conv(x)
class CNNEncoder(nn.Module):
    """Three-stage convolutional encoder that also emits a pyramid of skip features.

    Args:
        config: object exposing ``decoder_channels``, ``encoder_channels``
            (at least three entries) and ``down_num``.
        n_channels: channels of the input volume.
    """

    def __init__(self, config, n_channels=2):
        super(CNNEncoder, self).__init__()
        self.n_channels = n_channels
        # Read but not used by this class; the lookup is kept so a config
        # without this field still fails at construction time.
        decoder_channels = config.decoder_channels
        widths = config.encoder_channels
        self.down_num = config.down_num
        self.inc = DoubleConv(n_channels, widths[0])
        self.down1 = Down(widths[0], widths[1])
        self.down2 = Down(widths[1], widths[2])
        self.width = widths[-1]

    def forward(self, x):
        """Return ``(feats, features)``.

        ``feats`` is the output of the third stage. ``features`` lists the
        three stage outputs followed by ``down_num`` successively max-pooled
        copies of ``feats``, in reverse order (smallest volume first).
        """
        stage1 = self.inc(x)
        stage2 = self.down1(stage1)
        feats = self.down2(stage2)
        features = [stage1, stage2, feats]

        # Extra, parameter-free pyramid levels below the last conv stage.
        pooled = feats
        for _ in range(self.down_num):
            pooled = nn.MaxPool3d(2)(pooled)
            features.append(pooled)

        features.reverse()
        return feats, features
class RegistrationHead(nn.Sequential):
    """Single ``Conv3d`` that maps decoder features to a displacement field.

    The weights are drawn from ``Normal(0, 1e-5)`` and the bias is zero, so the
    predicted flow starts out close to zero. ``upsampling`` is accepted but not
    used.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        head = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        tiny_normal = Normal(0, 1e-5)
        head.weight = nn.Parameter(tiny_normal.sample(head.weight.shape))
        head.bias = nn.Parameter(torch.zeros(head.bias.shape))
        super().__init__(head)
class ViTVNet(nn.Module):
    """Registration network: transformer encoder, conv decoder, flow head, warp.

    Args:
        config: forwarded to the sub-modules; ``decoder_channels[-1]`` and the
            item ``config['n_dims']`` size the flow head.
        img_size: spatial size of the input volumes.
        int_steps: accepted but currently unused (flow integration is disabled).
        vis: forwarded to the transformer.
        mode: interpolation mode of the spatial transformer.
    """

    def __init__(self, config, img_size=(64, 256, 256), int_steps=7, vis=False, mode='bilinear'):
        super().__init__()
        self.transformer = Transformer(config, img_size, vis)
        self.decoder = DecoderCup(config, img_size)
        self.reg_head = RegistrationHead(
            in_channels=config.decoder_channels[-1],
            out_channels=config['n_dims'],
            kernel_size=3,
        )
        self.spatial_trans = SpatialTransformer(img_size, mode)
        self.config = config

    def forward(self, x):
        """Return ``(warped, flow)``.

        Channel 0 of ``x`` is taken as the image to be warped; ``flow`` is the
        displacement field predicted from the whole input.
        """
        moving = x[:, 0:1, :, :]
        tokens, _attn_weights, skips = self.transformer(x)
        decoded = self.decoder(tokens, skips)
        flow = self.reg_head(decoded)
        warped = self.spatial_trans(moving, flow)
        return warped, flow
class vm2(nn.Module):
    """VoxelMorph-style U-Net that predicts a flow field and warps channel 0.

    Args:
        dim: number of spatial dimensions; selects ``Conv{dim}d`` and
            ``BatchNorm{dim}d`` and sets the number of flow channels.
        enc_nf: output channels of the encoder convolutions (never mutated here).
        dec_nf: output channels of the decoder convolutions; a list of length 7
            enables the extra final convolution (never mutated here).
        bn: when truthy, every conv block gets a batch norm and the flow is
            batch-normalised as well.
        full_size: when truthy, the decoder upsamples back to input resolution.
        img_size: spatial size handed to the spatial transformer.
        mode: interpolation mode of the spatial transformer.
    """

    def __init__(self, dim=3, enc_nf = [16, 32, 32, 32], dec_nf = [32, 32, 32, 32, 32, 16, 16], bn=None, full_size=True, img_size=(64, 256, 256), mode='bilinear'):
        super(vm2, self).__init__()
        # vm1 uses dec_nf = [32, 32, 32, 32, 8, 8]
        # vm2 uses dec_nf = [32, 32, 32, 32, 32, 16, 16]
        self.bn = bn
        self.dim = dim
        self.enc_nf = enc_nf
        self.full_size = full_size
        self.vm2 = len(dec_nf) == 7

        # Encoder: kernel 4, stride 2 convolutions; the input has 2 channels.
        self.enc = nn.ModuleList()
        prev_nf = 2
        for nf in enc_nf:
            self.enc.append(self.conv_block(dim, prev_nf, nf, 4, 2, batchnorm=bn))
            prev_nf = nf

        # Decoder: the input widths account for the skip concatenations
        # performed in forward().
        dec_in = [
            enc_nf[-1],
            dec_nf[0] * 2,
            dec_nf[1] * 2,
            dec_nf[2] + enc_nf[0],
            dec_nf[3],
        ]
        self.dec = nn.ModuleList()
        for idx, in_nf in enumerate(dec_in):
            self.dec.append(self.conv_block(dim, in_nf, dec_nf[idx], batchnorm=bn))

        if self.full_size:
            # +2: the raw two-channel input is concatenated at full resolution.
            self.dec.append(self.conv_block(dim, dec_nf[4] + 2, dec_nf[5], batchnorm=bn))

        if self.vm2:
            self.vm2_conv = self.conv_block(dim, dec_nf[5], dec_nf[6], batchnorm=bn)

        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        # One conv to get the flow field.
        conv_fn = getattr(nn, 'Conv%dd' % dim)
        self.flow = conv_fn(dec_nf[-1], dim, kernel_size=3, padding=1)

        # Make flow weights + bias small so the initial flow is close to zero.
        nd = Normal(0, 1e-5)
        self.flow.weight = nn.Parameter(nd.sample(self.flow.weight.shape))
        self.flow.bias = nn.Parameter(torch.zeros(self.flow.bias.shape))
        # The flow has `dim` channels, so the norm is sized with `dim`; the
        # previous hard-coded 3 only matched the default dim=3.
        self.batch_norm = getattr(nn, "BatchNorm{0}d".format(dim))(dim)
        self.spatial_trans = SpatialTransformer(img_size, mode)

    def conv_block(self, dim, in_channels, out_channels, kernel_size=3, stride=1, padding=1, batchnorm=False):
        """Build Conv -> [BatchNorm] -> LeakyReLU(0.2) for ``dim`` spatial dimensions."""
        conv_fn = getattr(nn, "Conv{0}d".format(dim))
        bn_fn = getattr(nn, "BatchNorm{0}d".format(dim))
        stages = [conv_fn(in_channels, out_channels, kernel_size, stride=stride, padding=padding)]
        if batchnorm:
            stages.append(bn_fn(out_channels))
        stages.append(nn.LeakyReLU(0.2))
        return nn.Sequential(*stages)

    def forward(self, x):
        """Return ``(warped, flow)``; channel 0 of ``x`` is the image that gets warped."""
        moving = x[:, 0:1, :, :]

        # Encoder activations, with the raw input kept as the first entry.
        x_enc = [x]
        for layer in self.enc:
            x_enc.append(layer(x_enc[-1]))

        # Three conv + upsample + concatenate stages.
        y = x_enc[-1]
        for i in range(3):
            y = self.dec[i](y)
            y = self.upsample(y)
            y = torch.cat([y, x_enc[-(i + 2)]], dim=1)

        # Two convs at half resolution.
        y = self.dec[3](y)
        y = self.dec[4](y)

        # Upsample to full resolution, concatenate the raw input and convolve.
        if self.full_size:
            y = self.upsample(y)
            y = torch.cat([y, x_enc[0]], dim=1)
            y = self.dec[5](y)

        # Extra conv for the vm2 variant.
        if self.vm2:
            y = self.vm2_conv(y)

        flow = self.flow(y)
        if self.bn:
            flow = self.batch_norm(flow)

        out = self.spatial_trans(moving, flow)
        return out, flow
class VecInt(nn.Module):
    """
    Integrates a vector field via scaling and squaring.

    Obtained from https://github.com/voxelmorph/voxelmorph

    Args:
        inshape: spatial shape handed to the internal ``SpatialTransformer``.
        nsteps: number of squaring steps; must be non-negative.
    """

    def __init__(self, inshape, nsteps):
        super().__init__()

        assert nsteps >= 0, 'nsteps should be >= 0, found: %d' % nsteps
        self.nsteps = nsteps
        # The field is first scaled down by 2**nsteps ...
        self.scale = 1.0 / (2 ** self.nsteps)
        self.transformer = SpatialTransformer(inshape)

    def forward(self, vec):
        """Return the integrated field for the input field ``vec``."""
        field = vec * self.scale
        # ... then composed with itself nsteps times.
        step = 0
        while step < self.nsteps:
            field = field + self.transformer(field, field)
            step += 1
        return field
# Registry of named model configurations; each value is built once, at import
# time, by the project-local `configs` module.
CONFIGS = {
'ViT-V-Net': configs.get_3DReg_config(),
}