-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
110 lines (89 loc) · 3.79 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch import nn, optim
from data import MYdata
from config import params
# from network.net import net
from evaluate import Total_loss, cal_iou
from tqdm import tqdm
import segmentation_models_pytorch as smp
class AverageMeter(object):
    """Running-statistics tracker: keeps the latest value, the weighted sum,
    the sample count, and the resulting mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every tracked statistic."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Fold in `val`, treated as an average over `n` samples."""
        self.val = val
        self.count += n
        self.sum += n * val
        self.avg = self.sum / self.count
def valid(epoch, model, val_data, optimizer, Loss):
    """Run one validation epoch and return the mean validation loss.

    Args:
        epoch: 1-based epoch number, used only for logging.
        model: network under evaluation (already on GPU, possibly DataParallel).
        val_data: DataLoader yielding (input, label) batches.
        optimizer: used only to report the current learning rate in the log line.
        Loss: callable returning (total_loss, ce_loss, dice_loss).

    Returns:
        Average total loss over all validation batches.
    """
    losses = AverageMeter()
    IOU = AverageMeter()
    model.eval()
    with torch.no_grad():
        for step, (input, label) in enumerate(tqdm(val_data)):
            input = input.cuda(params['gpu'][0])
            label = label.cuda(params['gpu'][0])
            out = model(input)
            loss, celoss, diceloss = Loss(out, label)
            # Mean IoU over the samples of this batch.
            iou = 0
            n_samples = 0
            for out_batch, label_batch in zip(out, label):
                # Collapse per-class logits to a predicted class map.
                out_batch = torch.argmax(out_batch, 0)
                iou += cal_iou(out_batch, label_batch)
                n_samples += 1
            # BUG FIX: the original `iou /= i` divided by the LAST enumerate
            # index (count - 1), which inflated the IoU and raised
            # ZeroDivisionError whenever a batch held a single sample.
            if n_samples:
                iou /= n_samples
            losses.update(loss.item())
            IOU.update(iou)
    # NOTE(review): celoss/diceloss come from the final batch only; this
    # mirrors the original logging behavior.
    print('epoch:', epoch, 'valid loss:%0.4f'%losses.avg, celoss.item(), diceloss.item(), 'iou:%0.4f'%IOU.avg, 'lr:', optimizer.param_groups[0]['lr'])
    return losses.avg
def train(epoch, model, train_data, optimizer, Loss):
    """Run one training epoch and return the mean training loss.

    Args:
        epoch: 1-based epoch number, used only for logging.
        model: network being optimized (already on GPU, possibly DataParallel).
        train_data: DataLoader yielding (input, label) batches.
        optimizer: optimizer stepped once per batch; its lr is logged.
        Loss: callable returning (total_loss, ce_loss, dice_loss).

    Returns:
        Average total loss over all training batches.
    """
    losses = AverageMeter()
    IOU = AverageMeter()
    model.train()
    for step, (input, label) in enumerate(tqdm(train_data)):
        input = input.cuda(params['gpu'][0])
        label = label.cuda(params['gpu'][0])
        out = model(input)
        loss, celoss, diceloss = Loss(out, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Mean IoU over the samples of this batch (monitoring only; does
        # not affect the gradient step above).
        iou = 0
        n_samples = 0
        for out_batch, label_batch in zip(out, label):
            # Collapse per-class logits to a predicted class map.
            out_batch = torch.argmax(out_batch, 0)
            iou += cal_iou(out_batch, label_batch)
            n_samples += 1
        # BUG FIX: the original `iou /= i` divided by the LAST enumerate
        # index (count - 1), which inflated the IoU and raised
        # ZeroDivisionError whenever a batch held a single sample.
        if n_samples:
            iou /= n_samples
        losses.update(loss.item())
        IOU.update(iou)
    # NOTE(review): celoss/diceloss come from the final batch only; this
    # mirrors the original logging behavior.
    print('epoch:', epoch, 'train loss:%0.4f'%losses.avg, celoss.item(), diceloss.item(), 'iou:%0.4f'%IOU.avg, 'lr:', optimizer.param_groups[0]['lr'])
    return losses.avg
def main():
    """Train a ResNet18-encoder U-Net (8 classes) as configured by `params`:
    builds train/valid loaders, optionally warm-starts from a checkpoint,
    and runs the epoch loop with LR reduction on validation-loss plateaus."""
    primary_gpu = params['gpu'][0]

    train_data = DataLoader(MYdata(params['csv'], mode='train'),
                            batch_size=params['batchsize'], shuffle=True,
                            num_workers=params['num_works'])
    valid_data = DataLoader(MYdata(params['csv'], mode='valid'),
                            batch_size=params['batchsize'], shuffle=False,
                            num_workers=params['num_works'])

    model = smp.Unet('resnet18', classes=8, encoder_weights='imagenet',
                     activation='softmax')
    model = nn.DataParallel(model.cuda(primary_gpu), device_ids=params['gpu'])

    if params['pretrain']:
        # Warm-start: copy over only the checkpoint entries whose keys
        # exist in the current model's state dict.
        current = model.state_dict()
        checkpoint = torch.load(params['pretrain'], map_location='cpu')
        current.update({k: v for k, v in checkpoint.items() if k in current})
        model.load_state_dict(current)

    optimizer = optim.Adam(model.parameters(), lr=params['lr'],
                           weight_decay=params['weight_decay'])
    schedule = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',
                                                    factor=0.333,
                                                    patience=3, verbose=True)
    Loss = Total_loss()

    for n in range(params['max_epoch']):
        epoch = n + 1
        train(epoch, model, train_data, optimizer, Loss)
        valid_loss = valid(epoch, model, valid_data, optimizer, Loss)
        schedule.step(valid_loss)
        # Checkpoint after every epoch, overwriting the same path.
        torch.save(model.state_dict(), params['save_path'])
# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()