forked from mozilla/DeepSpeech
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathflags.py
152 lines (103 loc) · 10.1 KB
/
flags.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from __future__ import absolute_import, division, print_function
import os
import absl.flags
# Module-level alias for absl's global flag registry. Flags registered by
# create_flags() via absl.flags.DEFINE_* land in this object, and their parsed
# values are read as attributes (e.g. FLAGS.train_files) once absl has parsed
# the command line.
FLAGS = absl.flags.FLAGS
def create_flags():
    """Register every command-line flag of this module on absl's global FLAGS.

    Each ``f.DEFINE_*`` call below adds one flag (name, default value, help
    text) to ``absl.flags.FLAGS``. The parsed values are then available as
    attributes of ``FLAGS`` after absl has processed the command line.
    Validators are also attached to the two file-path flags, so that a bad
    path is reported when flags are parsed rather than when the file is first
    opened.

    Takes no arguments and returns ``None``. The function is meant to be
    called once per process; absl rejects redefinition of an existing flag,
    so a second call is expected to raise ``absl.flags.DuplicateFlagError``.
    """
    # Importer
    # ========

    f = absl.flags

    f.DEFINE_string('train_files', '', 'comma separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run.')
    f.DEFINE_string('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.')
    f.DEFINE_string('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')

    f.DEFINE_string('feature_cache', '', 'path where cached features extracted from --train_files will be saved. If empty, caching will be done in memory and no files will be written.')

    f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
    f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds')
    f.DEFINE_integer('audio_sample_rate', 16000, 'sample rate value expected by model')

    # Data Augmentation
    # =================

    f.DEFINE_float('data_aug_features_additive', 0, 'std of the Gaussian additive noise')
    f.DEFINE_float('data_aug_features_multiplicative', 0, 'std of normal distribution around 1 for multiplicative noise')

    f.DEFINE_float('augmentation_spec_dropout_keeprate', 1, 'keep rate of dropout augmentation on spectrogram (if 1, no dropout will be performed on spectrogram)')

    f.DEFINE_boolean('augmentation_freq_and_time_masking', False, 'whether to use frequency and time masking augmentation')
    f.DEFINE_integer('augmentation_freq_and_time_masking_freq_mask_range', 5, 'max range of masks in the frequency domain when performing freqtime-mask augmentation')
    f.DEFINE_integer('augmentation_freq_and_time_masking_number_freq_masks', 3, 'number of masks in the frequency domain when performing freqtime-mask augmentation')
    f.DEFINE_integer('augmentation_freq_and_time_masking_time_mask_range', 2, 'max range of masks in the time domain when performing freqtime-mask augmentation')
    f.DEFINE_integer('augmentation_freq_and_time_masking_number_time_masks', 3, 'number of masks in the time domain when performing freqtime-mask augmentation')

    f.DEFINE_float('augmentation_speed_up_std', 0, 'std for speeding-up tempo. If std is 0, this augmentation is not performed')

    f.DEFINE_boolean('augmentation_pitch_and_tempo_scaling', False, 'whether to use spectrogram speed and tempo scaling')
    f.DEFINE_float('augmentation_pitch_and_tempo_scaling_min_pitch', 0.95, 'min value of pitch scaling')
    f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_pitch', 1.2, 'max value of pitch scaling')
    f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_tempo', 1.2, 'max value of tempo scaling')

    # Global Constants
    # ================

    f.DEFINE_integer('epochs', 75, 'how many epochs (complete runs through the train files) to train for')

    # A negative dropout value is documented in the help text as "defaults to
    # dropout_rate"; the substitution itself happens outside this function.
    f.DEFINE_float('dropout_rate', 0.05, 'dropout rate for feedforward layers')
    f.DEFINE_float('dropout_rate2', -1.0, 'dropout rate for layer 2 - defaults to dropout_rate')
    f.DEFINE_float('dropout_rate3', -1.0, 'dropout rate for layer 3 - defaults to dropout_rate')
    f.DEFINE_float('dropout_rate4', 0.0, 'dropout rate for layer 4 - defaults to 0.0')
    f.DEFINE_float('dropout_rate5', 0.0, 'dropout rate for layer 5 - defaults to 0.0')
    f.DEFINE_float('dropout_rate6', -1.0, 'dropout rate for layer 6 - defaults to dropout_rate')

    f.DEFINE_float('relu_clip', 20.0, 'ReLU clipping value for non-recurrent layers')

    # Adam optimizer (http://arxiv.org/abs/1412.6980) parameters

    f.DEFINE_float('beta1', 0.9, 'beta 1 parameter of Adam optimizer')
    f.DEFINE_float('beta2', 0.999, 'beta 2 parameter of Adam optimizer')
    f.DEFINE_float('epsilon', 1e-8, 'epsilon parameter of Adam optimizer')
    f.DEFINE_float('learning_rate', 0.001, 'learning rate of Adam optimizer')

    # Batch sizes

    f.DEFINE_integer('train_batch_size', 1, 'number of elements in a training batch')
    f.DEFINE_integer('dev_batch_size', 1, 'number of elements in a validation batch')
    f.DEFINE_integer('test_batch_size', 1, 'number of elements in a test batch')

    f.DEFINE_integer('export_batch_size', 1, 'number of elements per batch on the exported graph')

    # Performance

    f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
    f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
    f.DEFINE_boolean('use_allow_growth', False, 'use Allow Growth flag which will allocate only required amount of GPU memory and prevent full allocation of available GPU memory')
    f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
    f.DEFINE_string('cudnn_checkpoint', '', 'path to a checkpoint created using --use_cudnn_rnn. Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.')

    # Sample limits

    f.DEFINE_integer('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')
    f.DEFINE_integer('limit_dev', 0, 'maximum number of elements to use from validation set - 0 means no limit')
    f.DEFINE_integer('limit_test', 0, 'maximum number of elements to use from test set - 0 means no limit')

    # Checkpointing

    f.DEFINE_string('checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
    f.DEFINE_integer('checkpoint_secs', 600, 'checkpoint saving interval in seconds')
    f.DEFINE_integer('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5')
    f.DEFINE_string('load', 'auto', '"last" for loading most recent epoch checkpoint, "best" for loading best validated checkpoint, "init" for initializing a fresh model, "auto" for trying the other options in order last > best > init')

    # Exporting

    f.DEFINE_string('export_dir', '', 'directory in which exported models are stored - if omitted, the model won\'t get exported')
    f.DEFINE_boolean('remove_export', False, 'whether to remove old exported models')
    f.DEFINE_boolean('export_tflite', False, 'export a graph ready for TF Lite engine')
    f.DEFINE_integer('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
    f.DEFINE_string('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')

    # Reporting

    f.DEFINE_integer('log_level', 1, 'log level for console logs - 0: INFO, 1: WARN, 2: ERROR, 3: FATAL')
    f.DEFINE_boolean('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')

    f.DEFINE_boolean('log_placement', False, 'whether to log device placement of the operators to the console')
    f.DEFINE_integer('report_count', 10, 'number of phrases with lowest WER (best matching) to print out during a WER report')

    f.DEFINE_string('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification')

    f.DEFINE_string('test_output_file', '', 'path to a file to save all src/decoded/distance/loss tuples generated during a test epoch')

    # Geometry

    f.DEFINE_integer('n_hidden', 2048, 'layer width to use when initialising layers')

    # Initialization

    f.DEFINE_integer('random_seed', 4568, 'default random seed that is used to initialize variables')

    # Early Stopping

    f.DEFINE_boolean('early_stop', True, 'enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.')
    f.DEFINE_integer('es_steps', 4, 'number of validations to consider for early stopping. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point')
    f.DEFINE_float('es_mean_th', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
    f.DEFINE_float('es_std_th', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')

    # Decoder

    f.DEFINE_string('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.')
    f.DEFINE_string('lm_binary_path', 'data/lm/lm.binary', 'path to the language model binary file created with KenLM')
    f.DEFINE_string('lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie')
    f.DEFINE_integer('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
    f.DEFINE_float('lm_alpha', 0.75, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
    f.DEFINE_float('lm_beta', 1.85, 'the beta hyperparameter of the CTC decoder. Word insertion weight.')

    # Inference mode

    f.DEFINE_string('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.')

    def _is_readable_file(path):
        # os.path.isfile alone accepts files the process cannot open, which
        # would contradict the "must exist and be readable" messages below and
        # defer the failure to the first open() call.
        return os.path.isfile(path) and os.access(path, os.R_OK)

    # Register validators for paths which require a file to be specified
    f.register_validator('alphabet_config_path',
                         _is_readable_file,
                         message='The file pointed to by --alphabet_config_path must exist and be readable.')

    # --one_shot_infer is optional: an empty value means "not in one-shot
    # mode" and must pass; any non-empty value must name a readable file.
    f.register_validator('one_shot_infer',
                         lambda value: not value or _is_readable_file(value),
                         message='The file pointed to by --one_shot_infer must exist and be readable.')