-
Notifications
You must be signed in to change notification settings - Fork 2
/
transform.py
72 lines (57 loc) · 2.64 KB
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Copyright (c) 2016-2017 Shafeen Tejani. Released under GPLv3.
import tensorflow as tf
# Standard deviation passed to tf.truncated_normal when the convolution
# kernels are initialized in _conv_init_vars.
WEIGHTS_INIT_STDEV = .1
def net(image):
    """Build the image transformation network and return its output tensor.

    The input is divided by 255, passed through three convolution layers
    (the last two with stride 2), five residual blocks and two stride-2
    transposed-convolution layers. A final 9x9 convolution without ReLU
    produces a 3-channel map whose tanh is added to the scaled input; the
    sum goes through a second tanh and is mapped affinely into (0, 255).

    Args:
        image: input tensor. Presumably a float batch of RGB images in
            [0, 255] laid out as (batch, rows, cols, 3) -- confirm against
            the caller.

    Returns:
        The tensor ``tanh(image / 255 + tanh(features)) * 127.5 + 127.5``.
    """
    scaled = image / 255.0

    # Three plain convolutions: 9x9 stride 1, then two 3x3 stride 2.
    features = _conv_layer(scaled, 32, 9, 1)
    features = _conv_layer(features, 64, 3, 2)
    features = _conv_layer(features, 128, 3, 2)

    # Five identical residual blocks with 3x3 kernels.
    for _ in range(5):
        features = _residual_block(features, 3)

    # Two stride-2 transposed convolutions, then a 9x9 projection to 3 channels.
    features = _conv_tranpose_layer(features, 64, 3, 2)
    features = _conv_tranpose_layer(features, 32, 3, 2)
    features = _conv_layer(features, 3, 9, 1, relu=False)

    # The network output is added to the scaled input rather than replacing it.
    residual = tf.nn.tanh(features)
    combined = scaled + residual
    return tf.nn.tanh(combined) * 127.5 + 255. / 2
def _conv_layer(net, num_filters, filter_size, strides, relu=True):
    """Apply a SAME-padded convolution, instance normalization and optional ReLU.

    Args:
        net: input tensor; its static shape must have four known-rank
            dimensions because the kernel shape is derived from it.
        num_filters: number of output channels of the convolution.
        filter_size: side length of the square kernel.
        strides: stride applied along axes 1 and 2.
        relu: when true, the normalized result is passed through ReLU.

    Returns:
        The transformed tensor.
    """
    kernel = _conv_init_vars(net, num_filters, filter_size)
    convolved = tf.nn.conv2d(
        net, kernel, [1, strides, strides, 1], padding='SAME')
    normalized = _instance_norm(convolved)
    return tf.nn.relu(normalized) if relu else normalized
def _conv_tranpose_layer(net, num_filters, filter_size, strides):
    """Apply a SAME-padded transposed convolution, instance norm and ReLU.

    The requested output keeps the input's batch size, has ``num_filters``
    channels, and has axes 1 and 2 multiplied by ``strides``. Dimensions that
    are known statically are written into the output shape as Python ints.
    Dimensions that are ``None`` at graph-construction time (for example a
    placeholder with an unspecified batch size) are instead read from
    ``tf.shape(net)`` at run time; previously such inputs raised while the
    output shape was being built.

    Args:
        net: rank-4 input tensor. The last (channel) dimension must be
            statically known because the kernel shape depends on it.
        num_filters: number of output channels.
        filter_size: side length of the square kernel.
        strides: integer factor applied to axes 1 and 2.

    Returns:
        The normalized, ReLU-activated output tensor.
    """
    # Create the kernel first so variable creation order is unchanged.
    kernel = _conv_init_vars(net, num_filters, filter_size, transpose=True)

    batch_size, rows, cols, _ = net.get_shape().as_list()

    # Only add a tf.shape op when a dimension really is unknown, so fully
    # specified inputs build the same ops as before.
    if None in (batch_size, rows, cols):
        runtime_shape = tf.shape(net)
    else:
        runtime_shape = None

    def _scaled_dim(static_dim, axis):
        """Return the strided size of `axis`, static if possible."""
        if static_dim is None:
            return runtime_shape[axis] * strides
        return int(static_dim * strides)

    if batch_size is None:
        batch_size = runtime_shape[0]

    output_shape = tf.stack(
        [batch_size, _scaled_dim(rows, 1), _scaled_dim(cols, 2), num_filters])
    strides_shape = [1, strides, strides, 1]

    net = tf.nn.conv2d_transpose(
        net, kernel, output_shape, strides_shape, padding='SAME')
    net = _instance_norm(net)
    return tf.nn.relu(net)
