PyThaiTTS v0.3.0
- Add Lunarlist TTS model (ONNX)
- Change default model to Lunarlist TTS model
wannaphong committed Jan 24, 2024
1 parent c51b102 commit 3508e48
Showing 8 changed files with 520 additions and 21 deletions.
Binary file added notebook/cat.wav
307 changes: 307 additions & 0 deletions notebook/use_lunarlist_model_onnx.ipynb


25 changes: 17 additions & 8 deletions pythaitts/__init__.py
@@ -2,15 +2,16 @@
"""
PyThaiTTS
"""
__version__ = "0.2.1"
__version__ = "0.3.0"


class TTS:
def __init__(self, pretrained="khanomtan", mode="last_checkpoint", version="1.0", device:str="cpu") -> None:
def __init__(self, pretrained="lunarlist_onnx", mode="last_checkpoint", version="1.0", device:str="cpu") -> None:
"""
:param str pretrained: TTS pretrained (khanomtan, lunarlist)
:param str mode: pretrained mode
:param str pretrained: TTS pretrained model (lunarlist_onnx, khanomtan, lunarlist)
:param str mode: pretrained mode (not supported by lunarlist_onnx)
:param str version: model version (default is 1.0 or 1.1)
:param str device: device for running the model (lunarlist_onnx supports CPU only).
**Options for mode**
* *last_checkpoint* (default) - last checkpoint of model
@@ -21,6 +22,11 @@ def __init__(self, pretrained="khanomtan", mode="last_checkpoint", version="1.0"
For the lunarlist tts model, you must install nemo before using the model: pip install nemo_toolkit['tts'].
You can see more about lunarlist tts at `https://link.medium.com/OpPjQis6wBb <https://link.medium.com/OpPjQis6wBb>`_
For the lunarlist_onnx tts model, \
you can see more at `https://github.com/PyThaiNLP/thaitts-onnx <https://github.com/PyThaiNLP/thaitts-onnx>`_
"""
self.pretrained = pretrained
@@ -32,11 +38,14 @@ def load_pretrained(self,version):
"""
Load pretrained
"""
if self.pretrained == "khanomtan":
from pythaitts.pretrained import KhanomTan
if self.pretrained == "lunarlist_onnx":
from pythaitts.pretrained.lunarlist_onnx import LunarlistONNX
self.model = LunarlistONNX()
elif self.pretrained == "khanomtan":
from pythaitts.pretrained.khanomtan_tts import KhanomTan
self.model = KhanomTan(mode=self.mode, version=version)
elif self.pretrained == "lunarlist":
from pythaitts.pretrained import LunarlistModel
from pythaitts.pretrained.lunarlist_model import LunarlistModel
self.model = LunarlistModel(mode=self.mode, device=self.device)
else:
raise NotImplementedError(
@@ -53,7 +62,7 @@ def tts(self, text: str, speaker_idx: str = "Linda", language_idx: str = "th-th"
:param str return_type: return type (default is file)
:param str filename: path filename for save wav file if return_type is file.
"""
if self.pretrained == "lunarlist":
if self.pretrained == "lunarlist" or self.pretrained == "lunarlist_onnx":
return self.model(text=text,return_type=return_type,filename=filename)
return self.model(
text=text,
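For reference, a minimal usage sketch of the updated default; this is only a sketch, assuming onnxruntime, huggingface_hub, and soundfile are installed, and the Thai sample text and output path are placeholders:

from pythaitts import TTS

tts = TTS()  # pretrained now defaults to "lunarlist_onnx" (ONNX, CPU-only)
wav_path = tts.tts("สวัสดีครับ", filename="hello.wav")  # writes a 22.05 kHz wav and returns its path
waveform = tts.tts("สวัสดีครับ", return_type="waveform")  # or get the raw waveform array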
8 changes: 0 additions & 8 deletions pythaitts/pretrained/__init__.py
@@ -1,8 +0,0 @@
# -*- coding: utf-8 -*-
from pythaitts.pretrained.khanomtan_tts import KhanomTan
from pythaitts.pretrained.lunarlist_model import LunarlistModel

__all__ = [
"KhanomTan",
"LunarlistModel"
]
5 changes: 4 additions & 1 deletion pythaitts/pretrained/lunarlist_model.py
@@ -5,7 +5,10 @@
You can see more about lunarlist tts at `https://link.medium.com/OpPjQis6wBb <https://link.medium.com/OpPjQis6wBb>`_
"""
import tempfile
import torch
try:
import torch
except ImportError:
raise ImportError("You must install torch before using this model.")


class LunarlistModel:
189 changes: 189 additions & 0 deletions pythaitts/pretrained/lunarlist_onnx.py
@@ -0,0 +1,189 @@
# -*- coding: utf-8 -*-
"""
Lunarlist TTS model (ONNX)
You can see more about lunarlist tts at `https://link.medium.com/OpPjQis6wBb <https://link.medium.com/OpPjQis6wBb>`_
ONNX port: `https://github.com/PyThaiNLP/thaitts-onnx <https://github.com/PyThaiNLP/thaitts-onnx>`_
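The full pipeline is: text -> symbol ids -> Tacotron2 encoder -> autoregressive decoder -> postnet (mel spectrogram) -> vocoder (waveform), all run with onnxruntime.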
"""
import tempfile
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download


# from https://huggingface.co/lunarlist/tts-thai-last-step
index_list=['ก', 'ข', 'ค', 'ฆ', 'ง', 'จ', 'ฉ', 'ช', 'ซ', 'ฌ', 'ญ', 'ฎ', 'ฏ', 'ฐ', 'ฑ', 'ฒ', 'ณ', 'ด', 'ต', 'ถ', 'ท', 'ธ', 'น', 'บ', 'ป', 'ผ', 'ฝ', 'พ', 'ฟ', 'ภ', 'ม', 'ย', 'ร', 'ฤ', 'ล', 'ว', 'ศ', 'ษ', 'ส', 'ห', 'ฬ', 'อ', 'ฮ', 'ะ', 'ั', 'า', 'ำ', 'ิ', 'ี', 'ึ', 'ื', 'ุ', 'ู', 'เ', 'แ', 'โ', 'ใ', 'ไ', 'ๅ', '็', '่', '้', '๊', '๋', '์', ' ']
dict_idx={k:i for i,k in enumerate(index_list)}
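# clean() below maps a string to ids from index_list; 66 and 67 fall just past
# the 66-symbol table (ids 0-65) and appear to act as start/end tokens,
# e.g. clean("กา") returns (array([[66, 0, 45, 67]]), array([4])).
# Characters outside the table are assumed not to occur (they would raise a KeyError).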

def clean(text):
seq = np.array([[66]+[dict_idx[i] for i in text if i]+[67]])
_s=np.array([len(seq[0])])
return seq,_s

n_mel_channels = 80
n_frames_per_step = 1
attention_rnn_dim = 1024
decoder_rnn_dim=1024
encoder_embedding_dim=512
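# Tacotron2 decoder/attention dimensions; they must match the exported ONNX graphs,
# since the initial decoder states built below are fed straight into the decoder session.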

def initialize_decoder_states(memory):
B = memory.shape[0]
MAX_TIME = memory.shape[1]

attention_hidden = np.zeros((B, attention_rnn_dim), dtype=np.float32)
attention_cell = np.zeros((B, attention_rnn_dim), dtype=np.float32)

decoder_hidden = np.zeros((B, decoder_rnn_dim), dtype=np.float32)
decoder_cell = np.zeros((B, decoder_rnn_dim), dtype=np.float32)

attention_weights = np.zeros((B, MAX_TIME), dtype=np.float32)
attention_weights_cum = np.zeros((B, MAX_TIME), dtype=np.float32)
attention_context = np.zeros((B, encoder_embedding_dim), dtype=np.float32)

return (
attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context,
)


def get_go_frame(memory):
B = memory.shape[0]
decoder_input = np.zeros((B, n_mel_channels*n_frames_per_step), dtype=np.float32)
return decoder_input


def sigmoid(x):
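# numerically stable sigmoid: exp(-logaddexp(0, -x)) == 1 / (1 + exp(-x))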
return np.exp(-np.logaddexp(0, -x))


def parse_decoder_outputs(mel_outputs, gate_outputs, alignments):
# (T_out, B) -> (B, T_out)
alignments = np.stack(alignments).transpose((1, 0, 2, 3))
# (T_out, B) -> (B, T_out)
# Add a -1 to prevent squeezing the batch dimension in case
# batch is 1
gate_outputs = np.stack(gate_outputs).squeeze(-1).transpose((1, 0, 2))
# (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
mel_outputs = np.stack(mel_outputs).transpose((1, 0, 2, 3))
# decouple frames per step
mel_outputs = mel_outputs.reshape(mel_outputs.shape[0], -1, n_mel_channels)
# (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
mel_outputs = mel_outputs.transpose((0, 2, 1))

return mel_outputs, gate_outputs, alignments


# only numpy operations
def inference(text, encoder, decoder_iter, postnet):
sequences, sequence_lengths = clean(text)

# print("Running Tacotron2 Encoder")
inputs = {"seq": sequences, "seq_len": sequence_lengths}
memory, processed_memory, _ = encoder.run(None, inputs)

# print("Running Tacotron2 Decoder")
mel_lengths = np.zeros([memory.shape[0]], dtype=np.int32)
not_finished = np.ones([memory.shape[0]], dtype=np.int32)
mel_outputs, gate_outputs, alignments = [], [], []
gate_threshold = 0.5
max_decoder_steps = 5000
first_iter = True

(
attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context,
) = initialize_decoder_states(memory)

decoder_input = get_go_frame(memory)

while True:
inputs = {
"decoder_input": decoder_input,
"attention_hidden": attention_hidden,
"attention_cell": attention_cell,
"decoder_hidden": decoder_hidden,
"decoder_cell": decoder_cell,
"attention_weights": attention_weights,
"attention_weights_cum": attention_weights_cum,
"attention_context": attention_context,
"memory": memory,
"processed_memory": processed_memory,
}
(
mel_output,
gate_output,
attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context,
) = decoder_iter.run(None, inputs)

if first_iter:
mel_outputs = [np.expand_dims(mel_output, 2)]
gate_outputs = [np.expand_dims(gate_output, 2)]
alignments = [np.expand_dims(attention_weights, 2)]
first_iter = False
else:
mel_outputs += [np.expand_dims(mel_output, 2)]
gate_outputs += [np.expand_dims(gate_output, 2)]
alignments += [np.expand_dims(attention_weights, 2)]

dec = np.less(sigmoid(gate_output), gate_threshold)
dec = np.squeeze(dec, axis=1)
not_finished = not_finished * dec
mel_lengths += not_finished

if not_finished.sum() == 0:
# print("Stopping after ", len(mel_outputs), " decoder steps")
break
if len(mel_outputs) == max_decoder_steps:
# print("Warning! Reached max decoder steps")
break

decoder_input = mel_output
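# the frame just produced is fed back as the next decoder input (autoregressive decoding)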

mel_outputs, gate_outputs, alignments = parse_decoder_outputs(
mel_outputs, gate_outputs, alignments
)

# print("Running Tacotron2 PostNet")
inputs = {"mel_spec": mel_outputs}
mel_outputs_postnet = postnet.run(None, inputs)

return mel_outputs_postnet

class LunarlistONNX:
def __init__(self) -> None:
self.encoder = ort.InferenceSession(hf_hub_download(repo_id="pythainlp/thaitts-onnx",filename="tacotron2encoder-th.onnx"))
self.decoder = ort.InferenceSession(hf_hub_download(repo_id="pythainlp/thaitts-onnx",filename="tacotron2decoder-th.onnx"))
self.postnet = ort.InferenceSession(hf_hub_download(repo_id="pythainlp/thaitts-onnx",filename="tacotron2postnet-th.onnx"))
self.hifi = ort.InferenceSession(hf_hub_download(repo_id="pythainlp/thaitts-onnx",filename="vocoder.onnx"))
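# four ONNX sessions: Tacotron2 encoder, decoder step, postnet, and a vocoder that
# turns the predicted mel spectrogram into a 22.05 kHz waveform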
def tts(self, text: str):
mel = inference(text, self.encoder, self.decoder, self.postnet)
return self.hifi.run(None, {"spec": mel[0]})
def __call__(self, text: str, return_type: str = "file", filename: str = None):
wavs = self.tts(text)
if return_type == "waveform":
return wavs[0][0, 0, :]
import soundfile as sf
if filename is not None:
sf.write(filename, wavs[0][0, 0, :], 22050)
return filename
else:
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
sf.write(fp.name, wavs[0][0, 0, :], 22050)
return fp.name
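The new class can also be used on its own; a minimal sketch (the ONNX graphs are downloaded from the pythainlp/thaitts-onnx repo on Hugging Face on first use, and omitting filename writes to a temporary .wav file instead):

from pythaitts.pretrained.lunarlist_onnx import LunarlistONNX

model = LunarlistONNX()  # loads the four ONNX sessions
path = model("สวัสดี", filename="hello.wav")  # saves a 22.05 kHz wav and returns the path
waveform = model("สวัสดี", return_type="waveform")  # or return the waveform directly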
5 changes: 2 additions & 3 deletions requirements.txt
@@ -1,4 +1,3 @@
TTS>=0.8.0
pythainlp>=3.0.0
huggingface_hub
torch
numpy>=1.22
onnxruntime
2 changes: 1 addition & 1 deletion setup.py
@@ -9,7 +9,7 @@

setup(
name="PyThaiTTS",
version="0.2.1",
version="0.3.0",
description="Open Source Thai Text-to-speech library in Python",
long_description=readme,
long_description_content_type="text/markdown",
