# ------------------------------
# Residual Dense Network
# ------------------------------
import argparse
import math

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import init

from model import model
from data import DIV2K
from utils import *
parser = argparse.ArgumentParser(description='Semantic aware super-resolution')
parser.add_argument('--dataDir', default='./data', help='dataset directory')
parser.add_argument('--saveDir', default='./result', help='directory for saving results')
parser.add_argument('--load', default='model_name', help='name under which results are saved/loaded')
parser.add_argument('--model_name', default='RDN', help='model to select')
parser.add_argument('--finetuning', default=False, help='fine-tune from a previously saved model')
parser.add_argument('--need_patch', default=False, help='extract patches from the image')
parser.add_argument('--nDenselayer', type=int, default=6, help='number of dense layers per RDB')
parser.add_argument('--growthRate', type=int, default=32, help='growth rate of the dense blocks')
parser.add_argument('--nBlock', type=int, default=16, help='number of RDB blocks')
parser.add_argument('--nFeat', type=int, default=64, help='number of feature maps')
parser.add_argument('--nChannel', type=int, default=3, help='number of color channels to use')
parser.add_argument('--patchSize', type=int, default=96, help='patch size')
parser.add_argument('--nThreads', type=int, default=3, help='number of threads for data loading')
parser.add_argument('--batchSize', type=int, default=16, help='input batch size for training')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--epochs', type=int, default=1000, help='number of epochs to train')
parser.add_argument('--lrDecay', type=int, default=500, help='epoch interval for learning-rate decay')
parser.add_argument('--decayType', default='step', help='learning-rate decay type (step, exp, inv)')
parser.add_argument('--lossType', default='L1', help='loss function (L1 or MSE)')
parser.add_argument('--scale', type=int, default=1, help='scale factor (output size / input size)')
args = parser.parse_args()
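
# Example invocation (paths and values are illustrative, not part of the original script):
#   python main.py --dataDir ./data --saveDir ./result --model_name RDN \
#       --batchSize 16 --patchSize 96 --scale 2 --epochs 1000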

def weights_init_kaiming(m):
    # Kaiming (He) initialization for every Conv2d layer in the network.
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        init.kaiming_normal_(m.weight.data)

def get_dataset(args):
    data_train = DIV2K(args)
    dataloader = torch.utils.data.DataLoader(data_train, batch_size=args.batchSize,
                                             drop_last=True, shuffle=True,
                                             num_workers=int(args.nThreads), pin_memory=False)
    return dataloader
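
# Note: DIV2K (defined in data.py) is assumed to yield (low-res, high-res)
# image pairs, matching how batches are unpacked in the training loop below.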

def set_loss(args):
    lossType = args.lossType
    if lossType == 'MSE':
        lossfunction = nn.MSELoss()
    elif lossType == 'L1':
        lossfunction = nn.L1Loss()
    return lossfunction

def set_lr(args, epoch, optimizer):
    # Update the learning rate according to the chosen decay schedule.
    lrDecay = args.lrDecay
    decayType = args.decayType
    if decayType == 'step':
        epoch_iter = (epoch + 1) // lrDecay
        lr = args.lr / 2**epoch_iter
    elif decayType == 'exp':
        k = math.log(2) / lrDecay
        lr = args.lr * math.exp(-k * epoch)
    elif decayType == 'inv':
        k = 1 / lrDecay
        lr = args.lr / (1 + k * epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
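
# Worked example for the default 'step' schedule (lr=1e-4, lrDecay=500):
# epochs 0-498 use 1e-4, epochs 499-998 use 5e-5, and the rate keeps
# halving every lrDecay epochs thereafter.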

def train(args):
    # select network
    if args.model_name == 'RDN':
        my_model = model.RDN(args)
    my_model.apply(weights_init_kaiming)
    my_model.cuda()
    save = saveData(args)
    # fine-tuning or retrain
    if args.finetuning:
        my_model = save.load_model(my_model)
    # load data
    dataloader = get_dataset(args)
    L1_lossfunction = set_loss(args)
    for epoch in range(args.epochs):
        # the optimizer is re-created every epoch, as in the original script
        optimizer = optim.Adam(my_model.parameters())
        learning_rate = set_lr(args, epoch, optimizer)
        total_loss_ = 0
        L1_loss_ = 0
        for batch, (im_lr, im_hr) in enumerate(dataloader):
            im_lr = im_lr.cuda()
            im_hr = im_hr.cuda()
            my_model.zero_grad()
            output = my_model(im_lr)
            L1_loss = L1_lossfunction(output, im_hr)
            total_loss = L1_loss
            total_loss.backward()
            optimizer.step()
            L1_loss_ += L1_loss.item()
            total_loss_ += L1_loss.item()
        L1_loss_ = L1_loss_ / (batch + 1)
        total_loss_ = total_loss_ / (batch + 1)
        if (epoch + 1) % 10 == 0:
            log = "[{} / {}] \tLearning_rate: {}\t total_loss: {:.4f}\t L1_loss: {:.4f}".format(
                epoch + 1, args.epochs, learning_rate, total_loss_, L1_loss_)
            print(log)
            save.save_log(log)
            save.save_model(my_model)

if __name__ == '__main__':
    train(args)