def _residual_block(net, filter_size=3):
    """Add to `net` the result of two stacked 128-filter convolution layers.

    The first layer uses ReLU, the second does not. Because the result is
    added to the input, `net` is expected to have 128 channels already.

    Args:
        net: input tensor.
        filter_size: side length of the square kernel used by both layers.

    Returns:
        ``net`` plus the output of the second convolution layer.
    """
    hidden = _conv_layer(net, 128, filter_size, 1)
    correction = _conv_layer(hidden, 128, filter_size, 1, relu=False)
    return net + correction
def _instance_norm(net, train=True):
    """Normalize `net` over axes 1 and 2 and apply a learned scale and shift.

    Mean and variance are computed with ``tf.nn.moments`` over axes 1 and 2
    (kept as size-1 dimensions), so each sample and channel is normalized
    independently. Two new variables of length ``channels`` are created on
    every call: a shift initialized to zeros and a scale initialized to ones.

    Args:
        net: rank-4 input tensor whose last dimension is statically known.
        train: accepted for interface compatibility; not used by the body.

    Returns:
        ``scale * (net - mean) / sqrt(variance + 1e-3) + shift``.
    """
    # Unpacking four values also enforces that the input is rank 4.
    _, _, _, channels = net.get_shape().as_list()
    param_shape = [channels]

    mean, variance = tf.nn.moments(net, [1, 2], keep_dims=True)

    # Shift is created before scale; keep this order so auto-generated
    # variable names stay the same.
    shift = tf.Variable(tf.zeros(param_shape))
    scale = tf.Variable(tf.ones(param_shape))

    epsilon = 1e-3  # keeps the denominator away from zero
    centered = net - mean
    std_dev = (variance + epsilon) ** 0.5
    normalized = centered / std_dev
    return scale * normalized + shift
def _conv_init_vars(net, out_channels, filter_size, transpose=False):
    """Create the kernel variable for a (transposed) convolution over `net`.

    The kernel is square with side `filter_size`. Its last two dimensions are
    ``[in_channels, out_channels]`` for a regular convolution and
    ``[out_channels, in_channels]`` when `transpose` is true, where
    ``in_channels`` is the static size of the last dimension of `net`.

    Args:
        net: rank-4 tensor the kernel will be applied to.
        out_channels: number of channels the layer should produce.
        filter_size: side length of the square kernel.
        transpose: build the kernel for ``tf.nn.conv2d_transpose``.

    Returns:
        A float32 ``tf.Variable`` initialized from a truncated normal with
        standard deviation ``WEIGHTS_INIT_STDEV`` and op-level seed 1.
    """
    # Unpacking four values also enforces that the input is rank 4.
    _, _, _, in_channels = net.get_shape().as_list()

    if transpose:
        kernel_shape = [filter_size, filter_size, out_channels, in_channels]
    else:
        kernel_shape = [filter_size, filter_size, in_channels, out_channels]

    initial_value = tf.truncated_normal(
        kernel_shape, stddev=WEIGHTS_INIT_STDEV, seed=1)
    return tf.Variable(initial_value, dtype=tf.float32)