-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcifar10_main.py
521 lines (457 loc) · 17.9 KB
/
cifar10_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet model for classifying images from CIFAR-10 dataset.
Support single-host training with one or multiple devices.
ResNet as proposed in:
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Deep Residual Learning for Image Recognition. arXiv:1512.03385
CIFAR-10 as in:
http://www.cs.toronto.edu/~kriz/cifar.html
"""
from __future__ import division
from __future__ import print_function
import argparse
import functools
import itertools
import os
import cifar10
import cifar10_model
import cifar10_utils
import numpy as np
import six
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
# Set TensorFlow's logging verbosity to INFO for the whole module.
tf.logging.set_verbosity(tf.logging.INFO)
def get_model_fn(num_gpus, variable_strategy, num_workers):
  """Returns a function that will build the resnet model.

  Args:
    num_gpus: int. Number of GPU towers to build; 0 builds a single tower on
      the CPU.
    variable_strategy: str. 'CPU' or 'GPU'. Selects the device setter used for
      every tower and the device on which the gradient update ops are built.
    num_workers: int. Number of worker replicas. Divides the number of batches
      per epoch used for the learning rate boundaries and is passed as
      `replicas_to_aggregate` when `params.sync` is set.

  Returns:
    A model function `(features, labels, mode, params)` returning a
    tf.estimator.EstimatorSpec.

  Raises:
    ValueError: if `variable_strategy` is neither 'CPU' nor 'GPU'.
  """
  # Fail early with a clear message. Without this check an unknown strategy
  # would leave `device_setter` unbound inside the tower loop below.
  if variable_strategy not in ('CPU', 'GPU'):
    raise ValueError(
        'Invalid variable_strategy {!r}: expected "CPU" or "GPU".'.format(
            variable_strategy))

  def _resnet_model_fn(features, labels, mode, params):
    """Resnet model body.

    Support single host, one or more GPU training. Parameter distribution can
    be either one of the following scheme.
    1. CPU is the parameter server and manages gradient updates.
    2. Parameters are distributed evenly across all GPUs, and the first GPU
       manages gradient updates.

    Args:
      features: a list of tensors, one for each tower
      labels: a list of tensors, one for each tower
      mode: ModeKeys.TRAIN or EVAL
      params: Hyperparameters suitable for tuning

    Returns:
      A EstimatorSpec object.
    """
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    weight_decay = params.weight_decay
    momentum = params.momentum

    tower_features = features
    tower_labels = labels
    tower_losses = []
    tower_gradvars = []
    tower_preds = []

    # channels first (NCHW) is normally optimal on GPU and channels last (NHWC)
    # on CPU. The exception is Intel MKL on CPU which is optimal with
    # channels_last.
    data_format = params.data_format
    if not data_format:
      if num_gpus == 0:
        data_format = 'channels_last'
      else:
        data_format = 'channels_first'

    if num_gpus == 0:
      num_devices = 1
      device_type = 'cpu'
    else:
      num_devices = num_gpus
      device_type = 'gpu'

    # Build one tower per device. Variables live in the shared 'resnet'
    # variable scope and are reused by every tower after the first.
    for i in range(num_devices):
      worker_device = '/{}:{}'.format(device_type, i)
      if variable_strategy == 'CPU':
        device_setter = cifar10_utils.local_device_setter(
            worker_device=worker_device)
      else:  # 'GPU' (validated when the model function was created).
        device_setter = cifar10_utils.local_device_setter(
            ps_device_type='gpu',
            worker_device=worker_device,
            ps_strategy=tf.contrib.training.GreedyLoadBalancingStrategy(
                num_gpus, tf.contrib.training.byte_size_load_fn))
      with tf.variable_scope('resnet', reuse=bool(i != 0)):
        with tf.name_scope('tower_%d' % i) as name_scope:
          with tf.device(device_setter):
            loss, gradvars, preds = _tower_fn(
                is_training, weight_decay, tower_features[i], tower_labels[i],
                data_format, params.num_layers, params.batch_norm_decay,
                params.batch_norm_epsilon)
            tower_losses.append(loss)
            tower_gradvars.append(gradvars)
            tower_preds.append(preds)
            if i == 0:
              # Only trigger batch_norm moving mean and variance update from
              # the 1st tower. Ideally, we should grab the updates from all
              # towers but these stats accumulate extremely fast so we can
              # ignore the other stats from the other towers without
              # significant detriment.
              update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                             name_scope)

    # Now compute global loss and gradients.
    gradvars = []
    with tf.name_scope('gradient_averaging'):
      # Group the per-tower gradients by the variable they belong to.
      all_grads = {}
      for grad, var in itertools.chain(*tower_gradvars):
        if grad is not None:
          all_grads.setdefault(var, []).append(grad)
      for var, grads in six.iteritems(all_grads):
        # Average gradients on the same device as the variables
        # to which they apply.
        with tf.device(var.device):
          if len(grads) == 1:
            avg_grad = grads[0]
          else:
            avg_grad = tf.multiply(tf.add_n(grads), 1. / len(grads))
        gradvars.append((avg_grad, var))

    # Device that runs the ops to apply global gradient updates.
    consolidation_device = '/gpu:0' if variable_strategy == 'GPU' else '/cpu:0'
    with tf.device(consolidation_device):
      # Suggested learning rate scheduling from
      # https://github.com/ppwwyyxx/tensorpack/blob/master/examples/ResNet/cifar10-resnet.py#L155
      num_batches_per_epoch = cifar10.Cifar10DataSet.num_examples_per_epoch(
          'train') // (params.train_batch_size * num_workers)
      # Epoch boundaries converted to global-step boundaries.
      boundaries = [
          num_batches_per_epoch * x
          for x in np.array([82, 123, 300], dtype=np.int64)
      ]
      staged_lr = [params.learning_rate * x for x in [1, 0.1, 0.01, 0.002]]

      learning_rate = tf.train.piecewise_constant(tf.train.get_global_step(),
                                                  boundaries, staged_lr)

      loss = tf.reduce_mean(tower_losses, name='loss')

      examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
          params.train_batch_size, every_n_steps=10)

      tensors_to_log = {'learning_rate': learning_rate, 'loss': loss}

      logging_hook = tf.train.LoggingTensorHook(
          tensors=tensors_to_log, every_n_iter=100)

      train_hooks = [logging_hook, examples_sec_hook]

      optimizer = tf.train.MomentumOptimizer(
          learning_rate=learning_rate, momentum=momentum)

      if params.sync:
        # Wrap the optimizer so gradients from `num_workers` replicas are
        # aggregated before being applied; the hook it provides must run too.
        optimizer = tf.train.SyncReplicasOptimizer(
            optimizer, replicas_to_aggregate=num_workers)
        sync_replicas_hook = optimizer.make_session_run_hook(params.is_chief)
        train_hooks.append(sync_replicas_hook)

      # Create single grouped train op
      train_op = [
          optimizer.apply_gradients(
              gradvars, global_step=tf.train.get_global_step())
      ]
      train_op.extend(update_ops)
      train_op = tf.group(*train_op)

      # Concatenate per-tower outputs back into one batch-sized result.
      predictions = {
          'classes':
              tf.concat([p['classes'] for p in tower_preds], axis=0),
          'probabilities':
              tf.concat([p['probabilities'] for p in tower_preds], axis=0)
      }
      stacked_labels = tf.concat(labels, axis=0)
      metrics = {
          'accuracy':
              tf.metrics.accuracy(stacked_labels, predictions['classes'])
      }

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        training_hooks=train_hooks,
        eval_metric_ops=metrics)

  return _resnet_model_fn
def _tower_fn(is_training, weight_decay, feature, label, data_format,
              num_layers, batch_norm_decay, batch_norm_epsilon):
  """Build computation tower (Resnet).

  Args:
    is_training: true if is training graph.
    weight_decay: weight regularization strength, a float.
    feature: a Tensor.
    label: a Tensor.
    data_format: channels_last (NHWC) or channels_first (NCHW).
    num_layers: number of layers, an int.
    batch_norm_decay: decay for batch normalization, a float.
    batch_norm_epsilon: epsilon for batch normalization, a float.

  Returns:
    A tuple `(loss, gradvars, predictions)`: the tower loss including the L2
    weight-decay term, a list of `(gradient, variable)` pairs, and a dict with
    the keys 'classes' and 'probabilities'.
  """
  model = cifar10_model.ResNetCifar10(
      num_layers,
      batch_norm_decay=batch_norm_decay,
      batch_norm_epsilon=batch_norm_epsilon,
      is_training=is_training,
      data_format=data_format)
  # The input is always fed as channels_last; `data_format` above is what the
  # model is configured with.
  logits = model.forward_pass(feature, input_data_format='channels_last')
  tower_pred = {
      'classes': tf.argmax(input=logits, axis=1),
      'probabilities': tf.nn.softmax(logits)
  }

  tower_loss = tf.losses.sparse_softmax_cross_entropy(
      logits=logits, labels=label)
  tower_loss = tf.reduce_mean(tower_loss)

  # L2 regularization over everything returned by tf.trainable_variables().
  model_params = tf.trainable_variables()
  tower_loss += weight_decay * tf.add_n(
      [tf.nn.l2_loss(v) for v in model_params])

  tower_grad = tf.gradients(tower_loss, model_params)

  # Materialize the pairs: on Python 3 a bare zip() is a one-shot iterator,
  # which would be silently empty if a caller iterated it a second time.
  return tower_loss, list(zip(tower_grad, model_params)), tower_pred
def input_fn(data_dir,
             subset,
             num_shards,
             batch_size,
             use_distortion_for_training=True):
  """Create input graph for model.

  Args:
    data_dir: Directory where TFRecords representing the dataset are located.
    subset: one of 'train', 'validate' and 'eval'.
    num_shards: num of towers participating in data-parallel training.
    batch_size: total batch size for training to be divided by the number of
      shards.
    use_distortion_for_training: True to use distortions.

  Returns:
    two lists of tensors for features and labels, each of num_shards length.
  """
  with tf.device('/cpu:0'):
    # Distortion is only ever requested for the training subset.
    if subset == 'train':
      distort = use_distortion_for_training
    else:
      distort = False
    images, targets = cifar10.Cifar10DataSet(
        data_dir, subset, distort).make_batch(batch_size)

    # No GPU available or only 1 GPU: the whole batch is a single shard.
    if num_shards <= 1:
      return [images], [targets]

    # Note that passing num=batch_size is safe here, even though
    # dataset.batch(batch_size) can, in some cases, return fewer than batch_size
    # examples. This is because it does so only when repeating for a limited
    # number of epochs, but our dataset repeats forever.
    image_rows = tf.unstack(images, num=batch_size, axis=0)
    target_rows = tf.unstack(targets, num=batch_size, axis=0)

    # Deal the examples round-robin: shard s receives examples s,
    # s + num_shards, s + 2 * num_shards, ...
    feature_shards = [
        tf.parallel_stack(image_rows[shard::num_shards])
        for shard in range(num_shards)
    ]
    label_shards = [
        tf.parallel_stack(target_rows[shard::num_shards])
        for shard in range(num_shards)
    ]
    return feature_shards, label_shards
def get_experiment_fn(data_dir,
                      num_gpus,
                      variable_strategy,
                      use_distortion_for_training=True):
  """Returns an Experiment function.

  Experiments perform training on several workers in parallel,
  in other words experiments know how to invoke train and eval in a sensible
  fashion for distributed training. Arguments passed directly to this
  function are not tunable, all other arguments should be passed within
  tf.HParams, passed to the enclosed function.

  Args:
      data_dir: str. Location of the data for input_fns.
      num_gpus: int. Number of GPUs on each worker.
      variable_strategy: String. CPU to use CPU as the parameter server
      and GPU to use the GPUs as the parameter server.
      use_distortion_for_training: bool. See cifar10.Cifar10DataSet.

  Returns:
      A function (tf.estimator.RunConfig, tf.contrib.training.HParams) ->
      tf.contrib.learn.Experiment.

      Suitable for use by tf.contrib.learn.learn_runner, which will run various
      methods on Experiment (train, evaluate) based on information
      about the current runner in `run_config`.
  """

  def _bind_input_fn(subset, batch_size, **extra):
    """Binds `input_fn` to this experiment's data directory and shard count."""
    return functools.partial(
        input_fn,
        data_dir,
        subset=subset,
        num_shards=num_gpus,
        batch_size=batch_size,
        **extra)

  def _experiment_fn(run_config, hparams):
    """Returns an Experiment."""
    # Input pipelines for the two phases; only training may use distortion.
    train_input_fn = _bind_input_fn(
        'train',
        hparams.train_batch_size,
        use_distortion_for_training=use_distortion_for_training)
    eval_input_fn = _bind_input_fn('eval', hparams.eval_batch_size)

    # The evaluation set has to split into a whole number of batches.
    eval_set_size = cifar10.Cifar10DataSet.num_examples_per_epoch('eval')
    eval_steps, leftover = divmod(eval_set_size, hparams.eval_batch_size)
    if leftover != 0:
      raise ValueError(
          'validation set size must be multiple of eval_batch_size')

    train_steps = hparams.train_steps

    # Create estimator.
    worker_count = run_config.num_worker_replicas or 1
    classifier = tf.estimator.Estimator(
        model_fn=get_model_fn(num_gpus, variable_strategy, worker_count),
        config=run_config,
        params=hparams)

    # Create experiment.
    return tf.contrib.learn.Experiment(
        classifier,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        train_steps=train_steps,
        eval_steps=eval_steps)

  return _experiment_fn
