-
Notifications
You must be signed in to change notification settings - Fork 28
/
sr_network.py
175 lines (129 loc) · 7.06 KB
/
sr_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
''' @author: Andrew Glaws, Karen Stengel, Ryan King
'''
import tensorflow as tf
from utils import *
class SR_NETWORK(object):
    """TensorFlow graph builder for a super-resolution generator/discriminator pair.

    Depending on ``status`` the constructor wires up:

    * ``'pretraining'`` -- generator plus a content (mean-squared-error) loss.
    * ``'training'``    -- generator, discriminator, and the combined
      content + adversarial losses.
    * ``'testing'``     -- generator only, no losses.

    Attributes that do not apply to the selected status are set to ``None``:
    ``g_loss``, ``d_loss``, ``disc_HR``, ``disc_SR``, ``d_variables``,
    ``advers_perf``, ``content_loss`` and ``g_advers_loss``.

    NOTE(review): ``deconv_layer_2d``, ``conv_layer_2d``, ``flatten_layer`` and
    ``dense_layer`` are not defined in this file; presumably they come from the
    ``from utils import *`` at the top of the module -- confirm.
    """

    def __init__(self, x_LR=None, x_HR=None, r=None, status='pretraining', alpha_advers=0.001):
        """Build the graph for the requested ``status``.

        Args:
            x_LR: Low-resolution input tensor fed to the generator.
            x_HR: High-resolution reference tensor; used by the losses and the
                discriminator, so it is only needed for (pre)training.
            r: Iterable of integer up-scaling factors, applied one per
                super-resolution layer. Required.
            status: One of 'pretraining', 'training' or 'testing'
                (case-insensitive).
            alpha_advers: Weight of the adversarial term in the generator loss.

        Raises:
            SystemExit: with exit status 1 if ``status`` is not recognised or
                ``r`` is None. The error message is printed first.
        """
        status = status.lower()
        if status not in ('pretraining', 'training', 'testing'):
            print('Error in network status.')
            # Raise SystemExit directly: non-zero status, and no reliance on
            # the `exit` helper that only exists when `site` is loaded.
            raise SystemExit(1)

        self.x_LR, self.x_HR = x_LR, x_HR

        if r is None:
            print('Error in SR scaling. Variable r must be specified.')
            raise SystemExit(1)

        # Both training modes build the generator with is_training=True;
        # only 'testing' uses the dynamic-shape path.
        self.x_SR = self.generator(self.x_LR, r=r, is_training=(status != 'testing'))

        self.g_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')

        # Defaults for everything a given status does not populate.
        self.g_loss, self.d_loss = None, None
        self.disc_HR, self.disc_SR, self.d_variables = None, None, None
        self.advers_perf, self.content_loss, self.g_advers_loss = None, None, None

        if status == 'pretraining':
            self.g_loss = self.compute_losses(self.x_HR, self.x_SR, None, None, alpha_advers, isGAN=False)

        elif status == 'training':
            # The first call creates the discriminator variables, the second reuses them.
            self.disc_HR = self.discriminator(self.x_HR, reuse=False)
            self.disc_SR = self.discriminator(self.x_SR, reuse=True)
            self.d_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')

            loss_out = self.compute_losses(self.x_HR, self.x_SR, self.disc_HR, self.disc_SR, alpha_advers, isGAN=True)
            (self.g_loss, self.d_loss, self.advers_perf,
             self.content_loss, self.g_advers_loss) = loss_out

    def generator(self, x, r, is_training=False, reuse=False):
        """Build the generator sub-graph and return the super-resolved tensor.

        Args:
            x: 4-D input tensor. Dimensions 1 and 2 are treated as spatial
                height/width and dimension 3 as channels.
            r: Iterable of integer up-scaling factors; each one adds a layer
                followed by ``tf.depth_to_space`` with that block size.
            is_training: If True the spatial sizes are read from the static
                shape; otherwise they are read dynamically with ``tf.shape``.
            reuse: Passed to the 'generator' variable scope.

        Returns:
            The output tensor of the final 'deconv_out' layer.
        """
        if is_training:
            N, h, w, C = tf.shape(x)[0], x.get_shape()[1], x.get_shape()[2], x.get_shape()[3]
        else:
            N, h, w, C = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], x.get_shape()[3]

        k, stride = 3, 1
        # NOTE(review): the requested output is 2*k larger than the input in each
        # spatial dimension and k is also passed as the last argument of
        # deconv_layer_2d -- presumably the helper crops that margin; confirm.
        output_shape = [N, h+2*k, w+2*k, -1]

        with tf.variable_scope('generator', reuse=reuse):
            with tf.variable_scope('deconv1'):
                C_in, C_out = C, 64
                output_shape[-1] = C_out
                x = deconv_layer_2d(x, [k, k, C_out, C_in], output_shape, stride, k)
                x = tf.nn.relu(x)

            # Saved for the long skip connection added back after 'deconv2'.
            skip_connection = x

            # B residual blocks
            C_in, C_out = C_out, 64
            output_shape[-1] = C_out
            for i in range(16):
                B_skip_connection = x

                with tf.variable_scope('block_{}a'.format(i+1)):
                    x = deconv_layer_2d(x, [k, k, C_out, C_in], output_shape, stride, k)
                    x = tf.nn.relu(x)

                with tf.variable_scope('block_{}b'.format(i+1)):
                    x = deconv_layer_2d(x, [k, k, C_out, C_in], output_shape, stride, k)

                x = tf.add(x, B_skip_connection)

            with tf.variable_scope('deconv2'):
                x = deconv_layer_2d(x, [k, k, C_out, C_in], output_shape, stride, k)
                x = tf.add(x, skip_connection)

            # Super resolution scaling
            # r_prod tracks the cumulative spatial scale reached so far.
            r_prod = 1
            for i, r_i in enumerate(r):
                # r_i**2 times the channels so depth_to_space can fold them
                # into an r_i-times larger spatial grid with C_in channels.
                C_out = (r_i**2)*C_in
                with tf.variable_scope('deconv{}'.format(i+3)):
                    output_shape = [N, r_prod*h+2*k, r_prod*w+2*k, C_out]
                    x = deconv_layer_2d(x, [k, k, C_out, C_in], output_shape, stride, k)
                    x = tf.depth_to_space(x, r_i)
                    x = tf.nn.relu(x)

                r_prod *= r_i

            # Final layer maps back to the input channel count C.
            output_shape = [N, r_prod*h+2*k, r_prod*w+2*k, C]
            with tf.variable_scope('deconv_out'):
                x = deconv_layer_2d(x, [k, k, C, C_in], output_shape, stride, k)

        return x

    def discriminator(self, x, reuse=False):
        """Build the discriminator sub-graph and return its single-unit output.

        Args:
            x: 4-D input tensor whose channel count is read from static
                dimension 3.
            reuse: Passed to the 'discriminator' variable scope; set True to
                share weights with a previous call.

        Returns:
            Output of the final ``dense_layer(x, 1)``. ``compute_losses`` feeds
            it to ``sigmoid_cross_entropy_with_logits`` as logits.
        """
        C = x.get_shape()[3]

        # (output channels, stride) for scopes 'conv1' .. 'conv8'.
        conv_specs = [(32, 1), (32, 2), (64, 1), (64, 2),
                      (128, 1), (128, 2), (256, 1), (256, 2)]

        with tf.variable_scope('discriminator', reuse=reuse):
            C_in = C
            for idx, (C_out, stride) in enumerate(conv_specs, 1):
                with tf.variable_scope('conv{}'.format(idx)):
                    x = conv_layer_2d(x, [3, 3, C_in, C_out], stride)
                    x = tf.nn.leaky_relu(x, alpha=0.2)
                C_in = C_out

            x = flatten_layer(x)

            with tf.variable_scope('fully_connected1'):
                x = dense_layer(x, 1024)
                x = tf.nn.leaky_relu(x, alpha=0.2)

            with tf.variable_scope('fully_connected2'):
                x = dense_layer(x, 1)

        return x

    def compute_losses(self, x_HR, x_SR, d_HR, d_SR, alpha_advers=0.001, isGAN=False):
        """Build the loss tensors.

        Args:
            x_HR: High-resolution reference tensor.
            x_SR: Generator output, same shape as ``x_HR``.
            d_HR: Discriminator logits for ``x_HR`` (unused when not isGAN).
            d_SR: Discriminator logits for ``x_SR`` (unused when not isGAN).
            alpha_advers: Weight of the adversarial term in the generator loss.
            isGAN: Whether to include the adversarial losses.

        Returns:
            If ``isGAN`` is False: the scalar mean content loss.
            If ``isGAN`` is True: the tuple
            ``(g_loss, d_loss, advers_perf, content_loss, g_advers_loss)`` where
            ``g_loss`` and ``d_loss`` are scalars, ``advers_perf`` is a list of
            four scalar rates, and ``content_loss`` / ``g_advers_loss`` are the
            un-averaged per-example tensors.
        """
        # Mean squared error per example (reduced over axes 1, 2, 3 only).
        content_loss = tf.reduce_mean((x_HR - x_SR)**2, axis=[1, 2, 3])

        if not isGAN:
            return tf.reduce_mean(content_loss)

        # Generator is rewarded when the discriminator labels SR samples as real (1).
        g_advers_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=d_SR, labels=tf.ones_like(d_SR))

        # Discriminator target: HR samples -> 1, SR samples -> 0.
        d_advers_loss = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=tf.concat([d_HR, d_SR], axis=0),
            labels=tf.concat([tf.ones_like(d_HR), tf.zeros_like(d_SR)], axis=0))

        advers_perf = [tf.reduce_mean(tf.cast(tf.sigmoid(d_HR) > 0.5, tf.float32)),  # % true positive
                       tf.reduce_mean(tf.cast(tf.sigmoid(d_SR) < 0.5, tf.float32)),  # % true negative
                       tf.reduce_mean(tf.cast(tf.sigmoid(d_SR) > 0.5, tf.float32)),  # % false positive
                       tf.reduce_mean(tf.cast(tf.sigmoid(d_HR) < 0.5, tf.float32))]  # % false negative

        g_loss = tf.reduce_mean(content_loss) + alpha_advers*tf.reduce_mean(g_advers_loss)
        d_loss = tf.reduce_mean(d_advers_loss)

        return g_loss, d_loss, advers_perf, content_loss, g_advers_loss