forked from SSL92/hyperIQA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
HyerIQASolver.py
115 lines (91 loc) · 4.56 KB
/
HyerIQASolver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
from scipy import stats
import numpy as np
import models
import data_loader
class HyperIQASolver(object):
"""Solver for training and testing hyperIQA"""
def __init__(self, config, path, train_idx, test_idx):
self.epochs = config.epochs
self.test_patch_num = config.test_patch_num
self.model_hyper = models.HyperNet(16, 112, 224, 112, 56, 28, 14, 7).cuda()
self.model_hyper.train(True)
self.l1_loss = torch.nn.L1Loss().cuda()
backbone_params = list(map(id, self.model_hyper.res.parameters()))
self.hypernet_params = filter(lambda p: id(p) not in backbone_params, self.model_hyper.parameters())
self.lr = config.lr
self.lrratio = config.lr_ratio
self.weight_decay = config.weight_decay
paras = [{'params': self.hypernet_params, 'lr': self.lr * self.lrratio},
{'params': self.model_hyper.res.parameters(), 'lr': self.lr}
]
self.solver = torch.optim.Adam(paras, weight_decay=self.weight_decay)
train_loader = data_loader.DataLoader(config.dataset, path, train_idx, config.patch_size, config.train_patch_num, batch_size=config.batch_size, istrain=True)
test_loader = data_loader.DataLoader(config.dataset, path, test_idx, config.patch_size, config.test_patch_num, istrain=False)
self.train_data = train_loader.get_data()
self.test_data = test_loader.get_data()
def train(self):
"""Training"""
best_srcc = 0.0
best_plcc = 0.0
print('Epoch\tTrain_Loss\tTrain_SRCC\tTest_SRCC\tTest_PLCC')
for t in range(self.epochs):
epoch_loss = []
pred_scores = []
gt_scores = []
for img, label in self.train_data:
img = torch.tensor(img.cuda())
label = torch.tensor(label.cuda())
self.solver.zero_grad()
# Generate weights for target network
paras = self.model_hyper(img) # 'paras' contains the network weights conveyed to target network
# Building target network
model_target = models.TargetNet(paras).cuda()
for param in model_target.parameters():
param.requires_grad = False
# Quality prediction
pred = model_target(paras['target_in_vec']) # while 'paras['target_in_vec']' is the input to target net
pred_scores = pred_scores + pred.cpu().tolist()
gt_scores = gt_scores + label.cpu().tolist()
loss = self.l1_loss(pred.squeeze(), label.float().detach())
epoch_loss.append(loss.item())
loss.backward()
self.solver.step()
train_srcc, _ = stats.spearmanr(pred_scores, gt_scores)
test_srcc, test_plcc = self.test(self.test_data)
if test_srcc > best_srcc:
best_srcc = test_srcc
best_plcc = test_plcc
print('%d\t%4.3f\t\t%4.4f\t\t%4.4f\t\t%4.4f' %
(t + 1, sum(epoch_loss) / len(epoch_loss), train_srcc, test_srcc, test_plcc))
# Update optimizer
lr = self.lr / pow(10, (t // 6))
if t > 8:
self.lrratio = 1
self.paras = [{'params': self.hypernet_params, 'lr': lr * self.lrratio},
{'params': self.model_hyper.res.parameters(), 'lr': self.lr}
]
self.solver = torch.optim.Adam(self.paras, weight_decay=self.weight_decay)
print('Best test SRCC %f, PLCC %f' % (best_srcc, best_plcc))
return best_srcc, best_plcc
def test(self, data):
"""Testing"""
self.model_hyper.train(False)
pred_scores = []
gt_scores = []
for img, label in data:
# Data.
img = torch.tensor(img.cuda())
label = torch.tensor(label.cuda())
paras = self.model_hyper(img)
model_target = models.TargetNet(paras).cuda()
model_target.train(False)
pred = model_target(paras['target_in_vec'])
pred_scores.append(float(pred.item()))
gt_scores = gt_scores + label.cpu().tolist()
pred_scores = np.mean(np.reshape(np.array(pred_scores), (-1, self.test_patch_num)), axis=1)
gt_scores = np.mean(np.reshape(np.array(gt_scores), (-1, self.test_patch_num)), axis=1)
test_srcc, _ = stats.spearmanr(pred_scores, gt_scores)
test_plcc, _ = stats.pearsonr(pred_scores, gt_scores)
self.model_hyper.train(True)
return test_srcc, test_plcc