
Examples of MultiWorkerMirroredStrategy using Keras and Custom Training Loop #181

Open · wants to merge 11 commits into base: master
2 changes: 1 addition & 1 deletion README.md
@@ -16,7 +16,7 @@ request.
* Jupyter images for different versions of TensorFlow
* [TFServing](https://github.com/kubeflow/kubeflow/blob/master/user_guide.md#serve-a-model-using-tensorflow-serving) Docker images and K8s templates
- [kubernetes](kubernetes) - Templates for running distributed TensorFlow on
Kubernetes.
Kubernetes. For the most up-to-date examples, please also refer to the [distribution strategy](distribution_strategy) folder.
- [marathon](marathon) - Templates for running distributed TensorFlow using
Marathon, deployed on top of Mesos.
- [hadoop](hadoop) - TFRecord file InputFormat/OutputFormat for Hadoop MapReduce
117 changes: 117 additions & 0 deletions distribution_strategy/multi_worker_mirrored_strategy/README.md
@@ -0,0 +1,117 @@

# MultiWorkerMirroredStrategy training with examples

The steps below describe how to train models with [MultiWorkerMirroredStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy) and the TensorFlow 2.0 API on Kubernetes.

Reference programs such as [keras_mnist.py](examples/keras_mnist.py) and
[custom_training_mnist.py](examples/custom_training_mnist.py) are available in the examples directory.

The Kubernetes manifest templates and other cluster-specific configuration are available in the [kubernetes](kubernetes) directory.

## Prerequisites

1. (Optional) It is recommended that you have a Google Cloud project. Either create a new project or use an existing one. Install the
[gcloud command-line tools](https://cloud.google.com/functions/docs/quickstart)
on your system, then log in and set your project and zone.

2. [Jinja templates](http://jinja.pocoo.org/) must be installed.

3. A Kubernetes cluster running Kubernetes 1.15 or above must be available. To create a test
cluster on the local machine, [follow the steps here](https://kubernetes.io/docs/tutorials/kubernetes-basics/create-cluster/). Kubernetes clusters can also be created on all major cloud providers. For instance,
here are instructions to [create GKE clusters](https://cloud.google.com/kubernetes-engine/docs/how-to/creating-a-regional-cluster). Make sure that you have at least 12 GB of RAM across all nodes in the cluster. Creating the cluster should also install the `kubectl` tool on your system.

4. Set context for `kubectl` so that `kubectl` knows which cluster to use:

```bash
kubectl config use-context <cluster_name>
```

5. Install [Docker](https://docs.docker.com/get-docker/) on your system, and create an account with a container registry (for example, Docker Hub) that you can associate with your container images.

6. For model storage and checkpointing, a [persistent-volume-claim](https://kubernetes.io/docs/concepts/storage/persistent-volumes/) needs to be available to mount onto the chief worker pod. The steps below include the YAML to create a persistent-volume-claim on GKE backed by a GCE persistent disk. A quick way to sanity-check these prerequisites is shown after this list.
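
A minimal sketch for verifying the prerequisites above, using standard `kubectl` commands:

```sh
kubectl config current-context   # the cluster kubectl will talk to
kubectl get nodes                # nodes should be in the Ready state
kubectl get pvc --all-namespaces # any existing persistent-volume-claims
```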

## Steps to run the job

1. Follow the instructions for building and pushing the Docker image to a docker registry
in the [Docker README](examples/README.md).

2. Copy the template file:

```sh
cp kubernetes/MultiWorkerMirroredTemplate.yaml.jinja myjob.template.jinja
```

3. Edit the `myjob.template.jinja` file to set the job parameters.
1. `script` - the training program to run. This should be `keras_mnist.py`,
`custom_training_mnist.py`, or your own training script such as `your_own_training_example.py`

2. `name` - the prefix attached to all the Kubernetes jobs created

3. `worker_replicas` - number of parallel worker processes that train the example

4. `port` - the port used by tensorflow worker processes to communicate with each other

5. `checkpoint_pvc_name` - name of the persistent-volume-claim that will contain the checkpointed model.

6. `model_checkpoint_dir` - mount location for inspecting the trained model in the volume inspector pod. Only needs to be set when the volume inspector pod is created.

7. `image` - name of the Docker image built and pushed in step 1 that will be pulled onto the cluster

8. `deploy` - set to True when the manifest is actually expected to be deployed

9. `create_pvc_checkpoint` - Creates a ReadWriteOnce persistent volume claim to checkpoint the model if needed. The name of the claim `checkpoint_pvc_name` should also be specified.

10. `create_volume_inspector` - create a pod to inspect the contents of the volume after the training job is complete. If this is `True`, `deploy` cannot be `True`, since the checkpoint volume can only be mounted read-write by a single node; inspection cannot happen while training is happening. A quick way to validate the edited template is shown after this list.
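
Before deploying, you can render the edited template and validate the generated manifests without creating any resources. This is a sketch that assumes the same `render_template.py` used in the steps below and `kubectl` 1.18+ for `--dry-run=client`:

```sh
python ../../render_template.py myjob.template.jinja | kubectl apply --dry-run=client -f -
```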

4. Run the job:
1. Create a namespace to run your training jobs

```sh
kubectl create namespace <namespace>
```

2. [Optional: if a persistent volume claim does not already exist on the cluster] First set `deploy` to `False` and `create_pvc_checkpoint` to `True`, and set `checkpoint_pvc_name` appropriately in the .jinja file. Then run

```sh
python ../../render_template.py myjob.template.jinja | kubectl apply -n <namespace> -f -
```

This will create a persistent volume claim where you can checkpoint your model. On GKE, this claim will automatically provision a GCE persistent disk resource to back it.
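
You can confirm that the claim was created and bound (on GKE, binding may take a short while as the disk is provisioned):

```sh
kubectl get pvc -n <namespace>
```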

3. Set `deploy` to `True` and `create_pvc_checkpoint` to `False`, with all the other parameters from step 3 filled in, and then run

```sh
python ../../render_template.py myjob.template.jinja | kubectl apply -n <namespace> -f -
```

This will create the Kubernetes jobs on the cluster. Each job has a single service endpoint and a single pod that runs the training image. You can track the running jobs in the cluster by running

```sh
kubectl get jobs -n <namespace>
kubectl describe jobs -n <namespace>
```

To inspect the training logs of the running jobs, run

```sh
# Shows all the running pods
kubectl get pods -n <namespace>
kubectl logs -n <namespace> <pod-name>
```

4. Once the jobs are finished (based on the logs/output of `kubectl get jobs`),
the trained model can be inspected with a volume inspector pod. Set `deploy` to `False`
and `create_volume_inspector` to `True`. Also set `model_checkpoint_dir` to the location where the trained model should be mounted. Then run

```sh
python ../../render_template.py myjob.template.jinja | kubectl apply -n <namespace> -f -
```

This will create the volume inspector pod. Then, open a shell in the pod:

```sh
kubectl get pods -n <namespace>
kubectl -n <namespace> exec --stdin --tty <volume-inspector-pod> -- /bin/sh
```

The contents of the trained model are available for inspection at `model_checkpoint_dir`.
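
For example, once inside the inspector pod you can list the saved model and checkpoint files (a sketch; the path is whatever you configured as `model_checkpoint_dir`):

```sh
ls -R <model_checkpoint_dir>
```
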
@@ -0,0 +1,13 @@
FROM tensorflow/tensorflow:nightly

# Keeps Python from generating .pyc files in the container
ENV PYTHONDONTWRITEBYTECODE=1

# Turns off buffering for easier container logging
ENV PYTHONUNBUFFERED=1

WORKDIR /app

COPY . /app/

ENTRYPOINT ["python", "keras_mnist.py"]
@@ -0,0 +1,53 @@
# TensorFlow Docker Images

This directory contains examples of MultiWorkerMirroredStrategy training, along with the Dockerfile used to build a container image for them.

- [Dockerfile](Dockerfile) contains all dependencies required to build a container image with the training examples using Docker
- [keras_mnist.py](keras_mnist.py) demonstrates how to train an MNIST classifier using
[tf.distribute.MultiWorkerMirroredStrategy and the Keras TensorFlow 2.0 API](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras).
- [custom_training_mnist.py](custom_training_mnist.py) demonstrates how to train a Fashion MNIST classifier using
[tf.distribute.MultiWorkerMirroredStrategy and the TensorFlow 2.0 custom training loop APIs](https://www.tensorflow.org/tutorials/distribute/custom_training).

## Best Practices

- Always pin the TensorFlow version with the Docker image tag. This ensures that
TensorFlow updates don't adversely impact your training program for future
runs.
- When creating an image, specify version tags (see below). If you make code
changes, increment the version. Cluster managers will not pull an updated
Docker image if one is already cached. Versioning also ensures that you have
a single copy of the code running for each job.

## Building the Docker Files

Ensure that Docker is installed on your system.

First, pick an image name for the job. When running on a cluster manager, you
will want to push your images to a container registry. Note that both the
[Google Container Registry](https://cloud.google.com/container-registry/)
and the [Amazon EC2 Container Registry](https://aws.amazon.com/ecr/) require
special paths. We append `:v1` to version our images. Versioning images is
strongly recommended for reasons described in the best practices section.

```sh
docker build -t <image_name>:v1 -f Dockerfile .
# Use gcloud docker push instead if on Google Container Registry.
docker push <image_name>:v1
```

If you make any updates to the code, increment the version and rerun the above
commands with the new version.
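
The image's default entrypoint runs `keras_mnist.py`. For a quick local smoke test, or to run the custom-training-loop example instead, you can override the entrypoint when starting the container (a sketch; without `TF_CONFIG` the scripts run as a single worker, and the `/pvcmnt` directory they save to must exist inside the container):

```sh
# Run the default Keras example as a single local worker.
docker run --rm <image_name>:v1

# Run the custom training loop example by overriding the entrypoint.
docker run --rm --entrypoint python <image_name>:v1 custom_training_mnist.py
```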

## Running the keras_mnist.py example

The [keras_mnist.py](keras_mnist.py) example demonstrates how to train an MNIST classifier using
[tf.distribute.MultiWorkerMirroredStrategy and Keras Tensorflow 2.0 API](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras).
The final model is saved to disk by the chief worker process. The disk is assumed to be mounted onto the running container by the cluster manager.
The example assumes that the cluster configuration is passed in through the `TF_CONFIG` environment variable when deployed in the cluster.
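
When the job is deployed with the Kubernetes templates in this directory, `TF_CONFIG` is populated for each worker by the manifest. A minimal sketch of what a two-worker configuration looks like (hostnames and port are illustrative):

```sh
export TF_CONFIG='{
  "cluster": {
    "worker": ["worker-0.example.svc:5000", "worker-1.example.svc:5000"]
  },
  "task": {"type": "worker", "index": 0}
}'
```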

## Running the custom_training_mnist.py example

The [custom_training_mnist.py](custom_training_mnist.py) example demonstrates how to train a Fashion MNIST classifier using
[tf.distribute.MultiWorkerMirroredStrategy and Tensorflow 2.0 Custom Training Loop APIs](https://www.tensorflow.org/tutorials/distribute/custom_training).
The final model is saved to disk by the chief worker process. The disk is assumed to be mounted onto the running container by the cluster manager.
It assumes that the cluster configuration is passed in through the `TF_CONFIG` environment variable when deployed in the cluster.
@@ -0,0 +1,170 @@
# ==============================================================================
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# This code serves as an example of using Tensorflow 2.0 to build and train a CNN model on the
# Fashion MNIST dataset using the tf.distribute.MultiWorkerMirroredStrategy described here
# https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy
# using a custom training loop. This code is very similar to the example provided here
# https://www.tensorflow.org/tutorials/distribute/custom_training
# Assumptions:
# 1) The code assumes that the cluster configuration needed for the TF distribute strategy is available through the
# TF_CONFIG environment variable. See the link provided above for details
# 2) The model is checkpointed and saved in /pvcmnt by the chief worker process.

import tensorflow as tf
import numpy as np
import os

# Used to run example using CPU only. Untested on GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
MAIN_MODEL_PATH = '/pvcmnt'

EPOCHS = 10
BATCH_SIZE_PER_REPLICA = 64
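# The effective (global) batch size is the per-replica batch size multiplied by the
# number of workers; it is recomputed in main() once the strategy knows the cluster size.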
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA

def _is_chief(task_type, task_id):
# If `task_type` is None, this may be operating as single worker, which works
# effectively as chief.
return task_type is None or task_type == 'chief' or (
task_type == 'worker' and task_id == 0)

def _get_temp_dir(task_id):
base_dirpath = 'workertemp_' + str(task_id)
temp_dir = os.path.join("/tmp", base_dirpath)
os.makedirs(temp_dir)
return temp_dir

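# Every worker must go through the same save/checkpoint calls so the collective ops
# stay in sync, but only the chief should write to the shared persistent volume;
# non-chief workers are given a throwaway local temp directory instead.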
def write_filepath(strategy):
task_type, task_id = strategy.cluster_resolver.task_type, strategy.cluster_resolver.task_id
if not _is_chief(task_type, task_id):
checkpoint_dir = _get_temp_dir(task_id)
else:
base_dirpath = 'workertemp_' + str(task_id)
checkpoint_dir = os.path.join(MAIN_MODEL_PATH, base_dirpath)
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
return checkpoint_dir

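# A small CNN for 28x28x1 Fashion-MNIST images. Its variables are created inside
# strategy.scope() in main(), which is what makes them mirrored across workers.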
def create_model():
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Conv2D(64, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10)
])
return model

def get_dist_data_set(strategy, batch_size):
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]
# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(60000).batch(batch_size)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(batch_size)
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
return train_dist_dataset, test_dist_dataset

def main():
global GLOBAL_BATCH_SIZE
strategy = tf.distribute.MultiWorkerMirroredStrategy()
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
train_dist_dataset, test_dist_dataset = get_dist_data_set(strategy, GLOBAL_BATCH_SIZE)
checkpoint_pfx = write_filepath(strategy)
with strategy.scope():
model = create_model()
optimizer = tf.keras.optimizers.Adam()
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
reduction=tf.keras.losses.Reduction.NONE)
test_loss = tf.keras.metrics.Mean(name='test_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='train_accuracy')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='test_accuracy')

def compute_loss(labels, predictions):
per_example_loss = loss_object(labels, predictions)
return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

def test_step(inputs):
images, labels = inputs
predictions = model(images, training=False)
t_loss = loss_object(labels, predictions)
test_loss.update_state(t_loss)
test_accuracy.update_state(labels, predictions)

def train_step(inputs):
images, labels = inputs
with tf.GradientTape() as tape:
predictions = model(images, training=True)
loss = compute_loss(labels, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
train_accuracy.update_state(labels, predictions)
return loss

# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
return strategy.run(test_step, args=(dataset_inputs,))

for epoch in range(EPOCHS):
# TRAIN LOOP
total_loss = 0.0
num_batches = 0
for x in train_dist_dataset:
total_loss += distributed_train_step(x)
num_batches += 1
train_loss = total_loss / num_batches

# TEST LOOP
for x in test_dist_dataset:
distributed_test_step(x)
if epoch % 2 == 0:
checkpoint.save(checkpoint_pfx)

template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
"Test Accuracy: {}")
print (template.format(epoch+1, train_loss,
train_accuracy.result()*100, test_loss.result(),
test_accuracy.result()*100))

test_loss.reset_states()
train_accuracy.reset_states()
test_accuracy.reset_states()

if __name__=="__main__":
main()