def main(job_dir, data_dir, num_gpus, variable_strategy,
         use_distortion_for_training, log_device_placement, num_intra_threads,
         **hparams):
  """Configures the session and runs the experiment through learn_runner.

  Args:
    job_dir: str. Passed to the run config as `model_dir`.
    data_dir: str. Location of the input data, forwarded to the experiment.
    num_gpus: int. Number of GPUs, forwarded to the experiment.
    variable_strategy: str. 'CPU' or 'GPU', forwarded to the experiment.
    use_distortion_for_training: bool. Forwarded to the experiment.
    log_device_placement: bool. Whether the session logs device placement.
    num_intra_threads: int. Threads for intra-op parallelism.
    **hparams: remaining settings; all of them are forwarded into
      tf.contrib.training.HParams together with `is_chief`. If present,
      `num_inter_threads` is additionally applied to the session config.
  """
  # The env variable is on deprecation path, default is set to off.
  os.environ['TF_SYNC_ON_FINISH'] = '0'
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  # Honour --num-inter-threads, which was previously parsed but never used.
  # It stays in `hparams` as before; 0 (the default) leaves the choice to
  # TensorFlow, so the default behaviour is unchanged.
  num_inter_threads = hparams.get('num_inter_threads') or 0

  # Session configuration.
  sess_config = tf.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=log_device_placement,
      intra_op_parallelism_threads=num_intra_threads,
      inter_op_parallelism_threads=num_inter_threads,
      gpu_options=tf.GPUOptions(force_gpu_compatible=True))

  config = cifar10_utils.RunConfig(
      session_config=sess_config, model_dir=job_dir)
  tf.contrib.learn.learn_runner.run(
      get_experiment_fn(data_dir, num_gpus, variable_strategy,
                        use_distortion_for_training),
      run_config=config,
      hparams=tf.contrib.training.HParams(
          is_chief=config.is_chief,
          **hparams))
