Update to the 1.0 API #5

Open · wants to merge 2 commits into master
2 changes: 1 addition & 1 deletion README.md
@@ -18,7 +18,7 @@ This code is based on [Tensorflow-Slim](https://github.com/tensorflow/models/tre
 
 ## Requirements and Prerequisites:
 - Python 2.7.x
-- Tensorflow(>= 0.11)
+- Tensorflow(>= 1.0)
 
 And make sure you installed pyyaml:
 ```
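Since the only requirements change is the TensorFlow version bump, a quick runtime check of `tf.__version__` can confirm a compatible install. This is an illustrative sketch, not part of the PR:

```
import tensorflow as tf
from distutils.version import LooseVersion

# The 1.0 release renamed large parts of the API, so fail fast on older installs.
assert LooseVersion(tf.__version__) >= LooseVersion('1.0'), \
    'this branch targets the TensorFlow 1.0 API'
```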
4 changes: 2 additions & 2 deletions eval.py
@@ -39,8 +39,8 @@ def main(_):
     image = tf.expand_dims(image, 0)
     generated = model.net(image, training=False)
     generated = tf.squeeze(generated, [0])
-    saver = tf.train.Saver(tf.all_variables())
-    sess.run([tf.initialize_all_variables(), tf.initialize_local_variables()])
+    saver = tf.train.Saver(tf.global_variables())
+    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
     FLAGS.model_file = os.path.abspath(FLAGS.model_file)
     saver.restore(sess, FLAGS.model_file)
 
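This hunk covers the variable-collection and initializer renames in TensorFlow 1.0: `tf.all_variables` became `tf.global_variables`, `tf.initialize_all_variables` became `tf.global_variables_initializer`, and `tf.initialize_local_variables` became `tf.local_variables_initializer`. A minimal sketch of the new pattern (the variable here is a placeholder example, not from the repo):

```
import tensorflow as tf

w = tf.Variable(tf.zeros([2, 2]), name='w')
saver = tf.train.Saver(tf.global_variables())  # was tf.train.Saver(tf.all_variables())

with tf.Session() as sess:
    # was sess.run([tf.initialize_all_variables(), tf.initialize_local_variables()])
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
```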
8 changes: 4 additions & 4 deletions losses.py
@@ -15,7 +15,7 @@ def gram(layer):
     width = shape[1]
     height = shape[2]
     num_filters = shape[3]
-    filters = tf.reshape(layer, tf.pack([num_images, -1, num_filters]))
+    filters = tf.reshape(layer, tf.stack([num_images, -1, num_filters]))
     grams = tf.batch_matmul(filters, filters, adj_x=True) / tf.to_float(width * height * num_filters)
 
     return grams
@@ -43,7 +43,7 @@ def get_style_features(FLAGS):
     else:
         image = tf.image.decode_jpeg(img_bytes)
     # image = _aspect_preserving_resize(image, size)
-    images = tf.pack([image_preprocessing_fn(image, size, size)])
+    images = tf.stack([image_preprocessing_fn(image, size, size)])
     _, endpoints_dict = network_fn(images, spatial_squeeze=False)
     features = []
     for layer in FLAGS.style_layers:
@@ -90,7 +90,7 @@ def total_variation_loss(layer):
     shape = tf.shape(layer)
     height = shape[1]
     width = shape[2]
-    y = tf.slice(layer, [0, 0, 0, 0], tf.pack([-1, height - 1, -1, -1])) - tf.slice(layer, [0, 1, 0, 0], [-1, -1, -1, -1])
-    x = tf.slice(layer, [0, 0, 0, 0], tf.pack([-1, -1, width - 1, -1])) - tf.slice(layer, [0, 0, 1, 0], [-1, -1, -1, -1])
+    y = tf.slice(layer, [0, 0, 0, 0], tf.stack([-1, height - 1, -1, -1])) - tf.slice(layer, [0, 1, 0, 0], [-1, -1, -1, -1])
+    x = tf.slice(layer, [0, 0, 0, 0], tf.stack([-1, -1, width - 1, -1])) - tf.slice(layer, [0, 0, 1, 0], [-1, -1, -1, -1])
     loss = tf.nn.l2_loss(x) / tf.to_float(tf.size(x)) + tf.nn.l2_loss(y) / tf.to_float(tf.size(y))
     return loss
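All three losses.py hunks are the same rename: `tf.pack` became `tf.stack` in 1.0, with unchanged semantics. It still builds one tensor from a list, which is what lets dynamic `tf.shape` components sit next to literal values like -1. A small sketch of the pattern used in `gram`, with an assumed placeholder input:

```
import tensorflow as tf

layer = tf.placeholder(tf.float32, [None, None, None, 64])  # NHWC feature map
shape = tf.shape(layer)

# tf.stack (formerly tf.pack) mixes dynamic dims with the literal -1:
filters = tf.reshape(layer, tf.stack([shape[0], -1, shape[3]]))
```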
6 changes: 3 additions & 3 deletions model.py
@@ -19,7 +19,7 @@ def conv2d_transpose(x, input_filters, output_filters, kernel, strides):
     batch_size = tf.shape(x)[0]
     height = tf.shape(x)[1] * strides
     width = tf.shape(x)[2] * strides
-    output_shape = tf.pack([batch_size, height, width, output_filters])
+    output_shape = tf.stack([batch_size, height, width, output_filters])
     return tf.nn.conv2d_transpose(x, weight, output_shape, strides=[1, strides, strides, 1], name='conv_transpose')
 
 
@@ -51,7 +51,7 @@ def instance_norm(x):
 
     mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
 
-    return tf.div(tf.sub(x, mean), tf.sqrt(tf.add(var, epsilon)))
+    return tf.div(tf.subtract(x, mean), tf.sqrt(tf.add(var, epsilon)))
 
 
 def batch_norm(x, size, training, decay=0.999):
@@ -121,6 +121,6 @@ def net(image, training):
     # Remove border effect reducing padding.
     height = tf.shape(y)[1]
     width = tf.shape(y)[2]
-    y = tf.slice(y, [0, 10, 10, 0], tf.pack([-1, height - 20, width - 20, -1]))
+    y = tf.slice(y, [0, 10, 10, 0], tf.stack([-1, height - 20, width - 20, -1]))
 
     return y
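model.py picks up the same `tf.pack` to `tf.stack` rename plus the elementwise-op renames (`tf.sub` to `tf.subtract`; `tf.mul` and `tf.neg` were renamed the same way). A sketch of the updated `instance_norm` arithmetic, with an assumed epsilon and input shape:

```
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 256, 256, 3])
epsilon = 1e-9  # illustrative value, not taken from the repo

# Per-image mean and variance over the spatial axes.
mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
# was: tf.div(tf.sub(x, mean), tf.sqrt(tf.add(var, epsilon)))
normalized = tf.div(tf.subtract(x, mean), tf.sqrt(tf.add(var, epsilon)))
```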
13 changes: 6 additions & 7 deletions preprocessing/vgg_preprocessing.py
@@ -73,7 +73,7 @@ def _crop(image, offset_height, offset_width, crop_height, crop_width):
       ['Rank of image must be equal to 3.'])
   cropped_shape = control_flow_ops.with_dependencies(
       [rank_assertion],
-      tf.pack([crop_height, crop_width, original_shape[2]]))
+      tf.stack([crop_height, crop_width, original_shape[2]]))
 
   # print(original_shape[0], crop_height)
   # print(original_shape[1], crop_width)
@@ -83,7 +83,7 @@ def _crop(image, offset_height, offset_width, crop_height, crop_width):
           tf.greater_equal(original_shape[1], crop_width)),
       ['Crop size greater than the image size.'])
 
-  offsets = tf.to_int32(tf.pack([offset_height, offset_width, 0]))
+  offsets = tf.to_int32(tf.stack([offset_height, offset_width, 0]))
 
   # Use tf.slice instead of crop_to_bounding box as it accepts tensors to
   # define the crop size.
@@ -227,11 +227,10 @@ def _mean_image_subtraction(image, means):
   num_channels = image.get_shape().as_list()[-1]
   if len(means) != num_channels:
     raise ValueError('len(means) must match the number of channels')
-
-  channels = tf.split(2, num_channels, image)
+  channels = tf.split(image, num_channels, 2)
   for i in range(num_channels):
     channels[i] -= means[i]
-  return tf.concat(2, channels)
+  return tf.concat(channels, 2)
 
 
 def _mean_image_add(image, means):
@@ -241,10 +240,10 @@ def _mean_image_add(image, means):
   if len(means) != num_channels:
     raise ValueError('len(means) must match the number of channels')
 
-  channels = tf.split(2, num_channels, image)
+  channels = tf.split(image, num_channels, 2)
   for i in range(num_channels):
     channels[i] += means[i]
-  return tf.concat(2, channels)
+  return tf.concat(channels, 2)
 
 
 def _smallest_size_at_least(height, width, target_height, target_width):
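The preprocessing hunks are argument-order changes rather than renames: `tf.split(axis, num, value)` became `tf.split(value, num_or_size_splits, axis)`, and `tf.concat(axis, values)` became `tf.concat(values, axis)`. A sketch of the per-channel mean subtraction in the new order (the mean values are the standard VGG channel means, shown here for illustration):

```
import tensorflow as tf

image = tf.placeholder(tf.float32, [224, 224, 3])
means = [123.68, 116.779, 103.939]  # standard VGG channel means, illustrative here

channels = tf.split(image, 3, 2)                     # was tf.split(2, 3, image)
channels = [c - m for c, m in zip(channels, means)]  # subtract each channel mean
centered = tf.concat(channels, 2)                    # was tf.concat(2, channels)
```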
8 changes: 4 additions & 4 deletions train.py
@@ -42,9 +42,9 @@ def main(FLAGS):
         'train2014/', image_preprocessing_fn, epochs=FLAGS.epoch)
     generated = model.net(processed_images, training=True)
     processed_generated = [image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
-                           for image in tf.unpack(generated, axis=0, num=FLAGS.batch_size)
+                           for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
                            ]
-    processed_generated = tf.pack(processed_generated)
+    processed_generated = tf.stack(processed_generated)
     _, endpoints_dict = network_fn(tf.concat(0, [processed_generated, processed_images]), spatial_squeeze=False)
     tf.logging.info('Loss network layers(You can define them in "content_layers" and "style_layers"):')
     for key in endpoints_dict:
@@ -70,8 +70,8 @@ def main(FLAGS):
         tf.scalar_summary('style_losses/' + layer, style_loss_summary[layer])
     tf.image_summary('generated', generated)
     # tf.image_summary('processed_generated', processed_generated) # May be better?
-    tf.image_summary('origin', tf.pack([
-        image_unprocessing_fn(image) for image in tf.unpack(processed_images, axis=0, num=FLAGS.batch_size)
+    tf.image_summary('origin', tf.stack([
+        image_unprocessing_fn(image) for image in tf.unstack(processed_images, axis=0, num=FLAGS.batch_size)
     ]))
     summary = tf.merge_all_summaries()
     writer = tf.train.SummaryWriter(training_path)
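train.py completes the pack/unpack renames: `tf.unpack` became `tf.unstack` and `tf.pack` became `tf.stack`, so the per-image preprocessing loop round-trips a batch exactly as before. A sketch under assumed shapes, with a trivial stand-in for `image_preprocessing_fn`:

```
import tensorflow as tf

batch_size = 4
generated = tf.placeholder(tf.float32, [batch_size, 256, 256, 3])

# was: tf.unpack(generated, axis=0, num=batch_size)
images = tf.unstack(generated, axis=0, num=batch_size)
processed = [img * 255.0 for img in images]  # stand-in for image_preprocessing_fn
# was: tf.pack(processed)
processed_batch = tf.stack(processed)
```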