diff --git a/distribution_strategy/multi_worker_mirrored_strategy/README.md b/distribution_strategy/multi_worker_mirrored_strategy/README.md index 161eaca..e010258 100644 --- a/distribution_strategy/multi_worker_mirrored_strategy/README.md +++ b/distribution_strategy/multi_worker_mirrored_strategy/README.md @@ -1,7 +1,7 @@ # MultiWorkerMirrored Training Strategy with examples -The steps below are meant to train models using [MultiWorkerMirrored Strategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy) using the tensorflow 2.0 API on the Kubernetes platform. +The steps below are meant to train models using [MultiWorkerMirrored Strategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy) using the tensorflow 2.x API on the Kubernetes platform. Reference programs such as [keras_mnist.py](examples/keras_mnist.py) and [custom_training_mnist.py](examples/custom_training_mnist.py) are available in the examples directory. diff --git a/distribution_strategy/multi_worker_mirrored_strategy/examples/custom_training_mnist.py b/distribution_strategy/multi_worker_mirrored_strategy/examples/custom_training_mnist.py index fc4f429..7a0c2e0 100644 --- a/distribution_strategy/multi_worker_mirrored_strategy/examples/custom_training_mnist.py +++ b/distribution_strategy/multi_worker_mirrored_strategy/examples/custom_training_mnist.py @@ -1,5 +1,5 @@ # ============================================================================== -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================== -# This code serves as an example of using Tensorflow 2.0 to build and train a CNN model on the +# This code serves as an example of using Tensorflow 2.x to build and train a CNN model on the # Fashion MNIST dataset using the tf.distribute.MultiWorkerMirroredStrategy described here # https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy # using a custom training loop. This code is very similar to the example provided here @@ -33,8 +33,7 @@ MAIN_MODEL_PATH = '/pvcmnt' EPOCHS = 10 -BATCH_SIZE_PER_REPLICA = 64 -GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA +GLOBAL_BATCH_SIZE = 128 def _is_chief(task_type, task_id): # If `task_type` is None, this may be operating as single worker, which works @@ -92,7 +91,6 @@ def get_dist_data_set(strategy, batch_size): def main(): global GLOBAL_BATCH_SIZE strategy = tf.distribute.MultiWorkerMirroredStrategy() - GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync train_dist_dataset, test_dist_dataset = get_dist_data_set(strategy, GLOBAL_BATCH_SIZE) checkpoint_pfx = write_filepath(strategy) with strategy.scope(): diff --git a/distribution_strategy/multi_worker_mirrored_strategy/examples/keras_mnist.py b/distribution_strategy/multi_worker_mirrored_strategy/examples/keras_mnist.py index 288fed8..41882c7 100644 --- a/distribution_strategy/multi_worker_mirrored_strategy/examples/keras_mnist.py +++ b/distribution_strategy/multi_worker_mirrored_strategy/examples/keras_mnist.py @@ -1,5 +1,5 @@ # ============================================================================== -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================== -# This code serves as an example of using Tensorflow 2.0 Keras API to build and train a CNN model on the +# This code serves as an example of using Tensorflow 2.x Keras API to build and train a CNN model on the # MNIST dataset using the tf.distribute.MultiWorkerMirroredStrategy described here # https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy. # This code is very similar to the example provided here @@ -39,6 +39,8 @@ # Model save directory MAIN_MODEL_PATH = '/pvcmnt' +GLOBAL_BATCH_SIZE = 128 + def _is_chief(task_type, task_id): # If `task_type` is None, this may be operating as single worker, which works # effectively as chief. @@ -88,13 +90,11 @@ def build_and_compile_cnn_model(): return model def main(): - per_worker_batch_size = 64 tf_config = json.loads(os.environ['TF_CONFIG']) num_workers = len(tf_config['cluster']['worker']) strategy = tf.distribute.MultiWorkerMirroredStrategy() - global_batch_size = per_worker_batch_size * num_workers - multi_worker_dataset = mnist_dataset(global_batch_size) + multi_worker_dataset = mnist_dataset(GLOBAL_BATCH_SIZE) # missing needs to be fixed # multi_worker_dataset = strategy.distribute_datasets_from_function(mnist_dataset(global_batch_size))