def parse_bool_flag(value):
  """Converts a command-line boolean such as 'true' or 'False' to a bool.

  `type=bool` must not be used with argparse: bool('False') is True because
  every non-empty string is truthy.

  Args:
    value: a bool, or a string holding one of the accepted spellings
      (case-insensitive). The empty string is treated as False.

  Returns:
    The parsed boolean.

  Raises:
    argparse.ArgumentTypeError: if the value is not a recognised spelling.
  """
  if isinstance(value, bool):
    return value
  normalized = value.strip().lower()
  if normalized in ('true', 't', 'yes', 'y', '1'):
    return True
  if normalized in ('false', 'f', 'no', 'n', '0', ''):
    return False
  raise argparse.ArgumentTypeError(
      'Expected a boolean value (true/false), got {!r}.'.format(value))


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--data-dir',
      type=str,
      required=True,
      help='The directory where the CIFAR-10 input data is stored.')
  parser.add_argument(
      '--job-dir',
      type=str,
      required=True,
      help='The directory where the model will be stored.')
  parser.add_argument(
      '--variable-strategy',
      choices=['CPU', 'GPU'],
      type=str,
      default='CPU',
      help='Where to locate variable operations')
  parser.add_argument(
      '--num-gpus',
      type=int,
      default=1,
      help='The number of gpus used. Uses only CPU if set to 0.')
  parser.add_argument(
      '--num-layers',
      type=int,
      default=44,
      help='The number of layers of the model.')
  parser.add_argument(
      '--train-steps',
      type=int,
      default=80000,
      help='The number of steps to use for training.')
  parser.add_argument(
      '--train-batch-size',
      type=int,
      default=128,
      help='Batch size for training.')
  parser.add_argument(
      '--eval-batch-size',
      type=int,
      default=100,
      help='Batch size for validation.')
  parser.add_argument(
      '--momentum',
      type=float,
      default=0.9,
      help='Momentum for MomentumOptimizer.')
  parser.add_argument(
      '--weight-decay',
      type=float,
      default=2e-4,
      help='Weight decay for convolutions.')
  parser.add_argument(
      '--learning-rate',
      type=float,
      default=0.1,
      help="""\
      This is the inital learning rate value. The learning rate will decrease
      during training. For more details check the model_fn implementation in
      this file.\
      """)
  parser.add_argument(
      '--use-distortion-for-training',
      # A real parser instead of `bool`, so that "False" actually disables it.
      type=parse_bool_flag,
      default=True,
      help='If doing image distortion for training.')
  parser.add_argument(
      '--sync',
      action='store_true',
      default=False,
      help="""\
      If present when running in a distributed environment will run on sync mode.\
      """)
  parser.add_argument(
      '--num-intra-threads',
      type=int,
      default=0,
      help="""\
      Number of threads to use for intra-op parallelism. When training on CPU
      set to 0 to have the system pick the appropriate number or alternatively
      set it to the number of physical CPU cores.\
      """)
  parser.add_argument(
      '--num-inter-threads',
      type=int,
      default=0,
      help="""\
      Number of threads to use for inter-op parallelism. If set to 0, the
      system will pick an appropriate number.\
      """)
  parser.add_argument(
      '--data-format',
      type=str,
      default=None,
      help="""\
      If not set, the data format best for the training device is used.
      Allowed values: channels_first (NCHW) channels_last (NHWC).\
      """)
  parser.add_argument(
      '--log-device-placement',
      action='store_true',
      default=False,
      help='Whether to log device placement.')
  parser.add_argument(
      '--batch-norm-decay',
      type=float,
      default=0.997,
      help='Decay for batch norm.')
  parser.add_argument(
      '--batch-norm-epsilon',
      type=float,
      default=1e-5,
      help='Epsilon for batch norm.')
  args = parser.parse_args()

  # Validate flag combinations before building any graph. Explicit raises are
  # used rather than `assert`, which is stripped when Python runs with -O.
  if args.num_gpus > 0 and not tf.test.is_gpu_available():
    raise ValueError('Requested GPUs but none found.')
  if args.num_gpus < 0:
    raise ValueError(
        'Invalid GPU count: \"--num-gpus\" must be 0 or a positive integer.')
  if args.num_gpus == 0 and args.variable_strategy == 'GPU':
    raise ValueError('num-gpus=0, CPU must be used as parameter server. Set '
                     '--variable-strategy=CPU.')
  # Only depths of the form 6n + 2 are accepted.
  if (args.num_layers - 2) % 6 != 0:
    raise ValueError('Invalid --num-layers parameter.')
  # Each batch is split across the GPUs, so it must divide evenly.
  if args.num_gpus != 0 and args.train_batch_size % args.num_gpus != 0:
    raise ValueError('--train-batch-size must be multiple of --num-gpus.')
  if args.num_gpus != 0 and args.eval_batch_size % args.num_gpus != 0:
    raise ValueError('--eval-batch-size must be multiple of --num-gpus.')

  main(**vars(args))