Skip to content

Commit

Permalink
Incorporate Review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
shankgan committed Apr 1, 2021
1 parent b7fe9f9 commit 2fb34e4
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

# MultiWorkerMirrored Training Strategy with examples

The steps below are meant to train models using [MultiWorkerMirrored Strategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy) using the tensorflow 2.0 API on the Kubernetes platform.
The steps below are meant to train models using [MultiWorkerMirrored Strategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy) using the tensorflow 2.x API on the Kubernetes platform.

Reference programs such as [keras_mnist.py](examples/keras_mnist.py) and
[custom_training_mnist.py](examples/custom_training_mnist.py) are available in the examples directory.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ==============================================================================
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================

# This code serves as an example of using Tensorflow 2.0 to build and train a CNN model on the
# This code serves as an example of using Tensorflow 2.x to build and train a CNN model on the
# Fashion MNIST dataset using the tf.distribute.MultiWorkerMirroredStrategy described here
# https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy
# using a custom training loop. This code is very similar to the example provided here
Expand All @@ -33,8 +33,7 @@
MAIN_MODEL_PATH = '/pvcmnt'

EPOCHS = 10
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA
GLOBAL_BATCH_SIZE = 128

def _is_chief(task_type, task_id):
# If `task_type` is None, this may be operating as single worker, which works
Expand Down Expand Up @@ -92,7 +91,6 @@ def get_dist_data_set(strategy, batch_size):
def main():
global GLOBAL_BATCH_SIZE
strategy = tf.distribute.MultiWorkerMirroredStrategy()
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
train_dist_dataset, test_dist_dataset = get_dist_data_set(strategy, GLOBAL_BATCH_SIZE)
checkpoint_pfx = write_filepath(strategy)
with strategy.scope():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ==============================================================================
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================

# This code serves as an example of using Tensorflow 2.0 Keras API to build and train a CNN model on the
# This code serves as an example of using Tensorflow 2.x Keras API to build and train a CNN model on the
# MNIST dataset using the tf.distribute.MultiWorkerMirroredStrategy described here
# https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy.
# This code is very similar to the example provided here
Expand All @@ -39,6 +39,8 @@
# Model save directory
MAIN_MODEL_PATH = '/pvcmnt'

GLOBAL_BATCH_SIZE = 128

def _is_chief(task_type, task_id):
# If `task_type` is None, this may be operating as single worker, which works
# effectively as chief.
Expand Down Expand Up @@ -88,13 +90,11 @@ def build_and_compile_cnn_model():
return model

def main():
per_worker_batch_size = 64
tf_config = json.loads(os.environ['TF_CONFIG'])
num_workers = len(tf_config['cluster']['worker'])
strategy = tf.distribute.MultiWorkerMirroredStrategy()

global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_dataset(global_batch_size)
multi_worker_dataset = mnist_dataset(GLOBAL_BATCH_SIZE)

# missing needs to be fixed
# multi_worker_dataset = strategy.distribute_datasets_from_function(mnist_dataset(global_batch_size))
Expand Down

0 comments on commit 2fb34e4

Please sign in to comment.