forked from assassint2017/MICCAI-LITS2017
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_ds.py
149 lines (113 loc) · 3.19 KB
/
train_ds.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""
Training script for the ResUNet segmentation network with deep supervision.
"""
import os
from time import time
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from visdom import Visdom
from dataset.dataset import Dataset
from loss.Dice import DiceLoss
from loss.ELDice import ELDiceLoss
from loss.WBCE import WCELoss
from loss.Jaccard import JaccardLoss
from loss.SS import SSLoss
from loss.Tversky import TverskyLoss
from loss.Hybrid import HybridLoss
from loss.BCE import BCELoss
from net.ResUNet import net
import parameter as para
# --- Visdom loss curve --------------------------------------------------------
# Connect to Visdom on port 666 and open a line-plot window titled 'loss',
# seeded with the single point (0, 1.0).
viz = Visdom(port=666)
step_list = [0]  # x-coordinates already plotted; the training loop appends to it
win = viz.line(
    X=np.array([0]),
    Y=np.array([1.0]),
    opts=dict(title='loss'),
)

# --- GPU configuration --------------------------------------------------------
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
# initialised yet; it is set here, before the `.cuda()` call below.
os.environ['CUDA_VISIBLE_DEVICES'] = para.gpu
cudnn.benchmark = para.cudnn_benchmark

# --- Model --------------------------------------------------------------------
# Wrap the imported network in DataParallel, move it to the GPU and put it
# into training mode.
net = torch.nn.DataParallel(net).cuda()
net.train()

# --- Data ---------------------------------------------------------------------
# Dataset is built from the 'ct' and 'seg' sub-directories of the training set.
train_ds = Dataset(
    os.path.join(para.training_set_path, 'ct'),
    os.path.join(para.training_set_path, 'seg'),
)
train_dl = DataLoader(
    train_ds,
    batch_size=para.batch_size,
    shuffle=True,
    num_workers=para.num_workers,
    pin_memory=para.pin_memory,
)

# --- Loss ---------------------------------------------------------------------
# Every available loss is instantiated; index 5 (TverskyLoss) is the one used.
loss_func_list = [
    DiceLoss(),
    ELDiceLoss(),
    WCELoss(),
    JaccardLoss(),
    SSLoss(),
    TverskyLoss(),
    HybridLoss(),
    BCELoss(),
]
loss_func = loss_func_list[5]

# --- Optimisation -------------------------------------------------------------
opt = torch.optim.Adam(net.parameters(), lr=para.learning_rate)
# Multiply the LR by MultiStepLR's default gamma at each milestone listed in
# para.learning_rate_decay.
lr_decay = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=para.learning_rate_decay)

# Initial weight of the auxiliary (deep-supervision) losses; the training loop
# decays it over time.
alpha = para.alpha

# Wall-clock start time, used for the elapsed-minutes column of the log.
start = time()
for epoch in range(para.Epoch):
    # loss4 (loss of outputs[3]) of every minibatch in this epoch; averaged below.
    epoch_losses = []

    for step, (ct, seg) in enumerate(train_dl):
        ct = ct.cuda()
        seg = seg.cuda()

        # Deep supervision: outputs[0..2] are weighted by alpha, outputs[3]
        # enters the total loss with weight 1.
        outputs = net(ct)
        loss1 = loss_func(outputs[0], seg)
        loss2 = loss_func(outputs[1], seg)
        loss3 = loss_func(outputs[2], seg)
        loss4 = loss_func(outputs[3], seg)
        loss = (loss1 + loss2 + loss3) * alpha + loss4

        epoch_losses.append(loss4.item())

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Every 5th step: append loss4 to the Visdom curve and log progress.
        # `==` rather than `is`: identity of ints is an implementation detail.
        if step % 5 == 0:
            step_list.append(step_list[-1] + 1)
            viz.line(X=np.array([step_list[-1]]), Y=np.array([loss4.item()]), win=win, update='append')
            print('epoch:{}, step:{}, loss1:{:.3f}, loss2:{:.3f}, loss3:{:.3f}, loss4:{:.3f}, time:{:.3f} min'
                  .format(epoch, step, loss1.item(), loss2.item(), loss3.item(), loss4.item(), (time() - start) / 60))

    # Advance the LR schedule once per epoch, AFTER the optimizer steps.
    # Since PyTorch 1.1, calling it before opt.step() shifts every milestone
    # one epoch early (and triggers a UserWarning).
    lr_decay.step()

    if not epoch_losses:
        raise RuntimeError('training DataLoader yielded no batches in epoch {}'.format(epoch))
    mean_loss = sum(epoch_losses) / len(epoch_losses)

    # Checkpoint every 50 epochs (epoch 0 excluded). File name:
    # net<epoch>-<total loss of the last minibatch>-<mean loss4 of this epoch>.pth
    if epoch % 50 == 0 and epoch != 0:
        os.makedirs('./module', exist_ok=True)  # torch.save does not create directories
        torch.save(net.state_dict(), './module/net{}-{:.3f}-{:.3f}.pth'.format(epoch, loss.item(), mean_loss))

    # Decay the deep-supervision weight by a factor 0.8 every 40 epochs (epoch 0 excluded).
    if epoch % 40 == 0 and epoch != 0:
        alpha *= 0.8
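# Deep-supervision coefficient alpha over training: initial value, then the value after each
# x0.8 decay (applied every 40 epochs), assuming para.alpha starts at 1.0: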
# 1.000
# 0.800
# 0.640
# 0.512
# 0.410
# 0.328
# 0.262
# 0.210
# 0.168
# 0.134
# 0.107
# 0.086
# 0.069
# 0.055
# 0.044
# 0.035
# 0.028
# 0.023
# 0.018
# 0.014
# 0.012
# 0.009
# 0.007
# 0.006
# 0.005
# 0.004
# 0.003
# 0.002
# 0.002
# 0.002
# 0.001
# 0.001
# 0.001
# 0.001
# 0.001
# 0.000
# 0.000