
Commit

Merge pull request #4 from HaoyuHu/fp16
PR: FP16 Support
haoyuhu authored May 24, 2019
2 parents 24434dc + 323e167 commit 98cde20
Showing 6 changed files with 1,104 additions and 1,003 deletions.
13 changes: 12 additions & 1 deletion README.md
@@ -1,6 +1,6 @@
# bert-multi-gpu

- Feel free to fine tune large BERT models with large batch size easily. Multi-GPU are supported.
+ Feel free to fine tune large BERT models with large batch size easily. Multi-GPU and FP16 are supported.

## Dependencies

@@ -11,6 +11,15 @@ Feel free to fine tune large BERT models with large batch size easily. Multi-GPU



## Features

- CPU/GPU/TPU Support
- **Multi-GPU Support**: [`tf.distribute.MirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy) provides multi-GPU support for this project: it mirrors (replicates) variables across multiple devices and machines and keeps them in sync. The maximum batch_size per GPU is almost the same as in [bert](https://github.com/google-research/bert/blob/master/README.md#out-of-memory-issues), so the **global batch_size** scales with the number of GPUs (see the sketch after this list).
- **FP16 Support**: [FP16](https://en.wikipedia.org/wiki/Half-precision_floating-point_format) lets you use a larger batch_size. Training speed increases by 70~100% on Volta GPUs, but may decrease on Pascal GPUs ([REF1](https://github.com/tensorflow/tensorflow/issues/15585#issuecomment-361769151), [REF2](https://github.com/HaoyuHu/bert-multi-gpu/issues/1#issuecomment-493363383)).
- **SavedModel Export**
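
For orientation, a minimal sketch of how `MirroredStrategy` is typically wired into a TF 1.x Estimator; the `model_fn` here is a toy stand-in (not this repository's BERT `model_fn`) and all names are illustrative:

```python
import tensorflow as tf


def model_fn(features, labels, mode):
    # Toy stand-in for the BERT model_fn built in run_custom_classifier.py.
    logits = tf.layers.dense(features["x"], 2)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    train_op = tf.train.AdamOptimizer(1e-4).minimize(
        loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)


# One model replica per visible GPU; variables are mirrored and kept in sync.
strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(train_distribute=strategy)
estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)
# Each replica consumes a per-GPU batch, so the effective global batch size
# is per_gpu_batch_size * strategy.num_replicas_in_sync.
```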



## Usage

List some optional parameters below:
@@ -29,6 +38,7 @@ List some optional parameters below:
- `num_train_epochs`: Number of training epochs.
- `use_gpu`: Whether to use GPUs.
- `num_gpu_cores`: Total number of GPU cores to use; only used if `use_gpu` is True.
- `use_fp16`: Whether to use [`FP16`](https://en.wikipedia.org/wiki/Half-precision_floating-point_format).
- `output_dir`: **Checkpoints** and **SavedModel(.pb) files** will be saved in this directory.

```shell
@@ -49,6 +59,7 @@ python run_custom_classifier.py \
--num_train_epochs=3.0 \
--use_gpu=true \
--num_gpu_cores=3 \
--use_fp16=true \
--output_dir=/cfs/outputs/bert-large-uncased-qqp
```

30 changes: 25 additions & 5 deletions custom_optimization.py
@@ -28,7 +28,7 @@
from tensorflow.python.ops import resource_variable_ops


-def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps):
+def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, fp16=False):
"""Creates an optimizer training op."""
global_step = tf.train.get_or_create_global_step()

@@ -70,19 +70,39 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps):
      epsilon=1e-6,
      exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])

+  # REF: https://github.com/tensorflow/tensorflow/issues/25080
+  # if fp16:
+  #     loss_scale_manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(
+  #         init_loss_scale=2 ** 32,
+  #         incr_every_n_steps=1000,
+  #         decr_every_n_nan_or_inf=2,
+  #         decr_ratio=0.5)
+  #     optimizer = tf.contrib.mixed_precision.LossScaleOptimizer(optimizer, loss_scale_manager)
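  # (Left disabled: instead of contrib's loss-scale manager, the code below
  # detects non-finite gradients directly and skips those update steps.)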

  tvars = tf.trainable_variables()
- grads = tf.gradients(loss, tvars)
+ gvs = optimizer.compute_gradients(loss, tvars)
+ gvs = [(g, v) for g, v in gvs if g is not None]
+ grads, tvars = list(zip(*gvs))
+ if fp16:
+   all_finite = tf.reduce_all([tf.reduce_all(tf.is_finite(g)) for g in grads])
+ else:
+   all_finite = tf.constant(True, dtype=tf.bool)

  # This is how the model was pre-trained.
- (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
+ (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0,
+                                     use_norm=tf.cond(
+                                         all_finite,
+                                         lambda: tf.global_norm(grads),
+                                         lambda: tf.constant(1.0)))
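  # (When any gradient is non-finite, a dummy use_norm of 1.0 is fed to the
  # clip op so NaN/Inf does not propagate; the global-step update below is
  # skipped in that case, effectively dropping the bad batch.)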

  train_op = optimizer.apply_gradients(
      zip(grads, tvars), global_step=global_step)

  # Normally the global step update is done inside of `apply_gradients`.
  # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
  # a different optimizer, you should probably take this line out.
- new_global_step = global_step + 1
+ new_global_step = tf.cond(all_finite, lambda: global_step + 1, lambda: global_step)
+ new_global_step = tf.identity(new_global_step, name='update_step')
  train_op = tf.group(train_op, [global_step.assign(new_global_step)])
  return train_op
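
For orientation, a minimal, self-contained sketch of calling the new signature; the toy loss is illustrative, and only `create_optimizer(..., fp16=...)` comes from this diff:

```python
import tensorflow as tf

from custom_optimization import create_optimizer

# Toy scalar loss so the sketch runs on its own.
weights = tf.get_variable("w", shape=[10], dtype=tf.float32)
loss = tf.reduce_sum(tf.square(weights))

train_op = create_optimizer(
    loss,
    init_lr=2e-5,
    num_train_steps=10000,
    num_warmup_steps=1000,
    fp16=True)  # enables the all-finite guard: non-finite steps are skipped
```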

@@ -101,7 +121,7 @@ def __init__(self,
"""Constructs a AdamWeightDecayOptimizer."""
super(AdamWeightDecayOptimizer, self).__init__(False, name)

self.learning_rate = learning_rate
self.learning_rate = tf.identity(learning_rate, name='learning_rate')
self.weight_decay_rate = weight_decay_rate
self.beta_1 = beta_1
self.beta_2 = beta_2
37 changes: 37 additions & 0 deletions gpu_environment.py
@@ -0,0 +1,37 @@
# coding=utf-8
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf


def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
                                    initializer=None, regularizer=None,
                                    trainable=True,
                                    *args, **kwargs):
    """Custom variable getter that forces trainable variables to be stored in
    float32 precision and then casts them to the training precision.
    """
    storage_dtype = tf.float32 if trainable else dtype
    variable = getter(name, shape, dtype=storage_dtype,
                      initializer=initializer, regularizer=regularizer,
                      trainable=trainable,
                      *args, **kwargs)
    if trainable and dtype != tf.float32:
        variable = tf.cast(variable, dtype)
    return variable


def get_custom_getter(compute_type):
    return float32_variable_storage_getter if compute_type == tf.float16 else None
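
For orientation, a minimal usage sketch (assumed, not part of this diff): build the forward pass under a `tf.variable_scope` whose `custom_getter` comes from `get_custom_getter`, so weights are stored in float32 but computed in float16; the placeholder and dense layer are illustrative:

```python
import tensorflow as tf

from gpu_environment import get_custom_getter

compute_type = tf.float16  # tf.float32 when FP16 is disabled

inputs = tf.placeholder(compute_type, shape=[None, 128])
with tf.variable_scope("model", custom_getter=get_custom_getter(compute_type)):
    # The kernel/bias are created and stored in float32 (master weights) and
    # cast to float16 for the matmul, which is what enables Tensor Core use.
    hidden = tf.layers.dense(inputs, 256, activation=tf.nn.relu)
```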
