Punctuation

Punctuation restoration using TensorFlow

Introduction

This repository was inspired by github.com/ottokart/punctuator2. I did not want to use Python for serving the model, so I ported the model from Theano to TensorFlow. The Attention and Late Fusion layers were the biggest challenges for me, and I'm not sure they are 100% correct. I compared the results with punctuator2 and they looked very similar.

NEW: Feature representation of words

Added support for Kaldi-like word features. Word hashing was originally described in Learning deep structured semantic models for web search using clickthrough data by Po-Sen Huang et al., 2013. It allows a nearly unlimited vocabulary to be used for NN training instead of a shortlist. Samples for data preparation and training have been added.
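To illustrate the idea, here is a minimal sketch of letter n-gram word hashing as described in the cited paper: a word is wrapped in boundary marks and split into overlapping letter trigrams. The function names `letter_ngrams` and `word_features` are hypothetical and are not taken from this repository, whose feature extraction may differ in detail.

```python
from collections import Counter


def letter_ngrams(word, n=3, boundary="#"):
    """Return the letter n-grams of a word wrapped in boundary marks."""
    marked = f"{boundary}{word.lower()}{boundary}"
    if len(marked) < n:
        # Very short words yield a single, shorter feature.
        return [marked]
    return [marked[i:i + n] for i in range(len(marked) - n + 1)]


def word_features(word, n=3):
    """Count letter n-grams: a sparse bag-of-n-grams vector for one word."""
    return Counter(letter_ngrams(word, n))


print(letter_ngrams("good"))  # ['#go', 'goo', 'ood', 'od#']
```

Because the set of possible letter trigrams is small and fixed, any word, including one never seen in training, maps to a feature vector of bounded size.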

Requirements

The implementation was tested with Python 3.10. To create a conda environment:

conda create --name punct python=3.10

The training code uses the following modules:

pip install -r requirements
pip install tqdm

Hyperparameter optimization additionally requires:

pip install parameter-sherpa
pip install keras

Data preparation

For the initial data requirements, see github.com/ottokart/punctuator2. You can configure the punctuation set and the vocabulary size in ptf/data/data.py. To prepare the data for training:

python ptf/punctuator2/data.py <initialDataDir> <dataDir>

Or see egs for sample scripts.

Training

To train a model, first configure the model parameters in ptf/train.py. Training can then be run with:

mkdir model1 && cd model1
python ../ptf/train.py <dataDir> <modelPrefix>

The trained model is saved in Keras hd5 format in the working folder model1.

Optimization

A Python script optimizes the hyperparameters of a model using the sherpa tool; see optimize/optimize.py. To start an optimization:

mkdir optim1 && cd optim1
python ../optimize/optimize.py <dataDir> 

All models are saved in the working folder optim1.

Prediction, error calculation

To predict punctuation for a test text:

python ptf/predict.py <testTextFile> <vocabulary> <hd5ModelFile> <predictedOutputFile>

To evaluate error scores of the prediction:

python ptf/punctuator2/error_calculator.py <testTextFile> <predictedOutputFile>
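As a rough illustration of what such error scores measure, the sketch below computes per-punctuation precision, recall and F1 from a reference and a predicted token sequence. It is not the logic of error_calculator.py: the token names (`,COMMA`, `.PERIOD`, `?QUESTIONMARK`) are assumed to follow the punctuator2 convention, and the helpers `labels` and `scores` are hypothetical.

```python
from collections import Counter

# Assumed punctuation tokens, following the punctuator2 convention.
PUNCTUATION = {",COMMA", ".PERIOD", "?QUESTIONMARK"}
SPACE = "_SPACE"  # label meaning "no punctuation after this word"


def labels(tokens):
    """Return the punctuation label that follows each word."""
    result = []
    for token in tokens:
        if token in PUNCTUATION:
            if result:  # ignore punctuation before the first word
                result[-1] = token
        else:
            result.append(SPACE)
    return result


def scores(reference, predicted):
    """Compute (precision, recall, F1) for each punctuation token."""
    ref, hyp = labels(reference), labels(predicted)
    if len(ref) != len(hyp):
        raise ValueError("both texts must contain the same number of words")
    tp, fp, fn = Counter(), Counter(), Counter()
    for r, h in zip(ref, hyp):
        if r == h:
            if r != SPACE:
                tp[r] += 1
        else:
            if h != SPACE:
                fp[h] += 1
            if r != SPACE:
                fn[r] += 1
    out = {}
    for p in sorted(PUNCTUATION):
        precision = tp[p] / (tp[p] + fp[p]) if tp[p] + fp[p] else 0.0
        recall = tp[p] / (tp[p] + fn[p]) if tp[p] + fn[p] else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        out[p] = (precision, recall, f1)
    return out


reference = "hello ,COMMA world .PERIOD how are you ?QUESTIONMARK".split()
predicted = "hello world .PERIOD how are you .PERIOD".split()
for mark, (p, r, f) in scores(reference, predicted).items():
    print(f"{mark}: precision={p:.2f} recall={r:.2f} f1={f:.2f}")
```

In this toy input the predicted period after "world" is correct and the one after "you" replaces a question mark, so the period is scored with one true positive and one false positive.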

Saving as a pure tensorflow model

During training, all models are saved in Keras format. To save a model in pure TensorFlow format, use the script:

python ptf/save_as_tf.py <hd5ModelFile> <tfModelOutputDir>

Loading model with go

Sample Go code showing how to load the trained TensorFlow model: examples/goload/loadtf.go. To compile the sample, you need to install the TensorFlow library and configure LD_LIBRARY_PATH. See https://www.tensorflow.org/install/lang_go


Author

Airenas Vaičiūnas


License

Copyright © 2020, Airenas Vaičiūnas. Released under the 3-Clause BSD License.

Please also see the license of Ottokar Tilk.