diff --git a/README.md b/README.md
index 996fe778..d77f382e 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
# Tacotron-2:
-Tensorflow implementation of Deep mind's Tacotron-2. A deep neural network architecture described in this paper: [Natural TTS synthesis by conditioning Wavenet on MEL spectogram predictions](https://arxiv.org/pdf/1712.05884.pdf)
+Tensorflow implementation of DeepMind's Tacotron-2, a deep neural network architecture described in this paper: [Natural TTS synthesis by conditioning Wavenet on Mel spectrogram predictions](https://arxiv.org/pdf/1712.05884.pdf)
# Repository Structure:
@@ -15,10 +15,20 @@ Tensorflow implementation of Deep mind's Tacotron-2. A deep neural network archi
├── LJSpeech-1.1 (0)
│ └── wavs
├── logs-Tacotron (2)
+ │ ├── eval_-dir
+ │ │ ├── plots
+ │ │ └── wavs
│ ├── mel-spectrograms
│ ├── plots
│ ├── pretrained
│ └── wavs
+ ├── logs-Wavenet (4)
+ │ ├── eval-dir
+ │ │ ├── plots
+ │ │ └── wavs
+ │ ├── plots
+ │ ├── pretrained
+ │ └── wavs
├── papers
├── tacotron
│ ├── models
@@ -30,26 +40,34 @@ Tensorflow implementation of Deep mind's Tacotron-2. A deep neural network archi
│ │ ├── plots
│ │ └── wavs
│ └── natural
+ ├── wavenet_output (5)
+ │ ├── plots
+ │ └── wavs
├── training_data (1)
│ ├── audio
- │ └── mels
+ │ ├── linear
+ │ └── mels
└── wavenet_vocoder
└── models
-
-
-The previous tree shows what the current state of the repository.
+The previous tree shows the current state of the repository (separate training, one step at a time).
- Step **(0)**: Get your dataset, here I have set the examples of **Ljspeech**, **en_US** and **en_UK** (from **M-AILABS**).
- Step **(1)**: Preprocess your data. This will give you the **training_data** folder.
- Step **(2)**: Train your Tacotron model. Yields the **logs-Tacotron** folder.
- Step **(3)**: Synthesize/Evaluate the Tacotron model. Gives the **tacotron_output** folder.
+- Step **(4)**: Train your Wavenet model. Yields the **logs-Wavenet** folder.
+- Step **(5)**: Synthesize audio using the Wavenet model. Gives the **wavenet_output** folder.
Note:
- **Our preprocessing only supports Ljspeech and Ljspeech-like datasets (M-AILABS speech data)!** If running on datasets stored differently, you will probably need to make your own preprocessing script.
- In the previous tree, files **were not represented** and **max depth was set to 3** for simplicity.
+- If you run training of both **models at the same time**, the repository structure will be different.
+
+# Pretrained model and Samples:
+Pre-trained models and audio samples will be added at a later date. You can, however, check some preliminary insights into the model's performance (at early stages of training) [here](https://github.com/Rayhane-mamah/Tacotron-2/issues/4#issuecomment-378741465). Note that these insights are very outdated and will be updated soon.
# Model Architecture:
@@ -69,16 +87,12 @@ To have an overview of our advance on this project, please refer to [this discus
since the two parts of the global model are trained separately, we can start by training the feature prediction model and use its predictions later during the wavenet training.
# How to start
-first, you need to have python 3 installed along with [Tensorflow v1.6](https://www.tensorflow.org/install/).
+First, you need to have Python 3 installed along with [Tensorflow](https://www.tensorflow.org/install/).
-next you can install the requirements. If you are an Anaconda user:
+Next, you can install the requirements. If you are an Anaconda user: (otherwise replace **pip** with **pip3** and **python** with **python3**)
> pip install -r requirements.txt
-else:
-
-> pip3 install -r requirements.txt
-
# Dataset:
We tested the code above on the [ljspeech dataset](https://keithito.com/LJ-Speech-Dataset/), which has almost 24 hours of labeled single-speaker (female) voice recordings. (Further info on the dataset is available in the README file when you download it.)
@@ -86,6 +100,11 @@ We are also running current tests on the [new M-AILABS speech dataset](http://ww
After **downloading** the dataset, **extract** the compressed file, and **place the folder inside the cloned repository.**
+# Hparams setting:
+Before proceeding, you must pick the hyperparameters that best suit your needs. While it is possible to change the hyperparameters from the command line during preprocessing/training, I still recommend making the changes once and for all directly in the **hparams.py** file (a command line override example is given below).
+
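+For example, assuming the same comma-separated name=value format accepted by the **--hparams** flag elsewhere in this repository (the values below are purely illustrative):
+
+> python train.py --model='Tacotron-2' --hparams='tacotron_batch_size=32,outputs_per_step=1'
+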
+To pick optimal FFT parameters, I have made a **griffin_lim_synthesis_tool** notebook that you can use to invert real extracted mel/linear spectrograms and judge how good your preprocessing is (a minimal sketch of what the notebook does is shown below). All other options are well explained in **hparams.py** and have meaningful names so that you can experiment with them.
+
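+For reference, here is a minimal sketch of the inversion performed by the notebook, using the helpers from **datasets/audio.py** (the mel filename is a hypothetical placeholder; use any file produced by your preprocessing run):
+
+```python
+import os
+import numpy as np
+
+from datasets.audio import inv_mel_spectrogram, save_wav
+from hparams import hparams
+
+#Load a preprocessed mel (saved transposed, i.e. [frames, num_mels]); the filename is illustrative only
+mel = np.load('training_data/mels/mel-000001.npy')
+
+#Invert it back to a waveform with Griffin-Lim and write the result
+wav = inv_mel_spectrogram(mel.T, hparams)
+os.makedirs('wav_out', exist_ok=True)
+save_wav(wav, os.path.join('wav_out', 'inverted.wav'), hparams)
+```
+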
# Preprocessing
Before running the following steps, please make sure you are inside **Tacotron-2 folder**
@@ -95,90 +114,76 @@ Preprocessing can then be started using:
> python preprocess.py
-or
-
-> python3 preprocess.py
-
dataset can be chosen using the **--dataset** argument. If using M-AILABS dataset, you need to provide the **language, voice, reader, merge_books and book arguments** for your custom need. Default is **Ljspeech**.
Example M-AILABS:
> python preprocess.py --dataset='M-AILABS' --language='en_US' --voice='female' --reader='mary_ann' --merge_books=False --book='northandsouth'
+or if you want to use all books for a single speaker:
+
+> python preprocess.py --dataset='M-AILABS' --language='en_US' --voice='female' --reader='mary_ann' --merge_books=True
+
This should take no longer than a **few minutes.**
# Training:
-Feature prediction model can be **trained** using:
+To **train both models** sequentially (one after the other):
-> python train.py --model='Tacotron'
-
-or
+> python train.py --model='Tacotron-2'
-> python3 train.py --model='Tacotron'
-checkpoints will be made each **100 steps** and stored under **logs-Tacotron folder.**
+Feature prediction model can **separately** be **trained** using:
-Naturally, **training the wavenet** is done by: (Not implemented yet)
+> python train.py --model='Tacotron'
-> python train.py --model='Wavenet'
+Checkpoints will be made every **5000 steps** and stored under the **logs-Tacotron folder**.
-or
+Naturally, **training the wavenet separately** is done by:
-> python3 train.py --model='Wavenet'
+> python train.py --model='WaveNet'
logs will be stored inside **logs-Wavenet**.
**Note:**
-- If model argument is not provided, training will default to Tacotron model training.
+- If the model argument is not provided, training will default to Tacotron-2 model training (both models).
+- Please refer to train arguments under [train.py](https://github.com/begeekmyfriend/Tacotron-2/blob/master/train.py) for a set of options you can use.
+- It is now possible to run wavenet preprocessing alone using **wavenet_preprocess.py** (see the example below).
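+
+For example, wavenet preprocessing alone can be launched with (assuming the script's default arguments, as with **preprocess.py**):
+
+> python wavenet_preprocess.py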
# Synthesis
-There are **three types** of mel spectrograms synthesis for the Spectrogram prediction network (Tacotron):
+To **synthesize audio** in an **End-to-End** (text to audio) manner (both models at work):
-- **Evaluation** (synthesis on custom sentences). This is what we'll usually use after having a full end to end model.
+> python synthesize.py --model='Tacotron-2'
-> python synthesize.py --model='Tacotron' --mode='eval'
+For the spectrogram prediction network (separately), there are **three types** of mel spectrogram synthesis:
-or
+- **Evaluation** (synthesis on custom sentences). This is what we'll usually use after having a full end to end model.
-> python3 synthesize.py --model='Tacotron' --mode='eval'
+> python synthesize.py --model='Tacotron' --mode='eval'
- **Natural synthesis** (let the model make predictions alone by feeding last decoder output to the next time step).
> python synthesize.py --model='Tacotron' --GTA=False
-or
-
-> python3 synthesize.py --model='Tacotron' --GTA=False
- **Ground Truth Aligned synthesis** (DEFAULT: the model is assisted by true labels in a teacher forcing manner). This synthesis method is used when predicting mel spectrograms used to train the wavenet vocoder. (yields better results as stated in the paper)
-> python synthesize.py --model='Tacotron'
+> python synthesize.py --model='Tacotron' --GTA=True
-or
+Synthesizing the **waveforms** conditioned on previously synthesized Mel-spectrograms (separately) can be done with:
-> python3 synthesize.py --model='Tacotron'
-
-Synthesizing the waveforms conditionned on previously synthesized Mel-spectrograms can be done with:
-
-> python synthesize.py --model='Wavenet'
-
-or
-
-> python3 synthesize.py --model='Wavenet'
+> python synthesize.py --model='WaveNet'
**Note:**
-- If model argument is not provided, synthesis will default to Tacotron model synthesis.
-- If mode argument is not provided, synthesis defaults to Ground Truth Aligned synthesis.
-
-# Pretrained model and Samples:
-Pre-trained models and audio samples will be added at a later date due to technical difficulties. You can however check some primary insights of the model performance (at early stages of training) [here](https://github.com/Rayhane-mamah/Tacotron-2/issues/4#issuecomment-378741465).
+- If the model argument is not provided, synthesis will default to Tacotron-2 model synthesis (End-to-End TTS).
+- Please refer to synthesis arguments under [synthesize.py](https://github.com/begeekmyfriend/Tacotron-2/blob/master/synthesize.py) for a set of options you can use.
# References and Resources:
-- [Tensorflow original tacotron implementation](https://github.com/keithito/tacotron)
+- [Natural TTS synthesis by conditioning Wavenet on Mel spectrogram predictions](https://arxiv.org/pdf/1712.05884.pdf)
- [Original tacotron paper](https://arxiv.org/pdf/1703.10135.pdf)
- [Attention-Based Models for Speech Recognition](https://arxiv.org/pdf/1506.07503.pdf)
-- [Natural TTS synthesis by conditioning Wavenet on MEL spectogram predictions](https://arxiv.org/pdf/1712.05884.pdf)
+- [Wavenet: A generative model for raw audio](https://arxiv.org/pdf/1609.03499.pdf)
+- [Fast Wavenet](https://arxiv.org/pdf/1611.09482.pdf)
- [r9y9/wavenet_vocoder](https://github.com/r9y9/wavenet_vocoder)
+- [keithito/tacotron](https://github.com/keithito/tacotron)
-**Work in progress**
diff --git a/datasets/audio.py b/datasets/audio.py
index f1b5a6d3..321ad7de 100644
--- a/datasets/audio.py
+++ b/datasets/audio.py
@@ -1,20 +1,33 @@
import librosa
import librosa.filters
-import numpy as np
+import numpy as np
+import tensorflow as tf
from scipy import signal
-from hparams import hparams
-import tensorflow as tf
from scipy.io import wavfile
-def load_wav(path):
- return librosa.core.load(path, sr=hparams.sample_rate)[0]
+def load_wav(path, sr):
+ return librosa.core.load(path, sr=sr)[0]
-def save_wav(wav, path):
- wav *= 32767 / max(0.01, np.max(np.abs(wav)))
+def save_wav(wav, path, hparams):
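+	#Rescale to [-0.999, 0.999], apply a mild amplitude compression (sign(x) * |x|**0.95) scaled to about half the int16 range,
+	#then band-pass the result between fmin and fmax with a linear-phase FIR filter before writing 16-bit PCM.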
+ wav = wav / np.abs(wav).max() * 0.999
+ f1 = 0.5 * 32767 / max(0.01, np.max(np.abs(wav)))
+ f2 = np.sign(wav) * np.power(np.abs(wav), 0.95)
+ wav = f1 * f2
+ wav = signal.convolve(wav, signal.firwin(hparams.num_freq, [hparams.fmin, hparams.fmax], pass_zero=False, fs=hparams.sample_rate))
#proposed by @dsmiller
wavfile.write(path, hparams.sample_rate, wav.astype(np.int16))
+def save_wavenet_wav(wav, path, sr):
+ librosa.output.write_wav(path, wav, sr=sr)
+
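+#preemphasis applies the FIR filter y[n] = x[n] - k*x[n-1] (boosts high frequencies); inv_preemphasis undoes it with the matching IIR filter.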
+def preemphasis(wav, k):
+ return signal.lfilter([1, -k], [1], wav)
+
+def inv_preemphasis(wav, k):
+ return signal.lfilter([1], [1, -k], wav)
+
+#From https://github.com/r9y9/wavenet_vocoder/blob/master/audio.py
def start_and_end_indices(quantized, silence_threshold=2):
for start in range(quantized.size):
if abs(quantized[start] - 127) > silence_threshold:
@@ -28,63 +41,137 @@ def start_and_end_indices(quantized, silence_threshold=2):
return start, end
-def trim_silence(wav):
+def trim_silence(wav, hparams):
'''Trim leading and trailing silence
- Useful for M-AILABS dataset if we choose to trim the extra 0.5 silences.
+	Useful for M-AILABS dataset if we choose to trim the extra 0.5 seconds of silence at the beginning and end.
'''
- return librosa.effects.trim(wav)[0]
-
-def preemphasis(x):
- return signal.lfilter([1, -hparams.preemphasis], [1], x)
-
-def inv_preemphasis(x):
- return signal.lfilter([1], [1, -hparams.preemphasis], x)
+ #Thanks @begeekmyfriend and @lautjy for pointing out the params contradiction. These params are separate and tunable per dataset.
+ return librosa.effects.trim(wav, top_db= hparams.trim_top_db, frame_length=hparams.trim_fft_size, hop_length=hparams.trim_hop_size)[0]
-def get_hop_size():
+def get_hop_size(hparams):
hop_size = hparams.hop_size
if hop_size is None:
assert hparams.frame_shift_ms is not None
hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
return hop_size
-def melspectrogram(wav):
- D = _stft(wav)
- S = _amp_to_db(_linear_to_mel(np.abs(D))) - hparams.ref_level_db
+def linearspectrogram(wav, hparams):
+ D = _stft(preemphasis(wav, hparams.preemphasis), hparams)
+ S = _amp_to_db(np.abs(D), hparams) - hparams.ref_level_db
- if hparams.mel_normalization:
- return _normalize(S)
+ if hparams.signal_normalization:
+ return _normalize(S, hparams)
return S
-
-def inv_mel_spectrogram(mel_spectrogram):
+def melspectrogram(wav, hparams):
+ D = _stft(preemphasis(wav, hparams.preemphasis), hparams)
+ S = _amp_to_db(_linear_to_mel(np.abs(D), hparams), hparams) - hparams.ref_level_db
+
+ if hparams.signal_normalization:
+ return _normalize(S, hparams)
+ return S
+
+def inv_linear_spectrogram(linear_spectrogram, hparams):
+ '''Converts linear spectrogram to waveform using librosa'''
+ if hparams.signal_normalization:
+ D = _denormalize(linear_spectrogram, hparams)
+ else:
+ D = linear_spectrogram
+
+ S = _db_to_amp(D + hparams.ref_level_db) #Convert back to linear
+
+ if hparams.use_lws:
+ processor = _lws_processor(hparams)
+ D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
+ y = processor.istft(D).astype(np.float32)
+ return inv_preemphasis(y, hparams.preemphasis)
+ else:
+ return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis)
+
+def inv_mel_spectrogram(mel_spectrogram, hparams):
'''Converts mel spectrogram to waveform using librosa'''
- if hparams.mel_normalization:
- D = _denormalize(mel_spectrogram)
+ if hparams.signal_normalization:
+ D = _denormalize(mel_spectrogram, hparams)
+ else:
+ D = mel_spectrogram
+
+ S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db), hparams) # Convert back to linear
+
+ if hparams.use_lws:
+ processor = _lws_processor(hparams)
+ D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
+ y = processor.istft(D).astype(np.float32)
+ return inv_preemphasis(y, hparams.preemphasis)
+ else:
+ return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis)
+
+def inv_spectrogram_tensorflow(spectrogram, hparams):
+ '''Builds computational graph to convert spectrogram to waveform using TensorFlow.
+ Unlike inv_spectrogram, this does NOT invert the preemphasis. The caller should call
+ inv_preemphasis on the output after running the graph.
+ '''
+ if hparams.signal_normalization:
+ D = _denormalize_tensorflow(spectrogram, hparams)
+ else:
+		D = spectrogram
+
+ S = _db_to_amp_tensorflow(D + hparams.ref_level_db)
+ return _griffin_lim_tensorflow(tf.pow(S, hparams.power), hparams)
+
+def inv_mel_spectrogram_tensorflow(mel_spectrogram, hparams):
+ '''Builds computational graph to convert mel spectrogram to waveform using TensorFlow.
+ Unlike inv_mel_spectrogram, this does NOT invert the preemphasis. The caller should call
+ inv_preemphasis on the output after running the graph.
+ '''
+ if hparams.signal_normalization:
+ D = _denormalize_tensorflow(mel_spectrogram, hparams)
else:
D = mel_spectrogram
- S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db)) # Convert back to linear
+ S = _db_to_amp_tensorflow(D + hparams.ref_level_db)
+ S = _mel_to_linear(S, hparams) # Convert back to linear
+ return _griffin_lim_tensorflow(S ** hparams.power, hparams)
- return _griffin_lim(S ** hparams.power)
+def _lws_processor(hparams):
+ import lws
+ return lws.lws(hparams.n_fft, get_hop_size(hparams), fftsize=hparams.win_size, mode="speech")
-def _griffin_lim(S):
+def _griffin_lim(S, hparams):
'''librosa implementation of Griffin-Lim
Based on https://github.com/librosa/librosa/issues/434
'''
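+	#Each iteration keeps the target magnitude |S| and re-estimates the phase from an ISTFT/STFT round trip.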
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
S_complex = np.abs(S).astype(np.complex)
- y = _istft(S_complex * angles)
+ y = _istft(S_complex * angles, hparams)
for i in range(hparams.griffin_lim_iters):
- angles = np.exp(1j * np.angle(_stft(y)))
- y = _istft(S_complex * angles)
+ angles = np.exp(1j * np.angle(_stft(y, hparams)))
+ y = _istft(S_complex * angles, hparams)
return y
-def _stft(y):
- return librosa.stft(y=y, n_fft=hparams.fft_size, hop_length=get_hop_size())
+def _griffin_lim_tensorflow(S, hparams):
+ '''TensorFlow implementation of Griffin-Lim
+ Based on https://github.com/Kyubyong/tensorflow-exercises/blob/master/Audio_Processing.ipynb
+ '''
+ with tf.variable_scope('griffinlim'):
+ # TensorFlow's stft and istft operate on a batch of spectrograms; create batch of size 1
+ S = tf.expand_dims(S, 0)
+ S_complex = tf.identity(tf.cast(S, dtype=tf.complex64))
+ y = tf.contrib.signal.inverse_stft(S_complex, hparams.win_size, get_hop_size(hparams), hparams.n_fft)
+ for i in range(hparams.griffin_lim_iters):
+ est = tf.contrib.signal.stft(y, hparams.win_size, get_hop_size(hparams), hparams.n_fft)
+ angles = est / tf.cast(tf.maximum(1e-8, tf.abs(est)), tf.complex64)
+ y = tf.contrib.signal.inverse_stft(S_complex * angles, hparams.win_size, get_hop_size(hparams), hparams.n_fft)
+ return tf.squeeze(y, 0)
+
+def _stft(y, hparams):
+ if hparams.use_lws:
+ return _lws_processor(hparams).stft(y).T
+ else:
+ return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
-def _istft(y):
- return librosa.istft(y, hop_length=get_hop_size())
+def _istft(y, hparams):
+ return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
def num_frames(length, fsize, fshift):
"""Compute number of time frames of spectrogram
@@ -111,31 +198,34 @@ def pad_lr(x, fsize, fshift):
_mel_basis = None
_inv_mel_basis = None
-def _linear_to_mel(spectogram):
+def _linear_to_mel(spectogram, hparams):
global _mel_basis
if _mel_basis is None:
- _mel_basis = _build_mel_basis()
+ _mel_basis = _build_mel_basis(hparams)
return np.dot(_mel_basis, spectogram)
-def _mel_to_linear(mel_spectrogram):
+def _mel_to_linear(mel_spectrogram, hparams):
global _inv_mel_basis
if _inv_mel_basis is None:
- _inv_mel_basis = np.linalg.pinv(_build_mel_basis())
+ _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams))
return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
-def _build_mel_basis():
+def _build_mel_basis(hparams):
assert hparams.fmax <= hparams.sample_rate // 2
- return librosa.filters.mel(hparams.sample_rate, hparams.fft_size, n_mels=hparams.num_mels,
+ return librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=hparams.num_mels,
fmin=hparams.fmin, fmax=hparams.fmax)
-def _amp_to_db(x):
+def _amp_to_db(x, hparams):
min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
return 20 * np.log10(np.maximum(min_level, x))
def _db_to_amp(x):
return np.power(10.0, (x) * 0.05)
-def _normalize(S):
+def _db_to_amp_tensorflow(x):
+ return tf.pow(tf.ones(tf.shape(x)) * 10.0, x * 0.05)
+
+def _normalize(S, hparams):
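+	#Assuming S is in dB within [min_level_db, 0], map it linearly to [-max_abs_value, max_abs_value] (symmetric_mels) or [0, max_abs_value].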
if hparams.allow_clipping_in_normalization:
if hparams.symmetric_mels:
return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value,
@@ -143,17 +233,16 @@ def _normalize(S):
else:
return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value)
- assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
if hparams.symmetric_mels:
return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value
else:
return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db))
-def _denormalize(D):
+def _denormalize(D, hparams):
if hparams.allow_clipping_in_normalization:
if hparams.symmetric_mels:
return (((np.clip(D, -hparams.max_abs_value,
- hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value))
+ hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value))
+ hparams.min_level_db)
else:
return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
@@ -161,4 +250,18 @@ def _denormalize(D):
if hparams.symmetric_mels:
return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db)
else:
- return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
\ No newline at end of file
+ return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
+
+def _denormalize_tensorflow(D, hparams):
+ if hparams.allow_clipping_in_normalization:
+ if hparams.symmetric_mels:
+ return (((tf.clip_by_value(D, -hparams.max_abs_value,
+ hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value))
+ + hparams.min_level_db)
+ else:
+ return ((tf.clip_by_value(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
+
+ if hparams.symmetric_mels:
+ return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db)
+ else:
+ return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
diff --git a/datasets/preprocessor.py b/datasets/preprocessor.py
index 3eb34950..1e4eadec 100644
--- a/datasets/preprocessor.py
+++ b/datasets/preprocessor.py
@@ -1,19 +1,19 @@
+import glob, os
from concurrent.futures import ProcessPoolExecutor
from functools import partial
+import numpy as np
from datasets import audio
-import os
-import numpy as np
-from hparams import hparams
-from wavenet_vocoder.util import mulaw_quantize, mulaw, is_mulaw, is_mulaw_quantize
-def build_from_path(input_dirs, mel_dir, wav_dir, n_jobs=12, tqdm=lambda x: x):
+def build_from_path(hparams, input_dirs, mel_dir, linear_dir, wav_dir, n_jobs=12, tqdm=lambda x: x):
"""
- Preprocesses the Lj speech dataset from a gven input path to a given output directory
+	Preprocesses the speech dataset from a given input path to given output directories
Args:
+ - hparams: hyper parameters
		- input_dir: input directory that contains the files to preprocess
- mel_dir: output directory of the preprocessed speech mel-spectrogram dataset
+ - linear_dir: output directory of the preprocessed speech linear-spectrogram dataset
- wav_dir: output directory of the preprocessed speech audio dataset
- n_jobs: Optional, number of worker process to parallelize across
- tqdm: Optional, provides a nice progress bar
@@ -22,24 +22,27 @@ def build_from_path(input_dirs, mel_dir, wav_dir, n_jobs=12, tqdm=lambda x: x):
- A list of tuple describing the train examples. this should be written to train.txt
"""
- # We use ProcessPoolExecutor to parallelize across processes, this is just for
+	# We use ProcessPoolExecutor to parallelize across processes; this is just for
	# optimization purposes and it can be omitted
executor = ProcessPoolExecutor(max_workers=n_jobs)
futures = []
index = 1
for input_dir in input_dirs:
- with open(os.path.join(input_dir, 'metadata.csv'), encoding='utf-8') as f:
- for line in f:
- parts = line.strip().split('|')
- wav_path = os.path.join(input_dir, 'wavs', '{}.wav'.format(parts[0]))
- text = parts[2]
- futures.append(executor.submit(partial(_process_utterance, mel_dir, wav_dir, index, wav_path, text)))
+ trn_files = glob.glob(os.path.join(input_dir, 'biaobei_48000', '*.trn'))
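+		#Each .trn file holds the transcript of the .wav sharing its basename; the first line of the file is the text.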
+ for trn in trn_files:
+ with open(trn) as f:
+ basename = trn[:-4]
+ wav_file = basename + '.wav'
+ wav_path = wav_file
+ basename = basename.split('/')[-1]
+ text = f.readline().strip()
+ futures.append(executor.submit(partial(_process_utterance, mel_dir, linear_dir, wav_dir, basename, wav_path, text, hparams)))
index += 1
return [future.result() for future in tqdm(futures) if future.result() is not None]
-def _process_utterance(mel_dir, wav_dir, index, wav_path, text):
+def _process_utterance(mel_dir, linear_dir, wav_dir, index, wav_path, text, hparams):
"""
Preprocesses a single utterance wav/text pair
@@ -48,18 +51,19 @@ def _process_utterance(mel_dir, wav_dir, index, wav_path, text):
Args:
- mel_dir: the directory to write the mel spectograms into
+ - linear_dir: the directory to write the linear spectrograms into
- wav_dir: the directory to write the preprocessed wav into
- index: the numeric index to use in the spectogram filename
- wav_path: path to the audio file containing the speech input
- text: text spoken in the input audio file
+ - hparams: hyper parameters
Returns:
- - A tuple: (mel_filename, n_frames, text)
+ - A tuple: (audio_filename, mel_filename, linear_filename, time_steps, mel_frames, linear_frames, text)
"""
-
try:
# Load the audio as numpy array
- wav = audio.load_wav(wav_path)
+ wav = audio.load_wav(wav_path, sr=hparams.sample_rate)
except FileNotFoundError: #catch missing wav exception
print('file {} present in csv metadata is not present in wav folder. skipping!'.format(
wav_path))
@@ -71,55 +75,49 @@ def _process_utterance(mel_dir, wav_dir, index, wav_path, text):
#M-AILABS extra silence specific
if hparams.trim_silence:
- wav = audio.trim_silence(wav)
-
- #Mu-law quantize
- if is_mulaw_quantize(hparams.input_type):
- #[0, quantize_channels)
- out = mulaw_quantize(wav, hparams.quantize_channels)
-
- #Trim silences
- start, end = audio.start_and_end_indices(out, hparams.silence_threshold)
- wav = wav[start: end]
- out = out[start: end]
-
- constant_values = mulaw_quantize(0, hparams.quantize_channels)
- out_dtype = np.int16
-
- elif is_mulaw(hparams.input_type):
- #[-1, 1]
- out = mulaw(wav, hparams.quantize_channels)
- constant_values = mulaw(0., hparams.quantize_channels)
- out_dtype = np.float32
-
- else:
- #[-1, 1]
- out = wav
- constant_values = 0.
- out_dtype = np.float32
+ wav = audio.trim_silence(wav, hparams)
+
+ #[-1, 1]
+ out = wav
+ constant_values = 0.
+ out_dtype = np.float32
# Compute the mel scale spectrogram from the wav
- mel_spectrogram = audio.melspectrogram(wav).astype(np.float32)
- n_frames = mel_spectrogram.shape[1]
+ mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
+ mel_frames = mel_spectrogram.shape[1]
+
+ if mel_frames > hparams.max_mel_frames or len(text) > hparams.max_text_length:
+ return None
+
+ #Compute the linear scale spectrogram from the wav
+ linear_spectrogram = audio.linearspectrogram(wav, hparams).astype(np.float32)
+ linear_frames = linear_spectrogram.shape[1]
+
+ #sanity check
+ assert linear_frames == mel_frames
+
#Ensure time resolution adjustement between audio and mel-spectrogram
- l, r = audio.pad_lr(wav, hparams.fft_size, audio.get_hop_size())
+ fft_size = hparams.n_fft if hparams.win_size is None else hparams.win_size
+ l, r = audio.pad_lr(wav, fft_size, audio.get_hop_size(hparams))
#Zero pad for quantized signal
out = np.pad(out, (l, r), mode='constant', constant_values=constant_values)
- time_steps = len(out)
- assert time_steps >= n_frames * audio.get_hop_size()
+ assert len(out) >= mel_frames * audio.get_hop_size(hparams)
#time resolution adjustement
#ensure length of raw audio is multiple of hop size so that we can use
#transposed convolution to upsample
- out = out[:n_frames * audio.get_hop_size()]
- assert time_steps % audio.get_hop_size() == 0
+ out = out[:mel_frames * audio.get_hop_size(hparams)]
+ assert len(out) % audio.get_hop_size(hparams) == 0
+ time_steps = len(out)
# Write the spectrogram and audio to disk
- audio_filename = 'speech-audio-{:05d}.npy'.format(index)
- mel_filename = 'speech-mel-{:05d}.npy'.format(index)
- np.save(os.path.join(wav_dir, audio_filename), out.astype(out_dtype), allow_pickle=False)
+ audio_filename = 'audio-{}.npy'.format(index)
+ mel_filename = 'mel-{}.npy'.format(index)
+ linear_filename = 'linear-{}.npy'.format(index)
+ # np.save(os.path.join(wav_dir, audio_filename), out.astype(out_dtype), allow_pickle=False)
np.save(os.path.join(mel_dir, mel_filename), mel_spectrogram.T, allow_pickle=False)
+ np.save(os.path.join(linear_dir, linear_filename), linear_spectrogram.T, allow_pickle=False)
# Return a tuple describing this training example
- return (audio_filename, mel_filename, time_steps, n_frames, text)
\ No newline at end of file
+ return (audio_filename, mel_filename, linear_filename, time_steps, mel_frames, text)
diff --git a/demo_server.py b/demo_server.py
new file mode 100644
index 00000000..4e78cc7c
--- /dev/null
+++ b/demo_server.py
@@ -0,0 +1,164 @@
+import argparse
+import chardet
+import thriftpy
+import falcon
+import tensorflow as tf
+import numpy as np
+import io
+import re
+import os
+import json
+import urllib
+from datasets import audio
+from mainstay import Mainstay
+from hparams import hparams
+from infolog import log
+from tacotron.synthesizer import Synthesizer
+from wsgiref import simple_server
+from pypinyin import pinyin, lazy_pinyin, Style
+
+
+html_body = '''
+Tacotron-2 Demo
+
+
+
+
+
+
+'''
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--checkpoint', default='pretrained/', help='Path to model checkpoint')
+parser.add_argument('--hparams', default='',help='Hyperparameter overrides as a comma-separated list of name=value pairs')
+parser.add_argument('--port', default=6006,help='Port of Http service')
+parser.add_argument('--host', default="localhost",help='Host of Http service')
+parser.add_argument('--name', help='Name of logging directory if the two models were trained together.')
+args = parser.parse_args()
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+checkpoint = os.path.join('logs-Tacotron', 'taco_' + args.checkpoint)
+try:
+ checkpoint_path = tf.train.get_checkpoint_state(checkpoint).model_checkpoint_path
+ log('loaded model at {}'.format(checkpoint_path))
+except:
+ raise RuntimeError('Failed to load checkpoint at {}'.format(checkpoint))
+
+synth = Synthesizer()
+modified_hp = hparams.parse(args.hparams)
+synth.load(checkpoint_path, modified_hp)
+
+class Res:
+ def on_get(self,req,res):
+ res.body = html_body
+ res.content_type = "text/html"
+
+class Syn:
+ def on_get(self,req,res):
+ if not req.params.get('text'):
+ raise falcon.HTTPBadRequest()
+ orig_chs = req.params.get('text')
+ norm_chs = chs_norm(orig_chs)
+ print(norm_chs.encode("utf-8").decode("utf-8"))
+ pys = chs_pinyin(norm_chs)
+ out = io.BytesIO()
+ wav = synth.eval(pys)
+ audio.save_wav(wav, out, hparams)
+ res.data = out.getvalue()
+ res.content_type = "audio/wav"
+
+def chs_pinyin(text):
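+	#Convert Chinese text to tone-numbered pinyin (Style.TONE3), map CJK punctuation to ASCII, drop quotes/brackets, and split into sentences at punctuation.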
+ pys = pinyin(text, style=Style.TONE3)
+ results = []
+ sentence = []
+ for i in range(len(pys)):
+ if pys[i][0][0] in ",、·,":
+ pys[i][0] = ','
+ elif pys[i][0][0] in ".。…":
+ pys[i][0] = '.'
+ elif pys[i][0][0] in "―――———":
+ pys[i][0] = ','
+ elif pys[i][0][0] in ";::;":
+ pys[i][0] = ','
+ elif pys[i][0][0] in "??":
+ pys[i][0] = '?'
+ elif pys[i][0][0] in "!!":
+ pys[i][0] = '!'
+ elif pys[i][0][0] in "《》()()":
+ continue
+		elif pys[i][0][0] in "“”‘’\"\'":
+ continue
+ elif pys[i][0][0] in " /<>「」":
+ continue
+
+ sentence.append(pys[i][0])
+ if pys[i][0] in ",.;?!:":
+ results.append(' '.join(sentence))
+ sentence = []
+
+ if len(sentence) > 0:
+ results.append(' '.join(sentence))
+
+ for i, res in enumerate(results):
+ if results[i][-1] not in ",.":
+ results[i] += ' .'
+ print(res)
+
+ return results
+
+
+def chs_norm(text):
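+	#Normalize the text (e.g. number formatting) by posting it to an external text-format web service and return the joined result.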
+ url = 'http://search.ximalaya.com/text-format/numberFormat/convert'
+ payload = json.dumps(list(text)).encode()
+ request = urllib.request.Request(url, payload)
+ request.add_header("Content-Type",'application/json')
+	response = urllib.request.urlopen(request)
+	return ''.join(json.loads(response.read().decode()))
+
+
+api = falcon.API()
+api.add_route("/",Res())
+api.add_route("/synthesize",Syn())
+log("host:{},port:{}".format(args.host,int(args.port)))
+simple_server.make_server(args.host,int(args.port),api).serve_forever()
diff --git a/griffin_lim_synthesis_tool.ipynb b/griffin_lim_synthesis_tool.ipynb
index d2faa098..cebf9006 100644
--- a/griffin_lim_synthesis_tool.ipynb
+++ b/griffin_lim_synthesis_tool.ipynb
@@ -11,10 +11,11 @@
"import numpy as np\n",
"from datasets.audio import *\n",
"import os\n",
+ "from hparams import hparams\n",
"\n",
"n_sample = 0 #Change n_steps here\n",
- "mel_folder = 'logs-Tacotron' #Or change file path\n",
- "mel_file = 'ljspeech-mel-prediction-step-{:05d}.npy'.format(n_sample) #Or file name (for other generated mels)\n",
+ "mel_folder = 'logs-Tacotron/mel-spectrograms' #Or change file path\n",
+ "mel_file = 'mel-prediction-step-{}.npy'.format(n_sample) #Or file name (for other generated mels)\n",
"out_dir = 'wav_out'\n",
"\n",
"os.makedirs(out_dir, exist_ok=True)\n",
@@ -30,9 +31,10 @@
"metadata": {},
"outputs": [],
"source": [
- "wav = inv_mel_spectrogram(mel_spectro.T) \n",
+ "wav = inv_mel_spectrogram(mel_spectro.T, hparams) \n",
"#save the wav under test__\n",
- "save_wav(wav, os.path.join(out_dir, 'tests_{}.wav'.format(mel_file.replace('/', '_').replace('\\\\', '_'))))"
+ "save_wav(wav, os.path.join(out_dir, 'test_{}.wav'.format(mel_file.replace('/', '_').replace('\\\\', '_'))),\n",
+    "          hparams)"
]
}
],
diff --git a/hparams.py b/hparams.py
index 6439f6e0..2bf66820 100644
--- a/hparams.py
+++ b/hparams.py
@@ -1,46 +1,68 @@
-import tensorflow as tf
-import numpy as np
-
+import numpy as np
+import tensorflow as tf
# Default hyperparameters
hparams = tf.contrib.training.HParams(
# Comma-separated list of cleaners to run on text prior to training and eval. For non-English
- # text, you may want to use "basic_cleaners" or "transliteration_cleaners" See TRAINING_DATA.md.
- cleaners='english_cleaners',
+ # text, you may want to use "basic_cleaners" or "transliteration_cleaners".
+ cleaners='basic_cleaners',
+ #Hardware setup (TODO: multi-GPU parallel tacotron training)
+ use_all_gpus = False, #Whether to use all GPU resources. If True, total number of available gpus will override num_gpus.
+ num_gpus = 1, #Determines the number of gpus in use
+ ###########################################################################################################################################
#Audio
- num_mels = 80,
- rescale = True,
- rescaling_max = 0.999,
- trim_silence = True,
+ num_mels = 80, #Number of mel-spectrogram channels and local conditioning dimensionality
+ num_freq = 2049, # (= n_fft / 2 + 1) only used when adding linear spectrograms post processing network
+ rescale = False, #Whether to rescale audio prior to preprocessing
+ rescaling_max = 0.999, #Rescaling value
+ trim_silence = True, #Whether to clip silence in Audio (at beginning and end of audio only, not the middle)
+ clip_mels_length = True, #For cases of OOM (Not really recommended, working on a workaround)
+ max_mel_frames = 600, #Only relevant when clip_mels_length = True
+ max_text_length = 150, #Only relevant when clip_mels_length = True
+
+ # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
+ # It's preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
+	# Does not work if n_fft is not a multiple of hop_size!!
+ use_lws=False,
+ silence_threshold=2, #silence threshold used for sound trimming for wavenet preprocessing
#Mel spectrogram
- fft_size = 1024,
- hop_size = 256,
- sample_rate = 22050, #22050 Hz (corresponding to ljspeech dataset)
+ n_fft = 4096, #Extra window size is filled with 0 paddings to match this parameter
+	hop_size = 600, #For 48000Hz, 600 = 12.5 ms
+	win_size = 2400, #For 48000Hz, 2400 = 50 ms (If None, win_size = n_fft)
+	sample_rate = 48000, #48000 Hz (corresponding to the 48kHz biaobei dataset)
frame_shift_ms = None,
+ preemphasis = 0.97, # preemphasis coefficient
+
+ #M-AILABS (and other datasets) trim params
+ trim_fft_size = 512,
+ trim_hop_size = 128,
+ trim_top_db = 60,
- #Mel spectrogram normalization/scaling and clipping
- mel_normalization = True,
- allow_clipping_in_normalization = True, #Only relevant if mel_normalization = True
+ #Mel and Linear spectrograms normalization/scaling and clipping
+ signal_normalization = True,
+	allow_clipping_in_normalization = False, #Only relevant if signal_normalization = True
symmetric_mels = True, #Whether to scale the data to be symmetric around 0
- max_abs_value = 4., #max absolute value of data. If symmetric, data will be [-max, max] else [0, max]
+ max_abs_value = 4., #max absolute value of data. If symmetric, data will be [-max, max] else [0, max]
+ normalize_for_wavenet = True, #whether to rescale to [0, 1] for wavenet.
#Limits
- min_level_db =- 100,
+ min_level_db = -120,
ref_level_db = 20,
- fmin = 125,
+ fmin = 125, #Set this to 75 if your speaker is male! if female, 125 should help taking off noise. (To test depending on dataset)
fmax = 7600,
#Griffin Lim
- power = 1.55,
+ power = 1.2,
griffin_lim_iters = 60,
-
+ ###########################################################################################################################################
#Tacotron
- outputs_per_step = 5, #number of frames to generate at each decoding step (speeds up computation and allows for higher batch size)
- stop_at_any = True, #Determines whether the decoder should stop when predicting to any frame or to all of them
+ outputs_per_step = 2, #number of frames to generate at each decoding step (speeds up computation and allows for higher batch size)
+ stop_at_any = False, #Determines whether the decoder should stop when predicting to any frame or to all of them
+ batch_norm_position = 'after', #Can be in ('before', 'after'). Determines whether we use batch norm before or after the activation function (relu). Matter for debate.
embedding_dim = 512, #dimension of embedding space
@@ -49,10 +71,18 @@
enc_conv_channels = 512, #number of encoder convolutions filters for each layer
encoder_lstm_units = 256, #number of lstm units for each direction (forward and backward)
- smoothing = False, #Whether to smooth the attention normalization function
+ smoothing = False, #Whether to smooth the attention normalization function
attention_dim = 128, #dimension of attention space
attention_filters = 32, #number of attention convolution filters
attention_kernel = (31, ), #kernel size of attention convolution
+ cumulative_weights = True, #Whether to cumulate (sum) all previous attention weights or simply feed previous weights (Recommended: True)
+
+ #Attention synthesis constraints
+ #"Monotonic" constraint forces the model to only look at the forwards attention_win_size steps.
+ #"Window" allows the model to look at attention_win_size neighbors, both forward and backward steps.
+ synthesis_constraint = False, #Whether to use attention windows constraints in synthesis only (Useful for long utterances synthesis)
+ synthesis_constraint_type = 'window', #can be in ('window', 'monotonic').
+	attention_win_size = 7, #Size of the window. The current step does not count. If mode is 'window' and attention_win_size is not even, the 1 extra step is given to the backward part of the window.
prenet_layers = [256, 256], #number of layers and number of units of prenet
decoder_layers = 2, #number of decoder lstm layers
@@ -63,101 +93,190 @@
postnet_kernel_size = (5, ), #size of postnet convolution filters for each layer
postnet_channels = 512, #number of postnet convolution filters for each layer
- mask_encoder = False, #whether to mask encoder padding while computing attention
- impute_finished = False, #Whether to use loss mask for padded sequences
- mask_finished = False, #Whether to mask alignments beyond the (False for debug, True for style)
-
-
- #Wavenet
- # Input type:
- # 1. raw [-1, 1]
- # 2. mulaw [-1, 1]
- # 3. mulaw-quantize [0, mu]
- # If input_type is raw or mulaw, network assumes scalar input and
- # discretized mixture of logistic distributions output, otherwise one-hot
- # input and softmax output are assumed.
- # **NOTE**: if you change the one of the two parameters below, you need to
- # re-run preprocessing before training.
- # **NOTE**: scaler input (raw or mulaw) is experimental. Use it your own risk.
- input_type="mulaw-quantize",
- quantize_channels=256, # 65536 or 256
-
- silence_threshold=2,
-
- # Mixture of logistic distributions:
- log_scale_min=float(np.log(1e-14)),
-
- #TODO model params
-
+ #CBHG mel->linear postnet
+ cbhg_kernels = 8, #All kernel sizes from 1 to cbhg_kernels will be used in the convolution bank of CBHG to act as "K-grams"
+ cbhg_conv_channels = 128, #Channels of the convolution bank
+ cbhg_pool_size = 2, #pooling size of the CBHG
+ cbhg_projection = 256, #projection channels of the CBHG (1st projection, 2nd is automatically set to num_mels)
+ cbhg_projection_kernel_size = 3, #kernel_size of the CBHG projections
+ cbhg_highwaynet_layers = 4, #Number of HighwayNet layers
+ cbhg_highway_units = 128, #Number of units used in HighwayNet fully connected layers
+ cbhg_rnn_units = 128, #Number of GRU units used in bidirectional RNN of CBHG block. CBHG output is 2x rnn_units in shape
+
+ #Loss params
+ mask_encoder = False, #whether to mask encoder padding while computing attention. Set to True for better prosody but slower convergence.
+ mask_decoder = False, #Whether to use loss mask for padded sequences (if False, loss function will not be weighted, else recommended pos_weight = 20)
+ cross_entropy_pos_weight = 1, #Use class weights to reduce the stop token classes imbalance (by adding more penalty on False Negatives (FN)) (1 = disabled)
+ predict_linear = True, #Whether to add a post-processing network to the Tacotron to predict linear spectrograms (True mode Not tested!!)
+ ###########################################################################################################################################
#Tacotron Training
- tacotron_batch_size = 64, #number of training samples on each training steps
- tacotron_reg_weight = 1e-6, #regularization weight (for l2 regularization)
-
+ #Reproduction seeds
+ tacotron_random_seed = 5339, #Determines initial graph and operations (i.e: model) random state for reproducibility
+ tacotron_data_random_state = 1234, #random state for train test split repeatability
+
+ #performance parameters
+ tacotron_swap_with_cpu = False, #Whether to use cpu as support to gpu for decoder computation (Not recommended: may cause major slowdowns! Only use when critical!)
+
+ #train/test split ratios, mini-batches sizes
+ tacotron_batch_size = 48, #number of training samples on each training steps
+ #Tacotron Batch synthesis supports ~16x the training batch size (no gradients during testing).
+ #Training Tacotron with unmasked paddings makes it aware of them, which makes synthesis times different from training. We thus recommend masking the encoder.
+ tacotron_synthesis_batch_size = 1, #DO NOT MAKE THIS BIGGER THAN 1 IF YOU DIDN'T TRAIN TACOTRON WITH "mask_encoder=True"!!
+ tacotron_test_size = 0.05, #% of data to keep as test data, if None, tacotron_test_batches must be not None. (5% is enough to have a good idea about overfit)
+ tacotron_test_batches = None, #number of test batches.
+
+ #Learning rate schedule
tacotron_decay_learning_rate = True, #boolean, determines if the learning rate will follow an exponential decay
- tacotron_decay_steps = 50000, #starting point for learning rate decay (and determines the decay slope)
- tacotron_decay_rate = 0.4, #learning rate decay rate
+ tacotron_start_decay = 40000, #Step at which learning decay starts
+ tacotron_decay_steps = 40000, #Determines the learning rate decay slope (UNDER TEST)
+ tacotron_decay_rate = 0.4, #learning rate decay rate (UNDER TEST)
tacotron_initial_learning_rate = 1e-3, #starting learning rate
tacotron_final_learning_rate = 1e-5, #minimal learning rate
+ #Optimization parameters
tacotron_adam_beta1 = 0.9, #AdamOptimizer beta1 parameter
tacotron_adam_beta2 = 0.999, #AdamOptimizer beta2 parameter
- tacotron_adam_epsilon = 1e-6, #AdamOptimizer beta3 parameter
+ tacotron_adam_epsilon = 1e-6, #AdamOptimizer Epsilon parameter
+ #Regularization parameters
+ tacotron_reg_weight = 1e-6, #regularization weight (for L2 regularization)
+ tacotron_scale_regularization = False, #Whether to rescale regularization weight to adapt for outputs range (used when reg_weight is high and biasing the model)
tacotron_zoneout_rate = 0.1, #zoneout rate for all LSTM cells in the network
tacotron_dropout_rate = 0.5, #dropout rate for all convolutional layers + prenet
-
- tacotron_teacher_forcing_ratio = 1., #Value from [0., 1.], 0.=0%, 1.=100%, determines the % of times we force next decoder inputs
-
-
- #Wavenet Training TODO
-
-
-
- #Eval sentences
+ tacotron_clip_gradients = True, #whether to clip gradients
+
+ #Evaluation parameters
+ tacotron_natural_eval = False, #Whether to use 100% natural eval (to evaluate Curriculum Learning performance) or with same teacher-forcing ratio as in training (just for overfit)
+
+	#Decoder RNN learning can be done in one of two ways:
+ # Teacher Forcing: vanilla teacher forcing (usually with ratio = 1). mode='constant'
+ # Scheduled Sampling Scheme: From Teacher-Forcing to sampling from previous outputs is function of global step. (teacher forcing ratio decay) mode='scheduled'
+ #The second approach is inspired by:
+ #Bengio et al. 2015: Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks.
+ #Can be found under: https://arxiv.org/pdf/1506.03099.pdf
+ tacotron_teacher_forcing_mode = 'constant', #Can be ('constant' or 'scheduled'). 'scheduled' mode applies a cosine teacher forcing ratio decay. (Preference: scheduled)
+ tacotron_teacher_forcing_ratio = 1., #Value from [0., 1.], 0.=0%, 1.=100%, determines the % of times we force next decoder inputs, Only relevant if mode='constant'
+ tacotron_teacher_forcing_init_ratio = 1., #initial teacher forcing ratio. Relevant if mode='scheduled'
+ tacotron_teacher_forcing_final_ratio = 0., #final teacher forcing ratio. (Set None to use alpha instead) Relevant if mode='scheduled'
+ tacotron_teacher_forcing_start_decay = 10000, #starting point of teacher forcing ratio decay. Relevant if mode='scheduled'
+ tacotron_teacher_forcing_decay_steps = 40000, #Determines the teacher forcing ratio decay slope. Relevant if mode='scheduled'
+ tacotron_teacher_forcing_decay_alpha = None, #teacher forcing ratio decay rate. Defines the final tfr as a ratio of initial tfr. Relevant if mode='scheduled'
+ ###########################################################################################################################################
+
+ #Eval sentences (if no eval file was specified, these sentences are used for eval)
sentences = [
- # From July 8, 2017 New York Times:
- 'Scientists at the CERN laboratory say they have discovered a new particle.',
- 'There\'s a way to measure the acute emotional intelligence that has never gone out of style.',
- 'President Trump met with other leaders at the Group of 20 conference.',
- 'The Senate\'s bill to repeal and replace the Affordable Care Act is now imperiled.',
- # From Google's Tacotron example page:
- 'Generative adversarial network or variational auto-encoder.',
- 'Basilar membrane and otolaryngology are not auto-correlations.',
- 'He has read the whole thing.',
- 'He reads books.',
- "Don't desert me here in the desert!",
- 'He thought it was time to present the present.',
- 'Thisss isrealy awhsome.',
- 'Punctuation sensitivity, is working.',
- 'Punctuation sensitivity is working.',
- "The buses aren't the problem, they actually provide a solution.",
- "The buses aren't the PROBLEM, they actually provide a SOLUTION.",
- "The quick brown fox jumps over the lazy dog.",
- "Does the quick brown fox jump over the lazy dog?",
- "Peter Piper picked a peck of pickled peppers. How many pickled peppers did Peter Piper pick?",
- "She sells sea-shells on the sea-shore. The shells she sells are sea-shells I'm sure.",
- "The blue lagoon is a nineteen eighty American romance adventure film.",
- "Tajima Airport serves Toyooka.",
- 'Talib Kweli confirmed to AllHipHop that he will be releasing an album in the next year.',
- #From Training data:
- 'the rest being provided with barrack beds, and in dimensions varying from thirty feet by fifteen to fifteen feet by ten.',
- 'in giltspur street compter, where he was first lodged.',
- 'a man named burnett came with his wife and took up his residence at whitchurch, hampshire, at no great distance from laverstock,',
- 'it appears that oswald had only one caller in response to all of his fpcc activities,',
- 'he relied on the absence of the strychnia.',
- 'scoggins thought it was lighter.',
- '''would, it is probable, have eventually overcome the reluctance of some of the prisoners at least,
- and would have possessed so much moral dignity''',
- '''the only purpose of this whole sentence is to evaluate the scalability of the model for very long sentences.
- This is not even a long sentence anymore, it has become an entire paragraph.
- Should I stop now? Let\'s add this last sentence in which we talk about nothing special.''',
- 'Thank you so much for your support!!'
+ # "huan2 qiu2 wang3 bao4 dao4 .",
+ # "e2 luo2 si1 wei4 xing1 wang3 shi2 yi1 ri4 bao4 dao4 cheng1 .",
+ # "ji4 nian4 di4 yi2 ci4 shi4 jie4 da4 zhan4 jie2 shu4 yi4 bai3 zhou1 nian2 qing4 zhu4 dian3 li3 zai4 ba1 li2 ju3 xing2 .",
+ # "e2 luo2 si1 zong3 tong3 pu3 jing1 he2 mei3 guo2 zong3 tong3 te4 lang3 pu3 ,",
+ # "zai4 ba1 li2 kai3 xuan2 men2 jian4 mian4 shi2 wo4 shou3 zhi4 yi4 .",
+ # "pu3 jing1 biao3 shi4 .",
+ # "tong2 mei3 guo2 zong3 tong3 te4 lang3 pu3 ,",
+ # "jin4 xing2 le hen3 hao3 de jiao1 liu2 .",
+ # "e2 luo2 si1 zong3 tong3 zhu4 shou3 you2 li3 wu1 sha1 ke1 fu1 biao3 shi4 .",
+ # "fa3 guo2 fang1 mian4 zhi2 yi4 yao1 qiu2 ,",
+ # "bu2 yao4 zai4 ba1 li2 ju3 xing2 ji4 nian4 huo2 dong4 qi1 jian1 ." ,
+ # "ju3 xing2 e2 mei3 liang3 guo2 zong3 tong3 de dan1 du2 hui4 wu4 da2 cheng2 le xie2 yi4 .",
+ # "wo3 men yi3 jing1 kai1 shi3 xie2 tiao2 ,",
+ # "e2 luo2 si1 he2 mei3 guo2 zong3 tong3 hui4 wu4 de shi2 jian1 .",
+ # "dan4 hou4 lai2 ,",
+ # "wo3 men kao3 lv4 dao4 le fa3 guo2 tong2 hang2 men de dan1 you1 he2 guan1 qie4 .",
+ # "wu1 sha1 ke1 fu1 shuo1 .",
+ # "yin1 ci3 , wo3 men yu3 mei3 guo2 dai4 biao3 men yi4 qi3 jin4 xing2 le tao3 lun4 .",
+ # "jue2 ding4 zai4 bu4 yi2 nuo4 si1 ai4 li4 si1 feng1 hui4 shang4 ,",
+ # "jin4 xing2 nei4 rong2 geng4 feng1 fu4 de dui4 hua4 .",
+ # "bao4 dao4 cheng1 .",
+ # "pu3 jing1 he2 te4 lang3 pu3 zai4 ai4 li4 she4 gong1 wu3 can1 hui4 shang4 de zuo4 wei4 an1 pai2 ,",
+ # "zai4 zui4 hou4 yi4 fen1 zhong1 jin4 xing2 le tiao2 zheng3 .",
+ # "dan4 zhe4 bing4 bu4 fang2 ai4 ta1 men jiao1 tan2 .",
+ # "sui1 ran2 dong1 dao4 zhu3 fa3 guo2 dui4 ta1 men zai4 ba1 li2 de hui4 wu4 biao3 shi4 fan3 dui4 .",
+ # "dan4 e2 mei3 ling3 dao3 ren2 reng2 ran2 biao3 shi4 .",
+ # "ta1 men xi1 wang4 zai4 ai4 li4 she4 gong1 de gong1 zuo4 wu3 can1 shang4 hui4 mian4 .",
+ # "chu1 bu4 zuo4 wei4 biao3 xian3 shi4 .",
+ # "te4 lang3 pu3 bei4 an1 pai2 zai4 pu3 jing1 pang2 bian1 .",
+ # "dan4 zai4 sui2 hou4 jin4 xing2 de gong1 zuo4 wu3 can1 qi1 jian1 .",
+ # "zuo4 wei4 an1 pai2 xian3 ran2 yi3 jing1 fa1 sheng1 le bian4 hua4 .",
+ # "cong2 zhao4 pian1 lai2 kan4 .",
+ # "pu3 jing1 dang1 shi2 zheng4 quan2 shen2 guan4 zhu4 de yu3 lian2 he2 guo2 mi4 shu1 zhang3 gu3 te4 lei2 si1 jiao1 tan2 .",
+ # "ou1 meng2 wei3 yuan2 hui4 zhu3 xi2 rong2 ke4 zuo4 zai4 pu3 jing1 de you4 bian1 .",
+ # "er2 te4 lang3 pu3 ze2 zuo4 zai4 ma3 ke4 long2 pang2 bian1 .",
+ # "ma3 ke4 long2 de you4 bian1 ze2 shi4 de2 guo2 zong3 li3 mo4 ke4 er3 .",
+ # "ci3 qian2 . pu3 jing1 zai4 fang3 wen4 ba1 li2 qi1 jian1 biao3 shi4 .",
+ # "ta1 bu4 pai2 chu2 yu3 te4 lang3 pu3 zai4 gong1 zuo4 wu3 can1 shi2 jin4 xing2 jiao1 liu2 .",
+ # "pu3 jing1 zai4 fa3 guo2 pin2 dao4 de jie2 mu4 zhong1 hui2 da2 ,",
+ # "shi4 fou3 yi3 tong2 te4 lang3 pu3 jin4 xing2 jiao1 liu2 de wen4 ti2 shi2 biao3 shi4 ,",
+ # "zan4 shi2 mei2 you3 .",
+ # "wo3 men zhi3 da3 le ge4 zhao1 hu .",
+ # "yi2 shi4 yi3 zhe4 yang4 de fang1 shi4 jin4 xing2 .",
+ # "wo3 men wu2 fa3 zai4 na4 li3 jin4 xing2 jiao1 liu2 .",
+ # "wo3 men guan1 kan4 le fa1 sheng1 de shi4 qing2 .",
+ # "dan4 xian4 zai4 hui4 you3 gong1 zuo4 wu3 can1 .",
+ # "ye3 xu3 zai4 na4 li3 .",
+ # "wo3 men hui4 jin4 xing2 jie1 chu4 .",
+ # "dan4 shi4 . wu2 lun4 ru2 he2 .",
+ # "wo3 men shang1 ding4 .",
+ # "wo3 men zai4 zhe4 li3 ,",
+ # "bu2 hui4 wei2 fan3 zhu3 ban4 guo2 de gong1 zuo4 an1 pai2 .",
+ # "gen1 ju4 ta1 men de yao1 qiu2 .",
+ # "wo3 men bu2 hui4 zai4 zhe4 li3 zu3 zhi1 ren4 he2 hui4 mian4 .",
+ # "er2 shi4 ke3 neng2 hui4 zai4 feng1 hui4 qi1 jian1 ,",
+ # "huo4 zai4 ci3 zhi1 hou4 ju3 xing2 hui4 mian4 .",
+ # "pu3 jing1 hai2 biao3 shi4 .",
+ # "e2 luo2 si1 zhun3 bei4 tong2 mei3 guo2 jin4 xing2 dui4 hua4 .",
+ # "fan3 zheng4 bu2 shi4 mo4 si1 ke1 yao4 tui4 chu1 zhong1 dao3 tiao2 yue1 .",
+
+ # "guan1 yu2 xi1 zang4 de chuan2 shuo1 you3 hen3 duo1 ,",
+ # "li4 lai2 , dou1 shi4 chao2 sheng4 zhe3 de tian1 tang2 ,",
+ # "er2 zuo4 wei2 zhong1 guo2 xi1 nan2 bian1 chui2 zhong4 de4 ,",
+ # "ye3 dou1 shi4 zhong1 guo2 ling3 tu3 bu4 ke3 fen1 ge1 de yi2 bu4 fen .",
+ # "er4 ling2 yi1 wu3 nian2 , yang1 shi4 ceng2 jing1 bo1 chu1 guo4 yi2 bu4 gao1 fen1 ji4 lu4 pian4 ,",
+ # "di4 san1 ji2",
+ # "pian4 zhong1 , tian1 gao1 di4 kuo4 de feng1 jing3 ,",
+ # "rang4 wu2 shu4 ren2 dui4 xi1 zang4 qing2 gen1 shen1 zhong4 .",
+ # "shi2 ge2 liang3 nian2 , you2 yuan2 ban1 ren2 ma3 da3 zao4 de jie3 mei4 pian1 ,",
+ # "ji2 di4 , qiao1 ran2 shang4 xian4 !",
+ # "mei3 yi4 zheng4 dou1 shi4 bi4 zhi3 , mei3 yi2 mu4 dou1 shi4 ren2 jian1 xian1 jing4 .",
+ # "zi4 ying3 pian1 bo1 chu1 zhi1 lai2 , hao3 ping2 ru2 chao2 ,",
+ # "jiu4 lian2 yi2 xiang4 yi3 yan2 jin3 chu1 ming2 de dou4 ban4 ping2 fen1 ye3 shi4 hen3 gao1 .",
+ # "zao3 zai4 er4 ling2 yi1 wu3 nian2 ,",
+ # "ta1 de di4 yi1 ji4 di4 san1 ji2 jiu4 na2 dao4 le dou4 ban4 jiu2 dian3 er4 fen1 .",
+ # "er2 rang4 ta1 yi2 xia4 na2 dao4 jiu2 dian3 wu3 fen1 de yuan2 yin1 shi4 yin1 wei4, ",
+ # "ta1 zhan3 shi4 le zai4 na4 pian4 jue2 mei3 yu3 pin2 ji2 bing4 cun2 de jing4 tu3 shang4 ,",
+ # "pu3 tong1 ren2 de zhen1 shi2 sheng1 huo2 shi4 shen2 me yang4 zi .",
+
+ "bai2 jia1 xuan1 hou4 lai2 yin2 yi3 hao2 zhuang4 de shi4 yi4 sheng1 li3 qu3 guo4 qi1 fang2 nv3 ren2 .",
+ "qu3 tou2 fang2 xi2 fu4 shi2 ta1 gang1 gang1 guo4 shi2 liu4 sui4 sheng1 ri4 .",
+ "na4 shi4 xi1 yuan2 shang4 gong3 jia1 cun1 da4 hu4 gong3 zeng1 rong2 de tou2 sheng1 nv3 ,",
+ "bi3 ta1 da4 liang3 sui4 .",
+ "ta1 zai4 wan2 quan2 wu2 zhi1 huang1 luan4 zhong1 , du4 guo4 le xin1 hun1 zhi1 ye4 ,",
+ "liu2 xia4 le yong2 yuan3 xiu1 yu2 xiang4 ren2 dao4 ji2 de ke3 xiao4 de sha3 yang4 ,",
+ "er2 zi4 ji3 que4 yong3 sheng1 nan2 yi3 wang4 ji4 .",
+ "yi4 nian2 hou4 , zhe4 ge4 nv3 ren2 si3 yu2 nan2 chan3 .",
+ "di4 er4 fang2 qu3 de shi4 nan2 yuan2 pang2 jia1 cun1 yin1 shi2 ren2 jia1 , pang2 xiu1 rui4 de nai3 gan1 nv3 er2 .",
+ "zhe4 nv3 zi3 you4 zheng4 hao3 bi3 ta1 xiao2 liang3 sui4 ,",
+ "mu2 yang4 jun4 xiu4 yan3 jing1 hu1 ling2 er .",
+ "ta1 wan2 quan2 bu4 zhi1 dao4 jia4 ren2 shi4 zen3 me hui2 shi4 ,",
+ "er2 ta1 ci3 shi2 yi3 an1 shu2 nan2 nv3 zhi1 jian1 suo2 you3 de yin3 mi4 .",
+ "ta1 kan4 zhe ta1 de xiu1 qie4 huang1 luan4 er2 xiang3 dao4 zi4 ji3 di4 yi1 ci4 de sha3 yang4 fan3 dao4 jue2 de geng4 fu4 ci4 ji .",
+ "dang1 ta1 hong1 suo1 zhe ba3 duo2 duo3 shan2 shan3 er2 you4 bu4 gan3 wei2 ao4 ta1 de xiao3 xi2 fu4 guo3 ru4 shen1 xia4 de shi2 hou4 ,",
+ "ta1 ting1 dao4 le ta1 de bu2 shi4 huan1 le4 er2 shi4 tong4 ku3 de yi4 sheng1 ku1 jiao4 .",
+ "dang1 ta1 pi2 bei4 de xie1 xi1 xia4 lai2 ,",
+ "cai2 fa1 jue2 jian1 bang3 nei4 ce4 teng2 tong4 zuan1 xin1 ,",
+ "ta1 ba3 ta1 yao3 lan4 le .",
+ "ta1 fu3 shang1 xi1 tong4 de shi2 hou4 ,",
+ "xin1 li3 jiu4 chao2 qi3 le dui4 zhe4 ge4 jiao1 guan4 de you2 dian3 ren4 xing4 de nai3 gan1 nv3 er de nao2 huo3 .",
+ "zheng4 yu4 fa1 zuo4 ,",
+ "ta1 que4 ban1 guo4 ta1 de jian1 bang3 an4 shi4 ta1 zai4 lai2 yi1 ci4 .",
+ "yi4 dang1 jing1 guo4 nan2 nv3 jian1 de di4 yi1 ci4 jiao1 huan1 ,",
+ "ta1 jiu4 bian4 de2 mei2 you3 jie2 zhi4 de ren4 xing4 .",
+ "zhe4 ge4 nv3 ren2 cong2 xia4 jiao4 ding3 zhe hong2 chou2 gai4 jin1 , jin4 ru4 bai2 jia1 men2 lou2 ,",
+ "dao4 tang3 jin4 yi2 ju4 bao2 ban3 guan1 cai tai2 chu1 zhe4 ge4 men2 lou2 ,",
+ "shi2 jian1 shang4 bu4 zu2 yi1 nian2 , shi4 hai4 lao2 bing4 si3 de .",
]
-
)
def hparams_debug_string():
- values = hparams.values()
- hp = [' %s: %s' % (name, values[name]) for name in sorted(values) if name != 'sentences']
-return 'Hyperparameters:\n' + '\n'.join(hp)
\ No newline at end of file
+ values = hparams.values()
+ hp = [' %s: %s' % (name, values[name]) for name in sorted(values) if name != 'sentences']
+ return 'Hyperparameters:\n' + '\n'.join(hp)
diff --git a/tacotron/utils/infolog.py b/infolog.py
similarity index 90%
rename from tacotron/utils/infolog.py
rename to infolog.py
index 635cf52d..cd51d725 100644
--- a/tacotron/utils/infolog.py
+++ b/infolog.py
@@ -1,10 +1,9 @@
import atexit
-from datetime import datetime
import json
-from threading import Thread
+from datetime import datetime
+from threading import Thread
from urllib.request import Request, urlopen
-
_format = '%Y-%m-%d %H:%M:%S.%f'
_file = None
_run_name = None
@@ -17,7 +16,7 @@ def init(filename, run_name, slack_url=None):
_file = open(filename, 'a')
_file.write('\n-----------------------------------------------------------------\n')
- _file.write('Starting new training run\n')
+ _file.write('Starting new {} training run\n'.format(run_name))
_file.write('-----------------------------------------------------------------\n')
_run_name = run_name
_slack_url = slack_url
@@ -48,4 +47,4 @@ def _send_slack(msg):
}).encode())
-atexit.register(_close_logfile)
\ No newline at end of file
+atexit.register(_close_logfile)
diff --git a/preprocess.py b/preprocess.py
index d0301154..e94836ba 100644
--- a/preprocess.py
+++ b/preprocess.py
@@ -1,46 +1,54 @@
import argparse
-from multiprocessing import cpu_count
import os
-from tqdm import tqdm
+from multiprocessing import cpu_count
+
from datasets import preprocessor
from hparams import hparams
+from tqdm import tqdm
-def preprocess(args, input_folders, out_dir):
+def preprocess(args, input_folders, out_dir, hparams):
mel_dir = os.path.join(out_dir, 'mels')
wav_dir = os.path.join(out_dir, 'audio')
+ linear_dir = os.path.join(out_dir, 'linear')
os.makedirs(mel_dir, exist_ok=True)
os.makedirs(wav_dir, exist_ok=True)
- metadata = preprocessor.build_from_path(input_folders, mel_dir, wav_dir, args.n_jobs, tqdm=tqdm)
+ os.makedirs(linear_dir, exist_ok=True)
+ metadata = preprocessor.build_from_path(hparams, input_folders, mel_dir, linear_dir, wav_dir, args.n_jobs, tqdm=tqdm)
write_metadata(metadata, out_dir)
def write_metadata(metadata, out_dir):
with open(os.path.join(out_dir, 'train.txt'), 'w', encoding='utf-8') as f:
for m in metadata:
f.write('|'.join([str(x) for x in m]) + '\n')
- frames = sum([int(m[3]) for m in metadata])
- timesteps = sum([int(m[2]) for m in metadata])
+ mel_frames = sum([int(m[4]) for m in metadata])
+ timesteps = sum([int(m[3]) for m in metadata])
sr = hparams.sample_rate
hours = timesteps / sr / 3600
print('Write {} utterances, {} mel frames, {} audio timesteps, ({:.2f} hours)'.format(
- len(metadata), frames, timesteps, hours))
- print('Max input length (text chars): {}'.format(max(len(m[4]) for m in metadata)))
- print('Max mel frames length: {}'.format(max(int(m[3]) for m in metadata)))
- print('Max audio timesteps length: {}'.format(max(m[2] for m in metadata)))
+ len(metadata), mel_frames, timesteps, hours))
+ print('Max input length (text chars): {}'.format(max(len(m[5]) for m in metadata)))
+ print('Max mel frames length: {}'.format(max(int(m[4]) for m in metadata)))
+ print('Max audio timesteps length: {}'.format(max(m[3] for m in metadata)))
def norm_data(args):
+
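+	#Booleans arrive from the command line as 'True'/'False' strings and are compared
+	#explicitly, since argparse's type=bool would treat any non-empty string as True.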
+ merge_books = (args.merge_books=='True')
+
print('Selecting data folders..')
- supported_datasets = ['LJSpeech-1.1', 'M-AILABS']
+ supported_datasets = ['LJSpeech-1.0', 'LJSpeech-1.1', 'M-AILABS', 'THCHS-30']
if args.dataset not in supported_datasets:
raise ValueError('dataset value entered {} does not belong to supported datasets: {}'.format(
args.dataset, supported_datasets))
- if args.dataset == 'LJSpeech-1.1':
+ if args.dataset.startswith('LJSpeech'):
return [os.path.join(args.base_dir, args.dataset)]
-
+ if args.dataset.startswith('THCHS-30'):
+ return [os.path.join(args.base_dir, 'data_thchs30')]
+
if args.dataset == 'M-AILABS':
- supported_languages = ['en_US', 'en_UK', 'fr_FR', 'it_IT', 'de_DE', 'es_ES', 'ru_RU',
+ supported_languages = ['en_US', 'en_UK', 'fr_FR', 'it_IT', 'de_DE', 'es_ES', 'ru_RU',
'uk_UK', 'pl_PL', 'nl_NL', 'pt_PT', 'fi_FI', 'se_SE', 'tr_TR', 'ar_SA']
if args.language not in supported_languages:
raise ValueError('Please enter a supported language to use from M-AILABS dataset! \n{}'.format(
@@ -52,15 +60,14 @@ def norm_data(args):
supported_voices))
path = os.path.join(args.base_dir, args.language, 'by_book', args.voice)
- supported_readers = [e for e in os.listdir(path) if 'DS_Store' not in e]
+ supported_readers = [e for e in os.listdir(path) if os.path.isdir(os.path.join(path,e))]
if args.reader not in supported_readers:
raise ValueError('Please enter a valid reader for your language and voice settings! \n{}'.format(
supported_readers))
path = os.path.join(path, args.reader)
- supported_books = [e for e in os.listdir(path) if e != '.DS_Store']
-
- if args.merge_books:
+ supported_books = [e for e in os.listdir(path) if os.path.isdir(os.path.join(path,e))]
+ if merge_books:
return [os.path.join(path, book) for book in supported_books]
else:
@@ -71,29 +78,35 @@ def norm_data(args):
return [os.path.join(path, args.book)]
-def run_preprocess(args):
+def run_preprocess(args, hparams):
input_folders = norm_data(args)
output_folder = os.path.join(args.base_dir, args.output)
- preprocess(args, input_folders, output_folder)
+ preprocess(args, input_folders, output_folder, hparams)
def main():
print('initializing preprocessing..')
parser = argparse.ArgumentParser()
parser.add_argument('--base_dir', default='')
- parser.add_argument('--dataset', default='LJSpeech-1.1')
+ parser.add_argument('--hparams', default='',
+ help='Hyperparameter overrides as a comma-separated list of name=value pairs')
+ parser.add_argument('--dataset', default='THCHS-30')
parser.add_argument('--language', default='en_US')
parser.add_argument('--voice', default='female')
parser.add_argument('--reader', default='mary_ann')
- parser.add_argument('--merge_books', type=bool, default=False)
+ parser.add_argument('--merge_books', default='False')
parser.add_argument('--book', default='northandsouth')
parser.add_argument('--output', default='training_data')
parser.add_argument('--n_jobs', type=int, default=cpu_count())
args = parser.parse_args()
- run_preprocess(args)
+ modified_hp = hparams.parse(args.hparams)
+
+ assert args.merge_books in ('False', 'True')
+
+ run_preprocess(args, modified_hp)
if __name__ == '__main__':
- main()
\ No newline at end of file
+ main()
diff --git a/requirements.txt b/requirements.txt
index b899d09e..7bc12670 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,13 @@
falcon==1.2.0
inflect==0.2.5
+audioread==2.1.5
librosa==0.5.1
matplotlib==2.0.2
-numpy==1.13.0
+numpy==1.14.0
scipy==1.0.0
tqdm==4.11.2
-Unidecode==0.4.20
\ No newline at end of file
+Unidecode==0.4.20
+pyaudio==0.2.11
+sounddevice==0.3.10
+lws
+keras
\ No newline at end of file
diff --git a/synthesize.py b/synthesize.py
index c4058ce8..8cdc5893 100644
--- a/synthesize.py
+++ b/synthesize.py
@@ -1,33 +1,71 @@
import argparse
+import os
+from warnings import warn
+
+import tensorflow as tf
+
+from hparams import hparams
+from infolog import log
from tacotron.synthesize import tacotron_synthesize
+def prepare_run(args):
+ modified_hp = hparams.parse(args.hparams)
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+ run_name = args.name or args.tacotron_name or args.model
+ taco_checkpoint = os.path.join('logs-' + run_name, 'taco_' + args.checkpoint)
+
+ run_name = args.name or args.wavenet_name or args.model
+ wave_checkpoint = os.path.join('logs-' + run_name, 'wave_' + args.checkpoint)
+ return taco_checkpoint, wave_checkpoint, modified_hp
+
+def get_sentences(args):
+ if args.text_list != '':
+ with open(args.text_list) as f:
+ #sentences = list(map(lambda l: l.decode("utf-8")[:-1], f.readlines()))
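+			#Keep every second line of the text file: this assumes the list alternates between
+			#a reference line (e.g. the original sentence) and the line to actually synthesize.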
+ i = 0
+ sentences = []
+ for line in f.readlines():
+ i += 1
+ if i % 2 == 0:
+ sentences.append(line.strip())
+ else:
+ sentences = hparams.sentences
+ return sentences
+
+
def main():
- accepted_modes = ['eval', 'synthesis']
+ accepted_modes = ['eval', 'synthesis', 'live']
parser = argparse.ArgumentParser()
- parser.add_argument('--checkpoint', default='logs-Tacotron/pretrained/', help='Path to model checkpoint')
+ parser.add_argument('--checkpoint', default='pretrained/', help='Path to model checkpoint')
parser.add_argument('--hparams', default='',
help='Hyperparameter overrides as a comma-separated list of name=value pairs')
+ parser.add_argument('--name', help='Name of logging directory if the two models were trained together.')
+ parser.add_argument('--tacotron_name', help='Name of logging directory of Tacotron. If trained separately')
+ parser.add_argument('--wavenet_name', help='Name of logging directory of WaveNet. If trained separately')
parser.add_argument('--model', default='Tacotron')
parser.add_argument('--input_dir', default='training_data/', help='folder to contain inputs sentences/targets')
+ parser.add_argument('--mels_dir', default='tacotron_output/eval/', help='folder to contain mels to synthesize audio from using the Wavenet')
parser.add_argument('--output_dir', default='output/', help='folder to contain synthesized mel spectrograms')
- parser.add_argument('--mode', default='synthesis', help='mode of run: can be one of {}'.format(accepted_modes))
- parser.add_argument('--GTA', default=True, help='Ground truth aligned synthesis, defaults to True, only considered in synthesis mode')
+ parser.add_argument('--mode', default='eval', help='mode of run: can be one of {}'.format(accepted_modes))
+ parser.add_argument('--GTA', default='True', help='Ground truth aligned synthesis, defaults to True, only considered in synthesis mode')
+ parser.add_argument('--text_list', default='', help='Text file contains list of texts to be synthesized. Valid if mode=eval')
+ parser.add_argument('--speaker_id', default=None, help='Defines the speakers ids to use when running standalone Wavenet on a folder of mels. this variable must be a comma-separated list of ids')
args = parser.parse_args()
-
- accepted_models = ['Tacotron', 'Wavenet']
-
- if args.model not in accepted_models:
- raise ValueError('please enter a valid model to train: {}'.format(accepted_models))
if args.mode not in accepted_modes:
raise ValueError('accepted modes are: {}, found {}'.format(accepted_modes, args.mode))
- if args.model == 'Tacotron':
- tacotron_synthesize(args)
- elif args.model == 'Wavenet':
- raise NotImplementedError('Wavenet is still a work in progress, thank you for your patience!')
+ if args.GTA not in ('True', 'False'):
+ raise ValueError('GTA option must be either True or False')
+
+ taco_checkpoint, wave_checkpoint, hparams = prepare_run(args)
+ sentences = get_sentences(args)
+
+ tacotron_synthesize(args, hparams, taco_checkpoint, sentences)
if __name__ == '__main__':
- main()
\ No newline at end of file
+ main()
diff --git a/tacotron/feeder.py b/tacotron/feeder.py
index 73c7b845..22931a1a 100644
--- a/tacotron/feeder.py
+++ b/tacotron/feeder.py
@@ -1,27 +1,17 @@
-import numpy as np
import os
import threading
import time
import traceback
-from tacotron.utils.text import text_to_sequence
-from tacotron.utils.infolog import log
-import tensorflow as tf
-from hparams import hparams
+import numpy as np
+import tensorflow as tf
+from infolog import log
+from sklearn.model_selection import train_test_split
+from tacotron.utils.text import text_to_sequence
_batches_per_group = 32
-#pad input sequences with the 0 ( _ )
-_pad = 0
-#explicitely setting the padding to a value that doesn't originally exist in the spectogram
-#to avoid any possible conflicts, without affecting the output range of the model too much
-if hparams.symmetric_mels:
- _target_pad = -(hparams.max_abs_value + .1)
-else:
- _target_pad = -0.1
-#Mark finished sequences with 1s
-_token_pad = 1.
-
-class Feeder(threading.Thread):
+
+class Feeder:
"""
Feeds batches of data into queue on a background thread.
"""
@@ -31,113 +21,214 @@ def __init__(self, coordinator, metadata_filename, hparams):
self._coord = coordinator
self._hparams = hparams
self._cleaner_names = [x.strip() for x in hparams.cleaners.split(',')]
- self._offset = 0
+ self._train_offset = 0
+ self._test_offset = 0
# Load metadata
self._mel_dir = os.path.join(os.path.dirname(metadata_filename), 'mels')
+ self._linear_dir = os.path.join(os.path.dirname(metadata_filename), 'linear')
with open(metadata_filename, encoding='utf-8') as f:
self._metadata = [line.strip().split('|') for line in f]
frame_shift_ms = hparams.hop_size / hparams.sample_rate
- hours = sum([int(x[3]) for x in self._metadata]) * frame_shift_ms / (3600)
+ hours = sum([int(x[4]) for x in self._metadata]) * frame_shift_ms / (3600)
log('Loaded metadata for {} examples ({:.2f} hours)'.format(len(self._metadata), hours))
- # Create placeholders for inputs and targets. Don't specify batch size because we want
- # to be able to feed different batch sizes at eval time.
- self._placeholders = [
- tf.placeholder(tf.int32, shape=(None, None), name='inputs'),
- tf.placeholder(tf.int32, shape=(None, ), name='input_lengths'),
- tf.placeholder(tf.float32, shape=(None, None, hparams.num_mels), name='mel_targets'),
- tf.placeholder(tf.float32, shape=(None, None), name='token_targets'),
- ]
-
- # Create queue for buffering data
- queue = tf.FIFOQueue(8, [tf.int32, tf.int32, tf.float32, tf.float32], name='input_queue')
- self._enqueue_op = queue.enqueue(self._placeholders)
- self.inputs, self.input_lengths, self.mel_targets, self.token_targets = queue.dequeue()
- self.inputs.set_shape(self._placeholders[0].shape)
- self.input_lengths.set_shape(self._placeholders[1].shape)
- self.mel_targets.set_shape(self._placeholders[2].shape)
- self.token_targets.set_shape(self._placeholders[3].shape)
-
- def start_in_session(self, session):
+ #Train test split
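+		#Either an explicit test size (tacotron_test_size) or a number of test batches
+		#(tacotron_test_batches) must be set; the split is made reproducible by tacotron_data_random_state.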
+ if hparams.tacotron_test_size is None:
+ assert hparams.tacotron_test_batches is not None
+
+ test_size = (hparams.tacotron_test_size if hparams.tacotron_test_size is not None
+ else hparams.tacotron_test_batches * hparams.tacotron_batch_size)
+ indices = np.arange(len(self._metadata))
+ train_indices, test_indices = train_test_split(indices,
+ test_size=test_size, random_state=hparams.tacotron_data_random_state)
+
+ #Make sure test_indices is a multiple of batch_size else round up
+ len_test_indices = self._round_down(len(test_indices), hparams.tacotron_batch_size)
+ extra_test = test_indices[len_test_indices:]
+ test_indices = test_indices[:len_test_indices]
+ train_indices = np.concatenate([train_indices, extra_test])
+
+ self._train_meta = list(np.array(self._metadata)[train_indices])
+ self._test_meta = list(np.array(self._metadata)[test_indices])
+
+ self.test_steps = len(self._test_meta) // hparams.tacotron_batch_size
+
+ if hparams.tacotron_test_size is None:
+ assert hparams.tacotron_test_batches == self.test_steps
+
+ #pad input sequences with the 0 ( _ )
+ self._pad = 0
+		#explicitly setting the padding to a value that doesn't originally exist in the spectrogram
+ #to avoid any possible conflicts, without affecting the output range of the model too much
+ if hparams.symmetric_mels:
+ self._target_pad = -(hparams.max_abs_value + .1)
+ else:
+ self._target_pad = -0.1
+ #Mark finished sequences with 1s
+ self._token_pad = 1.
+
+ with tf.device('/cpu:0'):
+ # Create placeholders for inputs and targets. Don't specify batch size because we want
+ # to be able to feed different batch sizes at eval time.
+ self._placeholders = [
+ tf.placeholder(tf.int32, shape=(None, None), name='inputs'),
+ tf.placeholder(tf.int32, shape=(None, ), name='input_lengths'),
+ tf.placeholder(tf.float32, shape=(None, None, hparams.num_mels), name='mel_targets'),
+ tf.placeholder(tf.float32, shape=(None, None), name='token_targets'),
+ tf.placeholder(tf.float32, shape=(None, None, hparams.num_freq), name='linear_targets'),
+ tf.placeholder(tf.int32, shape=(None, ), name='targets_lengths'),
+ ]
+
+ # Create queue for buffering data
+ queue = tf.FIFOQueue(8, [tf.int32, tf.int32, tf.float32, tf.float32, tf.float32, tf.int32], name='input_queue')
+ self._enqueue_op = queue.enqueue(self._placeholders)
+ self.inputs, self.input_lengths, self.mel_targets, self.token_targets, self.linear_targets, self.targets_lengths = queue.dequeue()
+
+ self.inputs.set_shape(self._placeholders[0].shape)
+ self.input_lengths.set_shape(self._placeholders[1].shape)
+ self.mel_targets.set_shape(self._placeholders[2].shape)
+ self.token_targets.set_shape(self._placeholders[3].shape)
+ self.linear_targets.set_shape(self._placeholders[4].shape)
+ self.targets_lengths.set_shape(self._placeholders[5].shape)
+
+ # Create eval queue for buffering eval data
+ eval_queue = tf.FIFOQueue(1, [tf.int32, tf.int32, tf.float32, tf.float32, tf.float32, tf.int32], name='eval_queue')
+ self._eval_enqueue_op = eval_queue.enqueue(self._placeholders)
+ self.eval_inputs, self.eval_input_lengths, self.eval_mel_targets, self.eval_token_targets, \
+ self.eval_linear_targets, self.eval_targets_lengths = eval_queue.dequeue()
+
+ self.eval_inputs.set_shape(self._placeholders[0].shape)
+ self.eval_input_lengths.set_shape(self._placeholders[1].shape)
+ self.eval_mel_targets.set_shape(self._placeholders[2].shape)
+ self.eval_token_targets.set_shape(self._placeholders[3].shape)
+ self.eval_linear_targets.set_shape(self._placeholders[4].shape)
+ self.eval_targets_lengths.set_shape(self._placeholders[5].shape)
+
+ def start_threads(self, session):
self._session = session
- self.start()
+ thread = threading.Thread(name='background', target=self._enqueue_next_train_group)
+ thread.daemon = True #Thread will close when parent quits
+ thread.start()
+
+ thread = threading.Thread(name='background', target=self._enqueue_next_test_group)
+ thread.daemon = True #Thread will close when parent quits
+ thread.start()
+
+ def _get_test_groups(self):
+ meta = self._test_meta[self._test_offset]
+ self._test_offset += 1
+
+ text = meta[5]
- def run(self):
- try:
- while not self._coord.should_stop():
- self._enqueue_next_group()
- except Exception as e:
- traceback.print_exc()
- self._coord.request_stop(e)
+ input_data = np.asarray(text_to_sequence(text, self._cleaner_names), dtype=np.int32)
+ mel_target = np.load(os.path.join(self._mel_dir, meta[1]))
+	#Create parallel sequences containing zeros to represent a non-finished sequence
+ token_target = np.asarray([0.] * (len(mel_target) - 1))
+ linear_target = np.load(os.path.join(self._linear_dir, meta[2]))
+ return (input_data, mel_target, token_target, linear_target, len(mel_target))
- def _enqueue_next_group(self):
+ def make_test_batches(self):
start = time.time()
# Read a group of examples
n = self._hparams.tacotron_batch_size
r = self._hparams.outputs_per_step
- examples = [self._get_next_example() for i in range(n * _batches_per_group)]
+
+ #Test on entire test set
+ examples = [self._get_test_groups() for i in range(len(self._test_meta))]
# Bucket examples based on similar output sequence length for efficiency
examples.sort(key=lambda x: x[-1])
batches = [examples[i: i+n] for i in range(0, len(examples), n)]
np.random.shuffle(batches)
- log('\nGenerated {} batches of size {} in {:.3f} sec'.format(len(batches), n, time.time() - start))
- for batch in batches:
- feed_dict = dict(zip(self._placeholders, _prepare_batch(batch, r)))
- self._session.run(self._enqueue_op, feed_dict=feed_dict)
-
- def _get_next_example(self):
- """
- Gets a single example (input, mel_target, token_target) from disk
- """
- if self._offset >= len(self._metadata):
- self._offset = 0
- np.random.shuffle(self._metadata)
- meta = self._metadata[self._offset]
- self._offset += 1
-
- text = meta[4]
-
- input_data = np.asarray(text_to_sequence(text, self._cleaner_names), dtype=np.int32)
- mel_target = np.load(os.path.join(self._mel_dir, meta[1]))
- #Create parallel sequences containing zeros to represent a non finished sequence
- token_target = np.asarray([0.] * len(mel_target))
- return (input_data, mel_target, token_target, len(mel_target))
+ log('\nGenerated {} test batches of size {} in {:.3f} sec'.format(len(batches), n, time.time() - start))
+ return batches, r
+ def _enqueue_next_train_group(self):
+ while not self._coord.should_stop():
+ start = time.time()
-def _prepare_batch(batch, outputs_per_step):
- np.random.shuffle(batch)
- inputs = _prepare_inputs([x[0] for x in batch])
- input_lengths = np.asarray([len(x[0]) for x in batch], dtype=np.int32)
- mel_targets = _prepare_targets([x[1] for x in batch], outputs_per_step)
- #Pad sequences with 1 to infer that the sequence is done
- token_targets = _prepare_token_targets([x[2] for x in batch], outputs_per_step)
- return (inputs, input_lengths, mel_targets, token_targets)
+ # Read a group of examples
+ n = self._hparams.tacotron_batch_size
+ r = self._hparams.outputs_per_step
+ examples = [self._get_next_example() for i in range(n * _batches_per_group)]
-def _prepare_inputs(inputs):
- max_len = max([len(x) for x in inputs])
- return np.stack([_pad_input(x, max_len) for x in inputs])
+ # Bucket examples based on similar output sequence length for efficiency
+ examples.sort(key=lambda x: x[-1])
+ batches = [examples[i: i+n] for i in range(0, len(examples), n)]
+ np.random.shuffle(batches)
-def _prepare_targets(targets, alignment):
- max_len = max([len(t) for t in targets]) + 1
- return np.stack([_pad_target(t, _round_up(max_len, alignment)) for t in targets])
+ log('\nGenerated {} train batches of size {} in {:.3f} sec'.format(len(batches), n, time.time() - start))
+ for batch in batches:
+ feed_dict = dict(zip(self._placeholders, self._prepare_batch(batch, r)))
+ self._session.run(self._enqueue_op, feed_dict=feed_dict)
-def _prepare_token_targets(targets, alignment):
- max_len = max([len(t) for t in targets]) + 1
- return np.stack([_pad_token_target(t, _round_up(max_len, alignment)) for t in targets])
+ def _enqueue_next_test_group(self):
+ #Create test batches once and evaluate on them for all test steps
+ test_batches, r = self.make_test_batches()
+ while not self._coord.should_stop():
+ for batch in test_batches:
+ feed_dict = dict(zip(self._placeholders, self._prepare_batch(batch, r)))
+ self._session.run(self._eval_enqueue_op, feed_dict=feed_dict)
-def _pad_input(x, length):
- return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
+ def _get_next_example(self):
+		"""Gets a single example (input, mel_target, token_target, linear_target, mel_length) from disk
+ """
+ if self._train_offset >= len(self._train_meta):
+ self._train_offset = 0
+ np.random.shuffle(self._train_meta)
-def _pad_target(t, length):
- return np.pad(t, [(0, length - t.shape[0]), (0, 0)], mode='constant', constant_values=_target_pad)
+ meta = self._train_meta[self._train_offset]
+ self._train_offset += 1
-def _pad_token_target(t, length):
- return np.pad(t, (0, length - t.shape[0]), mode='constant', constant_values=_token_pad)
+ text = meta[5]
-def _round_up(x, multiple):
- remainder = x % multiple
- return x if remainder == 0 else x + multiple - remainder
+ input_data = np.asarray(text_to_sequence(text, self._cleaner_names), dtype=np.int32)
+ mel_target = np.load(os.path.join(self._mel_dir, meta[1]))
+		#Create parallel sequences containing zeros to represent a non-finished sequence
+ token_target = np.asarray([0.] * (len(mel_target) - 1))
+ linear_target = np.load(os.path.join(self._linear_dir, meta[2]))
+ return (input_data, mel_target, token_target, linear_target, len(mel_target))
+
+
+ def _prepare_batch(self, batch, outputs_per_step):
+ np.random.shuffle(batch)
+ inputs = self._prepare_inputs([x[0] for x in batch])
+ input_lengths = np.asarray([len(x[0]) for x in batch], dtype=np.int32)
+ mel_targets = self._prepare_targets([x[1] for x in batch], outputs_per_step)
+ #Pad sequences with 1 to infer that the sequence is done
+ token_targets = self._prepare_token_targets([x[2] for x in batch], outputs_per_step)
+ linear_targets = self._prepare_targets([x[3] for x in batch], outputs_per_step)
+ targets_lengths = np.asarray([x[-1] for x in batch], dtype=np.int32) #Used to mask loss
+ return (inputs, input_lengths, mel_targets, token_targets, linear_targets, targets_lengths)
+
+ def _prepare_inputs(self, inputs):
+ max_len = max([len(x) for x in inputs])
+ return np.stack([self._pad_input(x, max_len) for x in inputs])
+
+ def _prepare_targets(self, targets, alignment):
+ max_len = max([len(t) for t in targets])
+ return np.stack([self._pad_target(t, self._round_up(max_len, alignment)) for t in targets])
+
+ def _prepare_token_targets(self, targets, alignment):
+ max_len = max([len(t) for t in targets]) + 1
+ return np.stack([self._pad_token_target(t, self._round_up(max_len, alignment)) for t in targets])
+
+ def _pad_input(self, x, length):
+ return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=self._pad)
+
+ def _pad_target(self, t, length):
+ return np.pad(t, [(0, length - t.shape[0]), (0, 0)], mode='constant', constant_values=self._target_pad)
+
+ def _pad_token_target(self, t, length):
+ return np.pad(t, (0, length - t.shape[0]), mode='constant', constant_values=self._token_pad)
+
+ def _round_up(self, x, multiple):
+ remainder = x % multiple
+ return x if remainder == 0 else x + multiple - remainder
+
+ def _round_down(self, x, multiple):
+ remainder = x % multiple
+ return x if remainder == 0 else x - remainder
diff --git a/tacotron/models/Architecture_wrappers.py b/tacotron/models/Architecture_wrappers.py
index 8cd34777..fa0ed52f 100644
--- a/tacotron/models/Architecture_wrappers.py
+++ b/tacotron/models/Architecture_wrappers.py
@@ -2,17 +2,14 @@
All notations and variable names were used in concordance with originial tensorflow implementation
"""
import collections
+
import numpy as np
import tensorflow as tf
+from tacotron.models.attention import _compute_attention
from tensorflow.contrib.rnn import RNNCell
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import rnn_cell_impl
-from tensorflow.python.ops import check_ops
+from tensorflow.python.framework import ops, tensor_shape
+from tensorflow.python.ops import array_ops, check_ops, rnn_cell_impl, tensor_array_ops
from tensorflow.python.util import nest
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import tensor_array_ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.contrib.seq2seq.python.ops.attention_wrapper import _compute_attention
_zero_state_tensors = rnn_cell_impl._zero_state_tensors
@@ -36,7 +33,7 @@ def __init__(self, convolutional_layers, lstm_layer):
self._convolutions = convolutional_layers
self._cell = lstm_layer
- def __call__(self, inputs, input_lengths):
+ def __call__(self, inputs, input_lengths=None):
#Pass input sequence through a stack of convolutional layers
conv_output = self._convolutions(inputs)
@@ -51,7 +48,7 @@ def __call__(self, inputs, input_lengths):
class TacotronDecoderCellState(
collections.namedtuple("TacotronDecoderCellState",
("cell_state", "attention", "time", "alignments",
- "alignment_history", "finished"))):
+ "alignment_history", "max_attentions"))):
"""`namedtuple` storing the state of a `TacotronDecoderCell`.
Contains:
- `cell_state`: The state of the wrapped `RNNCell` at the previous time
@@ -83,20 +80,20 @@ class TacotronDecoderCell(RNNCell):
* : This is typically taking a vanilla LSTM, wrapping it using tensorflow's attention wrapper,
and wrap that with the prenet before doing an input feeding, and with the prediction layer
- that uses RNN states to project on output space. Actions marked with (*) can be replaced with
+ that uses RNN states to project on output space. Actions marked with (*) can be replaced with
tensorflow's attention wrapper call if it was using cumulative alignments instead of previous alignments only.
"""
- def __init__(self, prenet, attention_mechanism, rnn_cell, frame_projection, stop_projection, mask_finished=False):
+ def __init__(self, prenet, attention_mechanism, rnn_cell, frame_projection, stop_projection):
"""Initialize decoder parameters
Args:
prenet: A tensorflow fully connected layer acting as the decoder pre-net
- attention_mechanism: A _BaseAttentionMechanism instance, usefull to
+			attention_mechanism: A _BaseAttentionMechanism instance, useful to
learn encoder-decoder alignments
rnn_cell: Instance of RNNCell, main body of the decoder
frame_projection: tensorflow fully connected layer with r * num_mels output units
- stop_projection: tensorflow fully connected layer, expected to project to a scalar
+ stop_projection: tensorflow fully connected layer, expected to project to a scalar
and through a sigmoid activation
mask_finished: Boolean, Whether to mask decoder frames after the
"""
@@ -108,7 +105,6 @@ def __init__(self, prenet, attention_mechanism, rnn_cell, frame_projection, stop
self._frame_projection = frame_projection
self._stop_projection = stop_projection
- self._mask_finished = mask_finished
self._attention_layer_size = self._attention_mechanism.values.get_shape()[-1].value
def _batch_size_checks(self, batch_size, error_message):
@@ -133,11 +129,11 @@ def state_size(self):
attention=self._attention_layer_size,
alignments=self._attention_mechanism.alignments_size,
alignment_history=(),
- finished=())
+ max_attentions=())
def zero_state(self, batch_size, dtype):
"""Return an initial (zero) state tuple for this `AttentionWrapper`.
-
+
Args:
batch_size: `0D` integer tensor: the batch size.
dtype: The internal state data type.
@@ -168,7 +164,7 @@ def zero_state(self, batch_size, dtype):
alignments=self._attention_mechanism.initial_alignments(batch_size, dtype),
alignment_history=tensor_array_ops.TensorArray(dtype=dtype, size=0,
dynamic_size=True),
- finished=tf.reshape(tf.tile([0.0], [batch_size]), [-1, 1]))
+ max_attentions=tf.zeros((batch_size, ), dtype=tf.int32))
def __call__(self, inputs, state):
#Information bottleneck (essential for learning attention)
@@ -180,18 +176,20 @@ def __call__(self, inputs, state):
#Unidirectional LSTM layers
LSTM_output, next_cell_state = self._cell(LSTM_input, state.cell_state)
+
#Compute the attention (context) vector and alignments using
- #the new decoder cell hidden state as query vector
+ #the new decoder cell hidden state as query vector
#and cumulative alignments to extract location features
#The choice of the new cell hidden state (s_{i}) of the last
#decoder RNN Cell is based on Luong et Al. (2015):
#https://arxiv.org/pdf/1508.04025.pdf
previous_alignments = state.alignments
previous_alignment_history = state.alignment_history
- context_vector, alignments, cumulated_alignments = _compute_attention(self._attention_mechanism,
+ context_vector, alignments, cumulated_alignments, max_attentions = _compute_attention(self._attention_mechanism,
LSTM_output,
previous_alignments,
- attention_layer=None)
+ attention_layer=None,
+ prev_max_attentions=state.max_attentions)
#Concat LSTM outputs and context vector to form projections inputs
projections_input = tf.concat([LSTM_output, context_vector], axis=-1)
@@ -200,19 +198,8 @@ def __call__(self, inputs, state):
cell_outputs = self._frame_projection(projections_input)
stop_tokens = self._stop_projection(projections_input)
- #mask attention computed for decoding steps where sequence is already finished
- #this is purely for visual purposes and will not affect the training of the model
- #we don't pay much attention to the alignments of the output paddings if we impute
- #the decoder outputs beyond the end of sequence.
- if self._mask_finished:
- finished = tf.cast(state.finished * tf.ones(tf.shape(alignments)), tf.bool)
- mask = tf.zeros(tf.shape(alignments))
- masked_alignments = tf.where(finished, mask, alignments)
- else:
- masked_alignments = alignments
-
#Save alignment history
- alignment_history = previous_alignment_history.write(state.time, masked_alignments)
+ alignment_history = previous_alignment_history.write(state.time, alignments)
#Prepare next decoder state
next_state = TacotronDecoderCellState(
@@ -221,6 +208,6 @@ def __call__(self, inputs, state):
attention=context_vector,
alignments=cumulated_alignments,
alignment_history=alignment_history,
- finished=state.finished)
+ max_attentions=max_attentions)
- return (cell_outputs, stop_tokens), next_state
+ return (cell_outputs, stop_tokens), next_state
diff --git a/tacotron/models/attention.py b/tacotron/models/attention.py
index c31fbc78..2804edcc 100644
--- a/tacotron/models/attention.py
+++ b/tacotron/models/attention.py
@@ -2,14 +2,37 @@
import tensorflow as tf
from tensorflow.contrib.seq2seq.python.ops.attention_wrapper import BahdanauAttention
-from tensorflow.python.ops import nn_ops
from tensorflow.python.layers import core as layers_core
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import math_ops
-from hparams import hparams
-
-
+from tensorflow.python.ops import array_ops, math_ops, nn_ops, variable_scope
+
+
+#From https://github.com/tensorflow/tensorflow/blob/r1.7/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+def _compute_attention(attention_mechanism, cell_output, attention_state,
+ attention_layer, prev_max_attentions):
+ """Computes the attention and alignments for a given attention_mechanism."""
+ alignments, next_attention_state, max_attentions = attention_mechanism(
+ cell_output, state=attention_state, prev_max_attentions=prev_max_attentions)
+
+ # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
+ expanded_alignments = array_ops.expand_dims(alignments, 1)
+ # Context is the inner product of alignments and values along the
+ # memory time dimension.
+ # alignments shape is
+ # [batch_size, 1, memory_time]
+ # attention_mechanism.values shape is
+ # [batch_size, memory_time, memory_size]
+ # the batched matmul is over memory_time, so the output shape is
+ # [batch_size, 1, memory_size].
+ # we then squeeze out the singleton dim.
+ context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
+ context = array_ops.squeeze(context, [1])
+
+ if attention_layer is not None:
+ attention = attention_layer(array_ops.concat([cell_output, context], 1))
+ else:
+ attention = context
+
+ return attention, alignments, next_attention_state, max_attentions
def _location_sensitive_score(W_query, W_fil, W_keys):
@@ -20,17 +43,17 @@ def _location_sensitive_score(W_query, W_fil, W_keys):
vances in Neural Information Processing Systems, 2015, pp.
577–585.
- #############################################################################
- hybrid attention (content-based + location-based)
- f = F * α_{i-1}
- energy = dot(v_a, tanh(W_keys(h_enc) + W_query(h_dec) + W_fil(f) + b_a))
- #############################################################################
+ #############################################################################
+ hybrid attention (content-based + location-based)
+ f = F * α_{i-1}
+ energy = dot(v_a, tanh(W_keys(h_enc) + W_query(h_dec) + W_fil(f) + b_a))
+ #############################################################################
- Args:
+ Args:
W_query: Tensor, shape '[batch_size, 1, attention_dim]' to compare to location features.
W_location: processed previous alignments into location features, shape '[batch_size, max_time, attention_dim]'
W_keys: Tensor, shape '[batch_size, max_time, attention_dim]', typically the encoder outputs.
- Returns:
+ Returns:
A '[batch_size, max_time]' attention score (energy)
"""
# Get the number of hidden units from the trailing dimension of keys
@@ -38,7 +61,8 @@ def _location_sensitive_score(W_query, W_fil, W_keys):
num_units = W_keys.shape[-1].value or array_ops.shape(W_keys)[-1]
v_a = tf.get_variable(
- 'attention_variable', shape=[num_units], dtype=dtype)
+ 'attention_variable_projection', shape=[num_units], dtype=dtype,
+ initializer=tf.contrib.layers.xavier_initializer())
b_a = tf.get_variable(
'attention_bias', shape=[num_units], dtype=dtype,
initializer=tf.zeros_initializer())
@@ -54,17 +78,17 @@ def _smoothing_normalization(e):
577–585.
############################################################################
- Smoothing normalization function
- a_{i, j} = sigmoid(e_{i, j}) / sum_j(sigmoid(e_{i, j}))
- ############################################################################
-
- Args:
- e: matrix [batch_size, max_time(memory_time)]: expected to be energy (score)
- values of an attention mechanism
- Returns:
- matrix [batch_size, max_time]: [0, 1] normalized alignments with possible
- attendance to multiple memory time steps.
- """
+ Smoothing normalization function
+ a_{i, j} = sigmoid(e_{i, j}) / sum_j(sigmoid(e_{i, j}))
+ ############################################################################
+
+ Args:
+ e: matrix [batch_size, max_time(memory_time)]: expected to be energy (score)
+ values of an attention mechanism
+ Returns:
+ matrix [batch_size, max_time]: [0, 1] normalized alignments with possible
+ attendance to multiple memory time steps.
+ """
return tf.nn.sigmoid(e) / tf.reduce_sum(tf.nn.sigmoid(e), axis=-1, keepdims=True)
@@ -75,8 +99,8 @@ class LocationSensitiveAttention(BahdanauAttention):
"D. Bahdanau, K. Cho, and Y. Bengio, “Neural machine transla-
tion by jointly learning to align and translate,” in Proceedings
of ICLR, 2015."
- to use previous alignments as additional location features.
-
+ to use previous alignments as additional location features.
+
This attention is described in:
J. K. Chorowski, D. Bahdanau, D. Serdyuk, K. Cho, and Y. Ben-
gio, “Attention-based models for speech recognition,” in Ad-
@@ -87,9 +111,12 @@ class LocationSensitiveAttention(BahdanauAttention):
def __init__(self,
num_units,
memory,
+ hparams,
+ is_training,
mask_encoder=True,
memory_sequence_length=None,
smoothing=False,
+ cumulate_weights=True,
name='LocationSensitiveAttention'):
"""Construct the Attention mechanism.
Args:
@@ -101,7 +128,7 @@ def __init__(self,
in memory. If provided, the memory tensor rows are masked with zeros
for values past the respective sequence lengths. Only relevant if mask_encoder = True.
smoothing (optional): Boolean. Determines which normalization function to use.
- Default normalization function (probablity_fn) is softmax. If smoothing is
+				Default normalization function (probability_fn) is softmax. If smoothing is
enabled, we replace softmax with:
a_{i, j} = sigmoid(e_{i, j}) / sum_j(sigmoid(e_{i, j}))
Introduced in:
@@ -109,9 +136,9 @@ def __init__(self,
gio, “Attention-based models for speech recognition,” in Ad-
vances in Neural Information Processing Systems, 2015, pp.
577–585.
- This is mainly used if the model wants to attend to multiple inputs parts
+ This is mainly used if the model wants to attend to multiple input parts
at the same decoding step. We probably won't be using it since multiple sound
- frames may depend from the same character, probably not the way around.
+			frames may depend on the same character/phone, probably not the other way around.
Note:
We still keep it implemented in case we want to test it. They used it in the
paper in the context of speech recognition, where one phoneme may depend on
@@ -130,12 +157,16 @@ def __init__(self,
name=name)
self.location_convolution = tf.layers.Conv1D(filters=hparams.attention_filters,
- kernel_size=hparams.attention_kernel, padding='same', use_bias=False,
- name='location_features_convolution')
- self.location_layer = tf.layers.Dense(units=num_units, use_bias=False,
+ kernel_size=hparams.attention_kernel, padding='same', use_bias=True,
+ bias_initializer=tf.zeros_initializer(), name='location_features_convolution')
+ self.location_layer = tf.layers.Dense(units=num_units, use_bias=False,
dtype=tf.float32, name='location_features_layer')
+ self._cumulate = cumulate_weights
+ self.synthesis_constraint = hparams.synthesis_constraint and not is_training
+ self.attention_win_size = tf.convert_to_tensor(hparams.attention_win_size, dtype=tf.int32)
+ self.constraint_type = hparams.synthesis_constraint_type
- def __call__(self, query, state):
+ def __call__(self, query, state, prev_max_attentions):
"""Score the query based on the keys and values.
Args:
query: Tensor of dtype matching `self.values` and shape
@@ -167,9 +198,29 @@ def __call__(self, query, state):
# energy shape [batch_size, max_time]
energy = _location_sensitive_score(processed_query, processed_location_features, self.keys)
+ if self.synthesis_constraint:
+ Tx = tf.shape(energy)[-1]
+ # prev_max_attentions = tf.squeeze(prev_max_attentions, [-1])
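+				#Restrict attention to a window of encoder steps around the previously most attended
+				#position: 'monotonic' only allows steps from that position up to attention_win_size ahead,
+				#'window' centers the window on it. Energies outside the window are pushed to a large
+				#negative value before normalization.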
+ if self.constraint_type == 'monotonic':
+ key_masks = tf.sequence_mask(prev_max_attentions, Tx)
+ reverse_masks = tf.sequence_mask(Tx - self.attention_win_size - prev_max_attentions, Tx)[:, ::-1]
+ else:
+ assert self.constraint_type == 'window'
+ key_masks = tf.sequence_mask(prev_max_attentions - (self.attention_win_size // 2 + (self.attention_win_size % 2 != 0)), Tx)
+ reverse_masks = tf.sequence_mask(Tx - (self.attention_win_size // 2) - prev_max_attentions, Tx)[:, ::-1]
+
+ masks = tf.logical_or(key_masks, reverse_masks)
+ paddings = tf.ones_like(energy) * (-2 ** 32 + 1) # (N, Ty/r, Tx)
+ energy = tf.where(tf.equal(masks, False), energy, paddings)
+
# alignments shape = energy shape = [batch_size, max_time]
alignments = self._probability_fn(energy, previous_alignments)
+ max_attentions = tf.argmax(alignments, -1, output_type=tf.int32) # (N, Ty/r)
# Cumulate alignments
- next_state = alignments + previous_alignments
- return alignments, next_state
+ if self._cumulate:
+ next_state = alignments + previous_alignments
+ else:
+ next_state = alignments
+
+ return alignments, next_state, max_attentions
diff --git a/tacotron/models/attention_old.py b/tacotron/models/attention_old.py
deleted file mode 100644
index 23180296..00000000
--- a/tacotron/models/attention_old.py
+++ /dev/null
@@ -1,174 +0,0 @@
-"""Attention file for location based attention (compatible with tensorflow attention wrapper)"""
-
-import tensorflow as tf
-from tensorflow.contrib.seq2seq.python.ops.attention_wrapper import BahdanauAttention
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.layers import core as layers_core
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import math_ops
-from hparams import hparams
-
-
-
-
-def _location_sensitive_score(W_query, W_fil, W_keys):
- """Impelements Bahdanau-style (cumulative) scoring function.
- This attention is described in:
- J. K. Chorowski, D. Bahdanau, D. Serdyuk, K. Cho, and Y. Ben-
- gio, “Attention-based models for speech recognition,” in Ad-
- vances in Neural Information Processing Systems, 2015, pp.
- 577–585.
-
- #############################################################################
- hybrid attention (content-based + location-based)
- f = F * α_{i-1}
- energy = dot(v_a, tanh(W_keys(h_enc) + W_query(h_dec) + W_fil(f) + b_a))
- #############################################################################
-
- Args:
- W_query: Tensor, shape '[batch_size, 1, attention_dim]' to compare to location features.
- W_location: processed previous alignments into location features, shape '[batch_size, max_time, attention_dim]'
- W_keys: Tensor, shape '[batch_size, max_time, attention_dim]', typically the encoder outputs.
- Returns:
- A '[batch_size, max_time]' attention score (energy)
- """
- # Get the number of hidden units from the trailing dimension of keys
- dtype = W_query.dtype
- num_units = W_keys.shape[-1].value or array_ops.shape(W_keys)[-1]
-
- v_a = tf.get_variable(
- 'attention_variable', shape=[num_units], dtype=dtype)
- b_a = tf.get_variable(
- 'attention_bias', shape=[num_units], dtype=dtype,
- initializer=tf.zeros_initializer())
-
- return tf.reduce_sum(v_a * tf.tanh(W_keys + W_query + W_fil + b_a), [2])
-
-def _smoothing_normalization(e):
- """Applies a smoothing normalization function instead of softmax
- Introduced in:
- J. K. Chorowski, D. Bahdanau, D. Serdyuk, K. Cho, and Y. Ben-
- gio, “Attention-based models for speech recognition,” in Ad-
- vances in Neural Information Processing Systems, 2015, pp.
- 577–585.
-
- ############################################################################
- Smoothing normalization function
- a_{i, j} = sigmoid(e_{i, j}) / sum_j(sigmoid(e_{i, j}))
- ############################################################################
-
- Args:
- e: matrix [batch_size, max_time(memory_time)]: expected to be energy (score)
- values of an attention mechanism
- Returns:
- matrix [batch_size, max_time]: [0, 1] normalized alignments with possible
- attendance to multiple memory time steps.
- """
- return tf.nn.sigmoid(e) / tf.reduce_sum(tf.nn.sigmoid(e), axis=-1, keepdims=True)
-
-
-class LocationSensitiveAttention(BahdanauAttention):
- """Impelements Bahdanau-style (cumulative) scoring function.
- Usually referred to as "hybrid" attention (content-based + location-based)
- Extends the additive attention described in:
- "D. Bahdanau, K. Cho, and Y. Bengio, “Neural machine transla-
- tion by jointly learning to align and translate,” in Proceedings
- of ICLR, 2015."
- to use previous alignments as additional location features.
-
- This attention is described in:
- J. K. Chorowski, D. Bahdanau, D. Serdyuk, K. Cho, and Y. Ben-
- gio, “Attention-based models for speech recognition,” in Ad-
- vances in Neural Information Processing Systems, 2015, pp.
- 577–585.
- """
-
- def __init__(self,
- num_units,
- memory,
- mask_encoder=True,
- memory_sequence_length=None,
- smoothing=False,
- name='LocationSensitiveAttention'):
- """Construct the Attention mechanism.
- Args:
- num_units: The depth of the query mechanism.
- memory: The memory to query; usually the output of an RNN encoder. This
- tensor should be shaped `[batch_size, max_time, ...]`.
- mask_encoder (optional): Boolean, whether to mask encoder paddings.
- memory_sequence_length (optional): Sequence lengths for the batch entries
- in memory. If provided, the memory tensor rows are masked with zeros
- for values past the respective sequence lengths. Only relevant if mask_encoder = True.
- smoothing (optional): Boolean. Determines which normalization function to use.
- Default normalization function (probablity_fn) is softmax. If smoothing is
- enabled, we replace softmax with:
- a_{i, j} = sigmoid(e_{i, j}) / sum_j(sigmoid(e_{i, j}))
- Introduced in:
- J. K. Chorowski, D. Bahdanau, D. Serdyuk, K. Cho, and Y. Ben-
- gio, “Attention-based models for speech recognition,” in Ad-
- vances in Neural Information Processing Systems, 2015, pp.
- 577–585.
- This is mainly used if the model wants to attend to multiple inputs parts
- at the same decoding step. We probably won't be using it since multiple sound
- frames may depend from the same character, probably not the way around.
- Note:
- We still keep it implemented in case we want to test it. They used it in the
- paper in the context of speech recognition, where one phoneme may depend on
- multiple subsequent sound frames.
- name: Name to use when creating ops.
- """
- #Create normalization function
- #Setting it to None defaults in using softmax
- normalization_function = _smoothing_normalization if (smoothing == True) else None
- memory_length = memory_sequence_length if (mask_encoder==True) else None
- super(LocationSensitiveAttention, self).__init__(
- num_units=num_units,
- memory=memory,
- memory_sequence_length=memory_length,
- probability_fn=normalization_function,
- name=name)
-
- self.location_convolution = tf.layers.Conv1D(filters=hparams.attention_filters,
- kernel_size=hparams.attention_kernel, padding='same', use_bias=False,
- name='location_features_convolution')
- self.location_layer = tf.layers.Dense(units=num_units, use_bias=False,
- dtype=tf.float32, name='location_features_layer')
-
- def __call__(self, query, previous_alignments):
- """Score the query based on the keys and values.
- Args:
- query: Tensor of dtype matching `self.values` and shape
- `[batch_size, query_depth]`.
- previous_alignments: Tensor of dtype matching `self.values` and shape
- `[batch_size, alignments_size]`
- (`alignments_size` is memory's `max_time`).
- Returns:
- alignments: Tensor of dtype matching `self.values` and shape
- `[batch_size, alignments_size]` (`alignments_size` is memory's
- `max_time`).
- """
- with variable_scope.variable_scope(None, "Location_Sensitive_Attention", [query]):
-
- # processed_query shape [batch_size, query_depth] -> [batch_size, attention_dim]
- processed_query = self.query_layer(query) if self.query_layer else query
- # -> [batch_size, 1, attention_dim]
- processed_query = tf.expand_dims(processed_query, 1)
-
- # processed_location_features shape [batch_size, max_time, attention dimension]
- # [batch_size, max_time] -> [batch_size, max_time, 1]
- expanded_alignments = tf.expand_dims(previous_alignments, axis=2)
- # location features [batch_size, max_time, filters]
- f = self.location_convolution(expanded_alignments)
- # Projected location features [batch_size, max_time, attention_dim]
- processed_location_features = self.location_layer(f)
-
- # energy shape [batch_size, max_time]
- energy = _location_sensitive_score(processed_query, processed_location_features, self.keys)
-
- # alignments shape = energy shape = [batch_size, max_time]
- alignments = self._probability_fn(energy, previous_alignments)
-
- # Cumulate alignments
- next_state = alignments + previous_alignments
- return alignments, next_state
diff --git a/tacotron/models/custom_decoder.py b/tacotron/models/custom_decoder.py
index 91593ffe..1029eaa9 100644
--- a/tacotron/models/custom_decoder.py
+++ b/tacotron/models/custom_decoder.py
@@ -1,19 +1,15 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function
import collections
-import tensorflow as tf
+import tensorflow as tf
+from tacotron.models.helpers import TacoTestHelper, TacoTrainingHelper
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import ops, tensor_shape
from tensorflow.python.layers import base as layers_base
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.util import nest
-from tacotron.models.helpers import TacoTrainingHelper, TacoTestHelper
-
class CustomDecoderOutput(
@@ -44,8 +40,7 @@ def __init__(self, cell, helper, initial_state, output_layer=None):
Raises:
TypeError: if `cell`, `helper` or `output_layer` have an incorrect type.
"""
- if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access
- raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
+ rnn_cell_impl.assert_like_rnncell(type(cell), cell)
if not isinstance(helper, helper_py.Helper):
raise TypeError("helper must be a Helper, received: %s" % type(helper))
if (output_layer is not None
@@ -136,4 +131,4 @@ def step(self, time, inputs, state, name=None):
stop_token_prediction=stop_token)
outputs = CustomDecoderOutput(cell_outputs, stop_token, sample_ids)
- return (outputs, next_state, next_inputs, finished)
\ No newline at end of file
+ return (outputs, next_state, next_inputs, finished)
diff --git a/tacotron/models/helpers.py b/tacotron/models/helpers.py
index ff35d58a..9c464921 100644
--- a/tacotron/models/helpers.py
+++ b/tacotron/models/helpers.py
@@ -1,15 +1,15 @@
import numpy as np
import tensorflow as tf
from tensorflow.contrib.seq2seq import Helper
-from hparams import hparams
class TacoTestHelper(Helper):
- def __init__(self, batch_size, output_dim, r):
+ def __init__(self, batch_size, hparams):
with tf.name_scope('TacoTestHelper'):
self._batch_size = batch_size
- self._output_dim = output_dim
- self._reduction_factor = r
+ self._output_dim = hparams.num_mels
+ self._reduction_factor = hparams.outputs_per_step
+ self.stop_at_any = hparams.stop_at_any
@property
def batch_size(self):
@@ -39,20 +39,20 @@ def next_inputs(self, time, outputs, state, sample_ids, stop_token_prediction, n
#A sequence is finished when the output probability is > 0.5
finished = tf.cast(tf.round(stop_token_prediction), tf.bool)
- #Since we are predicting r frames at each step, two modes are
+ #Since we are predicting r frames at each step, two modes are
#then possible:
# Stop when the model outputs a p > 0.5 for any frame between r frames (Recommended)
# Stop when the model outputs a p > 0.5 for all r frames (Safer)
#Note:
# With enough training steps, the model should be able to predict when to stop correctly
# and the use of stop_at_any = True would be recommended. If however the model didn't
- # learn to stop correctly yet, (stops too soon) one could choose to use the safer option
+ # learn to stop correctly yet, (stops too soon) one could choose to use the safer option
# to get a correct synthesis
- if hparams.stop_at_any:
- finished = tf.reduce_any(finished) #Recommended
+ if self.stop_at_any:
+ finished = tf.reduce_any(tf.reduce_all(finished, axis=0)) #Recommended
else:
- finished = tf.reduce_all(finished) #Safer option
-
+ finished = tf.reduce_all(tf.reduce_all(finished, axis=0)) #Safer option
+
# Feed last output frame as next input. outputs is [N, output_dim * r]
next_inputs = outputs[:, -self._output_dim:]
next_state = state
@@ -60,24 +60,24 @@ def next_inputs(self, time, outputs, state, sample_ids, stop_token_prediction, n
class TacoTrainingHelper(Helper):
- def __init__(self, batch_size, targets, stop_targets, output_dim, r, ratio, gta):
+ def __init__(self, batch_size, targets, hparams, gta, evaluating, global_step):
# inputs is [N, T_in], targets is [N, T_out, D]
with tf.name_scope('TacoTrainingHelper'):
self._batch_size = batch_size
- self._output_dim = output_dim
- self._reduction_factor = r
- self._ratio = ratio
+ self._output_dim = hparams.num_mels
+ self._reduction_factor = hparams.outputs_per_step
+ self._ratio = tf.convert_to_tensor(hparams.tacotron_teacher_forcing_ratio)
self.gta = gta
+ self.eval = evaluating
+ self._hparams = hparams
+ self.global_step = global_step
+ r = self._reduction_factor
# Feed every r-th target frame as input
self._targets = targets[:, r-1::r, :]
- if not gta:
- # Detect finished sequence using stop_targets
- self._stop_targets = stop_targets[:, r-1::r]
- else:
- # GTA synthesis
- self._lengths = tf.tile([tf.shape(self._targets)[1]], [self._batch_size])
+ #Maximal sequence length
+ self._lengths = tf.tile([tf.shape(self._targets)[1]], [self._batch_size])
@property
def batch_size(self):
@@ -96,6 +96,17 @@ def sample_ids_dtype(self):
return np.int32
def initialize(self, name=None):
+ #Compute teacher forcing ratio for this global step.
+ #In GTA mode, override teacher forcing scheme to work with full teacher forcing
+ if self.gta:
+ self._ratio = tf.convert_to_tensor(1.) #Force GTA model to always feed ground-truth
+ elif self.eval and self._hparams.tacotron_natural_eval:
+ self._ratio = tf.convert_to_tensor(0.) #Force eval model to always feed predictions
+ else:
+ if self._hparams.tacotron_teacher_forcing_mode == 'scheduled':
+ self._ratio = _teacher_forcing_ratio_decay(self._hparams.tacotron_teacher_forcing_init_ratio,
+ self.global_step, self._hparams)
+
return (tf.tile([False], [self._batch_size]), _go_frames(self._batch_size, self._output_dim))
def sample(self, time, outputs, state, name=None):
@@ -103,22 +114,56 @@ def sample(self, time, outputs, state, name=None):
def next_inputs(self, time, outputs, state, sample_ids, stop_token_prediction, name=None):
with tf.name_scope(name or 'TacoTrainingHelper'):
- if not self.gta:
- #mark sequences where stop_target == 1 as finished (for case of imputation)
- finished = tf.equal(self._stop_targets[:, time], [1.])
- else:
- #GTA synthesis stop
- finished = (time + 1 >= self._lengths)
+ #synthesis stop (we let the model see paddings as we mask them when computing loss functions)
+ finished = (time + 1 >= self._lengths)
- if np.random.random() <= self._ratio:
- next_inputs = self._targets[:, time, :] #Teacher-forcing: return true frame
- else:
- next_inputs = outputs[:, -self._output_dim:]
- #Update the finished state
- next_state = state.replace(finished=tf.cast(tf.reshape(finished, [-1, 1]), tf.float32))
+ #Pick previous outputs randomly with respect to teacher forcing ratio
+ next_inputs = tf.cond(
+ tf.less(tf.random_uniform([], minval=0, maxval=1, dtype=tf.float32), self._ratio),
+ lambda: self._targets[:, time, :], #Teacher-forcing: return true frame
+ lambda: outputs[:,-self._output_dim:])
+
+ #Pass on state
+ next_state = state
return (finished, next_inputs, next_state)
def _go_frames(batch_size, output_dim):
'''Returns all-zero frames for a given batch size and output dimension'''
- return tf.tile([[0.0]], [batch_size, output_dim])
\ No newline at end of file
+ return tf.tile([[0.0]], [batch_size, output_dim])
+
+def _teacher_forcing_ratio_decay(init_tfr, global_step, hparams):
+ #################################################################
+ # Narrow Cosine Decay:
+
+ # Phase 1: tfr = init
+ # We only start decaying the teacher forcing ratio after 10k steps
+
+ # Phase 2: tfr in ]init, final[
+ # decay reaches its minimal value at step ~40k
+
+ # Phase 3: tfr = final
+ # clip by the minimal teacher forcing ratio value (step >~ 40k)
+ #################################################################
+ #Pick final teacher forcing rate value
+ if hparams.tacotron_teacher_forcing_final_ratio is not None:
+ alpha = float(hparams.tacotron_teacher_forcing_final_ratio / hparams.tacotron_teacher_forcing_init_ratio)
+
+ else:
+ assert hparams.tacotron_teacher_forcing_decay_alpha is not None
+ alpha = hparams.tacotron_teacher_forcing_decay_alpha
+
+ #Compute natural cosine decay
+ tfr = tf.train.cosine_decay(init_tfr,
+ global_step=global_step - hparams.tacotron_teacher_forcing_start_decay, #tfr ~= init at step 10k
+ decay_steps=hparams.tacotron_teacher_forcing_decay_steps, #tfr ~= final at step ~40k
+ alpha=alpha, #tfr decays to alpha * init_tfr as its final value
+ name='tfr_cosine_decay')
+
+ #force teacher forcing ratio to take initial value when global step < start decay step.
+ narrow_tfr = tf.cond(
+ tf.less(global_step, tf.convert_to_tensor(hparams.tacotron_teacher_forcing_start_decay)),
+ lambda: tf.convert_to_tensor(init_tfr),
+ lambda: tfr)
+
+ return narrow_tfr
diff --git a/tacotron/models/modules.py b/tacotron/models/modules.py
index d77fe7eb..252e18f9 100644
--- a/tacotron/models/modules.py
+++ b/tacotron/models/modules.py
@@ -1,29 +1,160 @@
-import tensorflow as tf
-from tacotron.models.zoneout_LSTM import ZoneoutLSTMCell
-from tensorflow.contrib.rnn import LSTMBlockCell
-from hparams import hparams
+import numpy as np
+import tensorflow as tf
-def conv1d(inputs, kernel_size, channels, activation, is_training, scope):
- drop_rate = hparams.tacotron_dropout_rate
+def GuidedAttention(N, T, g=0.2):
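+ '''Guided attention penalty matrix:
+ W[n, t] = 1 - exp(-((t/T - n/N)**2) / (2 * g**2)) is ~0 on the diagonal (n/N ~ t/T)
+ and grows towards 1 as the alignment strays from it. Used in Tacotron.add_loss to
+ penalize non-monotonic (off-diagonal) attention.
+ '''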
+ W = np.zeros((N, T), dtype=np.float32)
+ for n in range(N):
+ for t in range(T):
+ W[n, t] = 1 - np.exp(-(t / float(T) - n / float(N)) ** 2 / (2 * g * g))
+ return W
- with tf.variable_scope(scope):
- conv1d_output = tf.layers.conv1d(
- inputs,
- filters=channels,
- kernel_size=kernel_size,
- activation=None,
- padding='same')
- batched = tf.layers.batch_normalization(conv1d_output, training=is_training)
- activated = activation(batched)
- return tf.layers.dropout(activated, rate=drop_rate, training=is_training,
- name='dropout_{}'.format(scope))
+
+class HighwayNet:
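+ '''Single highway layer: y = H(x) * T(x) + x * (1 - T(x)),
+ with a relu transform H and a sigmoid gate T (gate bias initialized to -1
+ so the layer starts close to the identity mapping).
+ '''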
+ def __init__(self, units, name=None):
+ self.units = units
+ self.scope = 'HighwayNet' if name is None else name
+
+ self.H_layer = tf.layers.Dense(units=self.units, activation=tf.nn.relu, name='H')
+ self.T_layer = tf.layers.Dense(units=self.units, activation=tf.nn.sigmoid, name='T', bias_initializer=tf.constant_initializer(-1.))
+
+ def __call__(self, inputs):
+ with tf.variable_scope(self.scope):
+ H = self.H_layer(inputs)
+ T = self.T_layer(inputs)
+ return H * T + inputs * (1. - T)
+
+
+class CBHG:
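+ '''Conv Bank + Highway net + bidirectional GRU block:
+ a bank of K 1-D convolutions (kernel sizes 1..K) followed by max pooling, two projection
+ convolutions with a residual connection, a stack of highway layers and a bidirectional GRU.
+ Used here as the post-processing network that maps predicted mels to linear spectrograms.
+ '''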
+ def __init__(self, K, conv_channels, pool_size, projections, projection_kernel_size, n_highwaynet_layers, highway_units, rnn_units, bnorm, is_training, name=None):
+ self.K = K
+ self.conv_channels = conv_channels
+ self.pool_size = pool_size
+
+ self.projections = projections
+ self.projection_kernel_size = projection_kernel_size
+ self.bnorm = bnorm
+
+ self.is_training = is_training
+ self.scope = 'CBHG' if name is None else name
+
+ self.highway_units = highway_units
+ self.highwaynet_layers = [HighwayNet(highway_units, name='{}_highwaynet_{}'.format(self.scope, i+1)) for i in range(n_highwaynet_layers)]
+ self._fw_cell = tf.nn.rnn_cell.GRUCell(rnn_units, name='{}_forward_RNN'.format(self.scope))
+ self._bw_cell = tf.nn.rnn_cell.GRUCell(rnn_units, name='{}_backward_RNN'.format(self.scope))
+
+ def __call__(self, inputs, input_lengths):
+ with tf.variable_scope(self.scope):
+ with tf.variable_scope('conv_bank'):
+ #Convolution bank: concatenate on the last axis to stack channels from all convolutions
+ #The convolution bank uses multiple kernel sizes to capture the input sequence at several time scales
+ #This is one of the strengths of the CBHG block on sequences.
+ conv_outputs = tf.concat(
+ [conv1d(inputs, k, self.conv_channels, tf.nn.relu, self.is_training, self.bnorm, 'conv1d_{}'.format(k)) for k in range(1, self.K+1)],
+ axis=-1
+ )
+
+ #Maxpooling (dimension reduction, using max instead of average helps find "edges" in mels)
+ maxpool_output = tf.layers.max_pooling1d(
+ conv_outputs,
+ pool_size=self.pool_size,
+ strides=1,
+ padding='same')
+
+ #Two projection layers
+ proj1_output = conv1d(maxpool_output, self.projection_kernel_size, self.projections[0], tf.nn.relu, self.is_training, self.bnorm, 'proj1')
+ proj2_output = conv1d(proj1_output, self.projection_kernel_size, self.projections[1], lambda _: _, self.is_training, self.bnorm, 'proj2')
+
+ #Residual connection
+ highway_input = proj2_output + inputs
+
+ #Additional projection in case of dimension mismatch (for HighwayNet "residual" connection)
+ if highway_input.shape[2] != self.highway_units:
+ highway_input = tf.layers.dense(highway_input, self.highway_units)
+
+ #4-layer HighwayNet
+ for highwaynet in self.highwaynet_layers:
+ highway_input = highwaynet(highway_input)
+ rnn_input = highway_input
+
+ #Bidirectional RNN
+ outputs, states = tf.nn.bidirectional_dynamic_rnn(
+ self._fw_cell,
+ self._bw_cell,
+ rnn_input,
+ sequence_length=input_lengths,
+ dtype=tf.float32)
+ return tf.concat(outputs, axis=2) #Concat forward and backward outputs
+
+
+class ZoneoutLSTMCell(tf.nn.rnn_cell.RNNCell):
+ '''Wrapper for tf LSTM to create Zoneout LSTM Cell
+
+ inspired by:
+ https://github.com/teganmaharaj/zoneout/blob/master/zoneout_tensorflow.py
+
+ Published by one of the authors of 'https://arxiv.org/pdf/1606.01305.pdf'.
+
+ Many thanks to @Ondal90 for pointing this out. You sir are a hero!
+ '''
+ def __init__(self, num_units, is_training, zoneout_factor_cell=0., zoneout_factor_output=0., state_is_tuple=True, name=None):
+ '''Initializer with possibility to set different zoneout values for cell/hidden states.
+ '''
+ zm = min(zoneout_factor_output, zoneout_factor_cell)
+ zs = max(zoneout_factor_output, zoneout_factor_cell)
+
+ if zm < 0. or zs > 1.:
+ raise ValueError('One/both provided Zoneout factors are not in [0, 1]')
+
+ self._cell = tf.nn.rnn_cell.LSTMCell(num_units, state_is_tuple=state_is_tuple, name=name)
+ self._zoneout_cell = zoneout_factor_cell
+ self._zoneout_outputs = zoneout_factor_output
+ self.is_training = is_training
+ self.state_is_tuple = state_is_tuple
+
+ @property
+ def state_size(self):
+ return self._cell.state_size
+
+ @property
+ def output_size(self):
+ return self._cell.output_size
+
+ def __call__(self, inputs, state, scope=None):
+ '''Runs vanilla LSTM Cell and applies zoneout.
+ '''
+ #Apply vanilla LSTM
+ output, new_state = self._cell(inputs, state, scope)
+
+ if self.state_is_tuple:
+ (prev_c, prev_h) = state
+ (new_c, new_h) = new_state
+ else:
+ num_proj = self._cell._num_units if self._cell._num_proj is None else self._cell._num_proj
+ prev_c = tf.slice(state, [0, 0], [-1, self._cell._num_units])
+ prev_h = tf.slice(state, [0, self._cell._num_units], [-1, num_proj])
+ new_c = tf.slice(new_state, [0, 0], [-1, self._cell._num_units])
+ new_h = tf.slice(new_state, [0, self._cell._num_units], [-1, num_proj])
+
+ #Apply zoneout
+ if self.is_training:
+ #nn.dropout takes keep_prob (probability to keep activations) not drop_prob (probability to mask activations)!
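+ #dropout zeroes a random (zoneout factor) fraction of (new - prev) and rescales the kept entries
+ #by 1 / (1 - zoneout factor); the leading (1 - zoneout factor) term cancels that rescaling,
+ #so kept units take their new value while dropped units keep their previous value.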
+ c = (1 - self._zoneout_cell) * tf.nn.dropout(new_c - prev_c, (1 - self._zoneout_cell)) + prev_c
+ h = (1 - self._zoneout_outputs) * tf.nn.dropout(new_h - prev_h, (1 - self._zoneout_outputs)) + prev_h
+
+ else:
+ c = (1 - self._zoneout_cell) * new_c + self._zoneout_cell * prev_c
+ h = (1 - self._zoneout_outputs) * new_h + self._zoneout_outputs * prev_h
+
+ new_state = tf.nn.rnn_cell.LSTMStateTuple(c, h) if self.state_is_tuple else tf.concat(1, [c, h])
+
+ return output, new_state
class EncoderConvolutions:
"""Encoder convolutional layers used to find local dependencies in inputs characters.
"""
- def __init__(self, is_training, kernel_size=(5, ), channels=512, activation=tf.nn.relu, scope=None):
+ def __init__(self, is_training, hparams, activation=tf.nn.relu, scope=None):
"""
Args:
is_training: Boolean, determines if the model is training or in inference to control dropout
@@ -35,17 +166,20 @@ def __init__(self, is_training, kernel_size=(5, ), channels=512, activation=tf.n
super(EncoderConvolutions, self).__init__()
self.is_training = is_training
- self.kernel_size = kernel_size
- self.channels = channels
+ self.kernel_size = hparams.enc_conv_kernel_size
+ self.channels = hparams.enc_conv_channels
self.activation = activation
self.scope = 'enc_conv_layers' if scope is None else scope
+ self.drop_rate = hparams.tacotron_dropout_rate
+ self.enc_conv_num_layers = hparams.enc_conv_num_layers
+ self.bnorm = hparams.batch_norm_position
def __call__(self, inputs):
with tf.variable_scope(self.scope):
x = inputs
- for i in range(hparams.enc_conv_num_layers):
+ for i in range(self.enc_conv_num_layers):
x = conv1d(x, self.kernel_size, self.channels, self.activation,
- self.is_training, 'conv_layer_{}_'.format(i + 1)+self.scope)
+ self.is_training, self.bnorm, 'conv_layer_{}_'.format(i + 1)+self.scope)
return x
@@ -67,19 +201,27 @@ def __init__(self, is_training, size=256, zoneout=0.1, scope=None):
self.zoneout = zoneout
self.scope = 'encoder_LSTM' if scope is None else scope
- #Create LSTM Cell
- self._cell = ZoneoutLSTMCell(size, is_training,
+ #Create forward LSTM Cell
+ self._fw_cell = ZoneoutLSTMCell(size, is_training,
+ zoneout_factor_cell=zoneout,
+ zoneout_factor_output=zoneout,
+ name='encoder_fw_LSTM')
+
+ #Create backward LSTM Cell
+ self._bw_cell = ZoneoutLSTMCell(size, is_training,
zoneout_factor_cell=zoneout,
- zoneout_factor_output=zoneout)
+ zoneout_factor_output=zoneout,
+ name='encoder_bw_LSTM')
def __call__(self, inputs, input_lengths):
with tf.variable_scope(self.scope):
outputs, (fw_state, bw_state) = tf.nn.bidirectional_dynamic_rnn(
- self._cell,
- self._cell,
+ self._fw_cell,
+ self._bw_cell,
inputs,
sequence_length=input_lengths,
- dtype=tf.float32)
+ dtype=tf.float32,
+ swap_memory=True)
return tf.concat(outputs, axis=2) # Concat and return forward + backward outputs
@@ -87,34 +229,33 @@ def __call__(self, inputs, input_lengths):
class Prenet:
"""Two fully connected layers used as an information bottleneck for the attention.
"""
- def __init__(self, is_training, layer_sizes=[256, 256], activation=tf.nn.relu, scope=None):
+ def __init__(self, is_training, layers_sizes=[256, 256], drop_rate=0.5, activation=tf.nn.relu, scope=None):
"""
Args:
- is_training: Boolean, determines if the model is in training or inference to control dropout
- layer_sizes: list of integers, the length of the list represents the number of pre-net
+ layers_sizes: list of integers, the length of the list represents the number of pre-net
layers and the list values represent the layers number of units
activation: callable, activation functions of the prenet layers.
scope: Prenet scope.
"""
super(Prenet, self).__init__()
- self.drop_rate = hparams.tacotron_dropout_rate
+ self.drop_rate = drop_rate
- self.layer_sizes = layer_sizes
- self.is_training = is_training
+ self.layers_sizes = layers_sizes
self.activation = activation
-
+ self.is_training = is_training
+
self.scope = 'prenet' if scope is None else scope
def __call__(self, inputs):
x = inputs
with tf.variable_scope(self.scope):
- for i, size in enumerate(self.layer_sizes):
+ for i, size in enumerate(self.layers_sizes):
dense = tf.layers.dense(x, units=size, activation=self.activation,
name='dense_{}'.format(i + 1))
#The paper discussed introducing diversity in generation at inference time
- #by using a dropout of 0.5 only in prenet layers.
- x = tf.layers.dropout(dense, rate=self.drop_rate, training=self.is_training,
+ #by using a dropout of 0.5 only in prenet layers (in both training and inference).
+ x = tf.layers.dropout(dense, rate=self.drop_rate, training=True,
name='dropout_{}'.format(i + 1) + self.scope)
return x
@@ -139,9 +280,10 @@ def __init__(self, is_training, layers=2, size=1024, zoneout=0.1, scope=None):
self.scope = 'decoder_rnn' if scope is None else scope
#Create a set of LSTM layers
- self.rnn_layers = [ZoneoutLSTMCell(size, is_training,
+ self.rnn_layers = [ZoneoutLSTMCell(size, is_training,
zoneout_factor_cell=zoneout,
- zoneout_factor_output=zoneout) for i in range(layers)]
+ zoneout_factor_output=zoneout,
+ name='decoder_LSTM_{}'.format(i+1)) for i in range(layers)]
self._cell = tf.contrib.rnn.MultiRNNCell(self.rnn_layers, state_is_tuple=True)
@@ -164,15 +306,17 @@ def __init__(self, shape=80, activation=None, scope=None):
self.shape = shape
self.activation = activation
-
+
self.scope = 'Linear_projection' if scope is None else scope
+ self.dense = tf.layers.Dense(units=shape, activation=activation, name='projection_{}'.format(self.scope))
def __call__(self, inputs):
with tf.variable_scope(self.scope):
#If activation==None, this returns a simple Linear projection
#else the projection will be passed through an activation function
- output = tf.layers.dense(inputs, units=self.shape, activation=self.activation,
- name='projection_{}'.format(self.scope))
+ # output = tf.layers.dense(inputs, units=self.shape, activation=self.activation,
+ # name='projection_{}'.format(self.scope))
+ output = self.dense(inputs)
return output
@@ -180,7 +324,7 @@ def __call__(self, inputs):
class StopProjection:
"""Projection to a scalar and through a sigmoid activation
"""
- def __init__(self, is_training, shape=hparams.outputs_per_step, activation=tf.nn.sigmoid, scope=None):
+ def __init__(self, is_training, shape=1, activation=tf.nn.sigmoid, scope=None):
"""
Args:
is_training: Boolean, to control the use of sigmoid function as it is useless to use it
@@ -191,7 +335,7 @@ def __init__(self, is_training, shape=hparams.outputs_per_step, activation=tf.nn
"""
super(StopProjection, self).__init__()
self.is_training = is_training
-
+
self.shape = shape
self.activation = activation
self.scope = 'stop_token_projection' if scope is None else scope
@@ -210,7 +354,7 @@ def __call__(self, inputs):
class Postnet:
"""Postnet that takes final decoder output and fine tunes it (using vision on past and future frames)
"""
- def __init__(self, is_training, kernel_size=(5, ), channels=512, activation=tf.nn.tanh, scope=None):
+ def __init__(self, is_training, hparams, activation=tf.nn.tanh, scope=None):
"""
Args:
is_training: Boolean, determines if the model is training or in inference to control dropout
@@ -222,16 +366,127 @@ def __init__(self, is_training, kernel_size=(5, ), channels=512, activation=tf.n
super(Postnet, self).__init__()
self.is_training = is_training
- self.kernel_size = kernel_size
- self.channels = channels
+ self.kernel_size = hparams.postnet_kernel_size
+ self.channels = hparams.postnet_channels
self.activation = activation
self.scope = 'postnet_convolutions' if scope is None else scope
+ self.postnet_num_layers = hparams.postnet_num_layers
+ self.drop_rate = hparams.tacotron_dropout_rate
+ self.bnorm = hparams.batch_norm_position
def __call__(self, inputs):
with tf.variable_scope(self.scope):
x = inputs
- for i in range(hparams.postnet_num_layers - 1):
+ for i in range(self.postnet_num_layers - 1):
x = conv1d(x, self.kernel_size, self.channels, self.activation,
- self.is_training, 'conv_layer_{}_'.format(i + 1)+self.scope)
- x = conv1d(x, self.kernel_size, self.channels, lambda _: _, self.is_training, 'conv_layer_{}_'.format(5)+self.scope)
- return x
\ No newline at end of file
+ self.is_training, self.bnorm, 'conv_layer_{}_'.format(i + 1)+self.scope)
+ x = conv1d(x, self.kernel_size, self.channels, lambda _: _, self.is_training, self.bnorm,
+ 'conv_layer_{}_'.format(5)+self.scope)
+ return x
+
+
+def conv1d(inputs, kernel_size, channels, activation, is_training, bnorm, scope):
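+ '''1-D convolution followed by batch normalization.
+ bnorm='before': conv -> batch norm -> activation (activation applied to the normalized output).
+ bnorm='after' : conv (with activation) -> batch norm (normalization applied last).
+ '''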
+ assert bnorm in ('before', 'after')
+ with tf.variable_scope(scope):
+ conv1d_output = tf.layers.conv1d(
+ inputs,
+ filters=channels,
+ kernel_size=kernel_size,
+ activation=activation if bnorm == 'after' else None,
+ padding='same')
+ batched = tf.layers.batch_normalization(conv1d_output, training=is_training)
+ return activation(batched) if bnorm == 'before' else batched
+
+def _round_up_tf(x, multiple):
+ # Tf version of remainder = x % multiple
+ remainder = tf.mod(x, multiple)
+ # Tf version of return x if remainder == 0 else x + multiple - remainder
+ x_round = tf.cond(tf.equal(remainder, tf.zeros(tf.shape(remainder), dtype=tf.int32)),
+ lambda: x,
+ lambda: x + multiple - remainder)
+
+ return x_round
+
+def sequence_mask(lengths, r, expand=True):
+ '''Returns a 2-D or 3-D tensorflow sequence mask depending on the argument 'expand'
+ '''
+ max_len = tf.reduce_max(lengths)
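+ #Round the mask length up to a multiple of the reduction factor r so the mask matches
+ #decoder outputs that are padded to whole groups of r frames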
+ max_len = _round_up_tf(max_len, tf.convert_to_tensor(r))
+ if expand:
+ return tf.expand_dims(tf.sequence_mask(lengths, maxlen=max_len, dtype=tf.float32), axis=-1)
+ return tf.sequence_mask(lengths, maxlen=max_len, dtype=tf.float32)
+
+def MaskedMSE(targets, outputs, targets_lengths, hparams, mask=None):
+ '''Computes a masked Mean Squared Error
+ '''
+
+ #[batch_size, time_dimension, 1]
+ #example:
+ #sequence_mask([1, 3, 2], 5) = [[[1., 0., 0., 0., 0.]],
+ # [[1., 1., 1., 0., 0.]],
+ # [[1., 1., 0., 0., 0.]]]
+ #Note the maxlen argument that ensures mask shape is compatible with r>1
+ #This will by default mask the extra paddings caused by r>1
+ if mask is None:
+ mask = sequence_mask(targets_lengths, hparams.outputs_per_step, True)
+
+ #[batch_size, time_dimension, channel_dimension(mels)]
+ ones = tf.ones(shape=[tf.shape(mask)[0], tf.shape(mask)[1], tf.shape(targets)[-1]], dtype=tf.float32)
+ mask_ = mask * ones
+
+ with tf.control_dependencies([tf.assert_equal(tf.shape(targets), tf.shape(mask_))]):
+ return tf.losses.mean_squared_error(labels=targets, predictions=outputs, weights=mask_)
+
+def MaskedSigmoidCrossEntropy(targets, outputs, targets_lengths, hparams, mask=None):
+ '''Computes a masked SigmoidCrossEntropy with logits
+ '''
+
+ #[batch_size, time_dimension]
+ #example:
+ #sequence_mask([1, 3, 2], 5) = [[1., 0., 0., 0., 0.],
+ # [1., 1., 1., 0., 0.],
+ # [1., 1., 0., 0., 0.]]
+ #Note the maxlen argument that ensures mask shape is compatible with r>1
+ #This will by default mask the extra paddings caused by r>1
+ if mask is None:
+ mask = sequence_mask(targets_lengths, hparams.outputs_per_step, False)
+
+ with tf.control_dependencies([tf.assert_equal(tf.shape(targets), tf.shape(mask))]):
+ #Use a weighted sigmoid cross entropy to measure the loss. Setting hparams.cross_entropy_pos_weight to 1
+ #has the same effect as vanilla tf.nn.sigmoid_cross_entropy_with_logits.
+ losses = tf.nn.weighted_cross_entropy_with_logits(targets=targets, logits=outputs, pos_weight=hparams.cross_entropy_pos_weight)
+
+ with tf.control_dependencies([tf.assert_equal(tf.shape(mask), tf.shape(losses))]):
+ masked_loss = losses * mask
+
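+ #Average over the non-zero (unmasked) loss entries only, so padded frames do not dilute the loss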
+ return tf.reduce_sum(masked_loss) / tf.count_nonzero(masked_loss, dtype=tf.float32)
+
+def MaskedLinearLoss(targets, outputs, targets_lengths, hparams, mask=None):
+ '''Computes a masked MAE loss with priority to low frequencies
+ '''
+
+ #[batch_size, time_dimension, 1]
+ #example:
+ #sequence_mask([1, 3, 2], 5) = [[[1., 0., 0., 0., 0.]],
+ # [[1., 1., 1., 0., 0.]],
+ # [[1., 1., 0., 0., 0.]]]
+ #Note the maxlen argument that ensures mask shape is compatible with r>1
+ #This will by default mask the extra paddings caused by r>1
+ if mask is None:
+ mask = sequence_mask(targets_lengths, hparams.outputs_per_step, True)
+
+ #[batch_size, time_dimension, channel_dimension(freq)]
+ ones = tf.ones(shape=[tf.shape(mask)[0], tf.shape(mask)[1], tf.shape(targets)[-1]], dtype=tf.float32)
+ mask_ = mask * ones
+
+ l1 = tf.abs(targets - outputs)
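+ #Number of linear-frequency bins below 2000 Hz (the num_freq bins span 0 .. sample_rate / 2)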
+ n_priority_freq = int(2000 / (hparams.sample_rate * 0.5) * hparams.num_freq)
+
+ with tf.control_dependencies([tf.assert_equal(tf.shape(targets), tf.shape(mask_))]):
+ masked_l1 = l1 * mask_
+ masked_l1_low = masked_l1[:,:,0:n_priority_freq]
+
+ mean_l1 = tf.reduce_sum(masked_l1) / tf.reduce_sum(mask_)
+ mean_l1_low = tf.reduce_sum(masked_l1_low) / tf.reduce_sum(mask_)
+
+ return 0.5 * mean_l1 + 0.5 * mean_l1_low
diff --git a/tacotron/models/rnn_wrappers.py b/tacotron/models/rnn_wrappers.py
deleted file mode 100644
index c7eab6e1..00000000
--- a/tacotron/models/rnn_wrappers.py
+++ /dev/null
@@ -1,99 +0,0 @@
-"""A set of RNN wrappers usefull for tacotron 2 architecture
-All notations and variable names were used in concordance with originial tensorflow implementation
-Some tensors were passed through wrappers to make sure we respect the described architecture
-"""
-
-import numpy as np
-import tensorflow as tf
-from tensorflow.contrib.rnn import RNNCell
-from .modules import prenet, projection
-from tensorflow.python.framework import ops
-from hparams import hparams
-
-
-
-class DecoderPrenetWrapper(RNNCell):
- '''Runs RNN inputs through a prenet before sending them to the cell.'''
- def __init__(self, cell, is_training):
- super(DecoderPrenetWrapper, self).__init__()
- self._cell = cell
- self._is_training = is_training
-
- @property
- def state_size(self):
- return self._cell.state_size
-
- @property
- def output_size(self):
- return self._cell.output_size
-
- def call(self, inputs, state):
- prenet_out = prenet(inputs, self._is_training, hparams.prenet_layers, scope='decoder_attention_prenet')
- self._prenet_out = prenet_out
- return self._cell(prenet_out, state)
-
- def zero_state(self, batch_size, dtype):
- return self._cell.zero_state(batch_size, dtype)
-
-
-class ConcatPrenetAndAttentionWrapper(RNNCell):
- '''Concatenates prenet output with the attention context vector.
- This is expected to wrap a cell wrapped with an AttentionWrapper constructed with
- attention_layer_size=None and output_attention=False. Such a cell's state will include an
- "attention" field that is the context vector.
- '''
- def __init__(self, cell):
- super(ConcatPrenetAndAttentionWrapper, self).__init__()
- self._cell = cell
-
- @property
- def state_size(self):
- return self._cell.state_size
-
- @property
- def output_size(self):
- #attention is stored in attentionwrapper cell state
- return self._cell.output_size + self._cell.state_size.attention
-
- def call(self, inputs, state):
- #We assume paper writers mentionned the attention network output when
- #they say "The pre-net output and attention context vector are concatenated and
- #passed through a stack of 2 uni-directional LSTM layers"
- #We rely on the original tacotron architecture for this hypothesis.
- output, res_state = self._cell(inputs, state)
-
- #Store attention in this wrapper to make access easier from future wrappers
- self._context_vector = res_state.attention
- return tf.concat([output, self._context_vector], axis=-1), res_state
-
- def zero_state(self, batch_size, dtype):
- return self._cell.zero_state(batch_size, dtype)
-
-
-class ConcatLSTMOutputAndAttentionWrapper(RNNCell):
- '''Concatenates decoder RNN cell output with the attention context vector.
- This is expected to wrap a cell wrapped with an AttentionWrapper constructed with
- attention_layer_size=None and output_attention=False. Such a cell's state will include an
- "attention" field that is the context vector.
- '''
- def __init__(self, cell):
- super(ConcatLSTMOutputAndAttentionWrapper, self).__init__()
- self._cell = cell
- self._prenet_attention_cell = self._cell._cells[0]
-
- @property
- def state_size(self):
- return self._cell.state_size
-
- @property
- def output_size(self):
- return self._cell.output_size + self._prenet_attention_cell.state_size.attention
-
- def call(self, inputs, state):
- output, res_state = self._cell(inputs, state)
- context_vector = self._prenet_attention_cell._context_vector
- self.lstm_concat_context = tf.concat([output, context_vector], axis=-1)
- return self.lstm_concat_context, res_state
-
- def zero_state(self, batch_size, dtype):
- return self._cell.zero_state(batch_size, dtype)
diff --git a/tacotron/models/tacotron.py b/tacotron/models/tacotron.py
index 2a5297f0..ed4642a2 100644
--- a/tacotron/models/tacotron.py
+++ b/tacotron/models/tacotron.py
@@ -1,20 +1,12 @@
-import tensorflow as tf
-from tacotron.utils.symbols import symbols
-from tacotron.utils.infolog import log
-from tacotron.models.helpers import TacoTrainingHelper, TacoTestHelper
+import tensorflow as tf
+from infolog import log
+from tacotron.models.Architecture_wrappers import TacotronDecoderCell, TacotronEncoderCell
+from tacotron.models.attention import LocationSensitiveAttention
+from tacotron.models.custom_decoder import CustomDecoder
+from tacotron.models.helpers import TacoTestHelper, TacoTrainingHelper
from tacotron.models.modules import *
-from tacotron.models.zoneout_LSTM import ZoneoutLSTMCell
+from tacotron.utils.symbols import symbols
from tensorflow.contrib.seq2seq import dynamic_decode
-from tacotron.models.Architecture_wrappers import TacotronEncoderCell, TacotronDecoderCell
-from tacotron.models.custom_decoder import CustomDecoder
-
-if int(tf.__version__.replace('.', '')) < 160:
- log('using old attention Tensorflow structure (1.5.0 and earlier)')
- from tacotron.models.attention_old import LocationSensitiveAttention
-else:
- log('using new attention Tensorflow structure (1.6.0 and later)')
- from tacotron.models.attention import LocationSensitiveAttention
-
class Tacotron():
@@ -23,7 +15,9 @@ class Tacotron():
def __init__(self, hparams):
self._hparams = hparams
- def initialize(self, inputs, input_lengths, mel_targets=None, stop_token_targets=None, gta=False):
+
+ def initialize(self, inputs, input_lengths, mel_targets=None, stop_token_targets=None, linear_targets=None, targets_lengths=None, gta=False,
+ global_step=None, is_training=False, is_evaluating=False):
"""
Initializes the model for inference
@@ -42,11 +36,24 @@ def initialize(self, inputs, input_lengths, mel_targets=None, stop_token_targets
raise ValueError('no mel targets were provided but token_targets were given')
if mel_targets is not None and stop_token_targets is None and not gta:
raise ValueError('Mel targets are provided without corresponding token_targets')
+ if not gta and self._hparams.predict_linear==True and linear_targets is None and is_training:
+ raise ValueError('Model is set to use post processing to predict linear spectrograms in training but no linear targets given!')
+ if gta and linear_targets is not None:
+ raise ValueError('Linear spectrogram prediction is not supported in GTA mode!')
+ if is_training and self._hparams.mask_decoder and targets_lengths is None:
+ raise RuntimeError('Model set to mask paddings but no targets lengths provided for the mask!')
+ if is_training and is_evaluating:
+ raise RuntimeError('Model can not be in training and evaluation modes at the same time!')
with tf.variable_scope('inference') as scope:
- is_training = mel_targets is not None and not gta
batch_size = tf.shape(inputs)[0]
hp = self._hparams
+ assert hp.tacotron_teacher_forcing_mode in ('constant', 'scheduled')
+ if hp.tacotron_teacher_forcing_mode == 'scheduled' and is_training:
+ assert global_step is not None
+
+ #GTA is only used for predicting mels to train the Wavenet vocoder, so we omit post processing when doing GTA synthesis
+ post_condition = hp.predict_linear and not gta
# Embeddings ==> [batch_size, sequence_length, embedding_dim]
embedding_table = tf.get_variable(
@@ -56,8 +63,7 @@ def initialize(self, inputs, input_lengths, mel_targets=None, stop_token_targets
#Encoder Cell ==> [batch_size, encoder_steps, encoder_lstm_units]
encoder_cell = TacotronEncoderCell(
- EncoderConvolutions(is_training, kernel_size=hp.enc_conv_kernel_size,
- channels=hp.enc_conv_channels, scope='encoder_convolutions'),
+ EncoderConvolutions(is_training, hparams=hp, scope='encoder_convolutions'),
EncoderRNN(is_training, size=hp.encoder_lstm_units,
zoneout=hp.tacotron_zoneout_rate, scope='encoder_LSTM'))
@@ -69,17 +75,18 @@ def initialize(self, inputs, input_lengths, mel_targets=None, stop_token_targets
#Decoder Parts
#Attention Decoder Prenet
- prenet = Prenet(is_training, layer_sizes=hp.prenet_layers, scope='decoder_prenet')
+ prenet = Prenet(is_training, layers_sizes=hp.prenet_layers, drop_rate=hp.tacotron_dropout_rate, scope='decoder_prenet')
#Attention Mechanism
- attention_mechanism = LocationSensitiveAttention(hp.attention_dim, encoder_outputs,
- mask_encoder=hp.mask_encoder, memory_sequence_length=input_lengths, smoothing=hp.smoothing)
+ attention_mechanism = LocationSensitiveAttention(hp.attention_dim, encoder_outputs, hparams=hp,
+ is_training=is_training, mask_encoder=hp.mask_encoder, memory_sequence_length=input_lengths,
+ smoothing=hp.smoothing, cumulate_weights=hp.cumulative_weights)
#Decoder LSTM Cells
decoder_lstm = DecoderRNN(is_training, layers=hp.decoder_layers,
size=hp.decoder_lstm_units, zoneout=hp.tacotron_zoneout_rate, scope='decoder_lstm')
#Frames Projection layer
frame_projection = FrameProjection(hp.num_mels * hp.outputs_per_step, scope='linear_transform')
# projection layer
- stop_projection = StopProjection(is_training, scope='stop_token_projection')
+ stop_projection = StopProjection(is_training or is_evaluating, shape=hp.outputs_per_step, scope='stop_token_projection')
#Decoder Cell ==> [batch_size, decoder_steps, num_mels * r] (after decoding)
@@ -88,45 +95,43 @@ def initialize(self, inputs, input_lengths, mel_targets=None, stop_token_targets
attention_mechanism,
decoder_lstm,
frame_projection,
- stop_projection,
- mask_finished=hp.mask_finished)
+ stop_projection)
#Define the helper for our decoder
- if (is_training or gta) == True:
- self.helper = TacoTrainingHelper(batch_size, mel_targets, stop_token_targets,
- hp.num_mels, hp.outputs_per_step, hp.tacotron_teacher_forcing_ratio, gta)
+ if is_training or is_evaluating or gta:
+ self.helper = TacoTrainingHelper(batch_size, mel_targets, hp, gta, is_evaluating, global_step)
else:
- self.helper = TacoTestHelper(batch_size, hp.num_mels, hp.outputs_per_step)
+ self.helper = TacoTestHelper(batch_size, hp)
#initial decoder state
decoder_init_state = decoder_cell.zero_state(batch_size=batch_size, dtype=tf.float32)
#Only use max iterations at synthesis time
- max_iters = hp.max_iters if not is_training else None
+ max_iters = hp.max_iters if not (is_training or is_evaluating) else None
#Decode
(frames_prediction, stop_token_prediction, _), final_decoder_state, _ = dynamic_decode(
CustomDecoder(decoder_cell, self.helper, decoder_init_state),
- impute_finished=hp.impute_finished,
- maximum_iterations=max_iters)
+ impute_finished=False,
+ maximum_iterations=max_iters,
+ swap_memory=hp.tacotron_swap_with_cpu)
- # Reshape outputs to be one output per entry
+ # Reshape outputs to be one output per entry
#==> [batch_size, non_reduced_decoder_steps (decoder_steps * r), num_mels]
decoder_output = tf.reshape(frames_prediction, [batch_size, -1, hp.num_mels])
stop_token_prediction = tf.reshape(stop_token_prediction, [batch_size, -1])
-
+
#Postnet
- postnet = Postnet(is_training, kernel_size=hp.postnet_kernel_size,
- channels=hp.postnet_channels, scope='postnet_convolutions')
+ postnet = Postnet(is_training, hparams=hp, scope='postnet_convolutions')
#Compute residual using post-net ==> [batch_size, decoder_steps * r, postnet_channels]
residual = postnet(decoder_output)
- #Project residual to same dimension as mel spectrogram
+ #Project residual to same dimension as mel spectrogram
#==> [batch_size, decoder_steps * r, num_mels]
residual_projection = FrameProjection(hp.num_mels, scope='postnet_projection')
projected_residual = residual_projection(residual)
@@ -135,9 +140,28 @@ def initialize(self, inputs, input_lengths, mel_targets=None, stop_token_targets
#Compute the mel spectrogram
mel_outputs = decoder_output + projected_residual
+
+ if post_condition:
+ # Add a post-processing CBHG. This does a great job of extracting features from mels before projecting them to linear spectrograms.
+ post_cbhg = CBHG(hp.cbhg_kernels, hp.cbhg_conv_channels, hp.cbhg_pool_size, [hp.cbhg_projection, hp.num_mels],
+ hp.cbhg_projection_kernel_size, hp.cbhg_highwaynet_layers,
+ hp.cbhg_highway_units, hp.cbhg_rnn_units, hp.batch_norm_position, is_training, name='CBHG_postnet')
+
+ #[batch_size, decoder_steps(mel_frames), cbhg_channels]
+ post_outputs = post_cbhg(mel_outputs, None)
+
+ #Linear projection of extracted features to make linear spectrogram
+ linear_specs_projection = FrameProjection(hp.num_freq, scope='cbhg_linear_specs_projection')
+
+ #[batch_size, decoder_steps(linear_frames), num_freq]
+ linear_outputs = linear_specs_projection(post_outputs)
+
+
#Grab alignments from the final decoder state
alignments = tf.transpose(final_decoder_state.alignment_history.stack(), [1, 2, 0])
+ if is_training:
+ self.ratio = self.helper._ratio
self.inputs = inputs
self.input_lengths = input_lengths
self.decoder_output = decoder_output
@@ -145,8 +169,16 @@ def initialize(self, inputs, input_lengths, mel_targets=None, stop_token_targets
self.stop_token_prediction = stop_token_prediction
self.stop_token_targets = stop_token_targets
self.mel_outputs = mel_outputs
+ if post_condition:
+ self.linear_outputs = linear_outputs
+ self.linear_targets = linear_targets
self.mel_targets = mel_targets
- log('Initialized Tacotron model. Dimensions: ')
+ self.targets_lengths = targets_lengths
+ log('Initialized Tacotron model. Dimensions (? = dynamic shape): ')
+ log(' Train mode: {}'.format(is_training))
+ log(' Eval mode: {}'.format(is_evaluating))
+ log(' GTA mode: {}'.format(gta))
+ log(' Synthesis mode: {}'.format(not (is_training or is_evaluating)))
log(' embedding: {}'.format(embedded_inputs.shape))
log(' enc conv out: {}'.format(enc_conv_output_shape))
log(' encoder out: {}'.format(encoder_outputs.shape))
@@ -154,6 +186,8 @@ def initialize(self, inputs, input_lengths, mel_targets=None, stop_token_targets
log(' residual out: {}'.format(residual.shape))
log(' projected residual out: {}'.format(projected_residual.shape))
log(' mel out: {}'.format(mel_outputs.shape))
+ if post_condition:
+ log(' linear out: {}'.format(linear_outputs.shape))
log(' out: {}'.format(stop_token_prediction.shape))
@@ -162,28 +196,73 @@ def add_loss(self):
with tf.variable_scope('loss') as scope:
hp = self._hparams
- # Compute loss of predictions before postnet
- before = tf.losses.mean_squared_error(self.mel_targets, self.decoder_output)
- # Compute loss after postnet
- after = tf.losses.mean_squared_error(self.mel_targets, self.mel_outputs)
- #Compute loss (for learning dynamic generation stop)
- stop_token_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
- labels=self.stop_token_targets,
- logits=self.stop_token_prediction))
+ if hp.mask_decoder:
+ # Compute loss of predictions before postnet
+ before = MaskedMSE(self.mel_targets, self.decoder_output, self.targets_lengths,
+ hparams=self._hparams)
+ # Compute loss after postnet
+ after = MaskedMSE(self.mel_targets, self.mel_outputs, self.targets_lengths,
+ hparams=self._hparams)
+ #Compute loss (for learning dynamic generation stop)
+ stop_token_loss = MaskedSigmoidCrossEntropy(self.stop_token_targets,
+ self.stop_token_prediction, self.targets_lengths, hparams=self._hparams)
+ #Compute masked linear loss
+ if hp.predict_linear:
+ #Compute Linear L1 mask loss (priority to low frequencies)
+ linear_loss = MaskedLinearLoss(self.linear_targets, self.linear_outputs,
+ self.targets_lengths, hparams=self._hparams)
+ else:
+ linear_loss=0.
+ #Guided attention loss is only computed in the unmasked branch below; define it here so the total loss stays well defined
+ attention_loss = 0.
+ else:
+ # guided_attention loss
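+ #Pad/crop the alignment history to a fixed [batch, N, T] shape with -1 as filler, mask out
+ #those filler entries, then penalize the remaining attention weights with the GuidedAttention
+ #matrix (near zero on the diagonal), averaged over the valid positions.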
+ N = self._hparams.max_text_length
+ T = self._hparams.max_mel_frames // self._hparams.outputs_per_step
+ A = tf.pad(self.alignments, [(0, 0), (0, N), (0, T)], mode="CONSTANT", constant_values=-1.)[:, :N, :T]
+ gts = tf.convert_to_tensor(GuidedAttention(N, T))
+ attention_masks = tf.to_float(tf.not_equal(A, -1))
+ attention_loss = tf.reduce_sum(tf.abs(A * gts) * attention_masks)
+ attention_loss /= tf.reduce_sum(attention_masks)
+ # Compute loss of predictions before postnet
+ before = tf.losses.mean_squared_error(self.mel_targets, self.decoder_output)
+ # Compute loss after postnet
+ after = tf.losses.mean_squared_error(self.mel_targets, self.mel_outputs)
+ #Compute loss (for learning dynamic generation stop)
+ stop_token_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
+ labels=self.stop_token_targets,
+ logits=self.stop_token_prediction))
+
+ if hp.predict_linear:
+ #Compute linear loss
+ #From https://github.com/keithito/tacotron/blob/tacotron2-work-in-progress/models/tacotron.py
+ #Prioritize loss for frequencies under 2000 Hz.
+ # l1 = tf.abs(self.linear_targets - self.linear_outputs)
+ # n_priority_freq = int(4000 / (hp.sample_rate * 0.5) * hp.num_freq)
+ # linear_loss = 0.5 * tf.reduce_mean(l1) + 0.5 * tf.reduce_mean(l1[:,:,0:n_priority_freq])
+ linear_loss = tf.losses.mean_squared_error(self.linear_targets, self.linear_outputs)
+ else:
+ linear_loss = 0.
+
+ # Compute the regularization weight
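+ # (scale the L2 weight down by the mel output range so the penalty strength stays comparable
+ # across max_abs_value / symmetric_mels settings)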
+ if hp.tacotron_scale_regularization:
+ reg_weight_scaler = 1. / (2 * hp.max_abs_value) if hp.symmetric_mels else 1. / (hp.max_abs_value)
+ reg_weight = hp.tacotron_reg_weight * reg_weight_scaler
+ else:
+ reg_weight = hp.tacotron_reg_weight
# Get all trainable variables
all_vars = tf.trainable_variables()
- # Compute the regularization term
regularization = tf.add_n([tf.nn.l2_loss(v) for v in all_vars
- if not('bias' in v.name or 'Bias' in v.name)]) * hp.tacotron_reg_weight
+ if not('bias' in v.name or 'Bias' in v.name)]) * reg_weight
# Compute final loss term
self.before_loss = before
self.after_loss = after
self.stop_token_loss = stop_token_loss
self.regularization_loss = regularization
+ self.linear_loss = linear_loss
+ self.attention_loss = attention_loss
- self.loss = self.before_loss + self.after_loss + self.stop_token_loss + self.regularization_loss
+ self.loss = self.before_loss + self.after_loss + self.stop_token_loss + self.regularization_loss + self.linear_loss + self.attention_loss
def add_optimizer(self, global_step):
'''Adds optimizer. Sets "gradients" and "optimize" fields. add_loss must have been called.
@@ -204,21 +283,41 @@ def add_optimizer(self, global_step):
hp.tacotron_adam_beta2, hp.tacotron_adam_epsilon)
gradients, variables = zip(*optimizer.compute_gradients(self.loss))
self.gradients = gradients
+ #Just as a precaution
+ #https://github.com/Rayhane-mamah/Tacotron-2/issues/11
+ if hp.tacotron_clip_gradients:
+ clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1.)
+ else:
+ clipped_gradients = gradients
# Add dependency on UPDATE_OPS; otherwise batchnorm won't work correctly. See:
# https://github.com/tensorflow/tensorflow/issues/1122
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
- self.optimize = optimizer.apply_gradients(zip(gradients, variables),
+ self.optimize = optimizer.apply_gradients(zip(clipped_gradients, variables),
global_step=global_step)
def _learning_rate_decay(self, init_lr, global_step):
- # Exponential decay
- # We won't drop learning rate below 10e-5
+ #################################################################
+ # Narrow Exponential Decay:
+
+ # Phase 1: lr = 1e-3
+ # We only start learning rate decay after 50k steps
+
+ # Phase 2: lr in ]1e-5, 1e-3[
+ # decay reaches its minimal value at step 310k
+
+ # Phase 3: lr = 1e-5
+ # clip by minimal learning rate value (step > 310k)
+ #################################################################
hp = self._hparams
- step = tf.cast(global_step + 1, dtype=tf.float32)
- lr = tf.train.exponential_decay(init_lr,
- global_step - self.decay_steps + 1,
- self.decay_steps,
- self.decay_rate,
- name='exponential_decay')
- return tf.maximum(hp.tacotron_final_learning_rate, lr)
\ No newline at end of file
+
+ #Compute natural exponential decay
+ lr = tf.train.exponential_decay(init_lr,
+ global_step - hp.tacotron_start_decay, #lr = 1e-3 at step 50k
+ self.decay_steps,
+ self.decay_rate, #lr = 1e-5 around step 310k
+ name='lr_exponential_decay')
+
+
+ #clip learning rate by max and min values (initial and final values)
+ return tf.minimum(tf.maximum(lr, hp.tacotron_final_learning_rate), init_lr)
diff --git a/tacotron/models/zoneout_LSTM.py b/tacotron/models/zoneout_LSTM.py
deleted file mode 100644
index 5de4fbbd..00000000
--- a/tacotron/models/zoneout_LSTM.py
+++ /dev/null
@@ -1,265 +0,0 @@
-import numpy as np
-import tensorflow as tf
-from tensorflow.python.ops.rnn_cell import RNNCell
-
-
-# Thanks to 'initializers_enhanced.py' of Project RNN Enhancement:
-# https://github.com/nicolas-ivanov/Seq2Seq_Upgrade_TensorFlow/blob/master/rnn_enhancement/initializers_enhanced.py
-def orthogonal_initializer(scale=1.0):
- def _initializer(shape, dtype=tf.float32):
- flat_shape = (shape[0], np.prod(shape[1:]))
- a = np.random.normal(0.0, 1.0, flat_shape)
- u, _, v = np.linalg.svd(a, full_matrices=False)
- q = u if u.shape == flat_shape else v
- q = q.reshape(shape)
- return tf.constant(scale * q[:shape[0], :shape[1]], dtype=tf.float32)
- return _initializer
-
-
-class ZoneoutLSTMCell(RNNCell):
- """Zoneout Regularization for LSTM-RNN.
- """
-
- def __init__(self, num_units, is_training, input_size=None,
- use_peepholes=False, cell_clip=None,
- #initializer=orthogonal_initializer(),
- initializer=tf.contrib.layers.xavier_initializer(),
- num_proj=None, proj_clip=None, ext_proj=None,
- forget_bias=1.0,
- state_is_tuple=True,
- activation=tf.tanh,
- zoneout_factor_cell=0.0,
- zoneout_factor_output=0.0,
- reuse=None):
- """Initialize the parameters for an LSTM cell.
- Args:
- num_units: int, The number of units in the LSTM cell.
- is_training: bool, set True when training.
- use_peepholes: bool, set True to enable diagonal/peephole
- connections.
- cell_clip: (optional) A float value, if provided the cell state
- is clipped by this value prior to the cell output activation.
- initializer: (optional) The initializer to use for the weight
- matrices.
- num_proj: (optional) int, The output dimensionality for
- the projection matrices. If None, no projection is performed.
- forget_bias: Biases of the forget gate are initialized by default
- to 1 in order to reduce the scale of forgetting at the beginning of
- the training.
- activation: Activation function of the inner states.
- """
- if not state_is_tuple:
- tf.logging.warn(
- "%s: Using a concatenated state is slower and will soon be "
- "deprecated. Use state_is_tuple=True.", self)
- if input_size is not None:
- tf.logging.warn(
- "%s: The input_size parameter is deprecated.", self)
-
- if not (zoneout_factor_cell >= 0.0 and zoneout_factor_cell <= 1.0):
- raise ValueError(
- "Parameter zoneout_factor_cell must be in [0 1]")
-
- if not (zoneout_factor_output >= 0.0 and zoneout_factor_output <= 1.0):
- raise ValueError(
- "Parameter zoneout_factor_cell must be in [0 1]")
-
- self.num_units = num_units
- self.is_training = is_training
- self.use_peepholes = use_peepholes
- self.cell_clip = cell_clip
- self.num_proj = num_proj
- self.proj_clip = proj_clip
- self.initializer = initializer
- self.forget_bias = forget_bias
- self.state_is_tuple = state_is_tuple
- self.activation = activation
- self.zoneout_factor_cell = zoneout_factor_cell
- self.zoneout_factor_output = zoneout_factor_output
-
- if num_proj:
- self._state_size = (
- tf.nn.rnn_cell.LSTMStateTuple(num_units, num_proj)
- if state_is_tuple else num_units + num_proj)
- self._output_size = num_proj
- else:
- self._state_size = (
- tf.nn.rnn_cell.LSTMStateTuple(num_units, num_units)
- if state_is_tuple else 2 * num_units)
- self._output_size = num_units
-
- self._ext_proj = ext_proj
-
- @property
- def state_size(self):
- return self._state_size
-
- @property
- def output_size(self):
- if self._ext_proj is None:
- return self._output_size
- return self._ext_proj
-
- def __call__(self, inputs, state, scope=None):
-
- num_proj = self.num_units if self.num_proj is None else self.num_proj
-
- if self.state_is_tuple:
- (c_prev, h_prev) = state
- else:
- c_prev = tf.slice(state, [0, 0], [-1, self.num_units])
- h_prev = tf.slice(state, [0, self.num_units], [-1, num_proj])
-
- # c_prev : Tensor with the size of [batch_size, state_size]
- # h_prev : Tensor with the size of [batch_size, state_size/2]
-
- dtype = inputs.dtype
- input_size = inputs.get_shape().with_rank(2)[1]
-
- with tf.variable_scope(scope or type(self).__name__):
- if input_size.value is None:
- raise ValueError(
- "Could not infer input size from inputs.get_shape()[-1]")
-
- # i = input_gate, j = new_input, f = forget_gate, o = output_gate
- lstm_matrix = _linear([inputs, h_prev], 4 * self.num_units, True)
- i, j, f, o = tf.split(lstm_matrix, 4, 1)
-
- # diagonal connections
- if self.use_peepholes:
- w_f_diag = tf.get_variable(
- "W_F_diag", shape=[self.num_units], dtype=dtype)
- w_i_diag = tf.get_variable(
- "W_I_diag", shape=[self.num_units], dtype=dtype)
- w_o_diag = tf.get_variable(
- "W_O_diag", shape=[self.num_units], dtype=dtype)
-
- with tf.name_scope(None, "zoneout"):
- # make binary mask tensor for cell
- keep_prob_cell = tf.convert_to_tensor(
- self.zoneout_factor_cell,
- dtype=c_prev.dtype
- )
- random_tensor_cell = keep_prob_cell
- random_tensor_cell += \
- tf.random_uniform(tf.shape(c_prev),
- seed=None, dtype=c_prev.dtype)
- binary_mask_cell = tf.floor(random_tensor_cell)
- # 0 <-> 1 swap
- binary_mask_cell_complement = tf.ones(tf.shape(c_prev)) \
- - binary_mask_cell
-
- # make binary mask tensor for output
- keep_prob_output = tf.convert_to_tensor(
- self.zoneout_factor_output,
- dtype=h_prev.dtype
- )
- random_tensor_output = keep_prob_output
- random_tensor_output += \
- tf.random_uniform(tf.shape(h_prev),
- seed=None, dtype=h_prev.dtype)
- binary_mask_output = tf.floor(random_tensor_output)
- # 0 <-> 1 swap
- binary_mask_output_complement = tf.ones(tf.shape(h_prev)) \
- - binary_mask_output
-
- # apply zoneout for cell
- if self.use_peepholes:
- c_temp = c_prev * \
- tf.sigmoid(f + self.forget_bias +
- w_f_diag * c_prev) + \
- tf.sigmoid(i + w_i_diag * c_prev) * \
- self.activation(j)
- if self.is_training and self.zoneout_factor_cell > 0.0:
- c = binary_mask_cell * c_prev + \
- binary_mask_cell_complement * c_temp
- else:
- c = c_temp
- else:
- c_temp = c_prev * tf.sigmoid(f + self.forget_bias) + \
- tf.sigmoid(i) * self.activation(j)
- if self.is_training and self.zoneout_factor_cell > 0.0:
- c = binary_mask_cell * c_prev + \
- binary_mask_cell_complement * c_temp
- else:
- c = c_temp
-
- if self.cell_clip is not None:
- c = tf.clip_by_value(c, -self.cell_clip, self.cell_clip)
-
- # apply zoneout for output
- if self.use_peepholes:
- h_temp = tf.sigmoid(o + w_o_diag * c) * self.activation(c)
- if self.is_training and self.zoneout_factor_output > 0.0:
- h = binary_mask_output * h_prev + \
- binary_mask_output_complement * h_temp
- else:
- h = h_temp
- else:
- h_temp = tf.sigmoid(o) * self.activation(c)
- if self.is_training and self.zoneout_factor_output > 0.0:
- h = binary_mask_output * h_prev + \
- binary_mask_output_complement * h_temp
- else:
- h = h_temp
-
- # apply prejection
- if self.num_proj is not None:
- w_proj = tf.get_variable(
- "W_P", [self.num_units, num_proj], dtype=dtype)
-
- h = tf.matmul(h, w_proj)
- if self.proj_clip is not None:
- h = tf.clip_by_value(h, -self.proj_clip, self.proj_clip)
-
- new_state = (tf.nn.rnn_cell.LSTMStateTuple(c, h)
- if self.state_is_tuple else tf.concat(1, [c, h]))
-
- return h, new_state
-
-
-def _linear(args, output_size, bias, bias_start=0.0, scope=None):
- """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
- Args:
- args: a 2D Tensor or a list of 2D, batch x n, Tensors.
- output_size: int, second dimension of W[i].
- bias: boolean, whether to add a bias term or not.
- bias_start: starting value to initialize the bias; 0 by default.
- scope: VariableScope for the created subgraph; defaults to "Linear".
- Returns:
- A 2D Tensor with shape [batch x output_size] equal to
- sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
- Raises:
- ValueError: if some of the arguments has unspecified or wrong shape.
- """
- if args is None or (isinstance(args, (list, tuple)) and not args):
- raise ValueError("`args` must be specified")
- if not isinstance(args, (list, tuple)):
- args = [args]
-
- # Calculate the total size of arguments on dimension 1.
- total_arg_size = 0
- shapes = [a.get_shape().as_list() for a in args]
- for shape in shapes:
- if len(shape) != 2:
- raise ValueError(
- "Linear is expecting 2D arguments: %s" % str(shapes))
- if not shape[1]:
- raise ValueError(
- "Linear expects shape[1] of arguments: %s" % str(shapes))
- else:
- total_arg_size += shape[1]
-
- # Now the computation.
- with tf.variable_scope(scope or "Linear"):
- matrix = tf.get_variable("Matrix", [total_arg_size, output_size])
- if len(args) == 1:
- res = tf.matmul(args[0], matrix)
- else:
- res = tf.matmul(tf.concat(args, 1), matrix)
- if not bias:
- return res
- bias_term = tf.get_variable(
- "Bias", [output_size],
- initializer=tf.constant_initializer(bias_start))
- return res + bias_term
\ No newline at end of file
diff --git a/tacotron/synthesize.py b/tacotron/synthesize.py
index f193afa8..51e78dbc 100644
--- a/tacotron/synthesize.py
+++ b/tacotron/synthesize.py
@@ -1,76 +1,125 @@
import argparse
import os
import re
+import time
+from time import sleep
+from datasets import audio
+import tensorflow as tf
from hparams import hparams, hparams_debug_string
+from infolog import log
from tacotron.synthesizer import Synthesizer
-import tensorflow as tf
-import time
from tqdm import tqdm
-def run_eval(args, checkpoint_path, output_dir):
- print(hparams_debug_string())
+def generate_fast(model, text):
+ model.synthesize(text, None, None, None, None)
+
+
+def run_live(args, checkpoint_path, hparams):
+ #Log to Terminal without keeping any records in files
+ log(hparams_debug_string())
synth = Synthesizer()
- synth.load(checkpoint_path)
+ synth.load(checkpoint_path, hparams)
+
+ #Generate fast greeting message
+ greetings = 'Hello, Welcome to the Live testing tool. Please type a message and I will try to read it!'
+ log(greetings)
+ generate_fast(synth, greetings)
+
+ #Interaction loop
+ while True:
+ try:
+ text = input()
+ generate_fast(synth, text)
+
+ except KeyboardInterrupt:
+ leave = 'Thank you for testing our features. see you soon.'
+ log(leave)
+ generate_fast(synth, leave)
+ sleep(2)
+ break
+
+def run_eval(args, checkpoint_path, output_dir, hparams, sentences):
eval_dir = os.path.join(output_dir, 'eval')
log_dir = os.path.join(output_dir, 'logs-eval')
+ if args.model == 'Tacotron-2':
+ assert os.path.normpath(eval_dir) == os.path.normpath(args.mels_dir) #mels_dir = wavenet_input_dir
+
#Create output path if it doesn't exist
os.makedirs(eval_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)
os.makedirs(os.path.join(log_dir, 'wavs'), exist_ok=True)
os.makedirs(os.path.join(log_dir, 'plots'), exist_ok=True)
- with open(os.path.join(eval_dir, 'map.txt'), 'w') as file:
- for i, text in enumerate(tqdm(hparams.sentences)):
- start = time.time()
- mel_filename = synth.synthesize(text, i+1, eval_dir, log_dir, None)
+ log(hparams_debug_string())
+ synth = Synthesizer()
+ synth.load(checkpoint_path, hparams)
+
+ sentences = list(map(lambda s: s.strip(), sentences))
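+ #Group sentences into batches of (at most) hparams.tacotron_synthesis_batch_size sentences each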
+ delta_size = hparams.tacotron_synthesis_batch_size if hparams.tacotron_synthesis_batch_size < len(sentences) else len(sentences)
+ batch_sentences = [sentences[i: i+hparams.tacotron_synthesis_batch_size] for i in range(0, len(sentences), delta_size)]
+ start = time.time()
+ for i, batch in enumerate(tqdm(batch_sentences)):
+ audio.save_wav(synth.eval(batch), os.path.join(log_dir, 'wavs', 'eval_batch_{:03}.wav'.format(i)), hparams)
+ log('\nGenerated total batch of {} in {:.3f} sec'.format(delta_size, time.time() - start))
+
+ return eval_dir
+
+def run_synthesis(args, checkpoint_path, output_dir, hparams):
+ GTA = (args.GTA == 'True')
+ if GTA:
+ synth_dir = os.path.join(output_dir, 'gta')
+
+ #Create output path if it doesn't exist
+ os.makedirs(synth_dir, exist_ok=True)
+ else:
+ synth_dir = os.path.join(output_dir, 'natural')
+
+ #Create output path if it doesn't exist
+ os.makedirs(synth_dir, exist_ok=True)
- file.write('{}|{}\n'.format(text, mel_filename))
- print('synthesized mel spectrograms at {}'.format(eval_dir))
-def run_synthesis(args, checkpoint_path, output_dir):
metadata_filename = os.path.join(args.input_dir, 'train.txt')
- print(hparams_debug_string())
+ log(hparams_debug_string())
synth = Synthesizer()
- synth.load(checkpoint_path, gta=args.GTA)
+ synth.load(checkpoint_path, hparams, gta=GTA)
with open(metadata_filename, encoding='utf-8') as f:
metadata = [line.strip().split('|') for line in f]
frame_shift_ms = hparams.hop_size / hparams.sample_rate
- hours = sum([int(x[3]) for x in metadata]) * frame_shift_ms / (3600)
- print('Loaded metadata for {} examples ({:.2f} hours)'.format(len(metadata), hours))
+ hours = sum([int(x[4]) for x in metadata]) * frame_shift_ms / (3600)
+ log('Loaded metadata for {} examples ({:.2f} hours)'.format(len(metadata), hours))
- if args.GTA==True:
- synth_dir = os.path.join(output_dir, 'gta')
- else:
- synth_dir = os.path.join(output_dir, 'natural')
-
- #Create output path if it doesn't exist
- os.makedirs(synth_dir, exist_ok=True)
+ metadata = [metadata[i: i+hparams.tacotron_synthesis_batch_size] for i in range(0, len(metadata), hparams.tacotron_synthesis_batch_size)]
- print('starting synthesis')
+ log('starting synthesis')
mel_dir = os.path.join(args.input_dir, 'mels')
+ wav_dir = os.path.join(args.input_dir, 'audio')
with open(os.path.join(synth_dir, 'map.txt'), 'w') as file:
for i, meta in enumerate(tqdm(metadata)):
- text = meta[4]
- mel_filename = os.path.join(mel_dir, meta[1])
- mel_output_filename = synth.synthesize(text, i+1, synth_dir, None, mel_filename)
+ texts = [m[5] for m in meta]
+ mel_filenames = [os.path.join(mel_dir, m[1]) for m in meta]
+ wav_filenames = [os.path.join(wav_dir, m[0]) for m in meta]
+ basenames = [os.path.basename(m).replace('.npy', '').replace('mel-', '') for m in mel_filenames]
+ mel_output_filenames, speaker_ids = synth.synthesize(texts, basenames, synth_dir, None, mel_filenames)
- file.write('{}|{}|{}\n'.format(text, mel_filename, mel_output_filename))
- print('synthesized mel spectrograms at {}'.format(synth_dir))
+ for elems in zip(wav_filenames, mel_filenames, mel_output_filenames, speaker_ids, texts):
+ file.write('|'.join([str(x) for x in elems]) + '\n')
+ log('synthesized mel spectrograms at {}'.format(synth_dir))
+ return os.path.join(synth_dir, 'map.txt')
-def tacotron_synthesize(args):
- hparams.parse(args.hparams)
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+def tacotron_synthesize(args, hparams, checkpoint, sentences=None):
output_dir = 'tacotron_' + args.output_dir
try:
- checkpoint_path = tf.train.get_checkpoint_state(args.checkpoint).model_checkpoint_path
- print('loaded model at {}'.format(checkpoint_path))
+ checkpoint_path = tf.train.get_checkpoint_state(checkpoint).model_checkpoint_path
+ log('loaded model at {}'.format(checkpoint_path))
except:
- raise AssertionError('Cannot restore checkpoint: {}, did you train a model?'.format(args.checkpoint))
+ raise RuntimeError('Failed to load checkpoint at {}'.format(checkpoint))
if args.mode == 'eval':
- run_eval(args, checkpoint_path, output_dir)
+ return run_eval(args, checkpoint_path, output_dir, hparams, sentences)
+ elif args.mode == 'synthesis':
+ return run_synthesis(args, checkpoint_path, output_dir, hparams)
else:
- run_synthesis(args, checkpoint_path, output_dir)
+ run_live(args, checkpoint_path, hparams)
diff --git a/tacotron/synthesizer.py b/tacotron/synthesizer.py
index 720baa32..31a439a9 100644
--- a/tacotron/synthesizer.py
+++ b/tacotron/synthesizer.py
@@ -1,69 +1,203 @@
import os
+import wave
+from datetime import datetime
+import io
import numpy as np
import tensorflow as tf
-from hparams import hparams
+from datasets import audio
+from infolog import log
from librosa import effects
from tacotron.models import create_model
-from tacotron.utils.text import text_to_sequence
from tacotron.utils import plot
-from datasets import audio
-from datetime import datetime
+from tacotron.utils.text import text_to_sequence
class Synthesizer:
- def load(self, checkpoint_path, gta=False, model_name='Tacotron'):
- print('Constructing model: %s' % model_name)
- inputs = tf.placeholder(tf.int32, [1, None], 'inputs')
- input_lengths = tf.placeholder(tf.int32, [1], 'input_lengths')
- targets = tf.placeholder(tf.float32, [1, None, hparams.num_mels], 'mel_targets')
+ def load(self, checkpoint_path, hparams, gta=False, model_name='Tacotron'):
+ log('Constructing model: %s' % model_name)
+ inputs = tf.placeholder(tf.int32, [None, None], 'inputs')
+ input_lengths = tf.placeholder(tf.int32, [None], 'input_lengths')
+ targets = tf.placeholder(tf.float32, [None, None, hparams.num_mels], 'mel_targets')
with tf.variable_scope('model') as scope:
self.model = create_model(model_name, hparams)
if gta:
self.model.initialize(inputs, input_lengths, targets, gta=gta)
- else:
+ else:
self.model.initialize(inputs, input_lengths)
+ self.alignments = self.model.alignments
self.mel_outputs = self.model.mel_outputs
- self.alignment = self.model.alignments[0]
+ self.stop_token_prediction = self.model.stop_token_prediction
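+ #When the model also predicts linear spectrograms (and is not running in GTA mode), expose them and build a small graph to invert linear spectrograms back to waveforms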
+ if hparams.predict_linear and not gta:
+ self.linear_outputs = self.model.linear_outputs
+ self.linear_spectrograms = tf.placeholder(tf.float32, (None, hparams.num_freq), name='linear_spectrograms')
+ self.linear_wav_outputs = audio.inv_spectrogram_tensorflow(self.linear_spectrograms, hparams)
self.gta = gta
- print('Loading checkpoint: %s' % checkpoint_path)
- self.session = tf.Session()
+ self._hparams = hparams
+ #pad input sequences with 0 (the id of the '_' padding symbol)
+ self._pad = 0
+ #explicitly setting the padding to a value that doesn't originally exist in the spectrogram
+ #to avoid any possible conflicts, without affecting the output range of the model too much
+ if hparams.symmetric_mels:
+ self._target_pad = -(hparams.max_abs_value + .1)
+ else:
+ self._target_pad = -0.1
+
+ log('Loading checkpoint: %s' % checkpoint_path)
+ #Memory allocation on the GPU as needed
+ config = tf.ConfigProto()
+ config.gpu_options.allow_growth = True
+
+ self.session = tf.Session(config=config)
self.session.run(tf.global_variables_initializer())
+
saver = tf.train.Saver()
saver.restore(self.session, checkpoint_path)
- def synthesize(self, text, index, out_dir, log_dir, mel_filename):
+ def synthesize(self, texts, basenames, out_dir, log_dir, mel_filenames):
+ hparams = self._hparams
cleaner_names = [x.strip() for x in hparams.cleaners.split(',')]
- seq = text_to_sequence(text, cleaner_names)
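+ #Convert each input text to a sequence of symbol ids and pad all sequences to the batch max length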
+ seqs = [np.asarray(text_to_sequence(text, cleaner_names)) for text in texts]
+ input_lengths = [len(seq) for seq in seqs]
+ seqs = self._prepare_inputs(seqs)
feed_dict = {
- self.model.inputs: [np.asarray(seq, dtype=np.int32)],
- self.model.input_lengths: np.asarray([len(seq)], dtype=np.int32),
+ self.model.inputs: seqs,
+ self.model.input_lengths: np.asarray(input_lengths, dtype=np.int32),
}
if self.gta:
- feed_dict[self.model.mel_targets] = np.load(mel_filename).reshape(1, -1, 80)
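+ #In GTA mode, feed the ground truth mels (padded to a multiple of the reduction factor) as decoder targets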
+ np_targets = [np.load(mel_filename) for mel_filename in mel_filenames]
+ target_lengths = [len(np_target) for np_target in np_targets]
+ padded_targets = self._prepare_targets(np_targets, self._hparams.outputs_per_step)
+ feed_dict[self.model.mel_targets] = padded_targets.reshape(len(np_targets), -1, hparams.num_mels)
+
+ if self.gta or not hparams.predict_linear:
+ mels, alignments = self.session.run([self.mel_outputs, self.alignments], feed_dict=feed_dict)
+ if self.gta:
+ mels = [mel[:target_length, :] for mel, target_length in zip(mels, target_lengths)] #Take off the reduction factor padding frames for time consistency with wavenet
+ assert len(mels) == len(np_targets)
+
+ else:
+ linears, mels, alignments, stop_tokens = self.session.run([self.linear_outputs, self.mel_outputs, self.alignments, self.stop_token_prediction], feed_dict=feed_dict)
+
+ #Get Mel/Linear lengths for the entire batch from stop_tokens predictions
+ target_lengths = self._get_output_lengths(stop_tokens)
+
+ #Take off the batch wise padding
+ mels = [mel[:target_length, :] for mel, target_length in zip(mels, target_lengths)]
+ linears = [linear[:target_length, :] for linear, target_length in zip(linears, target_lengths)]
+ assert len(mels) == len(linears) == len(texts)
+
+ # if basenames is None:
+ # #Generate wav and read it
+ # wav = audio.inv_mel_spectrogram(mels.T, hparams)
+ # audio.save_wav(wav, 'temp.wav', hparams) #Find a better way
+
+ # chunk = 512
+ # f = wave.open('temp.wav', 'rb')
+ # p = pyaudio.PyAudio()
+ # stream = p.open(format=p.get_format_from_width(f.getsampwidth()),
+ # channels=f.getnchannels(),
+ # rate=f.getframerate(),
+ # output=True)
+ # data = f.readframes(chunk)
+ # while data:
+ # stream.write(data)
+ # data=f.readframes(chunk)
+
+ # stream.stop_stream()
+ # stream.close()
+
+ # p.terminate()
+ # return
+
+
+ saved_mels_paths = []
+ speaker_ids = []
+ for i, mel in enumerate(mels):
+ #Get speaker id for global conditioning (only used with GTA generally)
+ speaker_id = ''
+ speaker_ids.append(speaker_id)
+
+ # Write the spectrogram to disk
+ # Note: outputs mel-spectrogram files and target ones have same names, just different folders
+ mel_filename = os.path.join(out_dir, '{}.npy'.format(basenames[i]))
+ np.save(mel_filename, mel.T, allow_pickle=False)
+ saved_mels_paths.append(mel_filename)
+
+ if log_dir is not None:
+ #save wav (mel -> wav)
+ wav = audio.inv_mel_spectrogram(mel.T, hparams)
+ audio.save_wav(wav, os.path.join(log_dir, 'wavs/wav-{}-mel.wav'.format(basenames[i])), hparams)
+
+ #save alignments
+ plot.plot_alignment(alignments[i], os.path.join(log_dir, 'plots/alignment-{}.png'.format(basenames[i])),
+ info='{}'.format(texts[i]), split_title=True)
+
+ #save mel spectrogram plot
+ plot.plot_spectrogram(mel, os.path.join(log_dir, 'plots/mel-{}.png'.format(basenames[i])),
+ info='{}'.format(texts[i]), split_title=True)
+
+ if hparams.predict_linear and not self.gta:
+ #save wav (linear -> wav)
+ linear_wav = self.session.run(self.linear_wav_outputs, feed_dict={self.linear_spectrograms: linears[i]})
+ wav = audio.inv_preemphasis(linear_wav, hparams.preemphasis)
+ audio.save_wav(wav, os.path.join(log_dir, 'wavs/wav-{}-linear.wav'.format(basenames[i])), hparams)
+
+ #save linear spectrogram plot
+ plot.plot_spectrogram(linears[i], os.path.join(log_dir, 'plots/linear-{}.png'.format(basenames[i])),
+ info='{}'.format(texts[i]), split_title=True, auto_aspect=True)
+
+ return saved_mels_paths, speaker_ids
+
+ def eval(self, batch):
+ hparams = self._hparams
+ cleaner_names = [x.strip() for x in hparams.cleaners.split(',')]
+ seqs = [np.asarray(text_to_sequence(text, cleaner_names)) for text in batch]
+ input_lengths = [len(seq) for seq in seqs]
+ seqs = self._prepare_inputs(seqs)
+ feed_dict = {
+ self.model.inputs: seqs,
+ self.model.input_lengths: np.asarray(input_lengths, dtype=np.int32),
+ }
+
+ linears, stop_tokens = self.session.run([self.linear_outputs, self.stop_token_prediction], feed_dict=feed_dict)
+
+ #Get Mel/Linear lengths for the entire batch from stop_tokens predictions
+ target_lengths = self._get_output_lengths(stop_tokens)
+
+ #Take off the batch wise padding
+ linears = [linear[:target_length, :] for linear, target_length in zip(linears, target_lengths)]
+ assert len(linears) == len(batch)
- mels, alignment = self.session.run([self.mel_outputs, self.alignment], feed_dict=feed_dict)
+ #save wav (linear -> wav)
+ results = []
+ for i, linear in enumerate(linears):
+ linear_wav = self.session.run(self.linear_wav_outputs, feed_dict={self.linear_spectrograms: linear})
+ wav = audio.inv_preemphasis(linear_wav, hparams.preemphasis)
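+ #Append a short silence (40 hop-size frames of zeros) after each sentence before concatenating everything into a single waveform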
+ results.append(np.concatenate((wav, np.zeros(self._hparams.hop_size * 40))))
+ return np.concatenate(results)
- mels = mels.reshape(-1, 80) #Thanks to @imdatsolak for pointing this out
+ def _round_up(self, x, multiple):
+ remainder = x % multiple
+ return x if remainder == 0 else x + multiple - remainder
- # Write the spectrogram to disk
- # Note: outputs mel-spectrogram files and target ones have same names, just different folders
- mel_filename = os.path.join(out_dir, 'ljspeech-mel-{:05d}.npy'.format(index))
- np.save(mel_filename, mels, allow_pickle=False)
+ def _prepare_inputs(self, inputs):
+ max_len = max([len(x) for x in inputs])
+ return np.stack([self._pad_input(x, max_len) for x in inputs])
- if log_dir is not None:
- #save wav
- wav = audio.inv_mel_spectrogram(mels.T)
- audio.save_wav(wav, os.path.join(log_dir, 'wavs/ljspeech-wav-{:05d}.wav'.format(index)))
+ def _pad_input(self, x, length):
+ return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=self._pad)
- #save alignments
- plot.plot_alignment(alignment, os.path.join(log_dir, 'plots/ljspeech-alignment-{:05d}.png'.format(index)),
- info='{}'.format(text), split_title=True)
+ def _prepare_targets(self, targets, alignment):
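+ #Pad every target spectrogram to the batch max length, rounded up to a multiple of the reduction factor (outputs_per_step)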
+ max_len = max([len(t) for t in targets])
+ return np.stack([self._pad_target(t, self._round_up(max_len, alignment)) for t in targets])
- #save mel spectrogram plot
- plot.plot_spectrogram(mels, os.path.join(log_dir, 'plots/ljspeech-mel-{:05d}.png'.format(index)),
- info='{}'.format(text), split_title=True)
+ def _pad_target(self, t, length):
+ return np.pad(t, [(0, length - t.shape[0]), (0, 0)], mode='constant', constant_values=self._target_pad)
- return mel_filename
\ No newline at end of file
+ def _get_output_lengths(self, stop_tokens):
+ #Determine each mel length from the stop token predictions. (len = first occurrence of 1 in each stop_tokens row)
+ output_lengths = [row.index(1) + 1 if 1 in row else len(row) for row in np.round(stop_tokens).tolist()]
+ return output_lengths
diff --git a/tacotron/train.py b/tacotron/train.py
index 82717f31..480768d9 100644
--- a/tacotron/train.py
+++ b/tacotron/train.py
@@ -1,82 +1,144 @@
-import numpy as np
-from datetime import datetime
+import argparse
import os
import subprocess
import time
-import tensorflow as tf
import traceback
-import argparse
-
+from datetime import datetime
+import infolog
+import numpy as np
+import tensorflow as tf
+from datasets import audio
+from hparams import hparams_debug_string
from tacotron.feeder import Feeder
-from hparams import hparams, hparams_debug_string
from tacotron.models import create_model
+from tacotron.utils import ValueWindow, plot
from tacotron.utils.text import sequence_to_text
-from tacotron.utils import infolog, plot, ValueWindow
-from datasets import audio
+from tqdm import tqdm
+
log = infolog.log
-def add_stats(model):
+def add_train_stats(model, hparams):
with tf.variable_scope('stats') as scope:
tf.summary.histogram('mel_outputs', model.mel_outputs)
tf.summary.histogram('mel_targets', model.mel_targets)
tf.summary.scalar('before_loss', model.before_loss)
tf.summary.scalar('after_loss', model.after_loss)
+ if hparams.predict_linear:
+ tf.summary.scalar('linear_loss', model.linear_loss)
tf.summary.scalar('regularization_loss', model.regularization_loss)
tf.summary.scalar('stop_token_loss', model.stop_token_loss)
+ tf.summary.scalar('attention_loss', model.attention_loss)
tf.summary.scalar('loss', model.loss)
- tf.summary.scalar('learning_rate', model.learning_rate) #control learning rate decay speed
+ tf.summary.scalar('learning_rate', model.learning_rate) #Control learning rate decay speed
+ if hparams.tacotron_teacher_forcing_mode == 'scheduled':
+ tf.summary.scalar('teacher_forcing_ratio', model.ratio) #Control teacher forcing ratio decay when mode = 'scheduled'
gradient_norms = [tf.norm(grad) for grad in model.gradients]
tf.summary.histogram('gradient_norm', gradient_norms)
tf.summary.scalar('max_gradient_norm', tf.reduce_max(gradient_norms)) #visualize gradients (in case of explosion)
return tf.summary.merge_all()
+def add_eval_stats(summary_writer, step, linear_loss, before_loss, after_loss, stop_token_loss, attention_loss, loss):
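+ #Eval metrics are written as manually built tf.Summary values rather than through tf.summary ops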
+ values = [
+ tf.Summary.Value(tag='eval_model/eval_stats/eval_before_loss', simple_value=before_loss),
+ tf.Summary.Value(tag='eval_model/eval_stats/eval_after_loss', simple_value=after_loss),
+ tf.Summary.Value(tag='eval_model/eval_stats/stop_token_loss', simple_value=stop_token_loss),
+ tf.Summary.Value(tag='eval_model/eval_stats/attention_loss', simple_value=attention_loss),
+ tf.Summary.Value(tag='eval_model/eval_stats/eval_loss', simple_value=loss),
+ ]
+ if linear_loss is not None:
+ values.append(tf.Summary.Value(tag='eval_model/eval_stats/eval_linear_loss', simple_value=linear_loss))
+ test_summary = tf.Summary(value=values)
+ summary_writer.add_summary(test_summary, step)
+
def time_string():
return datetime.now().strftime('%Y-%m-%d %H:%M')
-def train(log_dir, args):
- save_dir = os.path.join(log_dir, 'pretrained/')
- checkpoint_path = os.path.join(save_dir, 'model.ckpt')
- input_path = os.path.join(args.base_dir, args.input)
+def model_train_mode(args, feeder, hparams, global_step):
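+ #Build the training graph under the shared 'model' scope; AUTO_REUSE lets the eval graph below reuse the same weights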
+ with tf.variable_scope('model', reuse=tf.AUTO_REUSE) as scope:
+ model_name = None
+ if args.model == 'Tacotron-2':
+ model_name = 'Tacotron'
+ model = create_model(model_name or args.model, hparams)
+ if hparams.predict_linear:
+ model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.token_targets, linear_targets=feeder.linear_targets,
+ targets_lengths=feeder.targets_lengths, global_step=global_step,
+ is_training=True)
+ else:
+ model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.token_targets,
+ targets_lengths=feeder.targets_lengths, global_step=global_step,
+ is_training=True)
+ model.add_loss()
+ model.add_optimizer(global_step)
+ stats = add_train_stats(model, hparams)
+ return model, stats
+
+def model_test_mode(args, feeder, hparams, global_step):
+ with tf.variable_scope('model', reuse=tf.AUTO_REUSE) as scope:
+ model_name = None
+ if args.model == 'Tacotron-2':
+ model_name = 'Tacotron'
+ model = create_model(model_name or args.model, hparams)
+ if hparams.predict_linear:
+ model.initialize(feeder.eval_inputs, feeder.eval_input_lengths, feeder.eval_mel_targets, feeder.eval_token_targets,
+ linear_targets=feeder.eval_linear_targets, targets_lengths=feeder.eval_targets_lengths, global_step=global_step,
+ is_training=False, is_evaluating=True)
+ else:
+ model.initialize(feeder.eval_inputs, feeder.eval_input_lengths, feeder.eval_mel_targets, feeder.eval_token_targets,
+ targets_lengths=feeder.eval_targets_lengths, global_step=global_step, is_training=False, is_evaluating=True)
+ model.add_loss()
+ return model
+
+def train(log_dir, args, hparams):
+ save_dir = os.path.join(log_dir, 'taco_pretrained')
plot_dir = os.path.join(log_dir, 'plots')
wav_dir = os.path.join(log_dir, 'wavs')
mel_dir = os.path.join(log_dir, 'mel-spectrograms')
+ eval_dir = os.path.join(log_dir, 'eval-dir')
+ eval_plot_dir = os.path.join(eval_dir, 'plots')
+ eval_wav_dir = os.path.join(eval_dir, 'wavs')
+ tensorboard_dir = os.path.join(log_dir, 'tacotron_events')
+ os.makedirs(save_dir, exist_ok=True)
os.makedirs(plot_dir, exist_ok=True)
os.makedirs(wav_dir, exist_ok=True)
os.makedirs(mel_dir, exist_ok=True)
+ os.makedirs(eval_dir, exist_ok=True)
+ os.makedirs(eval_plot_dir, exist_ok=True)
+ os.makedirs(eval_wav_dir, exist_ok=True)
+ os.makedirs(tensorboard_dir, exist_ok=True)
+
+ checkpoint_path = os.path.join(save_dir, 'tacotron_model.ckpt')
+ input_path = os.path.join(args.base_dir, args.tacotron_input)
+
+ if hparams.predict_linear:
+ linear_dir = os.path.join(log_dir, 'linear-spectrograms')
+ os.makedirs(linear_dir, exist_ok=True)
+
log('Checkpoint path: {}'.format(checkpoint_path))
log('Loading training data from: {}'.format(input_path))
log('Using model: {}'.format(args.model))
log(hparams_debug_string())
+ #Start by setting a seed for repeatability
+ tf.set_random_seed(hparams.tacotron_random_seed)
+
#Set up data feeder
coord = tf.train.Coordinator()
with tf.variable_scope('datafeeder') as scope:
feeder = Feeder(coord, input_path, hparams)
#Set up model:
- step_count = 0
- try:
- #simple text file to keep count of global step
- with open(os.path.join(log_dir, 'step_counter.txt'), 'r') as file:
- step_count = int(file.read())
- except:
- print('no step_counter file found, assuming there is no saved checkpoint')
-
- global_step = tf.Variable(step_count, name='global_step', trainable=False)
- with tf.variable_scope('model') as scope:
- model = create_model(args.model, hparams)
- model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.token_targets)
- model.add_loss()
- model.add_optimizer(global_step)
- stats = add_stats(model)
+ global_step = tf.Variable(0, name='global_step', trainable=False)
+ model, stats = model_train_mode(args, feeder, hparams, global_step)
+ eval_model = model_test_mode(args, feeder, hparams, global_step)
#Book keeping
step = 0
- save_step = 0
time_window = ValueWindow(100)
loss_window = ValueWindow(100)
- saver = tf.train.Saver(max_to_keep=5)
+ saver = tf.train.Saver(max_to_keep=1)
+
+ log('Tacotron training set to a maximum of {} steps'.format(args.tacotron_train_steps))
#Memory allocation on the GPU as needed
config = tf.ConfigProto()
@@ -85,90 +147,165 @@ def train(log_dir, args):
#Train
with tf.Session(config=config) as sess:
try:
- summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
+ summary_writer = tf.summary.FileWriter(tensorboard_dir, sess.graph)
sess.run(tf.global_variables_initializer())
#saved model restoring
if args.restore:
- #Restore saved model if the user requested it, Default = True.
+ # Restore saved model if the user requested it, default = True
try:
checkpoint_state = tf.train.get_checkpoint_state(save_dir)
- except tf.errors.OutOfRangeError as e:
- log('Cannot restore checkpoint: {}'.format(e))
-
- if (checkpoint_state and checkpoint_state.model_checkpoint_path):
- log('Loading checkpoint {}'.format(checkpoint_state.model_checkpoint_path))
- saver.restore(sess, checkpoint_state.model_checkpoint_path)
+ if (checkpoint_state and checkpoint_state.model_checkpoint_path):
+ log('Loading checkpoint {}'.format(checkpoint_state.model_checkpoint_path), slack=True)
+ saver.restore(sess, checkpoint_state.model_checkpoint_path)
+ else:
+ log('No model to load at {}'.format(save_dir), slack=True)
+ except tf.errors.OutOfRangeError as e:
+ log('Cannot restore checkpoint: {}'.format(e), slack=True)
else:
- if not args.restore:
- log('Starting new training!')
- else:
- log('No model to load at {}'.format(save_dir))
+ log('Starting new training!', slack=True)
- #initiating feeder
- feeder.start_in_session(sess)
+ #initializing feeder
+ feeder.start_threads(sess)
#Training loop
- while not coord.should_stop():
+ while not coord.should_stop() and step < args.tacotron_train_steps:
start_time = time.time()
step, loss, opt = sess.run([global_step, model.loss, model.optimize])
time_window.append(time.time() - start_time)
loss_window.append(loss)
message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format(
step, time_window.average, loss, loss_window.average)
- log(message, end='\r')
+ log(message, end='\r', slack=(step % args.checkpoint_interval == 0))
if loss > 100 or np.isnan(loss):
log('Loss exploded to {:.5f} at step {}'.format(loss, step))
raise Exception('Loss exploded')
if step % args.summary_interval == 0:
- log('\nWriting summary at step: {}'.format(step))
+ log('\nWriting summary at step {}'.format(step))
summary_writer.add_summary(sess.run(stats), step)
-
- if step % args.checkpoint_interval == 0:
- with open(os.path.join(log_dir,'step_counter.txt'), 'w') as file:
- file.write(str(step))
- log('Saving checkpoint to: {}-{}'.format(checkpoint_path, step))
- saver.save(sess, checkpoint_path, global_step=step)
- save_step = step
-
- log('Saving alignment, Mel-Spectrograms and griffin-lim inverted waveform..')
- input_seq, prediction, alignment, target = sess.run([model.inputs[0],
- model.mel_outputs[0],
- model.alignments[0],
- model.mel_targets[0],
- ])
- #save predicted spectrogram to disk (for plot and manual evaluation purposes)
- mel_filename = 'ljspeech-mel-prediction-step-{}.npy'.format(step)
- np.save(os.path.join(mel_dir, mel_filename), prediction.T, allow_pickle=False)
-
- #save griffin lim inverted wav for debug.
- wav = audio.inv_mel_spectrogram(prediction.T)
- audio.save_wav(wav, os.path.join(wav_dir, 'step-{}-waveform.wav'.format(step)))
+
+ if step % args.eval_interval == 0:
+ #Run eval and save eval stats
+ log('\nRunning evaluation at step {}'.format(step))
+
+ eval_losses = []
+ before_losses = []
+ after_losses = []
+ stop_token_losses = []
+ attention_losses = []
+ linear_losses = []
+ linear_loss = None
+
+ if hparams.predict_linear:
+ for i in tqdm(range(feeder.test_steps)):
+ eloss, before_loss, after_loss, stop_token_loss, linear_loss, attention_loss, mel_p, mel_t, t_len, align, lin_p = sess.run(
+ [eval_model.loss, eval_model.before_loss, eval_model.after_loss, eval_model.stop_token_loss,
+ eval_model.linear_loss, eval_model.attention_loss, eval_model.mel_outputs[0], eval_model.mel_targets[0],
+ eval_model.targets_lengths[0], eval_model.alignments[0], eval_model.linear_outputs[0]])
+ eval_losses.append(eloss)
+ before_losses.append(before_loss)
+ after_losses.append(after_loss)
+ stop_token_losses.append(stop_token_loss)
+ attention_losses.append(attention_loss)
+ linear_losses.append(linear_loss)
+ linear_loss = sum(linear_losses) / len(linear_losses)
+
+ wav = audio.inv_linear_spectrogram(lin_p.T, hparams)
+ audio.save_wav(wav, os.path.join(eval_wav_dir, 'step-{}-eval-waveform-linear.wav'.format(step)), hparams)
+ else:
+ for i in tqdm(range(feeder.test_steps)):
+ eloss, before_loss, after_loss, stop_token_loss, attention_loss, mel_p, mel_t, t_len, align = sess.run(
+ [eval_model.loss, eval_model.before_loss, eval_model.after_loss, eval_model.stop_token_loss,
+ eval_model.attention_loss, eval_model.mel_outputs[0], eval_model.mel_targets[0],
+ eval_model.targets_lengths[0], eval_model.alignments[0]])
+ eval_losses.append(eloss)
+ before_losses.append(before_loss)
+ after_losses.append(after_loss)
+ stop_token_losses.append(stop_token_loss)
+ attention_losses.append(attention_loss)
+
+ eval_loss = sum(eval_losses) / len(eval_losses)
+ before_loss = sum(before_losses) / len(before_losses)
+ after_loss = sum(after_losses) / len(after_losses)
+ stop_token_loss = sum(stop_token_losses) / len(stop_token_losses)
+ attention_loss = sum(attention_losses) / len(attention_losses)
+
+ log('Saving eval log to {}..'.format(eval_dir))
+ #Save some logs to monitor model improvement on the same unseen sequence
+ wav = audio.inv_mel_spectrogram(mel_p.T, hparams)
+ audio.save_wav(wav, os.path.join(eval_wav_dir, 'step-{}-eval-waveform-mel.wav'.format(step)), hparams)
+
+ plot.plot_alignment(align, os.path.join(eval_plot_dir, 'step-{}-eval-align.png'.format(step)),
+ info='{}, {}, step={}, loss={:.5f}'.format(args.model, time_string(), step, eval_loss),
+ max_len=t_len // hparams.outputs_per_step)
+ plot.plot_spectrogram(mel_p, os.path.join(eval_plot_dir, 'step-{}-eval-mel-spectrogram.png'.format(step)),
+ info='{}, {}, step={}, loss={:.5}'.format(args.model, time_string(), step, eval_loss), target_spectrogram=mel_t,
+ max_len=t_len)
+
+ log('Eval loss for global step {}: {:.3f}'.format(step, eval_loss))
+ log('Writing eval summary!')
+ add_eval_stats(summary_writer, step, linear_loss, before_loss, after_loss, stop_token_loss, attention_loss, eval_loss)
+
+
+ if step % args.checkpoint_interval == 0 or step == args.tacotron_train_steps:
+ #Save model and current global step
+ saver.save(sess, checkpoint_path, global_step=global_step)
+
+ log('\nSaving alignment, Mel-Spectrograms and griffin-lim inverted waveform..')
+ if hparams.predict_linear:
+ input_seq, mel_prediction, linear_prediction, alignment, target, target_length = sess.run([
+ model.inputs[0],
+ model.mel_outputs[0],
+ model.linear_outputs[0],
+ model.alignments[0],
+ model.mel_targets[0],
+ model.targets_lengths[0],
+ ])
+
+ #save predicted linear spectrogram to disk (debug)
+ linear_filename = 'linear-prediction-step-{}.npy'.format(step)
+ np.save(os.path.join(linear_dir, linear_filename), linear_prediction.T, allow_pickle=False)
+
+ #save griffin lim inverted wav for debug (linear -> wav)
+ wav = audio.inv_linear_spectrogram(linear_prediction.T, hparams)
+ audio.save_wav(wav, os.path.join(wav_dir, 'step-{}-wave-from-linear.wav'.format(step)), hparams)
+
+ else:
+ input_seq, mel_prediction, alignment, target, target_length = sess.run([model.inputs[0],
+ model.mel_outputs[0],
+ model.alignments[0],
+ model.mel_targets[0],
+ model.targets_lengths[0],
+ ])
+
+ #save predicted mel spectrogram to disk (debug)
+ mel_filename = 'mel-prediction-step-{}.npy'.format(step)
+ np.save(os.path.join(mel_dir, mel_filename), mel_prediction.T, allow_pickle=False)
+
+ #save griffin lim inverted wav for debug (mel -> wav)
+ wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
+ audio.save_wav(wav, os.path.join(wav_dir, 'step-{}-wave-from-mel.wav'.format(step)), hparams)
#save alignment plot to disk (control purposes)
plot.plot_alignment(alignment, os.path.join(plot_dir, 'step-{}-align.png'.format(step)),
- info='{}, {}, step={}, loss={:.5f}'.format(args.model, time_string(), step, loss))
- #save real mel-spectrogram plot to disk (control purposes)
- plot.plot_spectrogram(target, os.path.join(plot_dir, 'step-{}-real-mel-spectrogram.png'.format(step)),
- info='{}, {}, step={}, Real'.format(args.model, time_string(), step, loss))
- #save predicted mel-spectrogram plot to disk (control purposes)
- plot.plot_spectrogram(prediction, os.path.join(plot_dir, 'step-{}-pred-mel-spectrogram.png'.format(step)),
- info='{}, {}, step={}, loss={:.5}'.format(args.model, time_string(), step, loss))
+ info='{}, {}, step={}, loss={:.5f}'.format(args.model, time_string(), step, loss),
+ max_len=target_length // hparams.outputs_per_step)
+ #save real and predicted mel-spectrogram plot to disk (control purposes)
+ plot.plot_spectrogram(mel_prediction, os.path.join(plot_dir, 'step-{}-mel-spectrogram.png'.format(step)),
+ info='{}, {}, step={}, loss={:.5}'.format(args.model, time_string(), step, loss), target_spectrogram=target,
+ max_len=target_length)
log('Input at step {}: {}'.format(step, sequence_to_text(input_seq)))
+ log('Tacotron training complete after {} global steps!'.format(args.tacotron_train_steps), slack=True)
+ return save_dir
+
except Exception as e:
log('Exiting due to exception: {}'.format(e), slack=True)
traceback.print_exc()
coord.request_stop(e)
-def tacotron_train(args):
- hparams.parse(args.hparams)
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(args.tf_log_level)
- run_name = args.model
- log_dir = os.path.join(args.base_dir, 'logs-{}'.format(run_name))
- os.makedirs(log_dir, exist_ok=True)
- infolog.init(os.path.join(log_dir, 'Terminal_train_log'), run_name)
- train(log_dir, args)
+def tacotron_train(args, log_dir, hparams):
+ return train(log_dir, args, hparams)
diff --git a/tacotron/utils/audio.py b/tacotron/utils/audio.py
deleted file mode 100644
index 84de0440..00000000
--- a/tacotron/utils/audio.py
+++ /dev/null
@@ -1,129 +0,0 @@
-import librosa
-import librosa.filters
-import numpy as np
-from scipy import signal
-from tacotron.hparams import hparams
-import tensorflow as tf
-
-
-def load_wav(path):
- return librosa.core.load(path, sr=hparams.sample_rate)[0]
-
-def save_wav(wav, path):
- wav *= 32767 / max(0.01, np.max(np.abs(wav)))
- librosa.output.write_wav(path, wav.astype(np.int16), hparams.sample_rate)
-
-def trim_silence(wav):
- '''Trim leading and trailing silence
-
- Useful for M-AILABS dataset if we choose to trim the extra 0.5 silences.
- '''
- return librosa.effects.trim(wav)[0]
-
-def preemphasis(x):
- return signal.lfilter([1, -hparams.preemphasis], [1], x)
-
-def inv_preemphasis(x):
- return signal.lfilter([1], [1, -hparams.preemphasis], x)
-
-def get_hop_size():
- hop_size = hparams.hop_size
- if hop_size is None:
- assert hparams.frame_shift_ms is not None
- hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
- return hop_size
-
-def melspectrogram(wav):
- D = _stft(wav)
- S = _amp_to_db(_linear_to_mel(np.abs(D))) - hparams.ref_level_db
-
- if hparams.mel_normalization:
- return _normalize(S)
- return S
-
-
-def inv_mel_spectrogram(mel_spectrogram):
- '''Converts mel spectrogram to waveform using librosa'''
- if hparams.mel_normalization:
- D = _denormalize(mel_spectrogram)
- else:
- D = mel_spectrogram
-
- S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db)) # Convert back to linear
-
- return _griffin_lim(S ** hparams.power)
-
-def _griffin_lim(S):
- '''librosa implementation of Griffin-Lim
- Based on https://github.com/librosa/librosa/issues/434
- '''
- angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
- S_complex = np.abs(S).astype(np.complex)
- y = _istft(S_complex * angles)
- for i in range(hparams.griffin_lim_iters):
- angles = np.exp(1j * np.angle(_stft(y)))
- y = _istft(S_complex * angles)
- return y
-
-def _stft(y):
- return librosa.stft(y=y, n_fft=hparams.fft_size, hop_length=get_hop_size())
-
-def _istft(y):
- return librosa.istft(y, hop_length=get_hop_size())
-
-
-# Conversions
-_mel_basis = None
-_inv_mel_basis = None
-
-def _linear_to_mel(spectogram):
- global _mel_basis
- if _mel_basis is None:
- _mel_basis = _build_mel_basis()
- return np.dot(_mel_basis, spectogram)
-
-def _mel_to_linear(mel_spectrogram):
- global _inv_mel_basis
- if _inv_mel_basis is None:
- _inv_mel_basis = np.linalg.pinv(_build_mel_basis())
- return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
-
-def _build_mel_basis():
- assert hparams.fmax <= hparams.sample_rate // 2
- return librosa.filters.mel(hparams.sample_rate, hparams.fft_size, n_mels=hparams.num_mels,
- fmin=hparams.fmin, fmax=hparams.fmax)
-
-def _amp_to_db(x):
- min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
- return 20 * np.log10(np.maximum(min_level, x))
-
-def _db_to_amp(x):
- return np.power(10.0, (x) * 0.05)
-
-def _normalize(S):
- if hparams.allow_clipping_in_normalization:
- if hparams.symmetric_mels:
- return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value,
- -hparams.max_abs_value, hparams.max_abs_value)
- else:
- return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value)
-
- assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
- if hparams.symmetric_mels:
- return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value
- else:
- return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db))
-
-def _denormalize(D):
- if hparams.allow_clipping_in_normalization:
- if hparams.symmetric_mels:
- return (((np.clip(D, -hparams.max_abs_value,
- hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value))
- + hparams.min_level_db)
- else:
- return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
-
- if hparams.symmetric_mels:
- return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db)
- else:
- return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
\ No newline at end of file
diff --git a/tacotron/utils/cleaners.py b/tacotron/utils/cleaners.py
index aa56c4c6..711be022 100644
--- a/tacotron/utils/cleaners.py
+++ b/tacotron/utils/cleaners.py
@@ -11,9 +11,10 @@
'''
import re
+
from unidecode import unidecode
-from .numbers import normalize_numbers
+from .numbers import normalize_numbers
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
@@ -52,6 +53,8 @@ def expand_numbers(text):
def lowercase(text):
+ '''lowercase input tokens.
+ '''
return text.lower()
diff --git a/tacotron/utils/cmudict.py b/tacotron/utils/cmudict.py
index 85f4cbb3..d52a5d3f 100644
--- a/tacotron/utils/cmudict.py
+++ b/tacotron/utils/cmudict.py
@@ -1,6 +1,5 @@
import re
-
valid_symbols = [
'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
diff --git a/tacotron/utils/numbers.py b/tacotron/utils/numbers.py
index ba9eb741..65287a33 100644
--- a/tacotron/utils/numbers.py
+++ b/tacotron/utils/numbers.py
@@ -1,6 +1,6 @@
-import inflect
import re
+import inflect
_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
diff --git a/tacotron/utils/plot.py b/tacotron/utils/plot.py
index b9a9c09e..f776567c 100644
--- a/tacotron/utils/plot.py
+++ b/tacotron/utils/plot.py
@@ -1,7 +1,8 @@
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
-import numpy as np
+
+import numpy as np
def split_title_line(title_text, max_words=5):
@@ -12,8 +13,13 @@ def split_title_line(title_text, max_words=5):
seq = title_text.split()
return '\n'.join([' '.join(seq[i:i + max_words]) for i in range(0, len(seq), max_words)])
-def plot_alignment(alignment, path, info=None, split_title=False):
- fig, ax = plt.subplots()
+def plot_alignment(alignment, path, info=None, split_title=False, max_len=None):
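+ #When max_len is given, crop the alignment to the actual number of decoder steps so padded frames are not plotted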
+ if max_len is not None:
+ alignment = alignment[:, :max_len]
+
+ fig = plt.figure(figsize=(8, 6))
+ ax = fig.add_subplot(111)
+
im = ax.imshow(
alignment,
aspect='auto',
@@ -31,20 +37,45 @@ def plot_alignment(alignment, path, info=None, split_title=False):
plt.ylabel('Encoder timestep')
plt.tight_layout()
plt.savefig(path, format='png')
+ plt.close()
-def plot_spectrogram(spectrogram, path, info=None, split_title=False):
- plt.figure()
- plt.imshow(np.rot90(spectrogram))
- plt.colorbar(shrink=0.65, orientation='horizontal')
- plt.ylabel('mels')
- xlabel = 'frames'
+def plot_spectrogram(pred_spectrogram, path, info=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False):
+ if max_len is not None:
+ target_spectrogram = target_spectrogram[:max_len]
+ pred_spectrogram = pred_spectrogram[:max_len]
+
if info is not None:
if split_title:
title = split_title_line(info)
else:
title = info
- plt.xlabel(xlabel)
- plt.title(title)
+ else:
+ title = ''
+
+ fig = plt.figure(figsize=(10, 8))
+ # Set common labels
+ fig.text(0.5, 0.18, title, horizontalalignment='center', fontsize=16)
+
+ #target spectrogram subplot
+ if target_spectrogram is not None:
+ ax1 = fig.add_subplot(311)
+ ax2 = fig.add_subplot(312)
+
+ if auto_aspect:
+ im = ax1.imshow(np.rot90(target_spectrogram), aspect='auto', interpolation='none')
+ else:
+ im = ax1.imshow(np.rot90(target_spectrogram), interpolation='none')
+ ax1.set_title('Target Mel-Spectrogram')
+ fig.colorbar(mappable=im, shrink=0.65, orientation='horizontal', ax=ax1)
+ ax2.set_title('Predicted Mel-Spectrogram')
+ else:
+ ax2 = fig.add_subplot(211)
+
+ if auto_aspect:
+ im = ax2.imshow(np.rot90(pred_spectrogram), aspect='auto', interpolation='none')
+ else:
+ im = ax2.imshow(np.rot90(pred_spectrogram), interpolation='none')
+ fig.colorbar(mappable=im, shrink=0.65, orientation='horizontal', ax=ax2)
+
plt.tight_layout()
plt.savefig(path, format='png')
+ plt.close()
diff --git a/tacotron/utils/symbols.py b/tacotron/utils/symbols.py
index 2764c95e..7b7fae07 100644
--- a/tacotron/utils/symbols.py
+++ b/tacotron/utils/symbols.py
@@ -4,14 +4,10 @@
The default is a set of ASCII characters that works well for English or text that has been run
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
'''
-from . import cmudict
_pad = '_'
_eos = '~'
-_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
-
-# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-_arpabet = ['@' + s for s in cmudict.valid_symbols]
+_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890@!\'(),-.:;? '
# Export all symbols:
-symbols = [_pad, _eos] + list(_characters) + _arpabet
\ No newline at end of file
+symbols = [_pad, _eos] + list(_characters)
diff --git a/tacotron/utils/text.py b/tacotron/utils/text.py
index 7b96538c..1116db7e 100644
--- a/tacotron/utils/text.py
+++ b/tacotron/utils/text.py
@@ -1,8 +1,8 @@
import re
+
from . import cleaners
from .symbols import symbols
-
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
diff --git a/train.py b/train.py
index 65844f6d..46e48f2e 100644
--- a/train.py
+++ b/train.py
@@ -1,32 +1,74 @@
import argparse
+import os
+from time import sleep
+
+import infolog
+import tensorflow as tf
+from hparams import hparams
+from infolog import log
+from tacotron.synthesize import tacotron_synthesize
from tacotron.train import tacotron_train
+
+
+def save_seq(file, sequence, input_path):
+ '''Save Tacotron-2 training state to disk (so completed stages can be skipped on future runs).
+ '''
+ sequence = [str(int(s)) for s in sequence] + [input_path]
+ with open(file, 'w') as f:
+ f.write('|'.join(sequence))
+
+def read_seq(file):
+ '''Load Tacotron-2 training state from disk (used to skip completed stages when this is not a first run).
+ '''
+ if os.path.isfile(file):
+ with open(file, 'r') as f:
+ sequence = f.read().split('|')
+
+ return [bool(int(s)) for s in sequence[:-1]], sequence[-1]
+ else:
+ return [0, 0, 0], ''
+
+def prepare_run(args):
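+ #Parse hparams overrides, set the TF log level and visible GPU, then create the run's log directory and init infolog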
+ modified_hp = hparams.parse(args.hparams)
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(args.tf_log_level)
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+ run_name = args.name or args.model
+ log_dir = os.path.join(args.base_dir, 'logs-{}'.format(run_name))
+ os.makedirs(log_dir, exist_ok=True)
+ infolog.init(os.path.join(log_dir, 'Terminal_train_log'), run_name, args.slack_url)
+ return log_dir, modified_hp
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--base_dir', default='')
parser.add_argument('--hparams', default='',
help='Hyperparameter overrides as a comma-separated list of name=value pairs')
- parser.add_argument('--input', default='training_data/train.txt')
+ parser.add_argument('--tacotron_input', default='training_data/train.txt')
+ parser.add_argument('--wavenet_input', default='tacotron_output/gta/map.txt')
+ parser.add_argument('--name', help='Name of logging directory.')
parser.add_argument('--model', default='Tacotron')
+ parser.add_argument('--input_dir', default='training_data', help='folder to contain inputs sentences/targets')
+ parser.add_argument('--output_dir', default='output', help='folder to contain synthesized mel spectrograms')
+ parser.add_argument('--mode', default='synthesis', help='mode for synthesis of tacotron after training')
+ parser.add_argument('--GTA', default='True', help='Ground truth aligned synthesis, defaults to True, only considered in Tacotron synthesis mode')
parser.add_argument('--restore', type=bool, default=True, help='Set this to False to do a fresh training')
- parser.add_argument('--summary_interval', type=int, default=100,
+ parser.add_argument('--summary_interval', type=int, default=250,
help='Steps between running summary ops')
- parser.add_argument('--checkpoint_interval', type=int, default=500,
+ parser.add_argument('--checkpoint_interval', type=int, default=1000,
help='Steps between writing checkpoints')
+ parser.add_argument('--eval_interval', type=int, default=1000,
+ help='Steps between eval on test data')
+ parser.add_argument('--tacotron_train_steps', type=int, default=150000, help='total number of tacotron training steps')
+ parser.add_argument('--wavenet_train_steps', type=int, default=1300000, help='total number of wavenet training steps')
parser.add_argument('--tf_log_level', type=int, default=1, help='Tensorflow C++ log level.')
+ parser.add_argument('--slack_url', default=None, help='slack webhook notification destination link')
args = parser.parse_args()
- accepted_models = ['Tacotron', 'Wavenet']
-
- if args.model not in accepted_models:
- raise ValueError('please enter a valid model to train: {}'.format(accepted_models))
-
- if args.model == 'Tacotron':
- tacotron_train(args)
- elif args.model == 'Wavenet':
- raise NotImplementedError('Wavenet is still a work in progress, thank you for your patience!')
+ log_dir, hparams = prepare_run(args)
+ tacotron_train(args, log_dir, hparams)
if __name__ == '__main__':
- main()
\ No newline at end of file
+ main()
diff --git a/wavenet_vocoder/__init__.py b/wavenet_vocoder/__init__.py
deleted file mode 100644
index 4287ca86..00000000
--- a/wavenet_vocoder/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-#
\ No newline at end of file
diff --git a/wavenet_vocoder/models/__init__.py b/wavenet_vocoder/models/__init__.py
deleted file mode 100644
index c81a9dea..00000000
--- a/wavenet_vocoder/models/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from .wavenet import WaveNet
-
-
-def create_model(name, hparams):
- if name == 'WaveNet'
- return WaveNet(hparams)
- else:
- raise Exception('Unknow model: {}'.format(name))
\ No newline at end of file
diff --git a/wavenet_vocoder/models/mixture.py b/wavenet_vocoder/models/mixture.py
deleted file mode 100644
index e3d0ce43..00000000
--- a/wavenet_vocoder/models/mixture.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import numpy as np
-import tensorflow as tf
-
-
-
-def log_sum_exp(x):
- """ numerically stable log_sum_exp implementation that prevents overflow """
- axis = len(x.get_shape())-1
- m = tf.reduce_max(x, axis)
- m2 = tf.reduce_max(x, axis, keepdims=True)
- return m + tf.log(tf.reduce_sum(tf.exp(x-m2), axis))
-
-def discretized_mix_logistic_loss(y_hat, y, num_classes=256,
- log_scale_min=-7.0, reduce=True):
- '''Discretized mix of logistic distributions loss.
-
- Note that it is assumed that input is scaled to [-1, 1]
-
- Args:
- y_hat: Tensor [batch_size, channels, time_length], predicted output.
- y: Tensor [batch_size, channels, time_length], Target.
- Returns:
- Tensor loss
- '''
- assert tf.rank(y_hat) == 3
- assert tf.shape(y_hat)[1] % 3 == 0
-
- nr_mix = tf.shape(y_hat)[1] // 3
-
- #[Batch_size, time_length, channels]
- y_hat = tf.transpose(y_hat, [0, 2, 1])
-
- #unpack parameters. [batch_size, time_length, num_mixtures] x 3
- logit_probs = y_hat[:, :, :nr_mix]
- means = y_hat[:, :, nr_mix:2 * nr_mix]
- log_scales = tf.maximum(y_hat[:, :, 2* nr_mix:3 * nr_mix], log_scale_min)
-
- #[batch_size, time_length, 1] -> [batch_size, time_length, num_mixtures]
- y = y * tf.ones(shape=[1, 1, nr_mix], dtype=tf.float32)
-
- centered_y = y - means
- inv_stdv = tf.exp(-log_scales)
- plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
- cdf_plus = tf.nn.sigmoid(plus_in)
- min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
- cdf_min = tf.nn.sigmoid(min_in)
-
- log_cdf_plus = plus_in - tf.nn.softplus(plus_in) # log probability for edge case of 0 (before scaling)
- log_one_minus_cdf_min = -tf.nn.softplus(min_in) # log probability for edge case of 255 (before scaling)
-
- #probability for all other cases
- cdf_delta = cdf_plus - cdf_min
-
- mid_in = inv_stdv * centered_y
- #log probability in the center of the bin, to be used in extreme cases
- #(not actually used in this code)
- log_pdf_mid = mid_in - log_scales - 2. * tf.nn.softplus(mid_in)
-
- log_probs = tf.where(x < -0.999, log_cdf_plus,
- tf.where(x > 0.999, log_one_minus_cdf_min,
- tf.where(cdf_delta > 1e-5,
- tf.log(tf.maximum(cdf_delta, 1e-12)),
- log_pdf_mid - np.log((num_classes - 1) / 2))))
- log_probs = tf.reduce_sum(log_probs, axis=-1, keepdims=True) + tf.nn.log_softmax(logit_probs, -1)
-
- if reduce:
- return -tf.reduce_sum(log_sum_exp(log_probs))
- else:
- return -tf.expand_dims(log_sum_exp(log_probs), [-1])
-
-def sample_from_discretized_mix_logistic(y, log_scale_min=-7.):
- '''
- Args:
- y: Tensor, [batch_size, channels, time_length]
- Returns:
- Tensor: sample in range of [-1, 1]
- '''
- assert tf.shape(y)[1] % 3 == 0
- nr_mix = tf.shape(y)[1] // 3
-
- #[batch_size, time_length, channels]
- y = tf.transpose(y, [0, 2, 1])
- logit_probs = y[:, :, :nr_mix]
-
- #sample mixture indicator from softmax
- temp = tf.random_uniform(logit_probs.get_shape(), minval=1e-5, maxval=1. - 1e-5)
- temp = logit_probs - tf.log(-tf.log(temp))
- argmax = tf.argmax(temp, -1)
-
- #[batch_size, time_length] -> [batch_size, time_length, nr_mix]
- one_hot = tf.one_hot(argmax, depth=nr_mix, dtype=tf.float32)
- #select logistic parameters
- means = tf.reduce_sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, axis=-1)
- log_scales = tf.maximum(tf.reduce_sum(
- y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, axis=-1), log_scale_min)
-
- #sample from logistic & clip to interval
- #we don't actually round to the nearest 8-bit value when sampling
- u = tf.random_uniform(means.get_shape(), minval=1e-5, maximum=1. - 1e-5)
- x = means + tf.exp(log_scales) * (tf.log(u) - tf.log(1 -u))
-
- return tf.minimum(tf.maximum(x, -1.), 1.)
diff --git a/wavenet_vocoder/models/modules.py b/wavenet_vocoder/models/modules.py
deleted file mode 100644
index b772bc04..00000000
--- a/wavenet_vocoder/models/modules.py
+++ /dev/null
@@ -1,198 +0,0 @@
-import numpy as np
-import tensorflow as tf
-
-
-class Embedding:
- """Embedding class for global conditions.
- """
- def __init__(self, num_embeddings, embedding_dim, std=0.1, name='gc_embedding'):
- #Create embedding table
- self.embedding_table = tf.get_variable(name,
- [num_embeddings, embedding_dim], dtype=tf.float32,
- initializer=tf.truncated_normal_initializer(mean=0., stddev=std))
-
- def __call__(self, inputs):
- #Do the actual embedding
- return tf.nn.embedding_lookup(self.embedding_table, inputs)
-
-class ReluActivation:
- """Simple class to wrap relu activation function in classe for later call.
- """
- def __init__(self, name=None):
- self.name = name
-
- def __call__(self, inputs):
- return tf.nn.relu(inputs, name=self.name)
-
-
-class Conv1d1x1(tf.layers.Conv1D):
- """Extend tf.layers.Conv1D for dilated layers convolutions.
- """
- def __init__(in_channels, filters, kernel_size=1, padding='same', dilation_rate=1, use_bias=True, **kwargs):
- super(Conv1d1x1, self).__init__(
- filters=filters,
- kernel_size=kernel_size,
- padding=padding,
- dilation_rate=dilation_rate,
- use_bias=use_bias,
- **kwargs)
- self.in_channels = in_channels
- self.input_buffer = None
- self._linearizer_weight = None
- tf.add_to_collections(tf.GraphKeys.UPDATE_OPS, self._clear_linearized_weight)
-
- def incremental_step(self, inputs):
- #input: [batch_size, time_length, channels]
- if self.training:
- raise RuntimeError('incremental_step only supports eval mode')
-
-
- #reshape weight
- weight = self._get_linearized_weight()
- kw = self.kernel_size[0]
- dilation = self.dilation_rate[0]
-
- batch_size = tf.shape(inputs)[0]
- if kw > 1:
- if self.input_buffer is None:
- self.input_buffer = tf.zeros((batch_size, kw + (kw - 1) * (dilation - 1), tf.shape(inputs)[2]))
- else:
- #shift buffer
- self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :]
- #append next input
- self.input_buffer[:, -1, :] = inputs[:, -1, :]
- inputs = self.input_buffer
- if dilation > 1:
- inputs = inputs[:, 0::dilation, :]
- output = tf.add(tf.matmul(inputs, weight), self.bias)
- return tf.reshape(output, [batch_size, 1, -1])
-
- def _get_linearized_weight(self):
- if self._linearizer_weight is None:
- kw = self.kernel.shape[0]
- #layers.Conv1D
- if tf.shape(self.kernel) == (self.filters, self.in_channels, kw):
- weight = tf.transpose(self.kernel, [0, 2, 1])
- else:
- weight = tf.transpose(self.kernel, [2, 0, 1])
- assert tf.shape(weight) == (self.filters, kw, self.in_channels)
- self._linearizer_weight = tf.reshape(self.filters, -1)
- return self._linearizer_weight
-
- def _clear_linearized_weight(self):
- self._linearizer_weight = None
-
- def clear_buffer(self):
- self.input_buffer = None
-
-def _conv1x1_forward(conv, x, is_incremental):
- """conv1x1 step
- """
- if is_incremental:
- x = conv.incremental_step(x)
- else:
- x = conv(x)
-
-class ResidualConv1dGLU():
- '''Residual dilated conv1d + Gated Linear Unit
- '''
-
- def __init__(self, residual_channels, gate_channels, kernel_size,
- skip_out_channels=None, cin_channels=-1, gin_channels=-1,
- dropout=1 - .95, padding=None, dilation=1, causal=True,
- bias=True, *args, **kwargs):
- self.dropout = dropout
-
- if skip_out_channels is None:
- skip_out_channels = residual_channels
-
- if padding is None:
- #No future time stamps available
- if causal:
- padding = (kernel_size - 1) * dilation
- else:
- padding = (kernel_size - 1) // 2 * dilation
-
- self.causal = causal
-
- self.conv = Conv1d1x1(residual_channels, gate_channels, kernel_size,
- padding=padding, dilation=dilation, bias=bias)
-
- #Local conditioning
- if cin_channels > 0:
- self.conv1x1c = Conv1d1x1(cin_channels, gate_channels,
- bias=bias)
- else:
- self.conv1x1c = None
-
- #Global conditioning
- if gin_channels > 0:
- self.conv1x1g = Conv1d1x1(gin_channels, gate_channels,
- bias=bias)
- else:
- self.conv1x1g = None
-
- gate_out_channels = gate_channels // 2
- self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias)
- self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_out_channels, bias=bias)
-
- def __call__(self, x, c=None, g=None):
- return self.step(x, c, g, False)
-
- def incremental_step(self, x, c=None, g=None):
- return self.step(x, c, g, True)
-
- def step(self, x, c, g, is_incremental):
- '''
-
- Args:
- x: Tensor [batch_size, channels, time_length]
- c: Tensor [batch_size, c_channels, time_length]. Local conditioning features
- g: Tensor [batch_size, g_channels, time_length], global conditioning features
- is_incremental: Boolean, whether incremental mode is on
- Returns:
- Tensor output
- '''
- residual = x
- x = tf.layers.dropout(x, rate=self.dropout, training=not is_incremental)
- if is_incremental:
- splitdim = -1
- x = self.conv.incremental_step(x)
- else:
- splitdim = 1
- x = self.conv(x)
- #Remove future time steps
- x = x[:, :, :tf.shape(residual)[-1]] if self.causal else x
-
- a, b = tf.split(x, num_or_size_splits=2, axis=splitdim)
-
- #local conditioning
- if c is not None:
- assert self.conv1x1c is not None
- c = _conv1x1_forward(self.conv1x1c, c, is_incremental)
- ca, cb = tf.split(c, num_or_size_splits=2, axis=splitdim)
- a, b = a + ca, b + cb
-
- #global conditioning
- if g is not None:
- assert self.conv1x1g is not None
- g = _conv1x1_forward(self.conv1x1g, g, is_incremental)
- ga, gb = tf.split(g, num_or_size_splits=2, axis=splitdim)
- a, b = a + ga, b + gb
-
- x = tf.nn.tanh(a) * tf.nn.sigmoid(b)
-
- #For Skip connection
- s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental)
-
- #For Residual connection
- x = _conv1x1_forward(self.conv1x1_out, x, is_incremental)
-
- x = (x + residual) * np.sqrt(0.5)
- return x, s
-
- def clear_buffer(self):
- for conv in [self.conv, self.conv1x1_out, self.conv1x1_skip,
- self.conv1x1c, self.conv1x1g]:
- if conv is not None:
- self.conv.clear_buffer()
\ No newline at end of file
diff --git a/wavenet_vocoder/models/wavenet.py b/wavenet_vocoder/models/wavenet.py
deleted file mode 100644
index 96344c10..00000000
--- a/wavenet_vocoder/models/wavenet.py
+++ /dev/null
@@ -1,323 +0,0 @@
-import numpy as np
-import tensorflow as tf
-
-from .modules import Conv1d1x1, ResidualConv1dGLU, ConvTranspose2d, Embedding, ReluActivation
-from .mixture import sample_from_discretized_mix_logistic
-
-
-
-def _expand_global_features(batch_size, time_length, global_features, data_format='BCT'):
- """Expand global conditioning features to all time steps
-
- Args:
- batch_size: int
- time_length: int
- global_features: Tensor of shape [batch_size, channels] or [batch_size, channels, 1]
- data_format: string, 'BCT' to get output of shape [batch_size, channels, time_length]
- or 'BTC' to get output of shape [batch_size, time_length, channels]
-
- Returns:
- None or Tensor of shape [batch_size, channels, time_length] or [batch_size, time_length, channels]
- """
- accepted_formats = ['BCT', 'BTC']
- if not (data_format in accepted_formats):
- raise ValueError('{} is an unknow data format, accepted formats are "BCT" and "BTC"'.format(data_format))
-
- if global_features is None:
- return None
-
- #[batch_size, channels] ==> [batch_size, channels, 1]
- g = tf.expand_dims(global_features, axis=-1) if tf.rank(global_features) == 2 else global_features
- g_shape = tf.shape(g)
-
- #[batch_size, channels, 1] ==> [batch_size, channels, time_length]
- ones = tf.ones([g_shape[0], g_shape[1], time_length], tf.float32)
- g = g * ones
-
- if data_format == 'BCT':
- return g
-
- else:
- #[batch_size, channels, time_length] ==> [batch_size, time_length, channels]
- return tf.transpose(g, [0, 2, 1])
-
-
-def receptive_field_size(total_layers, num_cycles, kernel_size, dilation=lambda x: 2**x):
- """Compute receptive field size.
-
- Args:
- total_layers; int
- num_cycles: int
- kernel_size: int
- dilation: callable, function used to compute dilation factor.
- use "lambda x: 1" to disable dilated convolutions.
-
- Returns:
- int: receptive field size in sample.
- """
- assert total_layers % num_cycles == 0
-
- layers_per_cycle = total_layers // num_cycles
- dilations = [dilation(i % layers_per_cycle for i in range(total_layers))]
- return (kernel_size - 1) * sum(dilations) + 1
-
-
-class WaveNet():
- """Tacotron-2 Wavenet Vocoder model.
- """
- def __init__(self, hparams):
- #Get hparams
- self._hparams = hparams
-
- #Initialize model architecture
- assert hparams.layers % hparams.stacks == 0
- layers_per_stack = hparams.layers // hparams.stacks
-
- #first convolution
- if hparams.scalar_input:
- self.first_conv = Conv1d1x1(1, hparams.residual_channels)
- else:
- self.first_conv = Conv1d1x1(out_channels, hparams.residual_channels)
-
- #Residual convolutions
- self.conv_layers = [ResidualConv1dGLU(
- hparams.residual_channels, hparams.gate_channels,
- kernel_size=hparams.kernel_size,
- skip_out_channels=hparams.skip_out_channels,
- bias=hparams.use_bias,
- dilation=2**(layer % layers_per_stack),
- dropout=hparams.dropout,
- cin_channels=hparams.cin_channels,
- gin_channels=hparams.gin_channels,
- weight_normalization=hparams.weight_normalization) for layer in range(hparams.layers)]
-
- #Final convolutions
- self.last_conv_layers = [
- ReluActivation(name='final_conv_relu1'),
- Conv1d1x1(hparams.skip_out_channels, hparams.skip_out_channels,
- weight_normalization=hparams.weight_normalization),
- ReluActivation(name='final_conv_relu2'),
- Conv1d1x1(hparams.skip_out_channels, hparams.out_channels,
- weight_normalization=hparams.weight_normalization)]
-
- #Global conditionning embedding
- if hparams.gin_channels > 0 and hparams.use_speaker_embedding:
- assert hparams.n_speakers is not None
- self.embed_speakers = Embedding(
- hparams.n_speakers, hparams.gin_channels, std=0.1, name='gc_embedding')
- else:
- self.embed_speakers = None
-
- #Upsample conv net
- if hparams.upsample_conditional_features:
- self.upsample_conv = []
- for i, s in enumerate(hparams.upsample_scales):
- freq_axis_padding = (hparams.freq_axis_kernel_size - 1) // 2
- convt = ConvTranspose2d(1, 1, (freq_axis_kernel_size, s),
- padding=(freq_axis_padding, 0),
- dilation=1, stride=(1, s),
- weight_normalization=hparams.weight_normalization)
- self.upsample_conv.append(convt)
- #assuming features are [0, 1] scaled
- #this should avoid negative upsampling output
- self.upsample_conv.append(ReluActivation(name='conditional_upsample_{}'.format(i+1)))
- else:
- self.upsample_conv = None
-
- self.receptive_field = receptive_field_size(hparams.layers,
- hparams.stacks, hparams.kernel_size)
-
- #Sanity check functions
- def has_speaker_embedding(self):
- return self.embed_speakers is not None
-
- def local_conditioning_enabled(self):
- return self.cin_channels > 0
-
- def step(self, x, c=None, g=None, softmax=False):
- """Forward step
-
- Args:
- x: Tensor of shape [batch_size, channels, time_length], One-hot encoded audio signal.
- c: Tensor of shape [batch_size, cin_channels, time_length], Local conditioning features.
- g: Tensor of shape [batch_size, gin_channels, 1] or Ids of shape [batch_size, 1],
- Global conditioning features.
- Note: set hparams.use_speaker_embedding to False to disable embedding layer and
- use extrnal One-hot encoded features.
- softmax: Boolean, Whether to apply softmax.
-
- Returns:
- a Tensor of shape [batch_size, out_channels, time_length]
- """
- batch_size, _, time_length = tf.shape(x)
-
- if g is not None:
- if self.embed_speakers is not None:
- #[batch_size, 1] ==> [batch_size, 1, gin_channels]
- g = self.embed_speakers(tf.reshape(g, [batch_size, -1]))
- #[batch_size, gin_channels, 1]
- g = tf.transpose(g, [0, 2, 1])
- assert tf.rank(g) == 3
-
- #Expand global conditioning features to all time steps
- g_bct = _expand_global_features(batch_size, time_length, g, data_format='BCT')
-
- if c is not None and self.upsample_conv is not None:
- #[batch_size, 1, cin_channels, time_length]
- c = tf.expand_dims(c, axis=1)
- for transposed_conv in self.upsample_conv:
- c = transposed_conv(c)
-
- #[batch_size, cin_channels, time_length]
- c = tf.squeeze(c, [1])
- assert c.shape()[-1] == x.shape()[-1]
-
- #Feed data to network
- x = self.first_conv(x)
- skips = None
- for conv in self.conv_layers:
- x, h = conv(x, c, g_bct)
- if skips is None:
- skips = h
- else:
- skips += h
- skips *= np.sqrt(0.5)
- x = skips
-
- for conv in self.last_conv_layers:
- x = conv(x)
-
- return tf.nn.softmax(x, axis=1) if softmax else x
-
- def _clear_linearized_weights(self):
- self.first_conv._clear_linearized_weight()
- self.last_conv_layers[1]._clear_linearized_weight()
- self.last_conv_layers[-1]._clear_linearized_weight()
-
- def incremental_step(self, initial_input=None, c=None, g=None,
- time_length=100, test_inputs=None,
- tqdm=lambda x: x, softmax=True, quantize=True,
- log_scale_min=-7.0):
-		"""Incremental forward step
-
- Inputs of shape [batch_size, channels, time_length] are reshaped to [batch_size, time_length, channels]
- Input of each time step is of shape [batch_size, 1, channels]
-
- Args:
-			initial_input: Tensor of shape [batch_size, channels, 1], initial recurrence input.
- c: Tensor of shape [batch_size, cin_channels, time_length], Local conditioning features
- g: Tensor of shape [batch_size, gin_channels, time_length] or [batch_size, gin_channels, 1]
- global conditioning features
-			time_length: int, Number of time steps to generate.
- test_inputs: Tensor, teacher forcing inputs (debug)
- tqdm: callable, tqdm style
- softmax: Boolean, whether to apply softmax activation
- quantize: Whether to quantize softmax output before feeding to
- next time step input
- log_scale_min: float, log scale minimum value.
-
- Returns:
- Tensor of shape [batch_size, channels, time_length] or [batch_size, channels, 1]
- Generated one_hot encoded samples
- """
- batch_size = 1
-
- #Note: should reshape to [batch_size, time_length, channels]
- #not [batch_size, channels, time_length]
- if test_inputs is not None:
- if self._hparams.scalar_input:
- if tf.shape(test_inputs)[1] == 1:
- test_inputs = tf.transpose(test_inputs, [0, 2, 1])
- else:
- if tf.shape(test_inputs)[1] == self._hparams.out_channels:
- test_inputs = tf.transpose(test_inputs, [0, 2, 1])
-
- batch_size = tf.shape(test_inputs)[0]
- if time_length is None:
- time_length = tf.shape(test_inputs)[1]
- else:
- time_length = max(time_length, tf.shape(test_inputs)[1])
-
- time_length = int(time_length)
-
- #Global conditioning
-		if g is not None:
- if self.embed_speakers is not None:
- g = self.embed_speakers(tf.reshape(g, [batch_size, -1]))
- #[batch_size, channels, 1]
- g = tf.transpose(g, [0, 2, 1])
-			assert g.shape.ndims == 3
-
- g_btc = _expand_global_features(batch_size, time_length, g, data_format='BTC')
-
- #Local conditioning
- if c is not None and self.upsample_conv is not None:
- #[batch_size, 1, channels, time_length]
- c = tf.expand_dims(c, axis=1)
- for upsample_conv in self.upsample_conv:
- c = upsample_conv(c)
- #[batch_size, channels, time_length]
- c = tf.squeeze(c, [1])
- assert tf.shape(c)[-1] == time_length
-
- if c is not None and tf.shape(c)[-1] == time_length:
- c = tf.transpose(c, [0, 2, 1])
-
- outputs = []
- if initial_input is None:
- if self.scalar_input:
-				initial_input = tf.zeros((batch_size, 1, 1), tf.float32)
- else:
- np_input = np.zeros((batch_size, 1, self._hparams.out_channels))
- np_input[:, :, 127] = 1
- initial_input = tf.convert_to_tensor(np_input)
- else:
- if tf.shape(initial_input)[1] == self._hparams.out_channels:
- initial_input = tf.transpose(initial_input, [0, 2, 1])
-
- current_input = initial_input
- for t in tqdm(range(time_length)):
- if test_inputs is not None and t < tf.shape(test_inputs)[1]:
- current_input = tf.expand_dims(test_inputs[:, t, :], axis=1)
- else:
- if t > 0:
- current_input = outputs[-1]
-
- #conditioning features for single time step
- ct = None if c is None else tf.expand_dims(c[:, t, :], axis=1)
- gt = None if g is None else tf.expand_dims(g_btc[:, t, :], axis=1)
-
- x = current_input
- x = self.first_conv.incremental_step(x)
- skips = None
- for conv in self.conv_layers:
- x, h = conv.incremental_step(x, ct, gt)
- skips = h if skips is None else (skips + h) * np.sqrt(0.5)
- x = skips
- for conv in self.last_conv_layers:
- try:
- x = conv.incremental_step(x)
- except AttributeError:
- x = conv(x)
-
- #Generate next input by sampling
- if self._hparams.scalar_input:
- x = sample_from_discretized_mix_logistic(
- tf.reshape(x, [batch_size, -1, 1]), log_scale_min=log_scale_min)
- else:
- x = tf.nn.softmax(tf.reshape(x, [batch_size, -1]), axis=1) if softmax \
- else tf.reshape(x, [batch_size, -1])
- if quantize:
-					sample = np.random.choice(
-						np.arange(self._hparams.out_channels), p=tf.reshape(x, [-1]).eval())
-					new_x = np.zeros(tf.shape(x).eval(), np.float32)
- new_x[:, sample] = 1.
-
- x = tf.convert_to_tensor(new_x, tf.float32)
- outputs.append(x)
-
- #[time_length, batch_size, channels]
- outputs = tf.stack(outputs)
- #[batch_size, channels, time_length]
- self._clear_linearized_weights()
- return tf.transpose(outputs, [1, 2, 0])
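-
-
-#Illustrative only (added sketch, not in the original file): a minimal synthesis driver for
-#incremental_step, assuming a built `model` (the WaveNet above), a tf.Session with restored
-#weights, and local conditioning `c` of shape [1, cin_channels, time_length].
-#`max_time_steps` is a hypothetical bound on the number of generated samples.
-def example_synthesize(model, sess, c=None, max_time_steps=16000):
-	with sess.as_default():
-		#One-hot encoded samples of shape [1, out_channels, time_length]
-		samples = model.incremental_step(c=c, time_length=max_time_steps,
-			softmax=True, quantize=True)
-		return sess.run(samples)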
diff --git a/wavenet_vocoder/train.py b/wavenet_vocoder/train.py
deleted file mode 100644
index 8aea6adb..00000000
--- a/wavenet_vocoder/train.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import argparse
-import sys
-import os
-from datetime import datetime
-
-from wavenet_vocoder.models import create_model
-import numpy as np
-
-from hparams import hparams, hparams_debug_string
-
-
-
-def sanity_check(model, c, g):
- if model.has_speaker_embedding():
- if g is None:
- raise RuntimeError('Wavenet expects speaker embedding, but speaker-id is not defined')
- else:
- if g is not None:
- raise RuntimeError('Wavenet expects no speaker embedding, but speaker-id is provided')
-
- if model.local_conditioning_enabled():
- if c is None:
-			raise RuntimeError('Wavenet expects local conditioning features, but none were given')
- else:
- if c is not None:
- raise RuntimeError('Wavenet expects no conditional features, but features were given')
-
-
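-#Illustrative only (added note, not in the original file): sanity_check is meant to run once
-#per batch before feeding the model. With a hypothetical model factory call and batch
-#tensors `c` (local conditioning features) and `g` (speaker ids), usage would look like:
-#	model = create_model('WaveNet', hparams)	#model name string is an assumption
-#	sanity_check(model, c, g)	#raises RuntimeError if conditioning doesn't match hparams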
diff --git a/wavenet_vocoder/util.py b/wavenet_vocoder/util.py
deleted file mode 100644
index 6798e814..00000000
--- a/wavenet_vocoder/util.py
+++ /dev/null
@@ -1,185 +0,0 @@
-import numpy as np
-
-
-
-def _assert_valid_input_type(s):
- assert s == 'mulaw-quantize' or s == 'mulaw' or s == 'raw'
-
-def is_mulaw_quantize(s):
- _assert_valid_input_type(s)
- return s == 'mulaw-quantize'
-
-def is_mulaw(s):
- _assert_valid_input_type(s)
- return s == 'mulaw'
-
-def is_raw(s):
- _assert_valid_input_type(s)
- return s == 'raw'
-
-def is_scalar_input(s):
- return is_raw(s) or is_mulaw(s)
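-
-#Added note (illustrative, not in the original file): hparams.input_type is expected to be
-#one of the three strings checked above. For example:
-#	is_mulaw_quantize('mulaw-quantize')	#True  -> one-hot categorical model output
-#	is_scalar_input('raw')			#True  -> scalar waveform input
-#	is_scalar_input('mulaw-quantize')	#False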
-
-
-#From https://github.com/r9y9/nnmnkwii/blob/master/nnmnkwii/preprocessing/generic.py
-def mulaw(x, mu=256):
- """Mu-Law companding
- Method described in paper [1]_.
- .. math::
- f(x) = sign(x) \ln (1 + \mu |x|) / \ln (1 + \mu)
- Args:
- x (array-like): Input signal. Each value of input signal must be in
- range of [-1, 1].
- mu (number): Compression parameter ``μ``.
- Returns:
- array-like: Compressed signal ([-1, 1])
- See also:
- :func:`nnmnkwii.preprocessing.inv_mulaw`
- :func:`nnmnkwii.preprocessing.mulaw_quantize`
- :func:`nnmnkwii.preprocessing.inv_mulaw_quantize`
- .. [1] Brokish, Charles W., and Michele Lewis. "A-law and mu-law companding
- implementations using the tms320c54x." SPRA163 (1997).
- """
- return _sign(x) * _log1p(mu * _abs(x)) / _log1p(mu)
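-
-#Added note (illustrative): with mu=256, mulaw(0.5) = ln(1 + 128) / ln(257) ≈ 0.876 and,
-#by symmetry, mulaw(-0.5) ≈ -0.876; small amplitudes are expanded, large ones compressed.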
-
-
-def inv_mulaw(y, mu=256):
- """Inverse of mu-law companding (mu-law expansion)
- .. math::
-		f^{-1}(y) = sign(y) (1 / \mu) ((1 + \mu)^{|y|} - 1)
- Args:
- y (array-like): Compressed signal. Each value of input signal must be in
- range of [-1, 1].
- mu (number): Compression parameter ``μ``.
- Returns:
-		array-like: Uncompressed signal (-1 <= x <= 1)
- See also:
- :func:`nnmnkwii.preprocessing.inv_mulaw`
- :func:`nnmnkwii.preprocessing.mulaw_quantize`
- :func:`nnmnkwii.preprocessing.inv_mulaw_quantize`
- """
- return _sign(y) * (1.0 / mu) * ((1.0 + mu)**_abs(y) - 1.0)
-
-
-def mulaw_quantize(x, mu=256):
- """Mu-Law companding + quantize
- Args:
- x (array-like): Input signal. Each value of input signal must be in
- range of [-1, 1].
- mu (number): Compression parameter ``μ``.
- Returns:
- array-like: Quantized signal (dtype=int)
- - y ∈ [0, mu] if x ∈ [-1, 1]
- - y ∈ [0, mu) if x ∈ [-1, 1)
- .. note::
- If you want to get quantized values of range [0, mu) (not [0, mu]),
- then you need to provide input signal of range [-1, 1).
- Examples:
- >>> from scipy.io import wavfile
- >>> import pysptk
- >>> import numpy as np
- >>> from nnmnkwii import preprocessing as P
- >>> fs, x = wavfile.read(pysptk.util.example_audio_file())
- >>> x = (x / 32768.0).astype(np.float32)
- >>> y = P.mulaw_quantize(x)
- >>> print(y.min(), y.max(), y.dtype)
- 15 246 int64
- See also:
- :func:`nnmnkwii.preprocessing.mulaw`
- :func:`nnmnkwii.preprocessing.inv_mulaw`
- :func:`nnmnkwii.preprocessing.inv_mulaw_quantize`
- """
- y = mulaw(x, mu)
- # scale [-1, 1] to [0, mu]
- return _asint((y + 1) / 2 * mu)
-
-
-def inv_mulaw_quantize(y, mu=256):
- """Inverse of mu-law companding + quantize
- Args:
- y (array-like): Quantized signal (∈ [0, mu]).
- mu (number): Compression parameter ``μ``.
- Returns:
- array-like: Uncompressed signal ([-1, 1])
- Examples:
- >>> from scipy.io import wavfile
- >>> import pysptk
- >>> import numpy as np
- >>> from nnmnkwii import preprocessing as P
- >>> fs, x = wavfile.read(pysptk.util.example_audio_file())
- >>> x = (x / 32768.0).astype(np.float32)
- >>> x_hat = P.inv_mulaw_quantize(P.mulaw_quantize(x))
- >>> x_hat = (x_hat * 32768).astype(np.int16)
- See also:
- :func:`nnmnkwii.preprocessing.mulaw`
- :func:`nnmnkwii.preprocessing.inv_mulaw`
- :func:`nnmnkwii.preprocessing.mulaw_quantize`
- """
-	# scale [0, mu] to [-1, 1]
- y = 2 * _asfloat(y) / mu - 1
- return inv_mulaw(y, mu)
-
-def _sign(x):
- isnumpy = isinstance(x, np.ndarray)
- isscalar = np.isscalar(x)
- return np.sign(x) if isnumpy or isscalar else x.sign()
-
-
-def _log1p(x):
- isnumpy = isinstance(x, np.ndarray)
- isscalar = np.isscalar(x)
- return np.log1p(x) if isnumpy or isscalar else x.log1p()
-
-
-def _abs(x):
- isnumpy = isinstance(x, np.ndarray)
- isscalar = np.isscalar(x)
- return np.abs(x) if isnumpy or isscalar else x.abs()
-
-
-def _asint(x):
- # ugly wrapper to support torch/numpy arrays
- isnumpy = isinstance(x, np.ndarray)
- isscalar = np.isscalar(x)
- return x.astype(np.int) if isnumpy else int(x) if isscalar else x.long()
-
-
-def _asfloat(x):
- # ugly wrapper to support torch/numpy arrays
- isnumpy = isinstance(x, np.ndarray)
- isscalar = np.isscalar(x)
- return x.astype(np.float32) if isnumpy else float(x) if isscalar else x.float()
-
-
-#From https://github.com/r9y9/wavenet_vocoder/blob/master/lrschedule.py
-def noam_learning_rate_decay(init_lr, global_step, warmup_steps=4000):
- # Noam scheme from tensor2tensor:
- warmup_steps = float(warmup_steps)
- step = global_step + 1.
- lr = init_lr * warmup_steps**0.5 * np.minimum(
- step * warmup_steps**-1.5, step**-0.5)
- return lr
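-
-#Added note (illustrative): with init_lr=1e-3 and warmup_steps=4000, the rate ramps up
-#linearly to ~1e-3 at step 4000, then decays like step**-0.5 (~5e-4 around step 16000).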
-
-
-def step_learning_rate_decay(init_lr, global_step,
- anneal_rate=0.98,
- anneal_interval=30000):
- return init_lr * anneal_rate ** (global_step // anneal_interval)
-
-
-def cyclic_cosine_annealing(init_lr, global_step, T, M):
- """Cyclic cosine annealing
-
- https://arxiv.org/pdf/1704.00109.pdf
-
- Args:
- init_lr (float): Initial learning rate
- global_step (int): Current iteration number
-		T (int): Total iteration number (i.e. number of epochs)
- M (int): Number of ensembles we want
-
- Returns:
- float: Annealed learning rate
- """
- TdivM = T // M
- return init_lr / 2.0 * (np.cos(np.pi * ((global_step - 1) % TdivM) / TdivM) + 1.0)
\ No newline at end of file
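-
-#Added note (illustrative): with init_lr=1e-3, T=600000 and M=6, each cosine cycle spans
-#T // M = 100000 steps, restarting near 1e-3 and annealing towards 0 before the next restart.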