diff --git a/examples/bert/README.md b/examples/bert/README.md index 28cdf454..7fbbc724 100644 --- a/examples/bert/README.md +++ b/examples/bert/README.md @@ -2,10 +2,10 @@ This is a Texar implementation of Google's BERT model, which allows to load pre-trained model parameters downloaded from the [official release](https://github.com/google-research/bert) and build/fine-tune arbitrary downstream applications with **distributed training** (This example showcases BERT for sentence classification). -This example shows two ways of building a BERT classifier, at different abstraction levels: - - * Use `texar.tf.modules.BERTClassifier` ([doc](https://texar.readthedocs.io/en/latest/code/modules.html#texar.modules.BertClassifier)) directly. The module supports both sequence classification (one label per sequence) and sequence labeling (one label per token). --- See `bert_classifier_main_v2.py` for implementation. - * Use lower-level modules by creating a `TransformerEncoder` ([doc](https://texar.readthedocs.io/en/latest/code/modules.html#transformerencoder)) instance and adding additional layers. Initialization with a pre-trained BERT checkpoint is done by calling `init_bert_checkpoint(path_to_bert_checkpoint)`. --- See `bert_classifier_main.py` for implementation. +Texar provides ready-to-use modules including +[`BERTEncoder`](https://texar.readthedocs.io/en/latest/code/modules.html#bertencoder), +and [`BERTClassifier`](https://texar.readthedocs.io/en/latest/code/modules.html#bertclassifier). +This example shows the use of `BERTClassifier` for sentence classification tasks. In sum, this example showcases: @@ -16,44 +16,41 @@ In sum, this example showcases: ## Quick Start -### Download BERT Pre-train Model - -``` -sh bert_pretrained_models/download_model.sh -``` -By default, it will download a pretrained model (BERT-Base Uncased: 12-layer, 768-hidden, 12-heads, 110M parameters) named `uncased_L-12_H-768_A-12` to `bert_pretrained_models/`. - -Under `bert_pretrained_models/uncased_L-12_H-768_A-12`, you can find 5 files, where -- `bert-config.json` is the model configuration of the BERT model. For the particular model we just downloaded, it is an uncased-vocabulary, 12-layer, 768-hidden, 12-heads Transformer model. - ### Download Dataset We explain the use of the example code based on the Microsoft Research Paraphrase Corpus (MRPC) corpus for sentence classification. -Download the data with the following cmd +Download the data with the following command: + ``` python data/download_glue_data.py --tasks=MRPC ``` -By default, it will download the MRPC dataset into the `data` directory. FYI, the MRPC dataset part of the [GLUE](https://gluebenchmark.com/tasks) dataset collection. + +By default, it will download the MRPC dataset into the `data` directory. FYI, the MRPC dataset is part of the [GLUE](https://gluebenchmark.com/tasks) dataset collection. ### Prepare data We first preprocess the downloaded raw data into [TFRecord](https://www.tensorflow.org/tutorials/load_data/tf_records) files. The preprocessing tokenizes raw text with BPE encoding, truncates sequences, adds special tokens, etc. -Run the following cmd to this end: +Run the following command to this end: + ``` python prepare_data.py --task=MRPC [--max_seq_length=128] [--vocab_file=bert_pretrained_models/uncased_L-12_H-768_A-12/vocab.txt] [--tfrecord_output_dir=data/MRPC] ``` -- `task`: Specifies the dataset name to preprocess. BERT provides default support for `{'CoLA', 'MNLI', 'MRPC', 'XNLI', 'SST'}` data. -- `max_seq_length`: The maxium length of sequence. This includes BERT special tokens that will be automatically added. Longer sequence will be trimmed. -- `vocab_file`: Path to a vocabary file used for tokenization. -- `tfrecord_output_dir`: The output path where the resulting TFRecord files will be put in. Be default, it is set to `data/{task}` where `{task}` is the (upper-cased) dataset name specified in `--task` above. So in the above cmd, the TFRecord files are output to `data/MRPC`. + +- `--task`: Specifies the dataset name to preprocess. BERT provides default support for `{'CoLA', 'MNLI', 'MRPC', 'XNLI', 'SST'}` data. +- `--max_seq_length`: The maxium length of sequence. This includes BERT special tokens that will be automatically added. Longer sequence will be trimmed. +- `--vocab_file`: Path to a vocabary file used for tokenization. +- `--tfrecord_output_dir`: The output path where the resulting TFRecord files will be put in. Be default, it is set to `data/{task}` where `{task}` is the (upper-cased) dataset name specified in `--task` above. So in the above cmd, the TFRecord files are output to `data/MRPC`. **Outcome of the Preprocessing**: + - The preprocessing will output 3 TFRecord data files `{train.tf_record, eval.tf_record, test.tf_record}` in the specified output directory. -- The cmd also prints logs as follows: + +- The command also prints logs as follows: + ``` INFO:tensorflow:Loading data INFO:tensorflow:num_classes:2; num_train_data:3668 @@ -66,25 +63,25 @@ Run the following cmd to this end: ### Train and Evaluate For **single-GPU** training (and evaluation), run the following cmd. The training updates the classification layer and fine-tunes the pre-trained BERT parameters. + ``` python bert_classifier_main.py --do_train --do_eval - [--config_bert_pretrain=uncased_L-12_H-768_A-12] [--config_downstream=config_classifier] [--config_data=config_data] [--output_dir=output] ``` Here: -- `config_bert_pretrain`: Specifies the architecture of pre-trained BERT model. Used to find architecture configs under `bert_pretrained_models/{config_bert_pretrain}`. - `config_downstream`: Configuration of the downstream part. In this example, [`config_classifier`](./config_classifier.py) configures the classification layer and the optimization method. - `config_data`: The data configuration. See the default [`config_data.py`](./config_data.py) for example. Make sure to specify `num_classes`, `num_train_data`, `max_seq_length`, and `tfrecord_data_dir` as used or output in the above [data preparation](#prepare-data) step. - `output_dir`: The output path where checkpoints and TensorBoard summaries are saved. +- `pretrained_model_name`: The name of a pre-trained model to load selected in the list of: `bert-base-uncased`, `bert-large-uncased`, `bert-base-cased`, `bert-large-cased`, `bert-base-multilingual-uncased`, `bert-base-multilingual-cased`, and `bert-base-chinese`. -*[NOTE: you can also use `bert_classifier_main_v2.py` in the above]* For **Multi-GPU training** on one or multiple machines, you may first install the prerequisite OpenMPI and Hovorod packages, as detailed in the [distributed_gpu](https://github.com/asyml/texar/tree/master/examples/distributed_gpu) example. Then run the following cmd for training and evaluation. The cmd trains the model on local with 2 GPUs. Evaluation is performed with the single rank-0 GPU. + ``` mpirun -np 2 \ -H localhost:2\ @@ -93,7 +90,6 @@ mpirun -np 2 \ -mca pml ob1 -mca btl tcp,self \ -mca btl_tcp_if_include ens3 \ python bert_classifier_main.py --do_train --do_eval --distributed - [--config_bert_pretrain=uncased_L-12_H-768_A-12] [--config_downstream=config_classifier] [--config_data=config_data] [--output_dir=output] @@ -107,9 +103,8 @@ Please refer to [distributed_gpu](https://github.com/asyml/texar/tree/master/exa Make sure to specifiy the `--distributed` flag as above for multi-gpu training. -  - After convergence, the evaluation performance is around the following. Due to certain randomness (e.g., random initialization of the classification layer), the evaluation accuracy is reasonable as long as it's `>0.84`. + ``` INFO:tensorflow:dev accu: 0.8676470588235294 ``` @@ -128,6 +123,7 @@ The output is by default saved in `output/test_results.tsv`, where each line con `bert_classifier_main.py` also support other datasets/tasks. To do this, specify a different value to the `--task` flag when running [data preparation](#prepare-data). For example, use the following commands to download the SST (Stanford Sentiment Treebank) dataset and run for sentence classification. Make sure to specify the correct data path and other info in the data configuration file. + ``` python data/download_glue_data.py --tasks=SST python prepare_data.py --task=SST diff --git a/examples/bert/bert_classifier_main.py b/examples/bert/bert_classifier_main.py index c2acb2a9..3e693954 100644 --- a/examples/bert/bert_classifier_main.py +++ b/examples/bert/bert_classifier_main.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Example of building a sentence classifier based on pre-trained BERT -model. +"""Example of building a sentence classifier based on pre-trained BERT model. """ from __future__ import absolute_import @@ -32,17 +31,12 @@ FLAGS = flags.FLAGS -flags.DEFINE_string( - "config_bert_pretrain", 'uncased_L-12_H-768_A-12', - "The architecture of pre-trained BERT model to use.") -flags.DEFINE_string( - "config_format_bert", "json", - "The configuration format. Set to 'json' if the BERT config file is in " - "the same format of the official BERT config file. Set to 'texar' if the " - "BERT config file is in Texar format.") flags.DEFINE_string( "config_downstream", "config_classifier", - "Configuration of the downstream part of the model and optmization.") + "Configuration of the downstream part of the model.") +flags.DEFINE_string( + "pretrained_model_name", "bert-base-uncased", + "Name of the pre-trained checkpoint to load.") flags.DEFINE_string( "config_data", "config_data", "The dataset config.") @@ -60,11 +54,11 @@ config_data = importlib.import_module(FLAGS.config_data) config_downstream = importlib.import_module(FLAGS.config_downstream) + def main(_): """ Builds the model and runs. """ - if FLAGS.distributed: import horovod.tensorflow as hvd hvd.init() @@ -73,22 +67,7 @@ def main(_): tx.utils.maybe_create_dir(FLAGS.output_dir) - bert_pretrain_dir = ('bert_pretrained_models' - '/%s') % FLAGS.config_bert_pretrain - # Loads BERT model configuration - if FLAGS.config_format_bert == "json": - bert_config = model_utils.transform_bert_to_texar_config( - os.path.join(bert_pretrain_dir, 'bert_config.json')) - elif FLAGS.config_format_bert == 'texar': - bert_config = importlib.import_module( - ('bert_config_lib.' - 'config_model_%s') % FLAGS.config_bert_pretrain) - else: - raise ValueError('Unknown config_format_bert.') - # Loads data - - num_classes = config_data.num_classes num_train_data = config_data.num_train_data # Configures distribued mode @@ -110,54 +89,17 @@ def main(_): input_length = tf.reduce_sum(1 - tf.cast(tf.equal(input_ids, 0), tf.int32), axis=1) # Builds BERT - with tf.variable_scope('bert'): - # Word embedding - embedder = tx.modules.WordEmbedder( - vocab_size=bert_config.vocab_size, - hparams=bert_config.embed) - word_embeds = embedder(input_ids) - - # Segment embedding for each type of tokens - segment_embedder = tx.modules.WordEmbedder( - vocab_size=bert_config.type_vocab_size, - hparams=bert_config.segment_embed) - segment_embeds = segment_embedder(segment_ids) - - # Position embedding - position_embedder = tx.modules.PositionEmbedder( - position_size=bert_config.position_size, - hparams=bert_config.position_embed) - seq_length = tf.ones([batch_size], tf.int32) * tf.shape(input_ids)[1] - pos_embeds = position_embedder(sequence_length=seq_length) - - # Aggregates embeddings - input_embeds = word_embeds + segment_embeds + pos_embeds - - # The BERT model (a TransformerEncoder) - encoder = tx.modules.TransformerEncoder(hparams=bert_config.encoder) - output = encoder(input_embeds, input_length) - - # Builds layers for downstream classification, which is also - # initialized with BERT pre-trained checkpoint. - with tf.variable_scope("pooler"): - # Uses the projection of the 1st-step hidden vector of BERT output - # as the representation of the sentence - bert_sent_hidden = tf.squeeze(output[:, 0:1, :], axis=1) - bert_sent_output = tf.layers.dense( - bert_sent_hidden, config_downstream.hidden_dim, - activation=tf.tanh) - output = tf.layers.dropout( - bert_sent_output, rate=0.1, training=tx.global_mode_train()) - - # Adds the final classification layer - logits = tf.layers.dense( - output, num_classes, - kernel_initializer=tf.truncated_normal_initializer(stddev=0.02)) - preds = tf.argmax(logits, axis=-1, output_type=tf.int32) + hparams = { + 'clas_strategy': 'cls_time' + } + model = tx.modules.BERTClassifier( + pretrained_model_name=FLAGS.pretrained_model_name, + hparams=hparams) + logits, preds = model(input_ids, input_length, segment_ids) + accu = tx.evals.accuracy(batch['label_ids'], preds) # Optimization - loss = tf.losses.sparse_softmax_cross_entropy( labels=batch["label_ids"], logits=logits) global_step = tf.Variable(0, trainable=False) @@ -167,7 +109,7 @@ def main(_): num_train_steps = int(num_train_data / config_data.train_batch_size * config_data.max_train_epoch) num_warmup_steps = int(num_train_steps * config_data.warmup_proportion) - lr = model_utils.get_lr(global_step, num_train_steps, # lr is a Tensor + lr = model_utils.get_lr(global_step, num_train_steps, # lr is a Tensor num_warmup_steps, static_lr) opt = tx.core.get_optimizer( @@ -190,8 +132,7 @@ def main(_): def _is_head(): if not FLAGS.distributed: return True - else: - return hvd.rank() == 0 + return hvd.rank() == 0 def _train_epoch(sess): """Trains on the training set, and evaluates on the dev set @@ -217,7 +158,7 @@ def _train_epoch(sess): dis_steps = config_data.display_steps if _is_head() and dis_steps > 0 and step % dis_steps == 0: - tf.logging.info('step:%d; loss:%f' % (step, rets['loss'])) + tf.logging.info('step:%d; loss:%f;' % (step, rets['loss'])) eval_steps = config_data.eval_steps if _is_head() and eval_steps > 0 and step % eval_steps == 0: @@ -277,10 +218,6 @@ def _test_epoch(sess): with tf.gfile.GFile(output_file, "w") as writer: writer.write('\n'.join(str(p) for p in _all_preds)) - # Loads pretrained BERT model parameters - init_checkpoint = os.path.join(bert_pretrain_dir, 'bert_model.ckpt') - model_utils.init_bert_checkpoint(init_checkpoint) - # Broadcasts global variables from rank-0 process if FLAGS.distributed: bcast = hvd.broadcast_global_variables(0) @@ -315,5 +252,6 @@ def _test_epoch(sess): if FLAGS.do_test: _test_epoch(sess) + if __name__ == "__main__": tf.app.run() diff --git a/examples/bert/bert_classifier_main_v2.py b/examples/bert/bert_classifier_main_v2.py deleted file mode 100644 index 0e4440a3..00000000 --- a/examples/bert/bert_classifier_main_v2.py +++ /dev/null @@ -1,254 +0,0 @@ -# Copyright 2019 The Texar Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Example of building a sentence classifier based on pre-trained BERT -model. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import importlib -import tensorflow as tf -import texar.tf as tx - -from utils import model_utils - -# pylint: disable=invalid-name, too-many-locals, too-many-statements - -flags = tf.flags - -FLAGS = flags.FLAGS - -flags.DEFINE_string( - "config_downstream", "config_classifier", - "Configuration of the downstream part of the model and optmization.") -flags.DEFINE_string( - "config_data", "config_data", - "The dataset config.") -flags.DEFINE_string( - "output_dir", "output/", - "The output directory where the model checkpoints will be written.") -flags.DEFINE_string( - "checkpoint", None, - "Path to a model chceckpoint (including bert modules) to restore from.") -flags.DEFINE_bool("do_train", False, "Whether to run training.") -flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") -flags.DEFINE_bool("do_test", False, "Whether to run test on the test set.") -flags.DEFINE_bool("distributed", False, "Whether to run in distributed mode.") - -config_data = importlib.import_module(FLAGS.config_data) -config_downstream = importlib.import_module(FLAGS.config_downstream) - - -def main(_): - """ - Builds the model and runs. - """ - - if FLAGS.distributed: - import horovod.tensorflow as hvd - hvd.init() - - tf.logging.set_verbosity(tf.logging.INFO) - - tx.utils.maybe_create_dir(FLAGS.output_dir) - - # Loads data - num_train_data = config_data.num_train_data - - # Configures distribued mode - if FLAGS.distributed: - config_data.train_hparam["dataset"]["num_shards"] = hvd.size() - config_data.train_hparam["dataset"]["shard_id"] = hvd.rank() - config_data.train_hparam["batch_size"] //= hvd.size() - - train_dataset = tx.data.TFRecordData(hparams=config_data.train_hparam) - eval_dataset = tx.data.TFRecordData(hparams=config_data.eval_hparam) - test_dataset = tx.data.TFRecordData(hparams=config_data.test_hparam) - - iterator = tx.data.FeedableDataIterator({ - 'train': train_dataset, 'eval': eval_dataset, 'test': test_dataset}) - batch = iterator.get_next() - input_ids = batch["input_ids"] - segment_ids = batch["segment_ids"] - batch_size = tf.shape(input_ids)[0] - input_length = tf.reduce_sum(1 - tf.cast(tf.equal(input_ids, 0), tf.int32), - axis=1) - # Builds BERT - hparams = { - 'clas_strategy': 'cls_time' - } - model = tx.modules.BertClassifier(hparams=hparams) - logits, preds = model(input_ids, input_length, segment_ids) - - accu = tx.evals.accuracy(batch['label_ids'], preds) - - # Optimization - loss = tf.losses.sparse_softmax_cross_entropy( - labels=batch["label_ids"], logits=logits) - global_step = tf.Variable(0, trainable=False) - - # Builds learning rate decay scheduler - static_lr = config_downstream.lr['static_lr'] - num_train_steps = int(num_train_data / config_data.train_batch_size - * config_data.max_train_epoch) - num_warmup_steps = int(num_train_steps * config_data.warmup_proportion) - lr = model_utils.get_lr(global_step, num_train_steps, # lr is a Tensor - num_warmup_steps, static_lr) - - opt = tx.core.get_optimizer( - global_step=global_step, - learning_rate=lr, - hparams=config_downstream.opt - ) - - if FLAGS.distributed: - opt = hvd.DistributedOptimizer(opt) - - train_op = tf.contrib.layers.optimize_loss( - loss=loss, - global_step=global_step, - learning_rate=None, - optimizer=opt) - - # Train/eval/test routine - - def _is_head(): - if not FLAGS.distributed: - return True - return hvd.rank() == 0 - - def _train_epoch(sess): - """Trains on the training set, and evaluates on the dev set - periodically. - """ - iterator.restart_dataset(sess, 'train') - - fetches = { - 'train_op': train_op, - 'loss': loss, - 'batch_size': batch_size, - 'step': global_step - } - - while True: - try: - feed_dict = { - iterator.handle: iterator.get_handle(sess, 'train'), - tx.global_mode(): tf.estimator.ModeKeys.TRAIN, - } - rets = sess.run(fetches, feed_dict) - step = rets['step'] - - dis_steps = config_data.display_steps - if _is_head() and dis_steps > 0 and step % dis_steps == 0: - tf.logging.info('step:%d; loss:%f;' % (step, rets['loss'])) - - eval_steps = config_data.eval_steps - if _is_head() and eval_steps > 0 and step % eval_steps == 0: - _eval_epoch(sess) - - except tf.errors.OutOfRangeError: - break - - def _eval_epoch(sess): - """Evaluates on the dev set. - """ - iterator.restart_dataset(sess, 'eval') - - cum_acc = 0.0 - cum_loss = 0.0 - nsamples = 0 - fetches = { - 'accu': accu, - 'loss': loss, - 'batch_size': batch_size, - } - while True: - try: - feed_dict = { - iterator.handle: iterator.get_handle(sess, 'eval'), - tx.context.global_mode(): tf.estimator.ModeKeys.EVAL, - } - rets = sess.run(fetches, feed_dict) - - cum_acc += rets['accu'] * rets['batch_size'] - cum_loss += rets['loss'] * rets['batch_size'] - nsamples += rets['batch_size'] - except tf.errors.OutOfRangeError: - break - - tf.logging.info('eval accu: {}; loss: {}; nsamples: {}'.format( - cum_acc / nsamples, cum_loss / nsamples, nsamples)) - - def _test_epoch(sess): - """Does predictions on the test set. - """ - iterator.restart_dataset(sess, 'test') - - _all_preds = [] - while True: - try: - feed_dict = { - iterator.handle: iterator.get_handle(sess, 'test'), - tx.context.global_mode(): tf.estimator.ModeKeys.PREDICT, - } - _preds = sess.run(preds, feed_dict=feed_dict) - _all_preds.extend(_preds.tolist()) - except tf.errors.OutOfRangeError: - break - - output_file = os.path.join(FLAGS.output_dir, "test_results.tsv") - with tf.gfile.GFile(output_file, "w") as writer: - writer.write('\n'.join(str(p) for p in _all_preds)) - - # Broadcasts global variables from rank-0 process - if FLAGS.distributed: - bcast = hvd.broadcast_global_variables(0) - - session_config = tf.ConfigProto() - if FLAGS.distributed: - session_config.gpu_options.visible_device_list = str(hvd.local_rank()) - - with tf.Session(config=session_config) as sess: - sess.run(tf.global_variables_initializer()) - sess.run(tf.local_variables_initializer()) - sess.run(tf.tables_initializer()) - - if FLAGS.distributed: - bcast.run() - - # Restores trained model if specified - saver = tf.train.Saver() - if FLAGS.checkpoint: - saver.restore(sess, FLAGS.checkpoint) - - iterator.initialize_dataset(sess) - - if FLAGS.do_train: - for i in range(config_data.max_train_epoch): - _train_epoch(sess) - saver.save(sess, FLAGS.output_dir + '/model.ckpt') - - if FLAGS.do_eval: - _eval_epoch(sess) - - if FLAGS.do_test: - _test_epoch(sess) - - -if __name__ == "__main__": - tf.app.run() diff --git a/examples/bert/bert_config_lib/README.md b/examples/bert/bert_config_lib/README.md deleted file mode 100644 index b7c79e4b..00000000 --- a/examples/bert/bert_config_lib/README.md +++ /dev/null @@ -1,3 +0,0 @@ -### Configuration files of BERT models in Texar style. - -For example, `config_model_uncased_L-12_H-768_A-12.py` is the Texar configuration file equivalent to `uncased_L-12_H-768_A-12` downloaded from [BERT official release](https://github.com/haoransh/texar_private/tree/master/examples/bert). diff --git a/examples/bert/bert_config_lib/__init__.py b/examples/bert/bert_config_lib/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/bert/bert_config_lib/config_model_uncased_L-12_H-768_A-12.py b/examples/bert/bert_config_lib/config_model_uncased_L-12_H-768_A-12.py deleted file mode 100644 index 1fae6645..00000000 --- a/examples/bert/bert_config_lib/config_model_uncased_L-12_H-768_A-12.py +++ /dev/null @@ -1,56 +0,0 @@ -embed = { - 'dim': 768, - 'name': 'word_embeddings' -} -vocab_size = 30522 - -segment_embed = { - 'dim': 768, - 'name': 'token_type_embeddings' -} -type_vocab_size = 2 - -position_embed = { - 'dim': 768, - 'name': 'position_embeddings' -} -position_size = 512 - - -encoder = { - 'dim': 768, - 'embedding_dropout': 0.1, - 'multihead_attention': { - 'dropout_rate': 0.1, - 'name': 'self', - 'num_heads': 12, - 'num_units': 768, - 'output_dim': 768, - 'use_bias': True - }, - 'name': 'encoder', - 'num_blocks': 12, - 'poswise_feedforward': { - 'layers': [ - { 'kwargs': { - 'activation': 'gelu', - 'name': 'intermediate', - 'units': 3072, - 'use_bias': True - }, - 'type': 'Dense' - }, - { 'kwargs': {'activation': None, - 'name': 'output', - 'units': 768, - 'use_bias': True - }, - 'type': 'Dense' - } - ] - }, - 'residual_dropout': 0.1, - 'use_bert_config': True -} - -output_size = 768 # The output dimension of BERT diff --git a/examples/bert/bert_pretrained_models/download_model.sh b/examples/bert/bert_pretrained_models/download_model.sh deleted file mode 100644 index fec01a5c..00000000 --- a/examples/bert/bert_pretrained_models/download_model.sh +++ /dev/null @@ -1,2 +0,0 @@ -wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip -P bert_pretrained_models/; -unzip bert_pretrained_models/uncased_L-12_H-768_A-12.zip -d bert_pretrained_models/ diff --git a/setup.py b/setup.py index edd016d1..c236930c 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,7 @@ install_requires=[ 'numpy<1.17.0', + 'pathlib>=1.0', 'pyyaml', 'requests', 'funcsigs>=1.0.2', diff --git a/texar/tf/modules/__init__.py b/texar/tf/modules/__init__.py index f7b2a8d9..711485fd 100644 --- a/texar/tf/modules/__init__.py +++ b/texar/tf/modules/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Modules of texar library module. +Modules of Texar library module. """ from __future__ import absolute_import @@ -21,14 +21,14 @@ # pylint: disable=wildcard-import -from texar.tf.modules.networks import * +from texar.tf.modules.classifiers import * +from texar.tf.modules.connectors import * +from texar.tf.modules.decoders import * from texar.tf.modules.embedders import * from texar.tf.modules.encoders import * -from texar.tf.modules.decoders import * -from texar.tf.modules.connectors import * -from texar.tf.modules.classifiers import * +from texar.tf.modules.memory import * +from texar.tf.modules.networks import * from texar.tf.modules.policies import * +from texar.tf.modules.pretrained import * from texar.tf.modules.qnets import * -from texar.tf.modules.memory import * from texar.tf.modules.regressors import * -from texar.tf.modules.pretrained import * diff --git a/texar/tf/modules/classifiers/__init__.py b/texar/tf/modules/classifiers/__init__.py index ec5edaf5..79e3bc99 100644 --- a/texar/tf/modules/classifiers/__init__.py +++ b/texar/tf/modules/classifiers/__init__.py @@ -23,6 +23,6 @@ from texar.tf.modules.classifiers.conv_classifiers import * from texar.tf.modules.classifiers.rnn_classifiers import * -from texar.tf.modules.classifiers.bert_classifiers import * +from texar.tf.modules.classifiers.bert_classifier import * from texar.tf.modules.classifiers.xlnet_classifier import * diff --git a/texar/tf/modules/classifiers/bert_classifiers.py b/texar/tf/modules/classifiers/bert_classifier.py similarity index 75% rename from texar/tf/modules/classifiers/bert_classifiers.py rename to texar/tf/modules/classifiers/bert_classifier.py index 6d0b887a..7f7cf4bd 100644 --- a/texar/tf/modules/classifiers/bert_classifiers.py +++ b/texar/tf/modules/classifiers/bert_classifier.py @@ -20,22 +20,26 @@ from __future__ import print_function import tensorflow as tf -from texar.tf.core import layers + +from texar.tf.core.layers import get_layer from texar.tf.modules.classifiers.classifier_base import ClassifierBase -from texar.tf.modules import BertEncoder -from texar.tf.utils import utils +from texar.tf.modules.encoders.bert_encoder import BERTEncoder from texar.tf.hyperparams import HParams +from texar.tf.modules.pretrained.bert import PretrainedBERTMixin +from texar.tf.utils.utils import dict_fetch # pylint: disable=too-many-arguments, invalid-name, no-member, # pylint: disable=too-many-branches, too-many-locals, too-many-statements __all__ = [ - "BertClassifier" + "BERTClassifier" ] -class BertClassifier(ClassifierBase): - """Classifier based on bert modules. +class BERTClassifier(ClassifierBase, PretrainedBERTMixin): + r"""Classifier based on BERT modules. Please see + :class:`~texar.tf.modules.PretrainedBERTMixin` for a brief description + of BERT. This is a combination of the :class:`~texar.tf.modules.BertEncoder` with a classification @@ -43,41 +47,42 @@ class BertClassifier(ClassifierBase): are supported, specified in :attr:`hparams`. Arguments are the same as in - :class:`~texar.tf.modules.BertEncoder`. + :class:`~texar.tf.modules.BERTEncoder`. Args: - pretrained_model_name (optional): a str with the name - of a pre-trained model to load selected in the list of: - `bert-base-uncased`, `bert-large-uncased`, `bert-base-cased`, - `bert-large-cased`, `bert-base-multilingual-uncased`, - `bert-base-multilingual-cased`, `bert-base-chinese`. - If `None`, will use the model name in :attr:`hparams`. + pretrained_model_name (optional): a `str`, the name + of pre-trained model (e.g., ``bert-base-uncased``). Please refer to + :class:`~texar.tf.modules.PretrainedBERTMixin` for + all supported models. + If `None`, the model name in :attr:`hparams` is used. cache_dir (optional): the path to a folder in which the pre-trained models will be cached. If `None` (default), - a default directory will be used. + a default directory (``texar_data`` folder under user's home + directory) will be used. hparams (dict or HParams, optional): Hyperparameters. Missing - hyperparameter will be set to default values. See - :meth:`default_hparams` for the hyperparameter sturcture + hyperparameters will be set to default values. See + :meth:`default_hparams` for the hyperparameter structure and default values. .. document private functions .. automethod:: _build """ + _ENCODER_CLASS = BERTEncoder def __init__(self, pretrained_model_name=None, cache_dir=None, hparams=None): - ClassifierBase.__init__(self, hparams) + super(BERTClassifier, self).__init__(hparams=hparams) with tf.variable_scope(self.variable_scope): # Creates the underlying encoder - encoder_hparams = utils.dict_fetch( - hparams, BertEncoder.default_hparams()) + encoder_hparams = dict_fetch( + hparams, BERTEncoder.default_hparams()) if encoder_hparams is not None: encoder_hparams['name'] = None - self._encoder = BertEncoder( + self._encoder = BERTEncoder( pretrained_model_name=pretrained_model_name, cache_dir=cache_dir, hparams=encoder_hparams) @@ -85,7 +90,7 @@ def __init__(self, # Creates an dropout layer drop_kwargs = {"rate": self._hparams.dropout} layer_hparams = {"type": "Dropout", "kwargs": drop_kwargs} - self._dropout_layer = layers.get_layer(hparams=layer_hparams) + self._dropout_layer = get_layer(hparams=layer_hparams) # Creates an additional classification layer if needed self._num_classes = self._hparams.num_classes @@ -105,11 +110,11 @@ def __init__(self, logit_kwargs['name'] = "logit_layer" layer_hparams = {"type": "Dense", "kwargs": logit_kwargs} - self._logit_layer = layers.get_layer(hparams=layer_hparams) + self._logit_layer = get_layer(hparams=layer_hparams) @staticmethod def default_hparams(): - """Returns a dictionary of hyperparameters with default values. + r"""Returns a dictionary of hyperparameters with default values. .. code-block:: python @@ -134,43 +139,44 @@ def default_hparams(): 2. Additional hyperparameters: - "num_classes": int + `"num_classes"`: int Number of classes: - - If **`> 0`**, an additional :tf_main:`Dense ` \ - layer is appended to the encoder to compute the logits over \ - classes. - - If **`<= 0`**, no dense layer is appended. The number of \ - classes is assumed to be the final dense layer size of the \ - encoder. + - If **> 0**, an additional :tf_main:`Dense ` + layer is appended to the encoder to compute the logits over + classes. + - If **<= 0**, no dense layer is appended. The number of + classes is assumed to be the final dense layer size of the + encoder. - "logit_layer_kwargs": dict + `"logit_layer_kwargs"`: dict Keyword arguments for the logit Dense layer constructor, - except for argument "units" which is set to "num_classes". + except for argument "units" which is set to `num_classes`. Ignored if no extra logit layer is appended. - "clas_strategy": str + `"clas_strategy"`: str The classification strategy, one of: - - **"cls_time"**: Sequence-level classification based on the \ - output of the first time step (which is the "CLS" token). \ - Each sequence has a class. - - **"all_time"**: Sequence-level classification based on \ - the output of all time steps. Each sequence has a class. - - **"time_wise"**: Step-wise classfication, i.e., make \ - classification for each time step based on its output. - - "max_seq_length": int, optional + + - **cls_time**: Sequence-level classification based on the + output of the first time step (which is the `CLS` token). + Each sequence has a class. + - **all_time**: Sequence-level classification based on + the output of all time steps. Each sequence has a class. + - **time_wise**: Step-wise classification, i.e., make + classification for each time step based on its output. + + `"max_seq_length"`: int, optional Maximum possible length of input sequences. Required if - "clas_strategy" is "all_time". + `clas_strategy` is `all_time`. - "dropout": float - The dropout rate of the bert encoder output. + `"dropout"`: float + The dropout rate of the BERT encoder output. - "name": str + `"name"`: str Name of the classifier. """ - hparams = BertEncoder.default_hparams() + hparams = BERTEncoder.default_hparams() hparams.update({ "num_classes": 2, "logit_layer_kwargs": None, @@ -187,7 +193,7 @@ def _build(self, segment_ids=None, mode=None, **kwargs): - """Feeds the inputs through the network and makes classification. + r"""Feeds the inputs through the network and makes classification. The arguments are the same as in :class:`~texar.tf.modules.BertEncoder`. diff --git a/texar/tf/modules/classifiers/bert_classifier_test.py b/texar/tf/modules/classifiers/bert_classifier_test.py index 6b908fa4..b682c04e 100644 --- a/texar/tf/modules/classifiers/bert_classifier_test.py +++ b/texar/tf/modules/classifiers/bert_classifier_test.py @@ -1,4 +1,3 @@ -# """ Unit tests for BERT classifiers. """ @@ -9,17 +8,29 @@ from __future__ import unicode_literals import numpy as np - import tensorflow as tf -from texar.tf.modules.classifiers.bert_classifiers import BertClassifier +from texar.tf.modules.classifiers.bert_classifier import BERTClassifier +from texar.tf.utils.test import pretrained_test # pylint: disable=too-many-locals, no-member -class BertClassifierTest(tf.test.TestCase): - """Tests :class:`~texar.tf.modules.BertClassifierTest` class. + +class BERTClassifierTest(tf.test.TestCase): + """Tests :class:`~texar.tf.modules.BERTClassifier` class. """ + @pretrained_test + def test_model_loading(self): + r"""Tests model loading functionality.""" + + inputs = tf.placeholder(dtype=tf.int32, shape=[None, None]) + + for pretrained_model_name in BERTClassifier.available_checkpoints(): + classifier = BERTClassifier( + pretrained_model_name=pretrained_model_name) + _, _ = classifier(inputs) + def test_trainable_variables(self): """Tests the functionality of automatically collecting trainable variables. @@ -27,24 +38,29 @@ def test_trainable_variables(self): inputs = tf.placeholder(dtype=tf.int32, shape=[None, None]) # case 1 - clas = BertClassifier() + hparams = { + "pretrained_model_name": None, + } + clas = BERTClassifier(hparams=hparams) _, _ = clas(inputs) self.assertEqual(len(clas.trainable_variables), 199+2) # case 2 hparams = { + "pretrained_model_name": None, "clas_strategy": "all_time", "max_seq_length": 8, } - clas = BertClassifier(hparams=hparams) + clas = BERTClassifier(hparams=hparams) _, _ = clas(inputs) self.assertEqual(len(clas.trainable_variables), 199+2) # case 2 hparams = { + "pretrained_model_name": None, "clas_strategy": "time_wise", } - clas = BertClassifier(hparams=hparams) + clas = BERTClassifier(hparams=hparams) _, _ = clas(inputs) self.assertEqual(len(clas.trainable_variables), 199+2) @@ -55,9 +71,11 @@ def test_encode(self): batch_size = 16 inputs = tf.random_uniform([batch_size, max_time], maxval=30521, dtype=tf.int32) - # case 1 - clas = BertClassifier() + hparams = { + "pretrained_model_name": None, + } + clas = BERTClassifier(hparams=hparams) logits, pred = clas(inputs) with self.test_session() as sess: @@ -69,10 +87,11 @@ def test_encode(self): # case 2 hparams = { + "pretrained_model_name": None, "num_classes": 10, "clas_strategy": "time_wise" } - clas = BertClassifier(hparams=hparams) + clas = BERTClassifier(hparams=hparams) logits, pred = clas(inputs) with self.test_session() as sess: @@ -84,10 +103,11 @@ def test_encode(self): # case 3 hparams = { + "pretrained_model_name": None, "num_classes": 0, "clas_strategy": "time_wise" } - clas = BertClassifier(hparams=hparams) + clas = BERTClassifier(hparams=hparams) logits, pred = clas(inputs) with self.test_session() as sess: @@ -97,15 +117,15 @@ def test_encode(self): (batch_size, max_time, clas.hparams.encoder.dim)) self.assertEqual(pred_.shape, (batch_size, max_time)) - # case 4 hparams = { + "pretrained_model_name": None, "num_classes": 10, "clas_strategy": "all_time", "max_seq_length": max_time } inputs = tf.placeholder(tf.int32, shape=[batch_size, 6]) - clas = BertClassifier(hparams=hparams) + clas = BERTClassifier(hparams=hparams) logits, pred = clas(inputs) with self.test_session() as sess: @@ -128,10 +148,11 @@ def test_binary(self): # case 2 hparams = { + "pretrained_model_name": None, "num_classes": 1, "clas_strategy": "time_wise" } - clas = BertClassifier(hparams=hparams) + clas = BERTClassifier(hparams=hparams) logits, pred = clas(inputs) with self.test_session() as sess: @@ -142,12 +163,13 @@ def test_binary(self): # case 3 hparams = { + "pretrained_model_name": None, "num_classes": 1, "clas_strategy": "cls_time", "max_seq_length": max_time } inputs = tf.placeholder(tf.int32, shape=[batch_size, 6]) - clas = BertClassifier(hparams=hparams) + clas = BERTClassifier(hparams=hparams) logits, pred = clas(inputs) with self.test_session() as sess: @@ -161,12 +183,13 @@ def test_binary(self): # case 4 hparams = { + "pretrained_model_name": None, "num_classes": 1, "clas_strategy": "all_time", "max_seq_length": max_time } inputs = tf.placeholder(tf.int32, shape=[batch_size, 6]) - clas = BertClassifier(hparams=hparams) + clas = BERTClassifier(hparams=hparams) logits, pred = clas(inputs) with self.test_session() as sess: diff --git a/texar/tf/modules/classifiers/xlnet_classifier.py b/texar/tf/modules/classifiers/xlnet_classifier.py index 9c066071..c6a45311 100644 --- a/texar/tf/modules/classifiers/xlnet_classifier.py +++ b/texar/tf/modules/classifiers/xlnet_classifier.py @@ -20,12 +20,14 @@ from __future__ import print_function import tensorflow as tf + from texar.tf.utils.mode import is_train_mode -from texar.tf.core import layers +from texar.tf.core.layers import get_layer, get_initializer from texar.tf.modules.classifiers.classifier_base import ClassifierBase -from texar.tf.modules import XLNetEncoder -from texar.tf.utils import utils +from texar.tf.modules.encoders.xlnet_encoder import XLNetEncoder from texar.tf.hyperparams import HParams +from texar.tf.modules.pretrained.xlnet import PretrainedXLNetMixin +from texar.tf.utils.utils import dict_fetch # pylint: disable=too-many-arguments, invalid-name, no-member, # pylint: disable=too-many-branches, too-many-locals, too-many-statements @@ -35,8 +37,10 @@ ] -class XLNetClassifier(ClassifierBase): - """Classifier based on XLNet modules. +class XLNetClassifier(ClassifierBase, PretrainedXLNetMixin): + """Classifier based on XLNet modules. Please see + :class:`~texar.tf.modules.PretrainedXLNetMixin` for a brief description + of XLNet. This is a combination of the :class:`~texar.tf.modules.XLNetEncoder` with a classification layer. Both step-wise classification and sequence-level @@ -45,14 +49,17 @@ class XLNetClassifier(ClassifierBase): Arguments are the same as in :class:`~texar.tf.modules.XLNetEncoder`. Args: - pretrained_model_name (optional): a str with the name - of a pre-trained model to load. Currently only 'xlnet-large-cased' - is supported. If `None`, will use the model name in :attr:`hparams`. + pretrained_model_name (optional): a `str`, the name + of pre-trained model (e.g., ``xlnet-based-cased``). Please refer to + :class:`~texar.tf.modules.PretrainedXLNetMixin` for + all supported models. + If `None`, the model name in :attr:`hparams` is used. cache_dir (optional): the path to a folder in which the pre-trained models will be cached. If `None` (default), - a default directory will be used. + a default directory (``texar_data`` folder under user's home + directory) will be used. hparams (dict or HParams, optional): Hyperparameters. Missing - hyperparameter will be set to default values. See + hyperparameters will be set to default values. See :meth:`default_hparams` for the hyperparameter structure and default values. @@ -64,13 +71,13 @@ def __init__(self, pretrained_model_name=None, cache_dir=None, hparams=None): - ClassifierBase.__init__(self, hparams) + super(XLNetClassifier, self).__init__(hparams=hparams) with tf.variable_scope(self.variable_scope): tf.get_variable_scope().set_initializer( - layers.get_initializer(self._hparams.initializer)) + get_initializer(self._hparams.initializer)) # Creates the underlying encoder - encoder_hparams = utils.dict_fetch( + encoder_hparams = dict_fetch( hparams, XLNetEncoder.default_hparams()) if encoder_hparams is not None: encoder_hparams['name'] = "encoder" @@ -79,7 +86,7 @@ def __init__(self, cache_dir=cache_dir, hparams=encoder_hparams) if self._hparams.use_projection: - self.projection = layers.get_layer(hparams={ + self.projection = get_layer(hparams={ "type": "Dense", "kwargs": { "units": self._encoder.output_size @@ -89,7 +96,7 @@ def __init__(self, # Creates an dropout layer drop_kwargs = {"rate": self._hparams.dropout} layer_hparams = {"type": "Dropout", "kwargs": drop_kwargs} - self._dropout_layer = layers.get_layer(hparams=layer_hparams) + self._dropout_layer = get_layer(hparams=layer_hparams) # Creates an additional classification layer if needed self._num_classes = self._hparams.num_classes @@ -109,7 +116,7 @@ def __init__(self, logit_kwargs['name'] = "logit_layer" layer_hparams = {"type": "Dense", "kwargs": logit_kwargs} - self._logit_layer = layers.get_layer(hparams=layer_hparams) + self._logit_layer = get_layer(hparams=layer_hparams) @staticmethod def default_hparams(): diff --git a/texar/tf/modules/classifiers/xlnet_classifier_test.py b/texar/tf/modules/classifiers/xlnet_classifier_test.py index 61ab6871..b85dd02a 100644 --- a/texar/tf/modules/classifiers/xlnet_classifier_test.py +++ b/texar/tf/modules/classifiers/xlnet_classifier_test.py @@ -9,10 +9,10 @@ from __future__ import unicode_literals import numpy as np - import tensorflow as tf from texar.tf.modules.classifiers.xlnet_classifier import XLNetClassifier +from texar.tf.utils.test import pretrained_test # pylint: disable=too-many-locals, no-member @@ -21,6 +21,17 @@ class XLNetClassifierTest(tf.test.TestCase): """Tests :class:`~texar.tf.modules.XLNetClassifier` class. """ + @pretrained_test + def test_model_loading(self): + r"""Tests model loading functionality.""" + + inputs = tf.placeholder(dtype=tf.int32, shape=[None, None]) + + for pretrained_model_name in XLNetClassifier.available_checkpoints(): + classifier = XLNetClassifier( + pretrained_model_name=pretrained_model_name) + _, _ = classifier(inputs) + def test_trainable_variables(self): """Tests the functionality of automatically collecting trainable variables. @@ -28,7 +39,10 @@ def test_trainable_variables(self): inputs = tf.placeholder(dtype=tf.int32, shape=[None, None]) # case 1 - clas = XLNetClassifier() + hparams = { + "pretrained_model_name": None, + } + clas = XLNetClassifier(hparams=hparams) clas(inputs) n_xlnet_vars = 162 n_projection_vars = 2 @@ -38,6 +52,7 @@ def test_trainable_variables(self): # case 2 hparams = { + "pretrained_model_name": None, "clas_strategy": "time_wise" } clas = XLNetClassifier(hparams=hparams) @@ -47,6 +62,7 @@ def test_trainable_variables(self): # case 3 hparams = { + "pretrained_model_name": None, "clas_strategy": "all_time" } clas = XLNetClassifier(hparams=hparams) diff --git a/texar/tf/modules/encoders/__init__.py b/texar/tf/modules/encoders/__init__.py index c15df5d9..557dfa1d 100644 --- a/texar/tf/modules/encoders/__init__.py +++ b/texar/tf/modules/encoders/__init__.py @@ -22,10 +22,10 @@ # pylint: disable=wildcard-import from texar.tf.modules.encoders.encoder_base import * -from texar.tf.modules.encoders.rnn_encoders import * +from texar.tf.modules.encoders.bert_encoder import * +from texar.tf.modules.encoders.conv_encoders import * from texar.tf.modules.encoders.hierarchical_encoders import * -from texar.tf.modules.encoders.transformer_encoders import * from texar.tf.modules.encoders.multihead_attention import * -from texar.tf.modules.encoders.conv_encoders import * -from texar.tf.modules.encoders.bert_encoders import * -from texar.tf.modules.encoders.xlnet_encoders import * +from texar.tf.modules.encoders.rnn_encoders import * +from texar.tf.modules.encoders.transformer_encoders import * +from texar.tf.modules.encoders.xlnet_encoder import * diff --git a/texar/tf/modules/encoders/bert_encoders.py b/texar/tf/modules/encoders/bert_encoder.py similarity index 65% rename from texar/tf/modules/encoders/bert_encoders.py rename to texar/tf/modules/encoders/bert_encoder.py index 76d0d0f2..c8735c0a 100644 --- a/texar/tf/modules/encoders/bert_encoders.py +++ b/texar/tf/modules/encoders/bert_encoder.py @@ -20,64 +20,57 @@ from __future__ import print_function import tensorflow as tf -from texar.tf.core import layers + +from texar.tf.core.layers import get_initializer, get_layer from texar.tf.modules.encoders.transformer_encoders import TransformerEncoder -from texar.tf.modules.embedders import WordEmbedder, PositionEmbedder -from texar.tf.hyperparams import HParams -from texar.tf.modules.pretrained.pretrained_base import PretrainedBase -from texar.tf.modules.pretrained import bert_utils -from texar.tf.modules.encoders import EncoderBase +from texar.tf.modules.embedders.embedders import WordEmbedder +from texar.tf.modules.embedders.position_embedders import PositionEmbedder +from texar.tf.modules.encoders.encoder_base import EncoderBase +from texar.tf.modules.pretrained.bert import PretrainedBERTMixin __all__ = [ - "BertEncoder", + "BERTEncoder", ] -class BertEncoder(PretrainedBase, EncoderBase): - """Raw BERT Transformer for encoding sequences. +class BERTEncoder(EncoderBase, PretrainedBERTMixin): + r"""Raw BERT Transformer for encoding sequences. Please see + :class:`~texar.tf.modules.PretrainedBERTMixin` for a brief description + of BERT. This module basically stacks - :class:`~texar.tf.modules.embedders.WordEmbedder`, - :class:`~texar.tf.modules.embedders.PositionEmbedder`, - :class:`~texar.tf.modules.encoders.TransformerEncoder` and a dense pooler. - - This module supports the architecture first proposed - in `(Devlin et al.)` BERT. + :class:`~texar.tf.modules.WordEmbedder`, + :class:`~texar.tf.modules.PositionEmbedder`, + :class:`~texar.tf.modules.TransformerEncoder` and a dense pooler. Args: - pretrained_model_name (optional): a str with the name - of a pre-trained model to load selected in the list of: - `bert-base-uncased`, `bert-large-uncased`, `bert-base-cased`, - `bert-large-cased`, `bert-base-multilingual-uncased`, - `bert-base-multilingual-cased`, `bert-base-chinese`. - If `None`, will use the model name in :attr:`hparams`. + pretrained_model_name (optional): a `str`, the name + of pre-trained model (e.g., ``bert-base-uncased``). Please refer to + :class:`~texar.tf.modules.PretrainedBERTMixin` for + all supported models. + If `None`, the model name in :attr:`hparams` is used. cache_dir (optional): the path to a folder in which the pre-trained models will be cached. If `None` (default), - a default directory will be used. + a default directory (``texar_data`` folder under user's home + directory) will be used. hparams (dict or HParams, optional): Hyperparameters. Missing hyperparameter will be set to default values. See - :meth:`default_hparams` for the hyperparameter sturcture + :meth:`default_hparams` for the hyperparameter structure and default values. .. document private functions .. automethod:: _build """ - model_name = "BERT" - def __init__(self, pretrained_model_name=None, cache_dir=None, hparams=None): - PretrainedBase.__init__(self, pretrained_model_name, cache_dir, hparams) - if self.pretrained_model_dir: - self._hparams = HParams(self.pretrained_model_hparams, - self._hparams.todict()) + super(BERTEncoder, self).__init__(hparams=hparams) + + self.load_pretrained_config(pretrained_model_name, cache_dir) with tf.variable_scope(self.variable_scope): - if self._hparams.initializer: - tf.get_variable_scope().set_initializer( - layers.get_initializer(self._hparams.initializer)) # Word embedding self.word_embedder = WordEmbedder( @@ -101,124 +94,127 @@ def __init__(self, kwargs_i = {"units": self._hparams.hidden_size, "activation": tf.tanh} layer_hparams = {"type": "Dense", "kwargs": kwargs_i} - self.pooler = layers.get_layer(hparams=layer_hparams) + self.pooler = get_layer(hparams=layer_hparams) + + def reset_parameters(self): + with tf.variable_scope(self.variable_scope): + if self._hparams.initializer: + tf.get_variable_scope().set_initializer( + get_initializer(self._hparams.initializer)) @staticmethod def default_hparams(): - """Returns a dictionary of hyperparameters with default values. + r"""Returns a dictionary of hyperparameters with default values. - * The encoder arch is determined by the constructor argument \ - :attr:`pretrained_model_name` if it's specified. In this case, \ - hparams are ignored. - * Otherwise, the encoder arch is determined by \ - `hparams['pretrained_model_name']` if it's specified. All other \ - configs in hparams are ignored. - * If the above two are `None`, the encoder arch is defined by \ - the configs in hparams and weights are randomly initialized. + * The encoder arch is determined by the constructor argument + :attr:`pretrained_model_name` if it's specified. In this case, + `hparams` are ignored. + * Otherwise, the encoder arch is determined by + `hparams['pretrained_model_name']` if it's specified. All other + configurations in `hparams` are ignored. + * If the above two are `None`, the encoder arch is defined by the + configurations in `hparams` and weights are randomly initialized. .. code-block:: python { - 'pretrained_model_name': 'bert-base-uncased', - 'embed': { - 'dim': 768, - 'name': 'word_embeddings' + "pretrained_model_name": "bert-base-uncased", + "embed": { + "dim": 768, + "name": "word_embeddings" }, - 'vocab_size': 30522, - 'segment_embed': { - 'dim': 768, - 'name': 'token_type_embeddings' + "vocab_size": 30522, + "segment_embed": { + "dim": 768, + "name": "token_type_embeddings" }, - 'type_vocab_size': 2, - 'position_embed': { - 'dim': 768, - 'name': 'position_embeddings' + "type_vocab_size": 2, + "position_embed": { + "dim": 768, + "name": "position_embeddings" }, - 'position_size': 512, - - 'encoder': { - 'dim': 768, - 'embedding_dropout': 0.1, - 'multihead_attention': { - 'dropout_rate': 0.1, - 'name': 'self', - 'num_heads': 12, - 'num_units': 768, - 'output_dim': 768, - 'use_bias': True + "position_size": 512, + + "encoder": { + "dim": 768, + "embedding_dropout": 0.1, + "multihead_attention": { + "dropout_rate": 0.1, + "name": "self", + "num_heads": 12, + "num_units": 768, + "output_dim": 768, + "use_bias": True }, - 'name': 'encoder', - 'num_blocks': 12, - 'poswise_feedforward': { - 'layers': [ - { 'kwargs': { - 'activation': 'gelu', - 'name': 'intermediate', - 'units': 3072, - 'use_bias': True + "name": "encoder", + "num_blocks": 12, + "poswise_feedforward": { + "layers": [ + { "kwargs": { + "activation": "gelu", + "name": "intermediate", + "units": 3072, + "use_bias": True }, - 'type': 'Dense' + "type": "Dense" }, - { 'kwargs': {'activation': None, - 'name': 'output', - 'units': 768, - 'use_bias': True + { "kwargs": {"activation": None, + "name": "output", + "units": 768, + "use_bias": True }, - 'type': 'Dense' + "type": "Dense" } ] }, - 'residual_dropout': 0.1, - 'use_bert_config': True + "residual_dropout": 0.1, + "use_bert_config": True }, - 'hidden_size': 768, - 'initializer': None, - 'name': 'bert_encoder' + "hidden_size": 768, + "initializer": None, + "name": "bert_encoder" } - - Here: The default parameters are values for uncased BERT-Base model. + `"pretrained_model_name"`: str or None + The name of the pre-trained BERT model. If None, the model + will be randomly initialized. - "pretrained_model_name": str or None - The name of the pretrained bert model. If None, the model - will be randomly initialized. - - "embed": dict + `"embed"`: dict Hyperparameters for word embedding layer. - "vocab_size": int - The vocabulary size of `inputs` in `BertModel`. + `"vocab_size"`: int + The vocabulary size of `inputs` in BERT model. - "segment_embed": dict + `"segment_embed"`: dict Hyperparameters for segment embedding layer. - "type_vocab_size": int + `"type_vocab_size"`: int The vocabulary size of the `segment_ids` passed into `BertModel`. - "position_embed": dict + `"position_embed"`: dict Hyperparameters for position embedding layer. - "position_size": int + `"position_size"`: int The maximum sequence length that this model might ever be used with. - "encoder": dict + `"encoder"`: dict Hyperparameters for the TransformerEncoder. See :func:`~texar.tf.modules.TransformerEncoder.default_harams` for details. - "hidden_size": int + `"hidden_size"`: int Size of the pooler dense layer. - "initializer": dict, optional + `"initializer"`: dict, optional Hyperparameters of the default initializer that initializes variables created in this module. See :func:`~texar.tf.core.get_initializer` for details. - "name": str + `"name"`: str Name of the module. """ @@ -318,7 +314,7 @@ def _build(self, - :attr:`pooled_output`: A Tensor of size \ `[batch_size, hidden_size]` which is the output of a \ - pooler pretrained on top of the hidden state associated \ + pooler berts on top of the hidden state associated \ to the first character of the input (`CLS`), see BERT's \ paper. """ @@ -351,8 +347,6 @@ def _build(self, self._add_internal_trainable_variables() self._built = True - if self.pretrained_model_dir: - bert_utils.init_bert_checkpoint(self.pretrained_model_dir, - self.variable_scope.name) + self.init_pretrained_weights(self.variable_scope.name) return output, pooled_output diff --git a/texar/tf/modules/encoders/bert_encoders_test.py b/texar/tf/modules/encoders/bert_encoder_test.py similarity index 78% rename from texar/tf/modules/encoders/bert_encoders_test.py rename to texar/tf/modules/encoders/bert_encoder_test.py index 515d39df..09516f15 100644 --- a/texar/tf/modules/encoders/bert_encoders_test.py +++ b/texar/tf/modules/encoders/bert_encoder_test.py @@ -1,6 +1,5 @@ -# """ -Unit tests for Bert encoders. +Unit tests for BERT encoders. """ from __future__ import absolute_import @@ -10,13 +9,25 @@ import tensorflow as tf -from texar.tf.modules.encoders.bert_encoders import BertEncoder +from texar.tf.modules.encoders.bert_encoder import BERTEncoder +from texar.tf.utils.test import pretrained_test -class BertEncoderTest(tf.test.TestCase): - """Tests :class:`~texar.tf.modules.BertEncoder` class. +class BERTEncoderTest(tf.test.TestCase): + """Tests :class:`~texar.tf.modules.BERTEncoder` class. """ + @pretrained_test + def test_model_loading(self): + r"""Tests model loading functionality.""" + + inputs = tf.placeholder(dtype=tf.int32, shape=[None, None]) + + for pretrained_model_name in BERTEncoder.available_checkpoints(): + encoder = BERTEncoder(pretrained_model_name=pretrained_model_name) + _, _ = encoder(inputs) + + @pretrained_test def test_hparams(self): """Tests the priority of the encoder arch parameter. """ @@ -27,7 +38,7 @@ def test_hparams(self): hparams = { "pretrained_model_name": "bert-large-uncased", } - encoder = BertEncoder(pretrained_model_name="bert-base-uncased", + encoder = BERTEncoder(pretrained_model_name="bert-base-uncased", hparams=hparams) _, _ = encoder(inputs) self.assertEqual(encoder.hparams.encoder.num_blocks, 12) @@ -39,7 +50,7 @@ def test_hparams(self): "num_blocks": 6 } } - encoder = BertEncoder(hparams=hparams) + encoder = BERTEncoder(hparams=hparams) _, _ = encoder(inputs) self.assertEqual(encoder.hparams.encoder.num_blocks, 24) @@ -50,23 +61,25 @@ def test_hparams(self): "num_blocks": 6 }, } - encoder = BertEncoder(hparams=hparams) + encoder = BERTEncoder(hparams=hparams) _, _ = encoder(inputs) self.assertEqual(encoder.hparams.encoder.num_blocks, 6) # case 4: using default hparams - encoder = BertEncoder() + encoder = BERTEncoder() _, _ = encoder(inputs) self.assertEqual(encoder.hparams.encoder.num_blocks, 12) + @pretrained_test def test_trainable_variables(self): """Tests the functionality of automatically collecting trainable variables. """ + inputs = tf.placeholder(dtype=tf.int32, shape=[None, None]) # case 1: bert base - encoder = BertEncoder() + encoder = BERTEncoder() _, _ = encoder(inputs) self.assertEqual(len(encoder.trainable_variables), 3+2+12*16+2) @@ -74,7 +87,7 @@ def test_trainable_variables(self): hparams = { "pretrained_model_name": "bert-large-uncased" } - encoder = BertEncoder(hparams=hparams) + encoder = BERTEncoder(hparams=hparams) _, _ = encoder(inputs) self.assertEqual(len(encoder.trainable_variables), 3+2+24*16+2) @@ -85,7 +98,7 @@ def test_trainable_variables(self): }, "pretrained_model_name": None } - encoder = BertEncoder(hparams=hparams) + encoder = BERTEncoder(hparams=hparams) _, _ = encoder(inputs) self.assertEqual(len(encoder.trainable_variables), 3+2+6*16+2) @@ -93,7 +106,10 @@ def test_encode(self): """Tests encoding. """ # case 1: bert base - encoder = BertEncoder() + hparams = { + "pretrained_model_name": None + } + encoder = BERTEncoder(hparams=hparams) max_time = 8 batch_size = 16 @@ -116,7 +132,7 @@ def test_encode(self): "hidden_size": 100, "pretrained_model_name": None } - encoder = BertEncoder(hparams=hparams) + encoder = BERTEncoder(hparams=hparams) max_time = 8 batch_size = 16 @@ -135,7 +151,5 @@ def test_encode(self): (batch_size, pooled_output_dim)) - - if __name__ == "__main__": tf.test.main() diff --git a/texar/tf/modules/encoders/xlnet_encoders.py b/texar/tf/modules/encoders/xlnet_encoder.py similarity index 93% rename from texar/tf/modules/encoders/xlnet_encoders.py rename to texar/tf/modules/encoders/xlnet_encoder.py index 34d00311..5c3e263c 100644 --- a/texar/tf/modules/encoders/xlnet_encoders.py +++ b/texar/tf/modules/encoders/xlnet_encoder.py @@ -23,56 +23,50 @@ from texar.tf.utils.mode import is_train_mode -from texar.tf.hyperparams import HParams -from texar.tf.core import layers -from texar.tf.modules.pretrained.pretrained_base import PretrainedBase -from texar.tf.modules.pretrained import xlnet_utils -from texar.tf.modules.pretrained.xlnet_model_utils import \ +from texar.tf.core.layers import get_initializer, get_layer +from texar.tf.modules.embedders.embedders import WordEmbedder +from texar.tf.modules.encoders.encoder_base import EncoderBase +from texar.tf.modules.pretrained.xlnet import PretrainedXLNetMixin +from texar.tf.modules.pretrained.xlnet_utils import \ (PositionWiseFF, RelativePositionalEncoding, RelativeMutiheadAttention) -from texar.tf.modules.embedders import WordEmbedder -from texar.tf.modules.encoders import EncoderBase - -from texar.tf.utils import dict_fetch +from texar.tf.utils.utils import dict_fetch __all__ = [ "XLNetEncoder" ] -class XLNetEncoder(PretrainedBase, EncoderBase): - r"""XLNet Transformer for encoding sequences. - - This module supports the architecture proposed - in `(Zhiling et al.)` XLNet. +class XLNetEncoder(EncoderBase, PretrainedXLNetMixin): + r"""Raw XLNet module for encoding sequences. Please see + :class:`~texar.tf.modules.PretrainedXLNetMixin` for a brief description + of XLNet. Args: - pretrained_model_name (optional): a str with the name - of a pre-trained model to load. Currently 'xlnet-large-cased' - and 'xlnet-base-cased' are supported. - If `None`, will use the model name in :attr:`hparams`. + pretrained_model_name (optional): a `str`, the name + of pre-trained model (e.g., ``xlnet-based-cased``). Please refer to + :class:`~texar.tf.modules.PretrainedXLNetMixin` for + all supported models. + If `None`, the model name in :attr:`hparams` is used. cache_dir (optional): the path to a folder in which the pre-trained models will be cached. If `None` (default), - a default directory will be used. + a default directory (``texar_data`` folder under user's home + directory) will be used. hparams (dict or HParams, optional): Hyperparameters. Missing hyperparameter will be set to default values. See - :meth:`default_hparams` for the hyperparameter sturcture + :meth:`default_hparams` for the hyperparameter structure and default values. .. document private functions .. automethod:: _build """ - model_name = "XLNet" - def __init__(self, pretrained_model_name=None, cache_dir=None, hparams=None): - PretrainedBase.__init__(self, pretrained_model_name, cache_dir, hparams) + super(XLNetEncoder, self).__init__(hparams=hparams) - if self.pretrained_model_dir: - self._hparams = HParams(self.pretrained_model_hparams, - self._hparams.todict()) + self.load_pretrained_config(pretrained_model_name, cache_dir) num_layers = self._hparams.num_layers use_segments = self._hparams.use_segments @@ -80,10 +74,6 @@ def __init__(self, with tf.variable_scope(self.variable_scope): - if self._hparams.initializer: - tf.get_variable_scope().set_initializer( - layers.get_initializer(self._hparams.initializer)) - if untie_r: self.r_w_bias = tf.get_variable('r_w_bias', [num_layers, @@ -177,11 +167,17 @@ def __init__(self, "rate": self._hparams.dropout } } - self.dropout = layers.get_layer(hparams=dropout_hparams) + self.dropout = get_layer(hparams=dropout_hparams) self.mask_embed = tf.get_variable( 'mask_emb', [1, 1, self.hparams.hidden_dim], dtype=tf.float32) + def reset_parameters(self): + with tf.variable_scope(self.variable_scope): + if self._hparams.initializer: + tf.get_variable_scope().set_initializer( + get_initializer(self._hparams.initializer)) + @staticmethod def default_hparams(): r"""Returns a dictionary of hyperparameters with default values. @@ -225,7 +221,7 @@ def default_hparams(): "pretrained_model_name": str or None - The name of the pretrained bert model. If None, the model + The name of the pre-trained bert model. If None, the model will be randomly initialized. "untie_r": bool @@ -619,8 +615,7 @@ def _execute(self, word_embed, segment_ids=None, # noqa: C901 self._built = True if self.pretrained_model_dir: - xlnet_utils.init_xlnet_checkpoint(self.pretrained_model_dir, - self.variable_scope.name) + self.init_pretrained_weights(self.variable_scope.name) if cache_len == 0: return output, None diff --git a/texar/tf/modules/encoders/xlnet_encoders_test.py b/texar/tf/modules/encoders/xlnet_encoder_test.py similarity index 89% rename from texar/tf/modules/encoders/xlnet_encoders_test.py rename to texar/tf/modules/encoders/xlnet_encoder_test.py index 4390c5e4..fbe03336 100644 --- a/texar/tf/modules/encoders/xlnet_encoders_test.py +++ b/texar/tf/modules/encoders/xlnet_encoder_test.py @@ -10,13 +10,25 @@ import tensorflow as tf -from texar.tf.modules.encoders.xlnet_encoders import XLNetEncoder +from texar.tf.modules.encoders.xlnet_encoder import XLNetEncoder +from texar.tf.utils.test import pretrained_test class XLNetEncoderTest(tf.test.TestCase): """Tests :class:`~texar.tf.modules.XLNetEncoder` class. """ + @pretrained_test + def test_model_loading(self): + r"""Tests model loading functionality.""" + + inputs = tf.placeholder(dtype=tf.int32, shape=[None, None]) + + for pretrained_model_name in XLNetEncoder.available_checkpoints(): + encoder = XLNetEncoder(pretrained_model_name=pretrained_model_name) + _ = encoder(inputs) + + @pretrained_test def test_hparams(self): """Tests the priority of the encoder architecture parameter. """ @@ -56,10 +68,12 @@ def test_hparams(self): self.assertEqual(len(encoder.attn_layers), 12) self.assertEqual(len(encoder.ff_layers), 12) + @pretrained_test def test_trainable_variables(self): """Tests the functionality of automatically collecting trainable variables. """ + inputs = tf.placeholder(dtype=tf.int32, shape=[None, None]) # case 1: XLNet with no pre-trained model @@ -102,7 +116,7 @@ def test_encode(self): """ # case 1: XLNet pre-trained hparams = { - "pretrained_model_name": "xlnet-base-cased", + "pretrained_model_name": None, "untie_r": False } encoder = XLNetEncoder(hparams=hparams) @@ -122,7 +136,7 @@ def test_encode(self): # case 2: XLNet pre-trained, untie_r=True hparams = { - "pretrained_model_name": "xlnet-base-cased", + "pretrained_model_name": None, "untie_r": True } diff --git a/texar/tf/modules/pretrained/__init__.py b/texar/tf/modules/pretrained/__init__.py index 2d0f87c7..f78eb8f6 100644 --- a/texar/tf/modules/pretrained/__init__.py +++ b/texar/tf/modules/pretrained/__init__.py @@ -15,6 +15,6 @@ Pre-trained modules of Texar library. """ -from texar.tf.modules.pretrained.bert_utils import * from texar.tf.modules.pretrained.pretrained_base import * -from texar.tf.modules.pretrained.xlnet_utils import * +from texar.tf.modules.pretrained.bert import * +from texar.tf.modules.pretrained.xlnet import * diff --git a/texar/tf/modules/pretrained/bert.py b/texar/tf/modules/pretrained/bert.py new file mode 100644 index 00000000..b1598604 --- /dev/null +++ b/texar/tf/modules/pretrained/bert.py @@ -0,0 +1,253 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utils of BERT Modules. +""" + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division +from __future__ import unicode_literals + +import collections +import json +import os +import re + +from abc import ABCMeta + +import tensorflow as tf + +from texar.tf.modules.pretrained.pretrained_base import PretrainedMixin + +__all__ = [ + "PretrainedBERTMixin", +] + +_BERT_PATH = "https://storage.googleapis.com/bert_models/" + + +class PretrainedBERTMixin(PretrainedMixin): + r"""A mixin class to support loading pre-trained checkpoints for modules + that implement the BERT model. + + The BERT model was proposed in (`Devlin et al`. 2018) + `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ + . A bidirectional Transformer language model pre-trained on large text + corpora. Available model names include: + + * ``bert-base-uncased``: 12-layer, 768-hidden, 12-heads, + 110M parameters. + * ``bert-large-uncased``: 24-layer, 1024-hidden, 16-heads, + 340M parameters. + * ``bert-base-cased``: 12-layer, 768-hidden, 12-heads , 110M parameters. + * ``bert-large-cased``: 24-layer, 1024-hidden, 16-heads, + 340M parameters. + * ``bert-base-multilingual-uncased``: 102 languages, 12-layer, + 768-hidden, 12-heads, 110M parameters. + * ``bert-base-multilingual-cased``: 104 languages, 12-layer, 768-hidden, + 12-heads, 110M parameters. + * ``bert-base-chinese``: Chinese Simplified and Traditional, 12-layer, + 768-hidden, 12-heads, 110M parameters. + + We provide the following BERT classes: + + * :class:`~texar.tf.modules.BERTEncoder` for text encoding. + * :class:`~texar.tf.modules.BERTClassifier` for text classification and + sequence tagging. + + .. _`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`: + https://arxiv.org/abs/1810.04805 + """ + + __metaclass__ = ABCMeta + + _MODEL_NAME = "BERT" + _MODEL2URL = { + 'bert-base-uncased': + _BERT_PATH + "2018_10_18/uncased_L-12_H-768_A-12.zip", + 'bert-large-uncased': + _BERT_PATH + "2018_10_18/uncased_L-24_H-1024_A-16.zip", + 'bert-base-cased': + _BERT_PATH + "2018_10_18/cased_L-12_H-768_A-12.zip", + 'bert-large-cased': + _BERT_PATH + "2018_10_18/cased_L-24_H-1024_A-16.zip", + 'bert-base-multilingual-uncased': + _BERT_PATH + "2018_11_23/multi_cased_L-12_H-768_A-12.zip", + 'bert-base-multilingual-cased': + _BERT_PATH + "2018_11_03/multilingual_L-12_H-768_A-12.zip", + 'bert-base-chinese': + _BERT_PATH + "2018_11_03/chinese_L-12_H-768_A-12.zip", + } + + @classmethod + def _transform_config(cls, pretrained_model_name, cache_dir): + info = list(os.walk(cache_dir)) + root, _, files = info[0] + config_path = None + + for file in files: + if file.endswith('config.json'): + config_path = os.path.join(root, file) + with open(config_path) as f: + config_ckpt = json.loads(f.read()) + + if config_path is None: + raise ValueError("Cannot find the config file in {}".format( + cache_dir)) + + configs = {} + hidden_dim = config_ckpt['hidden_size'] + configs['hidden_size'] = config_ckpt['hidden_size'] + configs['embed'] = { + 'name': 'word_embeddings', + 'dim': hidden_dim} + configs['vocab_size'] = config_ckpt['vocab_size'] + + configs['segment_embed'] = { + 'name': 'token_type_embeddings', + 'dim': hidden_dim} + configs['type_vocab_size'] = config_ckpt['type_vocab_size'] + + configs['position_embed'] = { + 'name': 'position_embeddings', + 'dim': hidden_dim} + configs['position_size'] = config_ckpt['max_position_embeddings'] + + configs['encoder'] = { + 'name': 'encoder', + 'embedding_dropout': config_ckpt['hidden_dropout_prob'], + 'num_blocks': config_ckpt['num_hidden_layers'], + 'multihead_attention': { + 'use_bias': True, + 'num_units': hidden_dim, + 'num_heads': config_ckpt['num_attention_heads'], + 'output_dim': hidden_dim, + 'dropout_rate': config_ckpt['attention_probs_dropout_prob'], + 'name': 'self' + }, + 'residual_dropout': config_ckpt['hidden_dropout_prob'], + 'dim': hidden_dim, + 'use_bert_config': True, + 'poswise_feedforward': { + "layers": [ + { + 'type': 'Dense', + 'kwargs': { + 'name': 'intermediate', + 'units': config_ckpt['intermediate_size'], + 'activation': config_ckpt['hidden_act'], + 'use_bias': True, + } + }, + { + 'type': 'Dense', + 'kwargs': { + 'name': 'output', + 'units': hidden_dim, + 'activation': None, + 'use_bias': True, + } + }, + ], + }, + } + return configs + + def _init_from_checkpoint(self, pretrained_model_name, + cache_dir, scope_name, **kwargs): + tvars = tf.trainable_variables() + init_checkpoint = os.path.abspath(os.path.join(cache_dir, + 'bert_model.ckpt')) + if init_checkpoint: + assignment_map, initialized_variable_names = \ + self._get_assignment_map_from_checkpoint( + tvars, init_checkpoint, scope_name) + tf.train.init_from_checkpoint(init_checkpoint, assignment_map) + + def _get_assignment_map_from_checkpoint(self, tvars, init_checkpoint, + scope_name): + r"""`https://github.com/google-research/bert/blob/master/modeling.py` + + Compute the union of the current variables and checkpoint variables. + Because the variable scope of the original BERT and Texar + implementation, we need to build a assignment map to match the + variables. + """ + initialized_variable_names = {} + + name_to_variable = collections.OrderedDict() + for var in tvars: + name = var.name + m = re.match("^(.*):\\d+$", name) + if m is not None: + name = m.group(1) + name_to_variable[name] = var + + init_vars = tf.train.list_variables(init_checkpoint) + + assignment_map = { + 'bert/embeddings/word_embeddings': + scope_name + '/word_embeddings/w', + 'bert/embeddings/token_type_embeddings': + scope_name + '/token_type_embeddings/w', + 'bert/embeddings/position_embeddings': + scope_name + '/position_embeddings/w', + 'bert/embeddings/LayerNorm/beta': + scope_name + '/encoder/LayerNorm/beta', + 'bert/embeddings/LayerNorm/gamma': + scope_name + '/encoder/LayerNorm/gamma', + } + for check_name, model_name in assignment_map.items(): + initialized_variable_names[model_name] = 1 + initialized_variable_names[model_name + ":0"] = 1 + + for check_name, _ in init_vars: + if check_name.startswith('bert'): + if check_name.startswith('bert/embeddings'): + continue + check_name_scope = check_name.replace("bert/", scope_name + '/') + model_name = re.sub( + 'layer_\\d+/output/dense', + lambda x: x.group(0).replace('output/dense', 'ffn/output'), + check_name_scope) + if model_name == check_name_scope: + model_name = re.sub( + 'layer_\\d+/output/LayerNorm', + lambda x: x.group(0).replace('output/LayerNorm', + 'ffn/LayerNorm'), + check_name_scope) + if model_name == check_name_scope: + model_name = re.sub( + 'layer_\\d+/intermediate/dense', + lambda x: x.group(0).replace('intermediate/dense', + 'ffn/intermediate'), + check_name_scope) + if model_name == check_name_scope: + model_name = re.sub('attention/output/dense', + 'attention/self/output', + check_name_scope) + if model_name == check_name_scope: + model_name = check_name_scope.replace( + 'attention/output/LayerNorm', 'output/LayerNorm') + + if model_name in name_to_variable.keys(): + assignment_map[check_name] = model_name + initialized_variable_names[model_name] = 1 + initialized_variable_names[model_name + ":0"] = 1 + else: + tf.logging.info( + 'model name:{} not exist'.format(model_name)) + + return assignment_map, initialized_variable_names diff --git a/texar/tf/modules/pretrained/bert_test.py b/texar/tf/modules/pretrained/bert_test.py new file mode 100644 index 00000000..8a221386 --- /dev/null +++ b/texar/tf/modules/pretrained/bert_test.py @@ -0,0 +1,99 @@ +""" +Unit tests for BERT utils. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import os +import tensorflow as tf + +from texar.tf.modules.pretrained.bert import * +from texar.tf.utils.test import pretrained_test + + +class BERTUtilsTest(tf.test.TestCase): + r"""Tests BERT utils. + """ + + @pretrained_test + def test_load_pretrained_bert_AND_transform_bert_to_texar_config(self): + + pretrained_model_dir = PretrainedBERTMixin.download_checkpoint( + pretrained_model_name="bert-base-uncased") + + info = list(os.walk(pretrained_model_dir)) + _, _, files = info[0] + self.assertIn('bert_model.ckpt.meta', files) + self.assertIn('bert_model.ckpt.data-00000-of-00001', files) + self.assertIn('bert_model.ckpt.index', files) + self.assertIn('bert_config.json', files) + + model_config = PretrainedBERTMixin._transform_config( + pretrained_model_name="bert-base-uncased", + cache_dir=pretrained_model_dir) + + exp_config = { + 'hidden_size': 768, + 'embed': { + 'name': 'word_embeddings', + 'dim': 768 + }, + 'vocab_size': 30522, + 'segment_embed': { + 'name': 'token_type_embeddings', + 'dim': 768 + }, + 'type_vocab_size': 2, + 'position_embed': { + 'name': 'position_embeddings', + 'dim': 768 + }, + 'position_size': 512, + 'encoder': { + 'name': 'encoder', + 'embedding_dropout': 0.1, + 'num_blocks': 12, + 'multihead_attention': { + 'use_bias': True, + 'num_units': 768, + 'num_heads': 12, + 'output_dim': 768, + 'dropout_rate': 0.1, + 'name': 'self' + }, + 'residual_dropout': 0.1, + 'dim': 768, + 'use_bert_config': True, + 'poswise_feedforward': { + 'layers': [ + { + 'type': 'Dense', + 'kwargs': { + 'name': 'intermediate', + 'units': 3072, + 'activation': 'gelu', + 'use_bias': True + } + }, + { + 'type': 'Dense', + 'kwargs': { + 'name': 'output', + 'units': 768, + 'activation': None, + 'use_bias': True + } + } + ] + } + } + } + + self.assertDictEqual(model_config, exp_config) + + +if __name__ == "__main__": + tf.test.main() diff --git a/texar/tf/modules/pretrained/bert_utils.py b/texar/tf/modules/pretrained/bert_utils.py deleted file mode 100644 index 7cc8819c..00000000 --- a/texar/tf/modules/pretrained/bert_utils.py +++ /dev/null @@ -1,234 +0,0 @@ -# Copyright 2019 The Texar Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Utility functions related to BERT encoders. -""" - -from __future__ import absolute_import -from __future__ import print_function -from __future__ import division -from __future__ import unicode_literals - -import json -import collections -import re -import os -import tensorflow as tf -from texar.tf.modules.pretrained.pretrained_utils import default_download_dir -from texar.tf.data.data_utils import maybe_download - -__all__ = [ - "transform_bert_to_texar_config", - "init_bert_checkpoint", - "load_pretrained_bert" -] - -_BERT_PATH = "https://storage.googleapis.com/bert_models/" -_MODEL2URL = { - 'bert-base-uncased': - _BERT_PATH + "2018_10_18/uncased_L-12_H-768_A-12.zip", - 'bert-large-uncased': - _BERT_PATH + "2018_10_18/uncased_L-24_H-1024_A-16.zip", - 'bert-base-cased': - _BERT_PATH + "2018_10_18/cased_L-12_H-768_A-12.zip", - 'bert-large-cased': - _BERT_PATH + "2018_10_18/cased_L-24_H-1024_A-16.zip", - 'bert-base-multilingual-uncased': - _BERT_PATH + "2018_11_23/multi_cased_L-12_H-768_A-12.zip", - 'bert-base-multilingual-cased': - _BERT_PATH + "2018_11_03/multilingual_L-12_H-768_A-12.zip", - 'bert-base-chinese': - _BERT_PATH + "2018_11_03/chinese_L-12_H-768_A-12.zip" -} - - -def _get_assignment_map_from_checkpoint(tvars, init_checkpoint, scope_name): - """ - Provided by Google AI Language Team. - Compute the union of the current variables and checkpoint variables. - Because the variable scope of the original BERT and Texar implementation, - we need to build a assignment map to match the variables. - """ - initialized_variable_names = {} - - name_to_variable = collections.OrderedDict() - for var in tvars: - name = var.name - m = re.match("^(.*):\\d+$", name) - if m is not None: - name = m.group(1) - name_to_variable[name] = var - - init_vars = tf.train.list_variables(init_checkpoint) - - assignment_map = { - 'bert/embeddings/word_embeddings': - scope_name + '/word_embeddings/w', - 'bert/embeddings/token_type_embeddings': - scope_name + '/token_type_embeddings/w', - 'bert/embeddings/position_embeddings': - scope_name + '/position_embeddings/w', - 'bert/embeddings/LayerNorm/beta': - scope_name + '/encoder/LayerNorm/beta', - 'bert/embeddings/LayerNorm/gamma': - scope_name + '/encoder/LayerNorm/gamma', - } - for check_name, model_name in assignment_map.items(): - initialized_variable_names[model_name] = 1 - initialized_variable_names[model_name + ":0"] = 1 - - for check_name, _ in init_vars: - if check_name.startswith('bert'): - if check_name.startswith('bert/embeddings'): - continue - check_name_scope = check_name.replace("bert/", scope_name+'/') - model_name = re.sub( - 'layer_\\d+/output/dense', - lambda x: x.group(0).replace('output/dense', 'ffn/output'), - check_name_scope) - if model_name == check_name_scope: - model_name = re.sub( - 'layer_\\d+/output/LayerNorm', - lambda x: x.group(0).replace('output/LayerNorm', - 'ffn/LayerNorm'), - check_name_scope) - if model_name == check_name_scope: - model_name = re.sub( - 'layer_\\d+/intermediate/dense', - lambda x: x.group(0).replace('intermediate/dense', - 'ffn/intermediate'), - check_name_scope) - if model_name == check_name_scope: - model_name = re.sub('attention/output/dense', - 'attention/self/output', check_name_scope) - if model_name == check_name_scope: - model_name = check_name_scope.replace( - 'attention/output/LayerNorm', 'output/LayerNorm') - - if model_name in name_to_variable.keys(): - assignment_map[check_name] = model_name - initialized_variable_names[model_name] = 1 - initialized_variable_names[model_name + ":0"] = 1 - else: - tf.logging.info('model name:{} not exist'.format(model_name)) - - return assignment_map, initialized_variable_names - - -def init_bert_checkpoint(init_checkpoint_dir, scope_name): - """ - Initializes BERT model parameters from a checkpoint. - Provided by Google AI Language Team. - - Args: - init_checkpoint_dir (str): path to the checkpoint. - scope_name: variable scope of bert encoder. - """ - tvars = tf.trainable_variables() - init_checkpoint = os.path.join(init_checkpoint_dir, 'bert_model.ckpt') - if init_checkpoint: - assignment_map, initialized_variable_names = \ - _get_assignment_map_from_checkpoint( - tvars, init_checkpoint, scope_name) - tf.train.init_from_checkpoint(init_checkpoint, assignment_map) - - -def load_pretrained_bert(pretrained_model_name, cache_dir=None): - """ - Return the directory in which the pretrained model is cached. - """ - if pretrained_model_name in _MODEL2URL: - download_path = _MODEL2URL[pretrained_model_name] - else: - raise ValueError( - "Pre-trained model not found: {}".format(pretrained_model_name)) - - if cache_dir is None: - cache_dir = default_download_dir("bert") - - file_name = download_path.split('/')[-1] - - cache_path = os.path.join(cache_dir, file_name.split('.')[0]) - if not os.path.exists(cache_path): - maybe_download(download_path, cache_dir, extract=True) - else: - print("Using cached pre-trained model {} from: {}".format( - pretrained_model_name, cache_dir)) - - return cache_path - - -def transform_bert_to_texar_config(config_dir): - """ - Load the Json config file and transform it into Texar style configuration. - """ - config_ckpt = json.loads( - open(os.path.join(config_dir, 'bert_config.json')).read()) - configs = {} - hidden_dim = config_ckpt['hidden_size'] - configs['hidden_size'] = config_ckpt['hidden_size'] - configs['embed'] = { - 'name': 'word_embeddings', - 'dim': hidden_dim} - configs['vocab_size'] = config_ckpt['vocab_size'] - - configs['segment_embed'] = { - 'name': 'token_type_embeddings', - 'dim': hidden_dim} - configs['type_vocab_size'] = config_ckpt['type_vocab_size'] - - configs['position_embed'] = { - 'name': 'position_embeddings', - 'dim': hidden_dim} - configs['position_size'] = config_ckpt['max_position_embeddings'] - - configs['encoder'] = { - 'name': 'encoder', - 'embedding_dropout': config_ckpt['hidden_dropout_prob'], - 'num_blocks': config_ckpt['num_hidden_layers'], - 'multihead_attention': { - 'use_bias': True, - 'num_units': hidden_dim, - 'num_heads': config_ckpt['num_attention_heads'], - 'output_dim': hidden_dim, - 'dropout_rate': config_ckpt['attention_probs_dropout_prob'], - 'name': 'self' - }, - 'residual_dropout': config_ckpt['hidden_dropout_prob'], - 'dim': hidden_dim, - 'use_bert_config': True, - 'poswise_feedforward': { - "layers": [ - { - 'type': 'Dense', - 'kwargs': { - 'name': 'intermediate', - 'units': config_ckpt['intermediate_size'], - 'activation': config_ckpt['hidden_act'], - 'use_bias': True, - } - }, - { - 'type': 'Dense', - 'kwargs': { - 'name': 'output', - 'units': hidden_dim, - 'activation': None, - 'use_bias': True, - } - }, - ], - }, - } - return configs diff --git a/texar/tf/modules/pretrained/bert_utils_test.py b/texar/tf/modules/pretrained/bert_utils_test.py deleted file mode 100644 index 6ba5d102..00000000 --- a/texar/tf/modules/pretrained/bert_utils_test.py +++ /dev/null @@ -1,78 +0,0 @@ -""" -Unit tests for bert utils. -""" - -import os -import tensorflow as tf - -from texar.tf.modules.pretrained.bert_utils import \ - load_pretrained_bert, transform_bert_to_texar_config - - -class BertUtilsTest(tf.test.TestCase): - r"""Tests bert utils. - """ - - def test_load_pretrained_model_AND_transform_bert_to_texar_config(self): - - pretrained_model_dir = load_pretrained_bert( - pretrained_model_name="bert-base-uncased") - - info = list(os.walk(pretrained_model_dir)) - _, _, files = info[0] - self.assertIn('bert_model.ckpt.meta', files) - self.assertIn('bert_model.ckpt.data-00000-of-00001', files) - self.assertIn('bert_model.ckpt.index', files) - self.assertIn('bert_config.json', files) - - model_config = transform_bert_to_texar_config(pretrained_model_dir) - - expected_config = { - 'hidden_size': 768, - 'embed': {'name': 'word_embeddings', 'dim': 768}, - 'vocab_size': 30522, - 'segment_embed': {'name': 'token_type_embeddings', 'dim': 768}, - 'type_vocab_size': 2, - 'position_embed': {'name': 'position_embeddings', 'dim': 768}, - 'position_size': 512, - 'encoder': { - 'name': 'encoder', - 'embedding_dropout': 0.1, - 'num_blocks': 12, - 'multihead_attention': { - 'use_bias': True, - 'num_units': 768, - 'num_heads': 12, - 'output_dim': 768, - 'dropout_rate': 0.1, - 'name': 'self'}, - 'residual_dropout': 0.1, - 'dim': 768, - 'use_bert_config': True, - 'poswise_feedforward': { - 'layers': [{ - 'type': 'Dense', - 'kwargs': { - 'name': 'intermediate', - 'units': 3072, - 'activation': 'gelu', - 'use_bias': True - } - }, - { - 'type': 'Dense', - 'kwargs': { - 'name': 'output', - 'units': 768, - 'activation': None, - 'use_bias': True} - }] - } - } - } - - self.assertDictEqual(model_config, expected_config) - - -if __name__ == "__main__": - tf.test.main() diff --git a/texar/tf/modules/pretrained/pretrained_base.py b/texar/tf/modules/pretrained/pretrained_base.py index eadfc1d0..bbbbeba9 100644 --- a/texar/tf/modules/pretrained/pretrained_base.py +++ b/texar/tf/modules/pretrained/pretrained_base.py @@ -15,62 +15,149 @@ Base class for Pre-trained Modules. """ -from texar.tf.module_base import ModuleBase -from texar.tf.modules.pretrained.bert_utils import ( - load_pretrained_bert, transform_bert_to_texar_config) -from texar.tf.modules.pretrained.xlnet_utils import ( - load_pretrained_xlnet, transform_xlnet_to_texar_config) +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division +from __future__ import unicode_literals + +import os +import sys + +from abc import ABCMeta, abstractmethod +from pathlib import Path +from texar.tf.data.data_utils import maybe_download +from texar.tf.hyperparams import HParams +from texar.tf.module_base import ModuleBase __all__ = [ - "PretrainedBase", + "default_download_dir", + "set_default_download_dir", + "PretrainedMixin", ] +_default_texar_download_dir = None + -class PretrainedBase(ModuleBase): - r"""Base class for all pre-trained classes to inherit. - - Args: - pretrained_model_name (optional): A str with the name - of a pre-trained model to load. If `None`, will use the model - name in :attr:`hparams`. - cache_dir (optional): The path to a folder in which the - pre-trained models will be cached. If `None` (default), - a default directory will be used. - hparams (dict or HParams, optional): Hyperparameters. Missing - hyperparameter will be set to default values. See - :meth:`default_hparams` for the hyperparameter structure - and default values. +def default_download_dir(name): + r"""Return the directory to which packages will be downloaded by default. """ + global _default_texar_download_dir # pylint: disable=global-statement + if _default_texar_download_dir is None: + if sys.platform == 'win32' and 'APPDATA' in os.environ: + # On Windows, use %APPDATA% + home_dir = Path(os.environ['APPDATA']) + else: + # Otherwise, install in the user's home directory. + home_dir = Path(os.environ["HOME"]) - def __init__(self, - pretrained_model_name=None, - cache_dir=None, - hparams=None): + if os.access(str(home_dir), os.W_OK): + _default_texar_download_dir = home_dir / 'texar_data' + else: + raise ValueError("The path {} is not writable. Please manually " + "specify the download directory".format(home_dir)) - ModuleBase.__init__(self, hparams=hparams) + if not _default_texar_download_dir.exists(): + _default_texar_download_dir.mkdir(parents=True) - self.pretrained_model_dir = None + return _default_texar_download_dir / name - if self.model_name == "BERT": - load_func = load_pretrained_bert - transform_func = transform_bert_to_texar_config - elif self.model_name == "XLNet": - load_func = load_pretrained_xlnet - transform_func = transform_xlnet_to_texar_config - else: - raise ValueError("Could not find this pre-trained model.") - if pretrained_model_name: - self.pretrained_model_dir = load_func( - pretrained_model_name, cache_dir) - elif self._hparams.pretrained_model_name is not None: - self.pretrained_model_dir = load_func( - self._hparams.pretrained_model_name, cache_dir) +def set_default_download_dir(path): + if isinstance(path, str): + path = Path(path) + elif not isinstance(path, Path): + raise ValueError("`path` must be a string or a pathlib.Path object") + + if not os.access(str(path), os.W_OK): + raise ValueError( + "The specified download directory {} is not writable".format(path)) + + global _default_texar_download_dir # pylint: disable=global-statement + _default_texar_download_dir = path + +class PretrainedMixin(ModuleBase): + r"""A mixin class for all pre-trained classes to inherit. + """ + __metaclass__ = ABCMeta + + _MODEL_NAME = None + _MODEL2URL = None + + pretrained_model_dir = None + + @classmethod + def available_checkpoints(cls): + return list(cls._MODEL2URL.keys()) + + def _name_to_variable(self, name): + r"""Find the corresponding variable given the specified name. + """ + pointer = self + for m_name in name.split("."): + if m_name.isdigit(): + num = int(m_name) + pointer = pointer[num] # type: ignore + else: + pointer = getattr(pointer, m_name) + return pointer # type: ignore + + def load_pretrained_config(self, + pretrained_model_name=None, + cache_dir=None, + hparams=None): + r"""Load paths and configurations of the pre-trained model. + + Args: + pretrained_model_name (optional): A str with the name + of a pre-trained model to load. If `None`, will use the model + name in :attr:`hparams`. + cache_dir (optional): The path to a folder in which the + pre-trained models will be cached. If `None` (default), + a default directory will be used. + hparams (dict or HParams, optional): Hyperparameters. Missing + hyperparameter will be set to default values. See + :meth:`default_hparams` for the hyperparameter structure + and default values. + """ + if not hasattr(self, "_hparams"): + self._hparams = HParams(hparams, self.default_hparams()) + else: + # Probably already parsed by subclasses. We rely on subclass + # implementations to get this right. + # As a sanity check, we require `hparams` to be `None` in this case. + if hparams is not None: + raise ValueError( + "`self._hparams` is already assigned, but `hparams` " + "argument is not None.") + + self.pretrained_model_dir = None + self.pretrained_model_name = pretrained_model_name + + if self.pretrained_model_name is None: + self.pretrained_model_name = self._hparams.pretrained_model_name + if self.pretrained_model_name is not None: + self.pretrained_model_dir = self.download_checkpoint( + self.pretrained_model_name, cache_dir) + pretrained_model_hparams = self._transform_config( + self.pretrained_model_name, self.pretrained_model_dir) + self._hparams = HParams( + pretrained_model_hparams, self._hparams.todict()) + + def init_pretrained_weights(self, scope_name, **kwargs): if self.pretrained_model_dir: - self.pretrained_model_hparams = transform_func( - self.pretrained_model_dir) + self._init_from_checkpoint( + self.pretrained_model_name, + self.pretrained_model_dir, scope_name, **kwargs) + else: + self.reset_parameters() + + def reset_parameters(self): + r"""Initialize parameters of the pre-trained model. This method is only + called if pre-trained checkpoints are not loaded. + """ + pass @staticmethod def default_hparams(): @@ -89,15 +176,80 @@ def default_hparams(): '@no_typecheck': ['pretrained_model_name'] } - def _build(self, inputs, *args, **kwargs): - r"""Encodes the inputs and (optionally) conduct downstream prediction. + @classmethod + def download_checkpoint(cls, pretrained_model_name, cache_dir=None): + r"""Download the specified pre-trained checkpoint, and return the + directory in which the checkpoint is cached. Args: - inputs: Inputs to the pre-trained module. - *args: Other arguments. - **kwargs: Keyword arguments. + pretrained_model_name (str): Name of the model checkpoint. + cache_dir (str, optional): Path to the cache directory. If `None`, + uses the default directory (user's home directory). Returns: - Encoding results or prediction results. + Path to the cache directory. + """ + if pretrained_model_name in cls._MODEL2URL: + download_path = cls._MODEL2URL[pretrained_model_name] + else: + raise ValueError( + "Pre-trained model not found: {}".format(pretrained_model_name)) + + if cache_dir is None: + cache_path = default_download_dir(cls._MODEL_NAME) + else: + cache_path = Path(cache_dir) + cache_path = cache_path / pretrained_model_name + + if not cache_path.exists(): + if isinstance(download_path, list): + for path in download_path: + maybe_download(path, str(cache_path)) + else: + filename = download_path.split('/')[-1] + maybe_download(download_path, str(cache_path), extract=True) + folder = None + for file in cache_path.iterdir(): + if file.is_dir(): + folder = file + assert folder is not None + (cache_path / filename).unlink() + for file in folder.iterdir(): + file.rename(file.parents[1] / file.name) + folder.rmdir() + print("Pre-trained {} checkpoint {} cached to {}".format( + cls._MODEL_NAME, pretrained_model_name, cache_path)) + else: + print("Using cached pre-trained {} checkpoint from {}.".format( + cls._MODEL_NAME, cache_path)) + + return str(cache_path) + + @classmethod + @abstractmethod + def _transform_config(cls, pretrained_model_name, cache_dir): + r"""Load the official configuration file and transform it into + Texar-style hyperparameters. + + Args: + pretrained_model_name (str): Name of the pre-trained model. + cache_dir (str): Path to the cache directory. + + Returns: + dict: Texar module hyperparameters. + """ + raise NotImplementedError + + @abstractmethod + def _init_from_checkpoint(self, pretrained_model_name, cache_dir, + scope_name, **kwargs): + r"""Initialize model parameters from weights stored in the pre-trained + checkpoint. + + Args: + pretrained_model_name (str): Name of the pre-trained model. + cache_dir (str): Path to the cache directory. + scope_name: Variable scope. + **kwargs: Additional arguments for specific models. """ raise NotImplementedError diff --git a/texar/tf/modules/pretrained/pretrained_utils.py b/texar/tf/modules/pretrained/pretrained_utils.py deleted file mode 100644 index 743e588c..00000000 --- a/texar/tf/modules/pretrained/pretrained_utils.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2019 The Texar Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Utils of Pre-trained Modules. -""" - -import os -import sys - - -__all__ = [ - "default_download_dir", -] - - -def default_download_dir(name): - r"""Return the directory to which packages will be downloaded by default. - """ - package_dir = os.path.dirname(os.path.dirname( - os.path.dirname(os.path.dirname(__file__)))) - if os.access(package_dir, os.W_OK): - texar_download_dir = os.path.join(package_dir, 'texar_download') - else: - # On Windows, use %APPDATA% - if sys.platform == 'win32' and 'APPDATA' in os.environ: - home_dir = os.environ['APPDATA'] - - # Otherwise, install in the user's home directory. - else: - home_dir = os.path.expanduser('~/') - if home_dir == '~/': - raise ValueError("Could not find a default download directory") - - texar_download_dir = os.path.join(home_dir, 'texar_download') - - if not os.path.exists(texar_download_dir): - os.mkdir(texar_download_dir) - - return os.path.join(texar_download_dir, name) diff --git a/texar/tf/modules/pretrained/xlnet.py b/texar/tf/modules/pretrained/xlnet.py new file mode 100644 index 00000000..5fa6d29a --- /dev/null +++ b/texar/tf/modules/pretrained/xlnet.py @@ -0,0 +1,177 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utils of XLNet Modules. +""" + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division +from __future__ import unicode_literals + +import collections +import json +import os +import re + +from abc import ABCMeta + +import tensorflow as tf + +from texar.tf.modules.pretrained.pretrained_base import PretrainedMixin + +__all__ = [ + "PretrainedXLNetMixin", +] + +_XLNET_PATH = "https://storage.googleapis.com/xlnet/released_models/" + + +class PretrainedXLNetMixin(PretrainedMixin): + r"""A mixin class to support loading pre-trained checkpoints for modules + that implement the XLNet model. + + The XLNet model was proposed in + `XLNet: Generalized Autoregressive Pretraining for Language Understanding`_ + by `Yang et al.` It is based on the Transformer-XL model, pre-trained on a + large corpus using a language modeling objective that considers all + permutations of the input sentence. + + The available XLNet models are as follows: + + * ``xlnet-based-cased``: 12-layer, 768-hidden, 12-heads. This model is + trained on full data (different from the one in the paper). + * ``xlnet-large-cased``: 24-layer, 1024-hidden, 16-heads. + + We provide the following XLNet classes: + + * :class:`~texar.torch.modules.XLNetEncoder` for text encoding. + * :class:`~texar.torch.modules.XLNetDecoder` for text generation and + decoding. + * :class:`~texar.torch.modules.XLNetClassifier` for text classification + and sequence tagging. + * :class:`~texar.torch.modules.XLNetRegressor` for text regression. + + .. _`XLNet: Generalized Autoregressive Pretraining for Language Understanding`: + http://arxiv.org/abs/1906.08237 + """ + + __metaclass__ = ABCMeta + + _MODEL_NAME = "XLNet" + _MODEL2URL = { + 'xlnet-base-cased': + _XLNET_PATH + "cased_L-12_H-768_A-12.zip", + 'xlnet-large-cased': + _XLNET_PATH + "cased_L-24_H-1024_A-16.zip", + } + + @classmethod + def _transform_config(cls, pretrained_model_name, cache_dir): + info = list(os.walk(cache_dir)) + root, _, files = info[0] + config_path = None + for file in files: + if file.endswith('config.json'): + config_path = os.path.join(root, file) + if config_path is None: + raise ValueError("Cannot find the config file in {}".format( + cache_dir)) + + with open(config_path) as f: + config_ckpt = json.loads(f.read()) + + configs = { + "head_dim": config_ckpt["d_head"], + "ffn_inner_dim": config_ckpt["d_inner"], + "hidden_dim": config_ckpt["d_model"], + "activation": config_ckpt["ff_activation"], + "num_heads": config_ckpt["n_head"], + "num_layers": config_ckpt["n_layer"], + "vocab_size": config_ckpt["n_token"], + "untie_r": config_ckpt["untie_r"] + } + + return configs + + def _init_from_checkpoint(self, pretrained_model_name, + cache_dir, scope_name, **kwargs): + + tvars = tf.trainable_variables() + init_checkpoint = os.path.join(cache_dir, 'xlnet_model.ckpt') + if init_checkpoint: + assignment_map, initialized_variable_names = \ + self._get_assignment_map_from_checkpoint( + tvars, init_checkpoint, scope_name) + tf.train.init_from_checkpoint(init_checkpoint, assignment_map) + + def _get_assignment_map_from_checkpoint(self, tvars, init_checkpoint, + scope_name): + r""" + Compute the union of the current variables and checkpoint variables. + Because of the variable scope of the original XLNet and Texar + implementation, we need to build a assignment map to match the variables. + """ + assignment_map = {} + initialized_variable_names = {} + + name_to_variable = collections.OrderedDict() + for var in tvars: + name = var.name + m = re.match("^(.*):\\d+$", name) + if m is not None: + name = m.group(1) + name_to_variable[name] = var + + init_vars = tf.train.list_variables(init_checkpoint) + + for check_name, _ in init_vars: + check_name_scope = check_name.replace( + 'model/transformer/', scope_name + '/') + model_name = check_name_scope + if check_name.startswith('model/lm_loss/bias'): + model_name = scope_name + '/lm_loss/bias' + elif check_name.startswith('model/transformer/mask_emb'): + model_name = check_name_scope.replace( + 'mask_emb/mask_emb', 'mask_emb') + elif check_name.startswith('model/transformer/word_embedding'): + model_name = scope_name + '/word_embedder/w' + elif re.match('model/transformer/r_[r,s,w]_bias', check_name): + model_name = check_name_scope + elif re.match('model/transformer/seg_embed', check_name): + model_name = check_name_scope + elif re.match('model/transformer/layer_\\d+/rel_attn/[q,k,v,r,o]', + check_name): + model_name = check_name_scope + elif re.match('model/transformer/layer_\\d+/rel_attn/LayerNorm', + check_name): + model_name = check_name_scope.replace('LayerNorm/', '') + elif re.match('model/transformer/layer_\\d+/ff/layer_[1,2]', + check_name): + model_name = check_name_scope.replace('ff/layer_1', 'ff/dense') + if model_name == check_name_scope: + model_name = check_name_scope.replace( + 'ff/layer_2', 'ff/dense_1') + elif re.match('model/transformer/layer_\\d+/ff/LayerNorm', + check_name): + model_name = check_name_scope.replace('LayerNorm/', '') + + if model_name in name_to_variable.keys(): + assignment_map[check_name] = model_name + initialized_variable_names[model_name] = 1 + initialized_variable_names[model_name + ":0"] = 1 + else: + tf.logging.info('model name:{} not exist'.format(model_name)) + + return assignment_map, initialized_variable_names diff --git a/texar/tf/modules/pretrained/xlnet_model_utils.py b/texar/tf/modules/pretrained/xlnet_model_utils.py deleted file mode 100644 index 05e2640a..00000000 --- a/texar/tf/modules/pretrained/xlnet_model_utils.py +++ /dev/null @@ -1,555 +0,0 @@ -# Copyright 2019 The Texar Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Model Utils of XLNet Modules. -Adapted from -https://github.com/zihangdai/xlnet/blob/master/modeling.py -""" - -import tensorflow as tf - -from texar.tf.core import layers -from texar.tf.utils.mode import is_train_mode -from texar.tf.module_base import ModuleBase - - -__all__ = [ - 'PositionWiseFF', - 'PositionalEmbedding', - 'RelativePositionalEncoding', - 'RelativeMutiheadAttention' -] - - -class PositionWiseFF(ModuleBase): - r"""Position Wise feed forward.""" - def __init__(self, hparams=None): - ModuleBase.__init__(self, hparams) - - hidden_dim = self._hparams.hidden_dim - ffn_inner_dim = self._hparams.ffn_inner_dim - dropout = self._hparams.dropout - activation = self._hparams.activation - if activation == 'gelu': - activation = layers.gelu - - with tf.variable_scope(self.variable_scope): - tf.get_variable_scope().set_initializer( - layers.get_initializer(self._hparams.initializer)) - l1_hparams = { - "type": "Dense", - "kwargs": { - "units": ffn_inner_dim, - "activation": activation - } - } - self.linear1 = layers.get_layer(hparams=l1_hparams) - dropout_hparams = { - "type": "Dropout", - "kwargs": { - "rate": dropout - } - } - self.dropout = layers.get_layer(hparams=dropout_hparams) - l2_hparams = { - "type": "Dense", - "kwargs": { - "units": hidden_dim - } - } - self.linear2 = layers.get_layer(hparams=l2_hparams) - - @staticmethod - def default_hparams(): - r"""Returns a dictionary of hyperparameters with default values. - - .. code-block:: python - - { - "hidden_dim": 768, - "ffn_inner_dim": 3072, - "dropout": 0.1, - "activation": 'gelu' - } - - Here - - "hidden_dim": int - Dimension of the layer fed as input to feed forward network - - "ffn_inner_dim": int - Inner dimension of the feed forward layer - - "dropout": float - Dropout rate for layers - - "activation": str or callable - Activation function applied to the output of the PositionWise FF. - See :func:`~texar.tf.core.get_activation_fn` for more details. - """ - return { - "name": "ff", - "initializer": None, - "hidden_dim": 768, - "ffn_inner_dim": 3072, - "dropout": 0.1, - "activation": 'gelu', - } - - def _build(self, input, mode=None): - r"""Compute feed forward for the input. - - Args: - input: Input tensor of size `(max_time, batch_size, hidden_dim)` - mode (optional): A tensor taking value in - :tf_main:`tf.estimator.ModeKeys `, including - `TRAIN`, `EVAL`, and `PREDICT`. If `None`, dropout is - controlled by :func:`texar.tf.global_mode`. - - :returns: A tensor output of the position wise feed forward network - """ - is_training = is_train_mode(mode) - output = self.linear1(input) - output = self.dropout(output, training=is_training) - output = self.linear2(output) - output = self.dropout(output, training=is_training) - - # residual + layer norm - output = tf.contrib.layers.layer_norm( - input + output, begin_norm_axis=-1, scope=self.variable_scope, - reuse=tf.AUTO_REUSE) - - return output - - -class PositionalEmbedding(ModuleBase): - r"""Sinosoidal Positional Embedding. - """ - - # TODO(avinash) : See if this can be merged with Sinosoidal Position - # Embedder - def __init__(self, embed_dim): - ModuleBase.__init__(self) - freq_seq = tf.range(0.0, embed_dim, 2.0) - self.inv_freq = 1 / (10000 ** (freq_seq / embed_dim)) - - def _build(self, pos_seq): - r"""Compute sinosoidal positional embeddings. - - Args: - pos_seq: A 1D tensor of position sequences - - :returns: A 2D tensor of sinosoidal embeddings for the sequence. - """ - pos_seq = tf.dtypes.cast(pos_seq, dtype=self.inv_freq.dtype) - sinusoid_inp = tf.einsum('i,d->id', pos_seq, self.inv_freq) - pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1) - return pos_emb - - -class RelativePositionalEncoding(ModuleBase): - r"""Relative positional encodings.""" - def __init__(self, hparams=None): - ModuleBase.__init__(self, hparams) - self.sinusoid_embed = PositionalEmbedding(self._hparams.dim) - - @staticmethod - def default_hparams(): - r"""Returns a dictionary of hyperparameters with default values. - - .. code-block:: python - - { - "dim": 768, - "max_seq_len": 512 - } - - Here - - "dim": int - Dimension size of the positional embedding - - "max_seq_len": int - Maximum size of the sequence length - """ - return { - "name": "relative_positional_encoder", - "dim": 768, - "max_seq_len": 512 - } - - def _create_positional_embedding(self, start, end, step, batch_size, - clamp_len=None): - pos_seq = tf.range(start, end, step) - if clamp_len is not None: - pos_seq = tf.clip_by_value(pos_seq, -clamp_len, clamp_len) - pos_emb = self.sinusoid_embed(pos_seq) - pos_emb = pos_emb[:, None, :] - - if batch_size is not None: - pos_emb = tf.tile(pos_emb, [1, batch_size, 1]) - - return pos_emb - - def _build(self, batch_size, max_time, total_len, clamp_len=None, - attn_type='bi', bi_data=True): - r"""Compute relative positional encoding. - - Args - batch_size: int - Batch size of the input - - max_time: int - Sequence length of the input - - total_len: int - Sequence length + Memory length - - clamp_len (optional): int - Clamp all relative distances larger than clamp_len. - None means no clamping. - - attn_type (optional): str - Attention type. Supported values are `"uni"` and `"bi"`. - - bi_data (optional): bool - Whether to use bidirectional data input pipeline. Usually set to - True during pretraining and False during finetuning. - - :returns: A tensor of shape `[total_len + max_time, batch_size, dim]` - (if attn_type == `"bi"`) or of shape `[total_len, batch_size, dim]` - (if attn_type == `"uni"`) representing relative positional encoding - of the sequence. - """ - if attn_type == 'bi': - start, end = total_len, -max_time - elif attn_type == 'uni': - start, end = total_len, -1 - else: - raise ValueError("Unknown `attn_type` {}".format(attn_type)) - - if bi_data: - if batch_size % 2 != 0: - raise ValueError("`batch_size` must be an even number") - fwd_pos_embed = self._create_positional_embedding( - start, end, -1, batch_size // 2, clamp_len) - bwd_pos_embed = self._create_positional_embedding( - -start, -end, 1, batch_size // 2, clamp_len) - pos_embed = tf.concat([fwd_pos_embed, bwd_pos_embed], axis=1) - else: - pos_embed = self._create_positional_embedding( - start, end, -1, batch_size, clamp_len) - return pos_embed - - -class RelativeMutiheadAttention(ModuleBase): - r"""Compute relative multi-head attention for XLNet encoder. - - This module computes relative multi-head attention as explained in - `Transformer-XL, (Zihang et. al)` and in `XLNet (Zhiling et. al)`. - - Args: - r_r_bias: A tensor of shape `(num_heads, head_dim)`. - The bias value added to query head while computing position based - attention score. - - r_w_bias: A tensor of shape `(num_heads, head_dim)`. - The bias value added to query head while computing content based - attention score. - - r_s_bias (optional): A tensor of shape `(num_heads, head_dim)`. - The bias value added to query head while computing segment based - attention score. - - segment_embed (optional): A tensor of shape `(2, num_heads, head_dim)` - if use_segments is True. Otherwise, this is set to None. - - hparams (dict or HParams, optional): Hyperparameters. Missing - hyperparameter will be set to default values. See - :meth:`default_hparams` for the hyperparameter sturcture - and default values. - """ - def __init__(self, r_r_bias, r_w_bias, r_s_bias=None, segment_embed=None, - hparams=None): - ModuleBase.__init__(self, hparams=hparams) - - self.num_heads = self._hparams.num_heads - self.head_dim = self._hparams.head_dim - hidden_dim = self._hparams.hidden_dim - - with tf.variable_scope(self.variable_scope): - if self._hparams.initializer: - tf.get_variable_scope().set_initializer( - layers.get_initializer(self._hparams.initializer)) - - # Official implementation creates these head variables. - # If we create dense layers instead, there would be dimension - # mismatch while loading the tensors - # TODO(avinash) : Can we reshape tensors while loading the ckpt? - self.q_head = tf.get_variable( - 'q/kernel', [hidden_dim, self.num_heads, self.head_dim]) - - self.k_head = tf.get_variable( - 'k/kernel', [hidden_dim, self.num_heads, self.head_dim]) - - self.v_head = tf.get_variable( - 'v/kernel', [hidden_dim, self.num_heads, self.head_dim]) - - self.k_head_r = tf.get_variable( - 'r/kernel', [hidden_dim, self.num_heads, self.head_dim]) - - self.dropout = layers.get_layer(hparams={ - "type": "Dropout", - "kwargs": { - "rate": self._hparams.dropout - } - }) - - self.dropout_attn = layers.get_layer(hparams={ - "type": "Dropout", - "kwargs": { - "rate": self._hparams.attention_dropout - } - }) - - self.output_projection = tf.get_variable( - 'o/kernel', [hidden_dim, self.num_heads, self.head_dim]) - - self.r_r_bias = r_r_bias - self.r_w_bias = r_w_bias - - if self._hparams.use_segments: - self.segment_embed = segment_embed - self.r_s_bias = r_s_bias - - self.scale = 1 / (self.head_dim ** 0.5) - - @staticmethod - def default_hparams(): - r"""Returns a dictionary of hyperparameters with default values. - - .. code-block:: python - - { - "name": "rel_attn", - "initializer": None, - "num_heads": 12, - "hidden_dim": 768, - "head_dim": 64, - "dropout": 0.1, - "attention_dropout": 0.1, - "use_segments": True - } - - - - Here: - - The default parameters are values for cased XLNet-Base model. - - "initializer": dict, optional - Hyperparameters of the default initializer that initializes - variables created in this module. - See :func:`~texar.tf.core.get_initializer` for details. - - "num_heads": int - Number of heads in the attention - - "hidden_dim": int - Hidden dimension of the embeddings - - "head_dim": int - Size of the vectors after head projection. - - "dropout": float - Dropout rate for layers - - "attention_dropout": float - Dropout rate for attention layers - - "use_segments": bool - Boolean to indicate if the input has segments - - "name": str - Name of the module. - """ - return { - "name": "rel_attn", - "initializer": None, - "num_heads": 12, - "hidden_dim": 768, - "head_dim": 64, - "dropout": 0.1, - "attention_dropout": 0.1, - "use_segments": True, - } - - @staticmethod - def _rel_shift(x, klen=-1): - """Perform relative shift to form the relative attention score.""" - x_size = tf.shape(x) - - x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]]) - x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1]) - x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]]) - x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1]) - - return x - - def _compute_attention_score(self, q_head, k_head_h, v_head_h, k_head_r, - segment_mat, attn_mask=None, mode=None): - is_training = is_train_mode(mode) - - # Content based attention score. - q_head_rw = q_head + self.r_w_bias - # attn_ac: (max_time, tot_len, batch_size, n_head) - attn_ac = tf.einsum('ibnd,jbnd->ijbn', q_head_rw, k_head_h) - - # Position based attention score. - q_head_rr = q_head + self.r_r_bias - # attn_bd: (max_time, tot_len, batch_size, n_head) - attn_bd = tf.einsum('ibnd,jbnd->ijbn', q_head_rr, k_head_r) - attn_bd = self._rel_shift(attn_bd, klen=tf.shape(attn_ac)[1]) - - # Segment based attention score. - if segment_mat is None: - attn_ef = 0 - else: - q_head_rs = q_head + self.r_s_bias - attn_ef = tf.einsum( - 'ibnd,snd->ibns', q_head_rs, self.segment_embed) - attn_ef = tf.einsum('ijbs,ibns->ijbn', segment_mat, attn_ef) - - # Merge attention scores and perform masking. - # attn_score: (max_time, tot_len, batch_size, n_head) - attn_score = (attn_ac + attn_bd + attn_ef) * self.scale - if attn_mask is not None: - # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask - attn_score = attn_score - 1e30 * attn_mask - - # attention probability - attn_prob = tf.nn.softmax(attn_score, 1) - attn_prob = self.dropout_attn(attn_prob, training=is_training) - - # attention output - attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h) - - return attn_vec - - def _post_attention(self, attn_vec, mode=None): - is_training = is_train_mode(mode) - attn_out = tf.einsum('ibnd,hnd->ibh', attn_vec, self.output_projection) - attn_out = self.dropout(attn_out, training=is_training) - return attn_out - - def _build(self, states_h, pos_embed, states_g=None, segment_mat=None, - attn_mask_h=None, attn_mask_g=None, target_mapping=None, - memory=None, mode=None): - r"""Compute relative multi-head attention with relative positional - encoding. - - Args: - states_h: A content representation tensor of shape - `[max_time, batch_size, hidden_dim]` - - pos_embed: Position embedding tensor of shape - `[max_time, batch_size, hidden_dim]`. - - states_g (optional): A query representation tensor of shape - `[max_time, batch_size, hidden_dim]`. This tensor is set during - decoding. - - segment_mat (optional): A tensor of size - `[max_time, tot_len, batch_size]` indicating if tokens are in the - same seqment. A value at `(i, j, k)` of `1` indicates tokens at - `i` and `j` are not in the same sequence in batch k. - - attn_mask_h (optional): A tensor of shape - `[max_time, max_time, batch_size, 1]` Attention mask used while - computing attention score for `states_h` - - attn_mask_g (optional): A tensor of shape - `[max_time, max_time, batch_size, 1]` Attention mask used while - computing attention score for `states_g` - - target_mapping (optional): The target token mapping. Float tensor of - shape `[num_targets, max_time, batch_size]`. - A value of 1 for ``target_mapping[i, j, k]`` indicates that - the `i`-th target token (in order of permutation) in batch `k` - is the token at position `j`. - Each row ``target_mapping[i, :, k]`` can have no more than one - value of 1. - - memory (optional): Memory from previous batches. A list of length - `num_layers`, each a tensor of shape - `[mem_len, batch_size, hidden_dim]`. - - mode (optional): A tensor taking value in - :tf_main:`tf.estimator.ModeKeys `, including - `TRAIN`, `EVAL`, and `PREDICT`. If `None`, dropout is - controlled by :func:`texar.tf.global_mode`. - - :returns: Returns output states for `states_h` and `states_g` - (`states_g` is not None) - """ - batch_size = tf.shape(states_h)[1] - - if memory is not None and memory.shape.ndims > 1: - concat_input = tf.concat([memory, states_h], axis=0) - else: - concat_input = states_h - - # Content heads. - q_head_h = tf.einsum('ibh,hnd->ibnd', states_h, self.q_head) - k_head_h = tf.einsum('ibh,hnd->ibnd', concat_input, self.k_head) - v_head_h = tf.einsum('ibh,hnd->ibnd', concat_input, self.v_head) - - # Positional heads. - k_head_r = tf.einsum('ibh,hnd->ibnd', pos_embed, self.k_head_r) - - # Core attention ops. - attn_vec_h = self._compute_attention_score( - q_head_h, k_head_h, v_head_h, k_head_r, segment_mat, attn_mask_h, - mode) - - # Post attention processing. - attn_out_h = self._post_attention(attn_vec_h, mode=mode) - - output_h = tf.contrib.layers.layer_norm( - attn_out_h + states_h, begin_norm_axis=-1, - scope=self.variable_scope, reuse=tf.AUTO_REUSE) - - if states_g is not None: - q_head_g = tf.einsum('ibh,hnd->ibnd', states_g, self.q_head) - shape = tf.shape(q_head_g) - q_head_g = tf.reshape( - q_head_g, - shape=(shape[0], batch_size, self.num_heads, self.head_dim)) - if target_mapping is not None: - q_head_g = tf.einsum( - 'mbnd,mlb->lbnd', q_head_g, target_mapping) - attn_vec_g = self._compute_attention_score( - q_head_g, k_head_h, v_head_h, k_head_r, - segment_mat, attn_mask_g, mode) - if target_mapping is not None: - attn_vec_g = tf.einsum( - 'lbnd,mlb->mbnd', attn_vec_g, target_mapping) - attn_out_g = self._post_attention(attn_vec_g, mode=mode) - output_g = tf.contrib.layers.layer_norm( - attn_out_g + states_g, begin_norm_axis=-1, - scope=self.variable_scope, reuse=tf.AUTO_REUSE) - else: - output_g = None - - return output_h, output_g diff --git a/texar/tf/modules/pretrained/xlnet_model_utils_test.py b/texar/tf/modules/pretrained/xlnet_model_utils_test.py deleted file mode 100644 index 08b63a58..00000000 --- a/texar/tf/modules/pretrained/xlnet_model_utils_test.py +++ /dev/null @@ -1,89 +0,0 @@ -""" -Unit tests for xlnet model utils. -""" -import tensorflow as tf - -from texar.tf.modules.pretrained.xlnet_model_utils import \ - PositionWiseFF, RelativePositionalEncoding, RelativeMutiheadAttention - - -class XLNetModelUtilsTest(tf.test.TestCase): - r"""Tests xlnet model utils. - """ - - def test_PositionWiseFF(self): - - # Case 1 - model = PositionWiseFF() - inputs = tf.random_uniform(shape=(32, model.hparams.hidden_dim)) - outputs = model(inputs) - self.assertEqual(outputs.shape, [32, model._hparams.hidden_dim]) - - # Case 2 - hparams = { - "hidden_dim": 16, - "ffn_inner_dim": 32, - "dropout": 0.1, - "activation": 'relu', - } - model = PositionWiseFF(hparams=hparams) - inputs = tf.random_uniform(shape=(32, 16)) - outputs = model(inputs) - self.assertEqual(outputs.shape, [32, 16]) - - # Case 3 - hparams = { - "hidden_dim": 16, - "ffn_inner_dim": 32, - "dropout": 0.1, - "activation": 'gelu', - } - model = PositionWiseFF(hparams=hparams) - inputs = tf.random_uniform(shape=(32, 16)) - outputs = model(inputs) - self.assertEqual(outputs.shape, [32, 16]) - - def test_RelativeMultiheadAttention(self): - num_heads = 12 - head_dim = 64 - - r_r_bias = tf.random_normal(shape=(num_heads, head_dim)) - r_w_bias = tf.random_normal(shape=(num_heads, head_dim)) - - model = RelativeMutiheadAttention(r_r_bias=r_r_bias, r_w_bias=r_w_bias) - - states_h = tf.random_uniform(shape=(16, 32, model._hparams.hidden_dim)) - pos_embed = tf.random_uniform(shape=(24, 32, model._hparams.hidden_dim)) - - output_h, output_g = model(states_h=states_h, pos_embed=pos_embed) - - self.assertEqual(output_h.shape, - [16, 32, model._hparams.hidden_dim]) - self.assertEqual(output_g, None) - - def test_RelativePositionalEncoding(self): - - batch_size = 16 - max_time = 8 - total_len = 32 - - # Case 1 - model = RelativePositionalEncoding() - pos_embed = model(batch_size=batch_size, - max_time=max_time, - total_len=total_len) - self.assertEqual(pos_embed.shape, - [40, 16, model._hparams.dim]) - - # Case 2 - model = RelativePositionalEncoding() - pos_embed = model(batch_size=batch_size, - max_time=max_time, - total_len=total_len, - attn_type='uni') - self.assertEqual(pos_embed.shape, - [33, 16, model._hparams.dim]) - - -if __name__ == "__main__": - tf.test.main() diff --git a/texar/tf/modules/pretrained/xlnet_test.py b/texar/tf/modules/pretrained/xlnet_test.py new file mode 100644 index 00000000..34c89386 --- /dev/null +++ b/texar/tf/modules/pretrained/xlnet_test.py @@ -0,0 +1,55 @@ +""" +Unit tests for xlnet utils. +""" + + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import os +import tensorflow as tf + +from texar.tf.modules.pretrained.xlnet import * +from texar.tf.utils.test import pretrained_test + + +class XLNetUtilsTest(tf.test.TestCase): + r"""Tests XLNet utils. + """ + + @pretrained_test + def test_load_pretrained_model_AND_transform_xlnet_to_texar_config(self): + + pretrained_model_dir = PretrainedXLNetMixin.download_checkpoint( + pretrained_model_name="xlnet-base-cased") + + info = list(os.walk(pretrained_model_dir)) + _, _, files = info[0] + self.assertIn('spiece.model', files) + self.assertIn('xlnet_model.ckpt.meta', files) + self.assertIn('xlnet_model.ckpt.data-00000-of-00001', files) + self.assertIn('xlnet_model.ckpt.index', files) + self.assertIn('xlnet_config.json', files) + + model_config = PretrainedXLNetMixin._transform_config( + pretrained_model_name="xlnet-base-cased", + cache_dir=pretrained_model_dir) + + expected_config = { + 'head_dim': 64, + 'ffn_inner_dim': 3072, + 'hidden_dim': 768, + 'activation': 'gelu', + 'num_heads': 12, + 'num_layers': 12, + 'vocab_size': 32000, + 'untie_r': True + } + + self.assertDictEqual(model_config, expected_config) + + +if __name__ == "__main__": + tf.test.main() diff --git a/texar/tf/modules/pretrained/xlnet_utils.py b/texar/tf/modules/pretrained/xlnet_utils.py index 8bf79a06..3894f105 100644 --- a/texar/tf/modules/pretrained/xlnet_utils.py +++ b/texar/tf/modules/pretrained/xlnet_utils.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,153 +12,549 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Utility functions related to XLNet encoders. +Model Utils of XLNet Modules. +Adapted from +https://github.com/zihangdai/xlnet/blob/master/modeling.py """ from __future__ import absolute_import -from __future__ import division from __future__ import print_function - -import collections -import json -import os -import re +from __future__ import division +from __future__ import unicode_literals import tensorflow as tf -from texar.tf.modules.pretrained.pretrained_utils import default_download_dir -from texar.tf.data.data_utils import maybe_download + +from texar.tf.core import layers +from texar.tf.module_base import ModuleBase +from texar.tf.utils.mode import is_train_mode + __all__ = [ - 'init_xlnet_checkpoint', - 'load_pretrained_xlnet', - 'transform_xlnet_to_texar_config' + 'PositionWiseFF', + 'PositionalEmbedding', + 'RelativePositionalEncoding', + 'RelativeMutiheadAttention' ] -_XLNET_PATH = "https://storage.googleapis.com/xlnet/released_models/" -_MODEL2URL = { - 'xlnet-large-cased': _XLNET_PATH + "cased_L-24_H-1024_A-16.zip", - 'xlnet-base-cased': _XLNET_PATH + "cased_L-12_H-768_A-12.zip" -} +class PositionWiseFF(ModuleBase): + r"""Position Wise feed forward.""" + def __init__(self, hparams=None): + ModuleBase.__init__(self, hparams) -def _get_assignment_map_from_checkpoint(tvars, # noqa: C901 - init_checkpoint, scope_name): - """ - Compute the union of the current variables and checkpoint variables. - Because of the variable scope of the original XLNet and Texar - implementation, we need to build a assignment map to match the variables. + hidden_dim = self._hparams.hidden_dim + ffn_inner_dim = self._hparams.ffn_inner_dim + dropout = self._hparams.dropout + activation = self._hparams.activation + if activation == 'gelu': + activation = layers.gelu + + with tf.variable_scope(self.variable_scope): + tf.get_variable_scope().set_initializer( + layers.get_initializer(self._hparams.initializer)) + l1_hparams = { + "type": "Dense", + "kwargs": { + "units": ffn_inner_dim, + "activation": activation + } + } + self.linear1 = layers.get_layer(hparams=l1_hparams) + dropout_hparams = { + "type": "Dropout", + "kwargs": { + "rate": dropout + } + } + self.dropout = layers.get_layer(hparams=dropout_hparams) + l2_hparams = { + "type": "Dense", + "kwargs": { + "units": hidden_dim + } + } + self.linear2 = layers.get_layer(hparams=l2_hparams) + + @staticmethod + def default_hparams(): + r"""Returns a dictionary of hyperparameters with default values. + + .. code-block:: python + + { + "hidden_dim": 768, + "ffn_inner_dim": 3072, + "dropout": 0.1, + "activation": 'gelu' + } + + Here + + `"hidden_dim"`: int + Dimension of the layer fed as input to feed forward network + + `"ffn_inner_dim"`: int + Inner dimension of the feed forward layer + + `"dropout"`: float + Dropout rate for layers + + `"activation"`: str or callable + Activation function applied to the output of the PositionWise FF. + See :func:`~texar.tf.core.get_activation_fn` for more details. + """ + return { + "name": "ff", + "initializer": None, + "hidden_dim": 768, + "ffn_inner_dim": 3072, + "dropout": 0.1, + "activation": 'gelu', + } + + def _build(self, input, mode=None): + r"""Compute feed forward for the input. + + Args: + input: Input tensor of size `(max_time, batch_size, hidden_dim)` + mode (optional): A tensor taking value in + :tf_main:`tf.estimator.ModeKeys `, including + `TRAIN`, `EVAL`, and `PREDICT`. If `None`, dropout is + controlled by :func:`texar.tf.global_mode`. + + :returns: A tensor output of the position wise feed forward network + """ + is_training = is_train_mode(mode) + output = self.linear1(input) + output = self.dropout(output, training=is_training) + output = self.linear2(output) + output = self.dropout(output, training=is_training) + + # residual + layer norm + output = tf.contrib.layers.layer_norm( + input + output, begin_norm_axis=-1, scope=self.variable_scope, + reuse=tf.AUTO_REUSE) + + return output + + +class PositionalEmbedding(ModuleBase): + r"""Sinosoidal Positional Embedding. """ - assignment_map = {} - initialized_variable_names = {} - - name_to_variable = collections.OrderedDict() - for var in tvars: - name = var.name - m = re.match("^(.*):\\d+$", name) - if m is not None: - name = m.group(1) - name_to_variable[name] = var - - init_vars = tf.train.list_variables(init_checkpoint) - - for check_name, _ in init_vars: - check_name_scope = check_name.replace( - 'model/transformer/', scope_name + '/') - model_name = check_name_scope - if check_name.startswith('model/lm_loss/bias'): - model_name = scope_name + '/lm_loss/bias' - elif check_name.startswith('model/transformer/mask_emb'): - model_name = check_name_scope.replace( - 'mask_emb/mask_emb', 'mask_emb') - elif check_name.startswith('model/transformer/word_embedding'): - model_name = scope_name + '/word_embedder/w' - elif re.match('model/transformer/r_[r,s,w]_bias', check_name): - model_name = check_name_scope - elif re.match('model/transformer/seg_embed', check_name): - model_name = check_name_scope - elif re.match('model/transformer/layer_\\d+/rel_attn/[q,k,v,r,o]', - check_name): - model_name = check_name_scope - elif re.match('model/transformer/layer_\\d+/rel_attn/LayerNorm', - check_name): - model_name = check_name_scope.replace('LayerNorm/', '') - elif re.match('model/transformer/layer_\\d+/ff/layer_[1,2]', - check_name): - model_name = check_name_scope.replace('ff/layer_1', 'ff/dense') - if model_name == check_name_scope: - model_name = check_name_scope.replace( - 'ff/layer_2', 'ff/dense_1') - elif re.match('model/transformer/layer_\\d+/ff/LayerNorm', check_name): - model_name = check_name_scope.replace('LayerNorm/', '') - - if model_name in name_to_variable.keys(): - assignment_map[check_name] = model_name - initialized_variable_names[model_name] = 1 - initialized_variable_names[model_name + ":0"] = 1 + + # TODO(avinash) : See if this can be merged with Sinosoidal Position + # Embedder + def __init__(self, embed_dim): + ModuleBase.__init__(self) + freq_seq = tf.range(0.0, embed_dim, 2.0) + self.inv_freq = 1 / (10000 ** (freq_seq / embed_dim)) + + def _build(self, pos_seq): + r"""Compute sinosoidal positional embeddings. + + Args: + pos_seq: A 1D tensor of position sequences + + :returns: A 2D tensor of sinosoidal embeddings for the sequence. + """ + pos_seq = tf.dtypes.cast(pos_seq, dtype=self.inv_freq.dtype) + sinusoid_inp = tf.einsum('i,d->id', pos_seq, self.inv_freq) + pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1) + return pos_emb + + +class RelativePositionalEncoding(ModuleBase): + r"""Relative positional encodings.""" + def __init__(self, hparams=None): + ModuleBase.__init__(self, hparams) + self.sinusoid_embed = PositionalEmbedding(self._hparams.dim) + + @staticmethod + def default_hparams(): + r"""Returns a dictionary of hyperparameters with default values. + + .. code-block:: python + + { + "dim": 768, + "max_seq_len": 512 + } + + Here + + `"dim"`: int + Dimension size of the positional embedding + + `"max_seq_len"`: int + Maximum size of the sequence length + """ + return { + "name": "relative_positional_encoder", + "dim": 768, + "max_seq_len": 512 + } + + def _create_positional_embedding(self, start, end, step, batch_size, + clamp_len=None): + pos_seq = tf.range(start, end, step) + if clamp_len is not None: + pos_seq = tf.clip_by_value(pos_seq, -clamp_len, clamp_len) + pos_emb = self.sinusoid_embed(pos_seq) + pos_emb = pos_emb[:, None, :] + + if batch_size is not None: + pos_emb = tf.tile(pos_emb, [1, batch_size, 1]) + + return pos_emb + + def _build(self, batch_size, max_time, total_len, clamp_len=None, + attn_type='bi', bi_data=True): + r"""Compute relative positional encoding. + + Args + batch_size: int + Batch size of the input + + max_time: int + Sequence length of the input + + total_len: int + Sequence length + Memory length + + clamp_len (optional): int + Clamp all relative distances larger than clamp_len. + None means no clamping. + + attn_type (optional): str + Attention type. Supported values are `"uni"` and `"bi"`. + + bi_data (optional): bool + Whether to use bidirectional data input pipeline. Usually set to + True during pretraining and False during finetuning. + + :returns: A tensor of shape `[total_len + max_time, batch_size, dim]` + (if attn_type == `"bi"`) or of shape `[total_len, batch_size, dim]` + (if attn_type == `"uni"`) representing relative positional encoding + of the sequence. + """ + if attn_type == 'bi': + start, end = total_len, -max_time + elif attn_type == 'uni': + start, end = total_len, -1 else: - tf.logging.info('model name:{} not exist'.format(model_name)) + raise ValueError("Unknown `attn_type` {}".format(attn_type)) - return assignment_map, initialized_variable_names + if bi_data: + if batch_size % 2 != 0: + raise ValueError("`batch_size` must be an even number") + fwd_pos_embed = self._create_positional_embedding( + start, end, -1, batch_size // 2, clamp_len) + bwd_pos_embed = self._create_positional_embedding( + -start, -end, 1, batch_size // 2, clamp_len) + pos_embed = tf.concat([fwd_pos_embed, bwd_pos_embed], axis=1) + else: + pos_embed = self._create_positional_embedding( + start, end, -1, batch_size, clamp_len) + return pos_embed -def init_xlnet_checkpoint(init_checkpoint_dir, scope_name): - """ - Initialize XLnet model parameters from a checkpoint. +class RelativeMutiheadAttention(ModuleBase): + r"""Compute relative multi-head attention for XLNet encoder. + + This module computes relative multi-head attention as explained in + `Transformer-XL, (Zihang et. al)` and in `XLNet (Zhiling et. al)`. Args: - init_checkpoint_dir (str): path to the checkpoint. - scope_name: variable scope of XLNet encoder. - """ - tvars = tf.trainable_variables() - init_checkpoint = os.path.join(init_checkpoint_dir, 'xlnet_model.ckpt') - if init_checkpoint: - assignment_map, initialized_variable_names = \ - _get_assignment_map_from_checkpoint( - tvars, init_checkpoint, scope_name) - tf.train.init_from_checkpoint(init_checkpoint, assignment_map) + r_r_bias: A tensor of shape `(num_heads, head_dim)`. + The bias value added to query head while computing position based + attention score. + r_w_bias: A tensor of shape `(num_heads, head_dim)`. + The bias value added to query head while computing content based + attention score. -def load_pretrained_xlnet(pretrained_model_name, cache_dir=None): - """ - Return the directory in which the pretrained model is cached. + r_s_bias (optional): A tensor of shape `(num_heads, head_dim)`. + The bias value added to query head while computing segment based + attention score. + + segment_embed (optional): A tensor of shape `(2, num_heads, head_dim)` + if use_segments is True. Otherwise, this is set to None. + + hparams (dict or HParams, optional): Hyperparameters. Missing + hyperparameter will be set to default values. See + :meth:`default_hparams` for the hyperparameter sturcture + and default values. """ - if pretrained_model_name in _MODEL2URL: - download_path = _MODEL2URL[pretrained_model_name] - else: - raise ValueError( - "Pre-trained model not found: {}".format(pretrained_model_name)) + def __init__(self, r_r_bias, r_w_bias, r_s_bias=None, segment_embed=None, + hparams=None): + ModuleBase.__init__(self, hparams=hparams) - if cache_dir is None: - cache_dir = default_download_dir("xlnet") + self.num_heads = self._hparams.num_heads + self.head_dim = self._hparams.head_dim + hidden_dim = self._hparams.hidden_dim - file_name = download_path.split('/')[-1] - # this is required because of the way xlnet model is bundled - file_name = "xlnet_" + file_name + with tf.variable_scope(self.variable_scope): + if self._hparams.initializer: + tf.get_variable_scope().set_initializer( + layers.get_initializer(self._hparams.initializer)) - cache_path = os.path.join(cache_dir, file_name.split('.')[0]) - if not os.path.exists(cache_path): - maybe_download(download_path, cache_dir, extract=True) - else: - print("Using cached pre-trained model {} from: {}".format( - pretrained_model_name, cache_dir)) + # Official implementation creates these head variables. + # If we create dense layers instead, there would be dimension + # mismatch while loading the tensors + # TODO(avinash) : Can we reshape tensors while loading the ckpt? + self.q_head = tf.get_variable( + 'q/kernel', [hidden_dim, self.num_heads, self.head_dim]) - return cache_path + self.k_head = tf.get_variable( + 'k/kernel', [hidden_dim, self.num_heads, self.head_dim]) + self.v_head = tf.get_variable( + 'v/kernel', [hidden_dim, self.num_heads, self.head_dim]) -def transform_xlnet_to_texar_config(config_dir): - """ - Load the Json config file and transform it into Texar style configuration. - """ - config_ckpt = json.loads( - open(os.path.join(config_dir, 'xlnet_config.json')).read()) - config = dict(untie_r=config_ckpt["untie_r"], - num_layers=config_ckpt["n_layer"], - # layer - head_dim=config_ckpt["d_head"], - hidden_dim=config_ckpt["d_model"], - num_heads=config_ckpt["n_head"], - vocab_size=config_ckpt["n_token"], - activation="gelu", - ffn_inner_dim=config_ckpt["d_inner"]) - - return config + self.k_head_r = tf.get_variable( + 'r/kernel', [hidden_dim, self.num_heads, self.head_dim]) + + self.dropout = layers.get_layer(hparams={ + "type": "Dropout", + "kwargs": { + "rate": self._hparams.dropout + } + }) + + self.dropout_attn = layers.get_layer(hparams={ + "type": "Dropout", + "kwargs": { + "rate": self._hparams.attention_dropout + } + }) + + self.output_projection = tf.get_variable( + 'o/kernel', [hidden_dim, self.num_heads, self.head_dim]) + + self.r_r_bias = r_r_bias + self.r_w_bias = r_w_bias + + if self._hparams.use_segments: + self.segment_embed = segment_embed + self.r_s_bias = r_s_bias + + self.scale = 1 / (self.head_dim ** 0.5) + + @staticmethod + def default_hparams(): + r"""Returns a dictionary of hyperparameters with default values. + + .. code-block:: python + + { + "name": "rel_attn", + "initializer": None, + "num_heads": 12, + "hidden_dim": 768, + "head_dim": 64, + "dropout": 0.1, + "attention_dropout": 0.1, + "use_segments": True + } + + + + Here: + + The default parameters are values for cased XLNet-Base model. + + "initializer": dict, optional + Hyperparameters of the default initializer that initializes + variables created in this module. + See :func:`~texar.tf.core.get_initializer` for details. + + "num_heads": int + Number of heads in the attention + + "hidden_dim": int + Hidden dimension of the embeddings + + "head_dim": int + Size of the vectors after head projection. + + "dropout": float + Dropout rate for layers + + "attention_dropout": float + Dropout rate for attention layers + + "use_segments": bool + Boolean to indicate if the input has segments + + "name": str + Name of the module. + """ + return { + "name": "rel_attn", + "initializer": None, + "num_heads": 12, + "hidden_dim": 768, + "head_dim": 64, + "dropout": 0.1, + "attention_dropout": 0.1, + "use_segments": True, + } + + @staticmethod + def _rel_shift(x, klen=-1): + """Perform relative shift to form the relative attention score.""" + x_size = tf.shape(x) + + x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]]) + x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1]) + x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]]) + x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1]) + + return x + + def _compute_attention_score(self, q_head, k_head_h, v_head_h, k_head_r, + segment_mat, attn_mask=None, mode=None): + is_training = is_train_mode(mode) + + # Content based attention score. + q_head_rw = q_head + self.r_w_bias + # attn_ac: (max_time, tot_len, batch_size, n_head) + attn_ac = tf.einsum('ibnd,jbnd->ijbn', q_head_rw, k_head_h) + + # Position based attention score. + q_head_rr = q_head + self.r_r_bias + # attn_bd: (max_time, tot_len, batch_size, n_head) + attn_bd = tf.einsum('ibnd,jbnd->ijbn', q_head_rr, k_head_r) + attn_bd = self._rel_shift(attn_bd, klen=tf.shape(attn_ac)[1]) + + # Segment based attention score. + if segment_mat is None: + attn_ef = 0 + else: + q_head_rs = q_head + self.r_s_bias + attn_ef = tf.einsum( + 'ibnd,snd->ibns', q_head_rs, self.segment_embed) + attn_ef = tf.einsum('ijbs,ibns->ijbn', segment_mat, attn_ef) + + # Merge attention scores and perform masking. + # attn_score: (max_time, tot_len, batch_size, n_head) + attn_score = (attn_ac + attn_bd + attn_ef) * self.scale + if attn_mask is not None: + # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask + attn_score = attn_score - 1e30 * attn_mask + + # attention probability + attn_prob = tf.nn.softmax(attn_score, 1) + attn_prob = self.dropout_attn(attn_prob, training=is_training) + + # attention output + attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h) + + return attn_vec + + def _post_attention(self, attn_vec, mode=None): + is_training = is_train_mode(mode) + attn_out = tf.einsum('ibnd,hnd->ibh', attn_vec, self.output_projection) + attn_out = self.dropout(attn_out, training=is_training) + return attn_out + + def _build(self, states_h, pos_embed, states_g=None, segment_mat=None, + attn_mask_h=None, attn_mask_g=None, target_mapping=None, + memory=None, mode=None): + r"""Compute relative multi-head attention with relative positional + encoding. + + Args: + states_h: A content representation tensor of shape + `[max_time, batch_size, hidden_dim]` + + pos_embed: Position embedding tensor of shape + `[max_time, batch_size, hidden_dim]`. + + states_g (optional): A query representation tensor of shape + `[max_time, batch_size, hidden_dim]`. This tensor is set during + decoding. + + segment_mat (optional): A tensor of size + `[max_time, tot_len, batch_size]` indicating if tokens are in the + same seqment. A value at `(i, j, k)` of `1` indicates tokens at + `i` and `j` are not in the same sequence in batch k. + + attn_mask_h (optional): A tensor of shape + `[max_time, max_time, batch_size, 1]` Attention mask used while + computing attention score for `states_h` + + attn_mask_g (optional): A tensor of shape + `[max_time, max_time, batch_size, 1]` Attention mask used while + computing attention score for `states_g` + + target_mapping (optional): The target token mapping. Float tensor of + shape `[num_targets, max_time, batch_size]`. + A value of 1 for ``target_mapping[i, j, k]`` indicates that + the `i`-th target token (in order of permutation) in batch `k` + is the token at position `j`. + Each row ``target_mapping[i, :, k]`` can have no more than one + value of 1. + + memory (optional): Memory from previous batches. A list of length + `num_layers`, each a tensor of shape + `[mem_len, batch_size, hidden_dim]`. + + mode (optional): A tensor taking value in + :tf_main:`tf.estimator.ModeKeys `, including + `TRAIN`, `EVAL`, and `PREDICT`. If `None`, dropout is + controlled by :func:`texar.tf.global_mode`. + + :returns: Returns output states for `states_h` and `states_g` + (`states_g` is not None) + """ + batch_size = tf.shape(states_h)[1] + + if memory is not None and memory.shape.ndims > 1: + concat_input = tf.concat([memory, states_h], axis=0) + else: + concat_input = states_h + + # Content heads. + q_head_h = tf.einsum('ibh,hnd->ibnd', states_h, self.q_head) + k_head_h = tf.einsum('ibh,hnd->ibnd', concat_input, self.k_head) + v_head_h = tf.einsum('ibh,hnd->ibnd', concat_input, self.v_head) + + # Positional heads. + k_head_r = tf.einsum('ibh,hnd->ibnd', pos_embed, self.k_head_r) + + # Core attention ops. + attn_vec_h = self._compute_attention_score( + q_head_h, k_head_h, v_head_h, k_head_r, segment_mat, attn_mask_h, + mode) + + # Post attention processing. + attn_out_h = self._post_attention(attn_vec_h, mode=mode) + + output_h = tf.contrib.layers.layer_norm( + attn_out_h + states_h, begin_norm_axis=-1, + scope=self.variable_scope, reuse=tf.AUTO_REUSE) + + if states_g is not None: + q_head_g = tf.einsum('ibh,hnd->ibnd', states_g, self.q_head) + shape = tf.shape(q_head_g) + q_head_g = tf.reshape( + q_head_g, + shape=(shape[0], batch_size, self.num_heads, self.head_dim)) + if target_mapping is not None: + q_head_g = tf.einsum( + 'mbnd,mlb->lbnd', q_head_g, target_mapping) + attn_vec_g = self._compute_attention_score( + q_head_g, k_head_h, v_head_h, k_head_r, + segment_mat, attn_mask_g, mode) + if target_mapping is not None: + attn_vec_g = tf.einsum( + 'lbnd,mlb->mbnd', attn_vec_g, target_mapping) + attn_out_g = self._post_attention(attn_vec_g, mode=mode) + output_g = tf.contrib.layers.layer_norm( + attn_out_g + states_g, begin_norm_axis=-1, + scope=self.variable_scope, reuse=tf.AUTO_REUSE) + else: + output_g = None + + return output_h, output_g diff --git a/texar/tf/modules/pretrained/xlnet_utils_test.py b/texar/tf/modules/pretrained/xlnet_utils_test.py index cfaf56b0..f52f8bbc 100644 --- a/texar/tf/modules/pretrained/xlnet_utils_test.py +++ b/texar/tf/modules/pretrained/xlnet_utils_test.py @@ -1,45 +1,88 @@ """ -Unit tests for xlnet utils. +Unit tests for xlnet model utils. """ -import os - import tensorflow as tf from texar.tf.modules.pretrained.xlnet_utils import \ - load_pretrained_xlnet, transform_xlnet_to_texar_config + PositionWiseFF, RelativePositionalEncoding, RelativeMutiheadAttention -class XLNetUtilsTest(tf.test.TestCase): - r"""Tests xlnet utils. +class XLNetModelUtilsTest(tf.test.TestCase): + r"""Tests xlnet model utils. """ - def test_load_pretrained_model_AND_transform_xlnet_to_texar_config(self): - - pretrained_model_dir = load_pretrained_xlnet( - pretrained_model_name="xlnet-base-cased") - - info = list(os.walk(pretrained_model_dir)) - _, _, files = info[0] - self.assertIn('spiece.model', files) - self.assertIn('xlnet_model.ckpt.meta', files) - self.assertIn('xlnet_model.ckpt.data-00000-of-00001', files) - self.assertIn('xlnet_model.ckpt.index', files) - self.assertIn('xlnet_config.json', files) - - model_config = transform_xlnet_to_texar_config(pretrained_model_dir) - - expected_config = { - 'head_dim': 64, - 'ffn_inner_dim': 3072, - 'hidden_dim': 768, - 'activation': 'gelu', - 'num_heads': 12, - 'num_layers': 12, - 'vocab_size': 32000, - 'untie_r': True + def test_PositionWiseFF(self): + + # Case 1 + model = PositionWiseFF() + inputs = tf.random_uniform(shape=(32, model.hparams.hidden_dim)) + outputs = model(inputs) + self.assertEqual(outputs.shape, [32, model._hparams.hidden_dim]) + + # Case 2 + hparams = { + "hidden_dim": 16, + "ffn_inner_dim": 32, + "dropout": 0.1, + "activation": 'relu', } + model = PositionWiseFF(hparams=hparams) + inputs = tf.random_uniform(shape=(32, 16)) + outputs = model(inputs) + self.assertEqual(outputs.shape, [32, 16]) + + # Case 3 + hparams = { + "hidden_dim": 16, + "ffn_inner_dim": 32, + "dropout": 0.1, + "activation": 'gelu', + } + model = PositionWiseFF(hparams=hparams) + inputs = tf.random_uniform(shape=(32, 16)) + outputs = model(inputs) + self.assertEqual(outputs.shape, [32, 16]) + + def test_RelativeMultiheadAttention(self): + num_heads = 12 + head_dim = 64 + + r_r_bias = tf.random_normal(shape=(num_heads, head_dim)) + r_w_bias = tf.random_normal(shape=(num_heads, head_dim)) + + model = RelativeMutiheadAttention(r_r_bias=r_r_bias, r_w_bias=r_w_bias) + + states_h = tf.random_uniform(shape=(16, 32, model._hparams.hidden_dim)) + pos_embed = tf.random_uniform(shape=(24, 32, model._hparams.hidden_dim)) + + output_h, output_g = model(states_h=states_h, pos_embed=pos_embed) + + self.assertEqual(output_h.shape, + [16, 32, model._hparams.hidden_dim]) + self.assertEqual(output_g, None) + + def test_RelativePositionalEncoding(self): + + batch_size = 16 + max_time = 8 + total_len = 32 + + # Case 1 + model = RelativePositionalEncoding() + pos_embed = model(batch_size=batch_size, + max_time=max_time, + total_len=total_len) + self.assertEqual(pos_embed.shape, + [40, 16, model._hparams.dim]) - self.assertDictEqual(model_config, expected_config) + # Case 2 + model = RelativePositionalEncoding() + pos_embed = model(batch_size=batch_size, + max_time=max_time, + total_len=total_len, + attn_type='uni') + self.assertEqual(pos_embed.shape, + [33, 16, model._hparams.dim]) if __name__ == "__main__": diff --git a/texar/tf/modules/regressors/xlnet_regressor.py b/texar/tf/modules/regressors/xlnet_regressor.py index 29fa26f4..01766b7e 100644 --- a/texar/tf/modules/regressors/xlnet_regressor.py +++ b/texar/tf/modules/regressors/xlnet_regressor.py @@ -20,12 +20,14 @@ from __future__ import print_function import tensorflow as tf + from texar.tf.utils.mode import is_train_mode -from texar.tf.core import layers +from texar.tf.core.layers import get_layer, get_initializer from texar.tf.modules.regressors.regressor_base import RegressorBase -from texar.tf.modules import XLNetEncoder -from texar.tf.utils import utils +from texar.tf.modules.encoders.xlnet_encoder import XLNetEncoder from texar.tf.hyperparams import HParams +from texar.tf.modules.pretrained.xlnet import PretrainedXLNetMixin +from texar.tf.utils.utils import dict_fetch # pylint: disable=too-many-arguments, invalid-name, no-member, # pylint: disable=too-many-branches, too-many-locals, too-many-statements @@ -35,8 +37,10 @@ ] -class XLNetRegressor(RegressorBase): - """Regressor based on XLNet modules. +class XLNetRegressor(RegressorBase, PretrainedXLNetMixin): + """Regressor based on XLNet modules. Please see + :class:`~texar.tf.modules.PretrainedXLNetMixin` for a brief description + of XLNet. This is a combination of the :class:`~texar.tf.modules.XLNetEncoder` with a classification layer. Both step-wise classification and sequence-level @@ -45,14 +49,17 @@ class XLNetRegressor(RegressorBase): Arguments are the same as in :class:`~texar.tf.modules.XLNetEncoder`. Args: - pretrained_model_name (optional): a str with the name - of a pre-trained model to load. Currently only 'xlnet-large-cased' - is supported. If `None`, will use the model name in :attr:`hparams`. + pretrained_model_name (optional): a `str`, the name + of pre-trained model (e.g., ``xlnet-based-cased``). Please refer to + :class:`~texar.tf.modules.PretrainedXLNetMixin` for + all supported models. + If `None`, the model name in :attr:`hparams` is used. cache_dir (optional): the path to a folder in which the pre-trained models will be cached. If `None` (default), - a default directory will be used. + a default directory (``texar_data`` folder under user's home + directory) will be used. hparams (dict or HParams, optional): Hyperparameters. Missing - hyperparameter will be set to default values. See + hyperparameters will be set to default values. See :meth:`default_hparams` for the hyperparameter structure and default values. @@ -68,9 +75,9 @@ def __init__(self, with tf.variable_scope(self.variable_scope): tf.get_variable_scope().set_initializer( - layers.get_initializer(self._hparams.initializer)) + get_initializer(self._hparams.initializer)) # Creates the underlying encoder - encoder_hparams = utils.dict_fetch( + encoder_hparams = dict_fetch( hparams, XLNetEncoder.default_hparams()) if encoder_hparams is not None: encoder_hparams['name'] = "encoder" @@ -79,7 +86,7 @@ def __init__(self, cache_dir=cache_dir, hparams=encoder_hparams) if self._hparams.use_projection: - self.projection = layers.get_layer(hparams={ + self.projection = get_layer(hparams={ "type": "Dense", "kwargs": { "units": self._encoder.output_size @@ -89,7 +96,7 @@ def __init__(self, # Creates an dropout layer drop_kwargs = {"rate": self._hparams.dropout} layer_hparams = {"type": "Dropout", "kwargs": drop_kwargs} - self._dropout_layer = layers.get_layer(hparams=layer_hparams) + self._dropout_layer = get_layer(hparams=layer_hparams) logit_kwargs = self._hparams.logit_layer_kwargs if logit_kwargs is None: @@ -104,7 +111,7 @@ def __init__(self, logit_kwargs['name'] = "logit_layer" layer_hparams = {"type": "Dense", "kwargs": logit_kwargs} - self._logit_layer = layers.get_layer(hparams=layer_hparams) + self._logit_layer = get_layer(hparams=layer_hparams) @staticmethod def default_hparams(): diff --git a/texar/tf/modules/regressors/xlnet_regressor_test.py b/texar/tf/modules/regressors/xlnet_regressor_test.py index 67221328..ff96e837 100644 --- a/texar/tf/modules/regressors/xlnet_regressor_test.py +++ b/texar/tf/modules/regressors/xlnet_regressor_test.py @@ -9,10 +9,10 @@ from __future__ import unicode_literals import numpy as np - import tensorflow as tf from texar.tf.modules.regressors.xlnet_regressor import XLNetRegressor +from texar.tf.utils.test import pretrained_test # pylint: disable=too-many-locals, no-member @@ -21,6 +21,17 @@ class XLNetRegressorTest(tf.test.TestCase): """Tests :class:`~texar.tf.modules.XLNetRegressor` class. """ + @pretrained_test + def test_model_loading(self): + r"""Tests model loading functionality.""" + + inputs = tf.placeholder(dtype=tf.int32, shape=[None, None]) + + for pretrained_model_name in XLNetRegressor.available_checkpoints(): + regressor = XLNetRegressor( + pretrained_model_name=pretrained_model_name) + _ = regressor(inputs) + def test_trainable_variables(self): """Tests the functionality of automatically collecting trainable variables. @@ -28,7 +39,10 @@ def test_trainable_variables(self): inputs = tf.placeholder(dtype=tf.int32, shape=[None, None]) # case 1 - regressor = XLNetRegressor() + hparams = { + "pretrained_model_name": None, + } + regressor = XLNetRegressor(hparams=hparams) regressor(inputs) n_xlnet_vars = 162 n_projection_vars = 2 @@ -38,6 +52,7 @@ def test_trainable_variables(self): # case 2 hparams = { + "pretrained_model_name": None, "regr_strategy": "all_time" } regressor = XLNetRegressor(hparams=hparams) @@ -47,6 +62,7 @@ def test_trainable_variables(self): # case 3 hparams = { + "pretrained_model_name": None, "regr_strategy": "time_wise" } regressor = XLNetRegressor(hparams=hparams) @@ -63,7 +79,10 @@ def test_encode(self): maxval=30521, dtype=tf.int32) # case 1 - regressor = XLNetRegressor() + hparams = { + "pretrained_model_name": None, + } + regressor = XLNetRegressor(hparams=hparams) logits = regressor(inputs) with self.test_session() as sess: @@ -73,6 +92,7 @@ def test_encode(self): # case 2 hparams = { + "pretrained_model_name": None, "regr_strategy": "cls_time" } regressor = XLNetRegressor(hparams=hparams) @@ -85,6 +105,7 @@ def test_encode(self): # case 3 hparams = { + "pretrained_model_name": None, "regr_strategy": "time_wise" } regressor = XLNetRegressor(hparams=hparams) @@ -98,6 +119,7 @@ def test_encode(self): # case 4 hparams = { + "pretrained_model_name": None, "regr_strategy": "all_time", "max_seq_len": max_time } @@ -118,6 +140,7 @@ def test_regression(self): batch_size = 8 hparams = { + "pretrained_model_name": None, "regr_strategy": "cls_time" } inputs = tf.placeholder(tf.int32, shape=[batch_size, 6]) diff --git a/texar/tf/utils/test.py b/texar/tf/utils/test.py new file mode 100644 index 00000000..bd20a144 --- /dev/null +++ b/texar/tf/utils/test.py @@ -0,0 +1,30 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utils for unit tests. +""" + +import os + + +def pretrained_test(func): + r"""Tests involving pre-trained checkpoints are skipped using the + `@pretrained_test` decorator. They can be tested locally by setting the + environment variable `TEST_PRETRAINED=1`. + """ + def wrapper(*args, **kwargs): + if os.environ.get('TEST_PRETRAINED', 0) or \ + os.environ.get('TEST_ALL', 0): + return func(*args, **kwargs) + return wrapper