The core building blocks are the RNN encoder-decoder architecture and the attention mechanism.
The package is largely implemented with the tf.contrib.seq2seq modules introduced in TensorFlow 1.2 (a minimal wiring sketch follows the list):
- AttentionWrapper
- Decoder
- BasicDecoder
- BeamSearchDecoder
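For orientation, here is a minimal sketch of how these modules are typically wired together under TF 1.2. This is an assumption about usage, not the package's actual code; all shapes, tensors, and names are illustrative placeholders.

```python
# Minimal TF 1.2 wiring sketch: AttentionWrapper + BasicDecoder + dynamic_decode.
import tensorflow as tf
from tensorflow.python.layers.core import Dense

batch_size, src_len, trg_len = 32, 20, 25
hidden_units, embedding_size, vocab_size = 1024, 500, 30000

# Assume these come from a tf.nn.dynamic_rnn encoder.
encoder_outputs = tf.zeros([batch_size, src_len, hidden_units])
encoder_lengths = tf.fill([batch_size], src_len)

attention = tf.contrib.seq2seq.LuongAttention(
    num_units=hidden_units,
    memory=encoder_outputs,
    memory_sequence_length=encoder_lengths)

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    tf.contrib.rnn.LSTMCell(hidden_units), attention,
    attention_layer_size=hidden_units)

# Training-time decoder: feed the embedded gold targets through a TrainingHelper.
decoder_inputs = tf.zeros([batch_size, trg_len, embedding_size])
decoder_lengths = tf.fill([batch_size], trg_len)
helper = tf.contrib.seq2seq.TrainingHelper(decoder_inputs, decoder_lengths)

decoder = tf.contrib.seq2seq.BasicDecoder(
    decoder_cell, helper,
    # In practice the zero state would be cloned with the encoder final state.
    initial_state=decoder_cell.zero_state(batch_size, tf.float32),
    output_layer=Dense(vocab_size))

outputs, final_state, lengths = tf.contrib.seq2seq.dynamic_decode(decoder)
logits = outputs.rnn_output  # [batch_size, trg_len, vocab_size]
```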
The package supports the following features (a cell-construction sketch follows the list):
- Multi-layer GRU/LSTM
- Residual connection
- Dropout
- Attention and input_feeding
- Beamsearch decoding
- Writing n-best lists
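The multi-layer, residual, and dropout options roughly correspond to stacking wrapped RNN cells. Below is a sketch under that assumption; build_rnn_cell and its arguments are illustrative, not the package's actual API.

```python
# Rough sketch: multi-layer LSTM/GRU with dropout and residual connections.
import tensorflow as tf

def build_rnn_cell(cell_type, hidden_units, depth, dropout_rate, use_residual):
    def single_cell(layer):
        cell_class = tf.contrib.rnn.GRUCell if cell_type == 'gru' else tf.contrib.rnn.LSTMCell
        cell = cell_class(hidden_units)
        # keep_prob = 1 - dropout_rate (e.g. --dropout_rate 0.3 keeps 70% of outputs)
        cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=1.0 - dropout_rate)
        # Residual connections require matching input/output sizes, so skip layer 0.
        if use_residual and layer > 0:
            cell = tf.contrib.rnn.ResidualWrapper(cell)
        return cell
    return tf.contrib.rnn.MultiRNNCell([single_cell(i) for i in range(depth)])

cell = build_rnn_cell('lstm', hidden_units=1024, depth=2,
                      dropout_rate=0.3, use_residual=True)
```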
Dependencies
- NumPy >= 1.11.1
- TensorFlow >= 1.2
Update history
- June 5, 2017: Major update
- June 6, 2017: Supports batch beamsearch decoding
- June 11, 2017: Separated training / decoding
- June 22, 2017: Supports TF 1.2 (contrib.rnn -> python.ops.rnn_cell)
To preprocess the raw parallel data sample_data.src and sample_data.trg, simply run
$ cd data/
$ ./preprocess.sh src trg sample_data ${max_seq_len}
Running the above script performs the following widely used preprocessing steps for Machine Translation (MT):
- Normalizing punctuation
- Tokenizing
- Byte-pair encoding (# merges = 30000) (Sennrich et al., 2016)
- Removing sequence pairs longer than ${max_seq_len}
- Shuffling
- Building dictionaries
To train a seq2seq model, run
$ python train.py --cell_type 'lstm' \
--attention_type 'luong' \
--hidden_units 1024 \
--depth 2 \
--embedding_size 500 \
--num_encoder_symbols 30000 \
--num_decoder_symbols 30000 ...
To decode with a trained model, run
$ python decode.py --beam_width 5 \
--decode_batch_size 30 \
--model_path $PATH_TO_A_MODEL_CHECKPOINT (e.g. model/translate.ckpt-100) \
--max_decode_step 300 \
--write_n_best False \
--decode_input $PATH_TO_DECODE_INPUT \
--decode_output $PATH_TO_DECODE_OUTPUT
If --beam_width=1, greedy decoding is performed at each time-step.
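For reference, the two decoding modes roughly correspond to the following tf.contrib.seq2seq usage. This is a hedged sketch; names, token ids, and shapes are placeholders rather than the package's actual code.

```python
# Sketch of the two decoding modes: greedy vs. beam search.
import tensorflow as tf
from tensorflow.python.layers.core import Dense

batch_size, vocab_size, hidden_units, embedding_size = 30, 30000, 1024, 500
GO_ID, EOS_ID, beam_width = 1, 2, 5

embedding = tf.get_variable('embedding', [vocab_size, embedding_size])
decoder_cell = tf.contrib.rnn.LSTMCell(hidden_units)
output_layer = Dense(vocab_size)
start_tokens = tf.fill([batch_size], GO_ID)

# --beam_width 1: greedy decoding with GreedyEmbeddingHelper + BasicDecoder.
helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding, start_tokens, EOS_ID)
greedy_decoder = tf.contrib.seq2seq.BasicDecoder(
    decoder_cell, helper,
    initial_state=decoder_cell.zero_state(batch_size, tf.float32),
    output_layer=output_layer)

# --beam_width > 1: BeamSearchDecoder; the decoder state is tiled beam_width times.
beam_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell=decoder_cell, embedding=embedding,
    start_tokens=start_tokens, end_token=EOS_ID,
    initial_state=decoder_cell.zero_state(batch_size * beam_width, tf.float32),
    beam_width=beam_width, output_layer=output_layer)

# --max_decode_step bounds the number of decoding time-steps.
outputs, state, lengths = tf.contrib.seq2seq.dynamic_decode(
    beam_decoder, maximum_iterations=300)
```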
Data params
--source_vocabulary: Path to source vocabulary
--target_vocabulary: Path to target vocabulary
--source_train_data: Path to source training data
--target_train_data: Path to target training data
--source_valid_data: Path to source validation data
--target_valid_data: Path to target validation data
Network params
--cell_type: RNN cell to use for encoder and decoder (default: lstm)
--attention_type: Attention mechanism (bahdanau, luong) (default: bahdanau)
--depth: Number of layers in the model (default: 2)
--embedding_size: Embedding dimensions of encoder and decoder inputs (default: 500)
--num_encoder_symbols: Source vocabulary size to use (default: 30000)
--num_decoder_symbols: Target vocabulary size to use (default: 30000)
--use_residual: Use residual connections between layers (default: True)
--attn_input_feeding: Use the input feeding method in the attentional decoder (Luong et al., 2015) (default: True)
--use_dropout: Use dropout on RNN cell outputs (default: True)
--dropout_rate: Dropout probability for cell outputs (0.0: no dropout) (default: 0.3)
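As a rough illustration (not necessarily the package's exact code), --attention_type and --attn_input_feeding map onto the tf.contrib.seq2seq API along these lines; the build_attention helper below is hypothetical.

```python
# Illustrative mapping of --attention_type onto the TF 1.2 attention mechanisms.
import tensorflow as tf

def build_attention(attention_type, hidden_units, encoder_outputs, encoder_lengths):
    if attention_type == 'bahdanau':
        return tf.contrib.seq2seq.BahdanauAttention(
            num_units=hidden_units, memory=encoder_outputs,
            memory_sequence_length=encoder_lengths)
    return tf.contrib.seq2seq.LuongAttention(
        num_units=hidden_units, memory=encoder_outputs,
        memory_sequence_length=encoder_lengths)

# With --attn_input_feeding True, the attention context is fed back as part of
# the next decoder input; AttentionWrapper's default cell_input_fn already
# concatenates the previous attention vector with the current input.
```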
Training params
--learning_rate: Learning rate (default: 0.0002)
--max_gradient_norm: Clip gradients to this norm (default: 1.0)
--batch_size: Batch size
--max_epochs: Maximum number of training epochs
--max_load_batches: Maximum number of batches to prefetch at one time
--max_seq_length: Maximum sequence length
--display_freq: Display training status every this many iterations
--save_freq: Save a model checkpoint every this many iterations
--valid_freq: Evaluate the model every this many iterations (valid_data needed)
--optimizer: Optimizer for training (adadelta, adam, rmsprop) (default: adam)
--model_dir: Path to save model checkpoints
--model_name: File name used for model checkpoints
--shuffle_each_epoch: Shuffle the training dataset each epoch (default: True)
--sort_by_length: Sort pre-fetched minibatches by their target sequence lengths (default: True)
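A rough sketch of how --learning_rate, --max_gradient_norm and --optimizer typically interact in the training step; this is an assumption, not the package's exact code, and the loss below is only a stand-in.

```python
# Sketch: gradient clipping to --max_gradient_norm and the --optimizer choice.
import tensorflow as tf

learning_rate, max_gradient_norm = 0.0002, 1.0

w = tf.get_variable('w', [10, 10])
loss = tf.reduce_mean(tf.square(w))          # placeholder for the real seq2seq loss

params = tf.trainable_variables()
gradients = tf.gradients(loss, params)
clipped_gradients, _ = tf.clip_by_global_norm(gradients, max_gradient_norm)

optimizer = tf.train.AdamOptimizer(learning_rate)    # --optimizer adam
train_op = optimizer.apply_gradients(zip(clipped_gradients, params))
```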
Decoding params
--beam_width: Beam width used in beam search (default: 1)
--decode_batch_size: Batch size used in decoding
--max_decode_step: Maximum time-step limit in decoding (default: 500)
--write_n_best: Write the beam search n-best list (n = beam_width) (default: False)
--decode_input: Path to the input file to decode
--decode_output: Path to the decoding output file
Runtime params
--allow_soft_placement: Allow device soft placement
--log_device_placement: Log placement of ops on devices
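These flags correspond to the standard tf.ConfigProto session options, for example:

```python
# Sketch: how the runtime flags are typically passed to the TensorFlow session.
import tensorflow as tf

config = tf.ConfigProto(allow_soft_placement=True,    # --allow_soft_placement
                        log_device_placement=False)   # --log_device_placement
sess = tf.Session(config=config)
```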
The implementation is based on the following projects:
- nematus: Theano implementation of Neural Machine Translation; the major reference for this project
- subword-nmt: Included subword-unit scripts to preprocess input data
- moses: Included preprocessing scripts to preprocess input data
- tf.seq2seq_legacy: Legacy TensorFlow seq2seq tutorial
- tf_tutorial_plus: Nice tutorials for tf.contrib.seq2seq API
For any comments or feedback, please email me at [email protected] or open an issue here.