Merge pull request #206 from gpengzhi/pretrained
Initial code refactor on Pre-trained modules
gpengzhi authored Sep 18, 2019
2 parents 033951e + 6d6323c commit bd2dbe4
Showing 35 changed files with 1,825 additions and 1,907 deletions.
52 changes: 24 additions & 28 deletions examples/bert/README.md
@@ -2,10 +2,10 @@

This is a Texar implementation of Google's BERT model, which allows loading pre-trained model parameters downloaded from the [official release](https://github.com/google-research/bert) and building/fine-tuning arbitrary downstream applications with **distributed training** (this example showcases BERT for sentence classification).

This example shows two ways of building a BERT classifier, at different abstraction levels:

* Use `texar.tf.modules.BERTClassifier` ([doc](https://texar.readthedocs.io/en/latest/code/modules.html#texar.modules.BertClassifier)) directly. The module supports both sequence classification (one label per sequence) and sequence labeling (one label per token). --- See `bert_classifier_main_v2.py` for implementation.
* Use lower-level modules by creating a `TransformerEncoder` ([doc](https://texar.readthedocs.io/en/latest/code/modules.html#transformerencoder)) instance and adding additional layers. Initialization with a pre-trained BERT checkpoint is done by calling `init_bert_checkpoint(path_to_bert_checkpoint)`. --- See `bert_classifier_main.py` for implementation.
Texar provides ready-to-use modules including
[`BERTEncoder`](https://texar.readthedocs.io/en/latest/code/modules.html#bertencoder)
and [`BERTClassifier`](https://texar.readthedocs.io/en/latest/code/modules.html#bertclassifier).
This example shows the use of `BERTClassifier` for sentence classification tasks.
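
At a glance, building the classifier takes only a few lines. Below is a minimal sketch mirroring the model construction in `bert_classifier_main.py` further down in this commit; the placeholder inputs are illustrative, and the real script feeds batches from its data iterator.

```
import tensorflow as tf
import texar.tf as tx   # the module path may be `texar` in older Texar releases

# Illustrative placeholders standing in for the real data pipeline:
# token ids and segment ids of shape [batch_size, seq_length].
input_ids = tf.placeholder(tf.int32, shape=[None, None])
segment_ids = tf.placeholder(tf.int32, shape=[None, None])
input_length = tf.reduce_sum(
    1 - tf.cast(tf.equal(input_ids, 0), tf.int32), axis=1)

# BERTClassifier loads the pre-trained checkpoint and adds a classification
# head; 'cls_time' classifies each sequence from its [CLS] token.
model = tx.modules.BERTClassifier(
    pretrained_model_name="bert-base-uncased",
    hparams={"clas_strategy": "cls_time"})
logits, preds = model(input_ids, input_length, segment_ids)
```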

In sum, this example showcases:

@@ -16,44 +16,41 @@ In sum, this example showcases:

## Quick Start

### Download Pre-trained BERT Model

```
sh bert_pretrained_models/download_model.sh
```
By default, it will download a pretrained model (BERT-Base Uncased: 12-layer, 768-hidden, 12-heads, 110M parameters) named `uncased_L-12_H-768_A-12` to `bert_pretrained_models/`.

Under `bert_pretrained_models/uncased_L-12_H-768_A-12`, you can find 5 files, where
- `bert_config.json` is the model configuration of the BERT model. For the particular model we just downloaded, it is an uncased-vocabulary, 12-layer, 768-hidden, 12-heads Transformer model.

### Download Dataset

We explain the use of the example code based on the Microsoft Research Paraphrase Corpus (MRPC) for sentence classification.

Download the data with the following cmd
Download the data with the following command:

```
python data/download_glue_data.py --tasks=MRPC
```
By default, it will download the MRPC dataset into the `data` directory. FYI, the MRPC dataset part of the [GLUE](https://gluebenchmark.com/tasks) dataset collection.

By default, it will download the MRPC dataset into the `data` directory. FYI, the MRPC dataset is part of the [GLUE](https://gluebenchmark.com/tasks) dataset collection.

### Prepare data

We first preprocess the downloaded raw data into [TFRecord](https://www.tensorflow.org/tutorials/load_data/tf_records) files. The preprocessing tokenizes raw text with BPE encoding, truncates sequences, adds special tokens, etc.
Run the following cmd to this end:
Run the following command to this end:

```
python prepare_data.py --task=MRPC
[--max_seq_length=128]
[--vocab_file=bert_pretrained_models/uncased_L-12_H-768_A-12/vocab.txt]
[--tfrecord_output_dir=data/MRPC]
```
- `task`: Specifies the dataset name to preprocess. BERT provides default support for `{'CoLA', 'MNLI', 'MRPC', 'XNLI', 'SST'}` data.
- `max_seq_length`: The maximum sequence length. This includes BERT special tokens that will be automatically added. Longer sequences will be trimmed.
- `vocab_file`: Path to a vocabulary file used for tokenization.
- `tfrecord_output_dir`: The output path where the resulting TFRecord files will be written. By default, it is set to `data/{task}` where `{task}` is the (upper-cased) dataset name specified in `--task` above. So in the above command, the TFRecord files are output to `data/MRPC`.

- `--task`: Specifies the dataset name to preprocess. BERT provides default support for `{'CoLA', 'MNLI', 'MRPC', 'XNLI', 'SST'}` data.
- `--max_seq_length`: The maximum sequence length. This includes BERT special tokens that will be automatically added. Longer sequences will be trimmed.
- `--vocab_file`: Path to a vocabulary file used for tokenization.
- `--tfrecord_output_dir`: The output path where the resulting TFRecord files will be written. By default, it is set to `data/{task}` where `{task}` is the (upper-cased) dataset name specified in `--task` above. So in the above command, the TFRecord files are output to `data/MRPC`.
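
To sanity-check the generated files, a minimal sketch along these lines can be used. The feature names `input_ids`, `segment_ids`, and `label_ids` match what `bert_classifier_main.py` reads from each batch; the exact schema and dtypes written by `prepare_data.py` are an assumption to verify.

```
import tensorflow as tf

# Path assumed from the default --tfrecord_output_dir for MRPC.
max_seq_length = 128
features = {
    "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
    "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
    "label_ids": tf.FixedLenFeature([], tf.int64),
}

dataset = tf.data.TFRecordDataset("data/MRPC/train.tf_record")
dataset = dataset.map(lambda record: tf.parse_single_example(record, features))
example = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    print(sess.run(example))  # dict of parsed feature tensors for one example
```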

**Outcome of the Preprocessing**:

- The preprocessing will output 3 TFRecord data files `{train.tf_record, eval.tf_record, test.tf_record}` in the specified output directory.
- The cmd also prints logs as follows:

- The command also prints logs as follows:

```
INFO:tensorflow:Loading data
INFO:tensorflow:num_classes:2; num_train_data:3668
@@ -66,25 +63,25 @@
### Train and Evaluate

For **single-GPU** training (and evaluation), run the following command. The training updates the classification layer and fine-tunes the pre-trained BERT parameters.

```
python bert_classifier_main.py --do_train --do_eval
[--config_bert_pretrain=uncased_L-12_H-768_A-12]
[--config_downstream=config_classifier]
[--config_data=config_data]
[--output_dir=output]
```
Here:

- `config_bert_pretrain`: Specifies the architecture of pre-trained BERT model. Used to find architecture configs under `bert_pretrained_models/{config_bert_pretrain}`.
- `config_downstream`: Configuration of the downstream part. In this example, [`config_classifier`](./config_classifier.py) configures the classification layer and the optimization method.
- `config_data`: The data configuration. See the default [`config_data.py`](./config_data.py) for an example. Make sure to specify `num_classes`, `num_train_data`, `max_seq_length`, and `tfrecord_data_dir` as used or output in the above [data preparation](#prepare-data) step (a minimal sketch of these fields follows this list).
- `output_dir`: The output path where checkpoints and TensorBoard summaries are saved.
- `pretrained_model_name`: The name of a pre-trained model to load, selected from the list: `bert-base-uncased`, `bert-large-uncased`, `bert-base-cased`, `bert-large-cased`, `bert-base-multilingual-uncased`, `bert-base-multilingual-cased`, and `bert-base-chinese`.
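
The fields referenced above and in `bert_classifier_main.py` give a picture of what `config_data.py` needs to define. The following is a minimal sketch with illustrative values only; the shipped [`config_data.py`](./config_data.py) also configures the train/eval/test dataset hparams, so use it as the authoritative reference.

```
# config_data.py -- illustrative values only
num_classes = 2           # number of labels (2 for MRPC)
num_train_data = 3668     # number of training examples (printed by prepare_data.py)

max_seq_length = 128      # must match the value used in prepare_data.py
tfrecord_data_dir = "data/MRPC"   # where {train,eval,test}.tf_record were written

train_batch_size = 32
max_train_epoch = 3
warmup_proportion = 0.1   # fraction of training steps used for learning-rate warm-up

display_steps = 50        # log training loss every this many steps
eval_steps = 100          # evaluate on the dev set every this many steps
```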

*[NOTE: you can also use `bert_classifier_main_v2.py` in the above]*

For **multi-GPU training** on one or multiple machines, you may first install the prerequisite OpenMPI and Horovod packages, as detailed in the [distributed_gpu](https://github.com/asyml/texar/tree/master/examples/distributed_gpu) example.

Then run the following command for training and evaluation. The command trains the model locally with 2 GPUs. Evaluation is performed on the single rank-0 GPU.

```
mpirun -np 2 \
-H localhost:2\
@@ -93,7 +90,6 @@
-mca pml ob1 -mca btl tcp,self \
-mca btl_tcp_if_include ens3 \
python bert_classifier_main.py --do_train --do_eval --distributed
[--config_bert_pretrain=uncased_L-12_H-768_A-12]
[--config_downstream=config_classifier]
[--config_data=config_data]
[--output_dir=output]
@@ -107,9 +103,8 @@ Please refer to [distributed_gpu](https://github.com/asyml/texar/tree/master/exa

Make sure to specify the `--distributed` flag as above for multi-GPU training.
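
For reference, `bert_classifier_main.py` only imports and initializes Horovod when this flag is given. A minimal sketch of that guard pattern follows; the flag definition shown here is assumed for self-containedness, so see the actual script for the full flag set.

```
import tensorflow as tf

flags = tf.flags
flags.DEFINE_bool("distributed", False, "Whether to run in distributed mode.")
FLAGS = flags.FLAGS


def main(_):
    if FLAGS.distributed:
        # Horovod is imported and initialized only when --distributed is given,
        # so the script still runs on machines without Horovod installed.
        import horovod.tensorflow as hvd
        hvd.init()

    def _is_head():
        # Only the rank-0 process logs, evaluates, and saves checkpoints.
        if not FLAGS.distributed:
            return True
        return hvd.rank() == 0

    tf.logging.info("is head process: %s", _is_head())


if __name__ == "__main__":
    tf.app.run()
```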


After convergence, the evaluation performance is roughly the following. Due to certain randomness (e.g., random initialization of the classification layer), the evaluation accuracy is reasonable as long as it is `>0.84`.

```
INFO:tensorflow:dev accu: 0.8676470588235294
```
@@ -128,6 +123,7 @@ The output is by default saved in `output/test_results.tsv`, where each line con
`bert_classifier_main.py` also supports other datasets/tasks. To do this, specify a different value for the `--task` flag when running [data preparation](#prepare-data).

For example, use the following commands to download the SST (Stanford Sentiment Treebank) dataset and run sentence classification on it. Make sure to specify the correct data path and other information in the data configuration file.

```
python data/download_glue_data.py --tasks=SST
python prepare_data.py --task=SST
98 changes: 18 additions & 80 deletions examples/bert/bert_classifier_main.py
@@ -11,8 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example of building a sentence classifier based on pre-trained BERT
model.
"""Example of building a sentence classifier based on pre-trained BERT model.
"""

from __future__ import absolute_import
@@ -32,17 +31,12 @@

FLAGS = flags.FLAGS

flags.DEFINE_string(
"config_bert_pretrain", 'uncased_L-12_H-768_A-12',
"The architecture of pre-trained BERT model to use.")
flags.DEFINE_string(
"config_format_bert", "json",
"The configuration format. Set to 'json' if the BERT config file is in "
"the same format of the official BERT config file. Set to 'texar' if the "
"BERT config file is in Texar format.")
flags.DEFINE_string(
"config_downstream", "config_classifier",
"Configuration of the downstream part of the model and optmization.")
"Configuration of the downstream part of the model.")
flags.DEFINE_string(
"pretrained_model_name", "bert-base-uncased",
"Name of the pre-trained checkpoint to load.")
flags.DEFINE_string(
"config_data", "config_data",
"The dataset config.")
@@ -60,11 +54,11 @@
config_data = importlib.import_module(FLAGS.config_data)
config_downstream = importlib.import_module(FLAGS.config_downstream)


def main(_):
"""
Builds the model and runs.
"""

if FLAGS.distributed:
import horovod.tensorflow as hvd
hvd.init()
@@ -73,22 +67,7 @@ def main(_):

tx.utils.maybe_create_dir(FLAGS.output_dir)

bert_pretrain_dir = ('bert_pretrained_models'
'/%s') % FLAGS.config_bert_pretrain
# Loads BERT model configuration
if FLAGS.config_format_bert == "json":
bert_config = model_utils.transform_bert_to_texar_config(
os.path.join(bert_pretrain_dir, 'bert_config.json'))
elif FLAGS.config_format_bert == 'texar':
bert_config = importlib.import_module(
('bert_config_lib.'
'config_model_%s') % FLAGS.config_bert_pretrain)
else:
raise ValueError('Unknown config_format_bert.')

# Loads data

num_classes = config_data.num_classes
num_train_data = config_data.num_train_data

# Configures distribued mode
@@ -110,54 +89,17 @@ def main(_):
input_length = tf.reduce_sum(1 - tf.cast(tf.equal(input_ids, 0), tf.int32),
axis=1)
# Builds BERT
with tf.variable_scope('bert'):
# Word embedding
embedder = tx.modules.WordEmbedder(
vocab_size=bert_config.vocab_size,
hparams=bert_config.embed)
word_embeds = embedder(input_ids)

# Segment embedding for each type of tokens
segment_embedder = tx.modules.WordEmbedder(
vocab_size=bert_config.type_vocab_size,
hparams=bert_config.segment_embed)
segment_embeds = segment_embedder(segment_ids)

# Position embedding
position_embedder = tx.modules.PositionEmbedder(
position_size=bert_config.position_size,
hparams=bert_config.position_embed)
seq_length = tf.ones([batch_size], tf.int32) * tf.shape(input_ids)[1]
pos_embeds = position_embedder(sequence_length=seq_length)

# Aggregates embeddings
input_embeds = word_embeds + segment_embeds + pos_embeds

# The BERT model (a TransformerEncoder)
encoder = tx.modules.TransformerEncoder(hparams=bert_config.encoder)
output = encoder(input_embeds, input_length)

# Builds layers for downstream classification, which is also
# initialized with BERT pre-trained checkpoint.
with tf.variable_scope("pooler"):
# Uses the projection of the 1st-step hidden vector of BERT output
# as the representation of the sentence
bert_sent_hidden = tf.squeeze(output[:, 0:1, :], axis=1)
bert_sent_output = tf.layers.dense(
bert_sent_hidden, config_downstream.hidden_dim,
activation=tf.tanh)
output = tf.layers.dropout(
bert_sent_output, rate=0.1, training=tx.global_mode_train())

# Adds the final classification layer
logits = tf.layers.dense(
output, num_classes,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))
preds = tf.argmax(logits, axis=-1, output_type=tf.int32)
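# 'cls_time' classifies each sequence from the representation of its first
# ([CLS]) token, replacing the hand-built pooler and classification layers above.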
hparams = {
'clas_strategy': 'cls_time'
}
model = tx.modules.BERTClassifier(
pretrained_model_name=FLAGS.pretrained_model_name,
hparams=hparams)
logits, preds = model(input_ids, input_length, segment_ids)

accu = tx.evals.accuracy(batch['label_ids'], preds)

# Optimization

loss = tf.losses.sparse_softmax_cross_entropy(
labels=batch["label_ids"], logits=logits)
global_step = tf.Variable(0, trainable=False)
Expand All @@ -167,7 +109,7 @@ def main(_):
num_train_steps = int(num_train_data / config_data.train_batch_size
* config_data.max_train_epoch)
num_warmup_steps = int(num_train_steps * config_data.warmup_proportion)
lr = model_utils.get_lr(global_step, num_train_steps, # lr is a Tensor
lr = model_utils.get_lr(global_step, num_train_steps, # lr is a Tensor
num_warmup_steps, static_lr)

opt = tx.core.get_optimizer(
@@ -190,8 +132,7 @@ def main(_):
def _is_head():
if not FLAGS.distributed:
return True
else:
return hvd.rank() == 0
return hvd.rank() == 0

def _train_epoch(sess):
"""Trains on the training set, and evaluates on the dev set
@@ -217,7 +158,7 @@ def _train_epoch(sess):

dis_steps = config_data.display_steps
if _is_head() and dis_steps > 0 and step % dis_steps == 0:
tf.logging.info('step:%d; loss:%f' % (step, rets['loss']))
tf.logging.info('step:%d; loss:%f;' % (step, rets['loss']))

eval_steps = config_data.eval_steps
if _is_head() and eval_steps > 0 and step % eval_steps == 0:
@@ -277,10 +218,6 @@ def _test_epoch(sess):
with tf.gfile.GFile(output_file, "w") as writer:
writer.write('\n'.join(str(p) for p in _all_preds))

# Loads pretrained BERT model parameters
init_checkpoint = os.path.join(bert_pretrain_dir, 'bert_model.ckpt')
model_utils.init_bert_checkpoint(init_checkpoint)

# Broadcasts global variables from rank-0 process
if FLAGS.distributed:
bcast = hvd.broadcast_global_variables(0)
@@ -315,5 +252,6 @@ def _test_epoch(sess):
if FLAGS.do_test:
_test_epoch(sess)


if __name__ == "__main__":
tf.app.run()