-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
define pheme
- Loading branch information
0 parents
commit 3658ae7
Showing
104 changed files
with
4,364 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[flake8] | ||
max-line-length = 88 | ||
exclude = .git,__pycache__,build,dist |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
app.py | ||
ckpt/* | ||
*/__pycache__/* | ||
__pycache__/* | ||
exp/* | ||
datasets/* | ||
wandb/* |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# Pheme Model | ||
This repo contains recipes and models used for training TTS models. | ||
|
||
Our model validates several hypotheses: | ||
1. We can train VALL-E style models with 10x less training data. | ||
2. The fundamental ingredients are the right semantic/acoustic token definition. | ||
3. The training can be performed with conversational, podcast, and noisy data like GIGA. | ||
4. Inference can be run in parallel using MaskGIT-style decoding.
5. The quality can be improved through student-teacher training with data generated by third-party providers. | ||
|
||
|
||
Official implementation for the paper: TODO[] | ||
|
||
# Set up the environment
Set up the conda environment:
``` | ||
conda create --name pheme3 python=3.10 | ||
conda activate pheme3 | ||
pip3 install torch torchvision torchaudio | ||
pip3 install -r requirements.txt --no-deps | ||
``` | ||
|
||
Download the pre-trained SpeechTokenizer model and the unique text token list:
``` bash | ||
st_dir="ckpt/speechtokenizer/" | ||
mkdir -p ${st_dir} | ||
cd ${st_dir} | ||
wget "https://huggingface.co/fnlp/SpeechTokenizer/resolve/main/speechtokenizer_hubert_avg/SpeechTokenizer.pt" | ||
wget "https://huggingface.co/fnlp/SpeechTokenizer/resolve/main/speechtokenizer_hubert_avg/config.json" | ||
cd .. | ||
wget "https://huggingface.co/fnlp/USLM/resolve/main/USLM_libritts/unique_text_tokens.k2symbols" | ||
``` | ||
|
||
You need to create an access token to use the speaker embedding of pyannote. | ||
``` | ||
export HUGGING_FACE_HUB_TOKEN=YOUR_PRIVATE_TOKEN | ||
``` | ||
|
||
Download pre-trained T2S and S2A models (100M): | ||
``` bash | ||
git clone https://huggingface.co/PolyAI/pheme_small ckpt/pheme | ||
mkdir -p "ckpt/t2s" | ||
mkdir -p "ckpt/s2a" | ||
mv ckpt/pheme/config_t2s.json ckpt/t2s/config.json | ||
mv ckpt/pheme/generation_config.json ckpt/t2s/generation_config.json | ||
mv ckpt/pheme/t2s.bin ckpt/t2s/pytorch_model.bin | ||
mv ckpt/pheme/config_s2a.json ckpt/s2a/config.json | ||
mv ckpt/pheme/s2a.ckpt ckpt/s2a/s2a.ckpt | ||
``` | ||
or the larger version (300M) at `https://huggingface.co/PolyAI/pheme` | ||
|
||
# Prompt-based Generation | ||
The generation can be invoked by: | ||
``` | ||
python transformer_infer.py | ||
``` | ||
# Training | ||
|
||
## Data Preparation | ||
The package expects a manifest at `datasets/example/train.json` and the corresponding wav files under `datasets/example/audios/`.
The manifest should follow the format: | ||
``` | ||
{ | ||
"LJ001-0051.wav": { | ||
"text": "and paying great attention to the press work or actual process of printing,", | ||
"raw-text": "and paying great attention to the press work or actual process of printing,", | ||
"duration": 4.860090702947846, | ||
"phoneme": "æ|n|d|_|p|eɪ|ɪ|ŋ|_|ɡ|ɹ|eɪ|t|_|ɐ|t|ɛ|n|ʃ|ə|n|_|t|ə|_|ð|ə|_|\"|p|ɹ|ɛ|s|_|w|ɜː|k|\"|_|ɔː|ɹ|_|æ|k|tʃ|uː|əl|_|p|ɹ|ɑː|s|ɛ|s|_|ʌ|v|_|p|ɹ|ɪ|n|t|ɪ|ŋ|," | ||
}, | ||
"LJ001-0120.wav": { | ||
... | ||
}, | ||
... | ||
} | ||
``` | ||
The following command will create semantic and acoustic tokens based on the `audios` folder. | ||
``` | ||
python utils/get_tokens_speech_tokenizer.py \ | ||
--config_path ckpt/speechtokenizer/config.json \ | ||
--ckpt_path ckpt/speechtokenizer/SpeechTokenizer.pt \ | ||
--encoding_input datasets/example/audios \ | ||
--encoding_output datasets/example/audios-speech-tokenizer | ||
``` | ||
## T2S | ||
``` | ||
python train_t2s.py --metapath datasets/example/train.json \ | ||
--val_metapath datasets/example/train.json \ | ||
--output_dir ~/experiments/t2s \ | ||
--model_size tiny --batch_size 16 \ | ||
--nworkers 12 --warmup_steps 10000 \ | ||
--save_steps 500 --n_epochs 10 | ||
``` | ||
## S2A
``` | ||
python train_s2a.py --saving_path exp/a2s --sampledir exp/a2s --vocoder_type SPEECHTOKENIZER \ | ||
--n_codes 1024 --n_cluster_groups 7 --metapath datasets/example/train.json \ | ||
--val_metapath datasets/example/train.json \ | ||
--warmup_step 10000 --nworkers 12 --first_n_lvls 7 \ | ||
--batch_size 1 --ffd_size 512 --hidden_size 512 --enc_nlayers 1 --nheads 8 \ | ||
--depthwise_conv_kernel_size 5 \ | ||
--val_check_interval 1 --sample_rate 16000 --lr 5e-4 \ | ||
--check_val_every_n_epoch 1 --n_semantic_codes 1024 \ | ||
--distributed | ||
``` | ||
|
||
## Speed test | ||
### A100 GPU | ||
| Model | Batch Size | Steps | RTF (ms) | | ||
| --------------------------- | --------- | ----------- | ----------- | | ||
| T2S-S2A Short sentence | 1 | 16 | 0.133 | | ||
| T2S-S2A Long sentence | 1 | 16 | 0.133 | | ||
|
||
### A10 GPU | ||
| Model | Batch Size | Steps | RTF (ms) | | ||
| --------------------------- | --------- | ----------- | ----------- | | ||
| T2S-S2A Short sentence | 1 | 16 | 0.143 | | ||
| T2S-S2A Long sentence | 1 | 16 | 0.143 | | ||
|
||
|
||
## Acknowledgements
[MQTTS](https://github.com/b04901014/MQTTS)\ | ||
[SpeechTokenizer](https://github.com/ZhangXInFD/soundstorm-speechtokenizer)\ | ||
[maskgit](https://github.com/google-research/maskgit)\ | ||
[SoundStorm](https://github.com/lifeiteng/SoundStorm) | ||
|
||
## TODO | ||
1. Add Tensorrt-LLM image | ||
|
||
## Citation | ||
If you use this code or its results in your paper, please cite our work as:
```Tex | ||
@misc{TODO} | ||
``` |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
"""Constants file. | ||
Copyright PolyAI Limited. | ||
""" | ||
# Width of a speaker-embedding vector.
SPKR_EMB_SIZE = 512

# Padding id; presumably sits just past a 1024-entry codebook -- TODO confirm.
PAD = 1024

# Ids reserved for the two speaker markers, directly after PAD.
SPKR_1 = 1025
SPKR_2 = 1026

# Special token ids.
# NOTE(review): BOS_TOKEN_ID and PAD_TOKEN_ID share the value 0 -- confirm
# this is intentional.
BOS_TOKEN_ID = 0
PAD_TOKEN_ID = 0
EOS_TOKEN_ID = 2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
"""Collators for T2S and S2A. | ||
Copyright PolyAI Limited. | ||
""" | ||
from pathlib import Path | ||
from typing import List, Tuple, Union | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from utils.symbol_table import SymbolTable | ||
|
||
|
||
class GlobalCollater:
    """Pad a list of samples to a common length and stack them into tensors.

    Each sample is a 5-tuple ``(spkr, qs, qe, itm_name, s_tokens)``. ``qs`` and
    ``qe`` are 2-D code arrays padded along their first axis with ``n_codes``;
    ``s_tokens`` is flattened and padded with ``n_semantic_codes``; ``spkr``
    is zero-padded along its first axis.
    """

    def __init__(self, n_codes, n_semantic_codes):
        """Store the fill values used when padding.

        Args:
            n_codes: Value used to pad ``qs`` and ``qe``.
            n_semantic_codes: Value used to pad the semantic tokens.
        """
        self.n_codes = n_codes
        self.sem_mask_id = n_semantic_codes

    def collate(self, batch):
        """Collate ``batch`` into a dict of padded tensors.

        Args:
            batch: Iterable of ``(spkr, qs, qe, itm_name, s_tokens)`` tuples.

        Returns:
            A dict with the keys ``speaker`` (float tensor),
            ``tts_quantize_input``, ``tts_quantize_output``,
            ``semantic_tokens`` and ``quantization_lengths`` (long tensors),
            ``quantize_mask`` (bool tensor, ``True`` on padded positions) and
            ``f_names`` (plain list of the item names).
        """
        output = {
            'speaker': [],
            'tts_quantize_input': [],
            'tts_quantize_output': [],
            'quantize_mask': [],
            'f_names': [],
            'semantic_tokens': [],
            'quantization_lengths': [],
        }

        # The longest ``qs`` in the batch defines the padded length.
        max_len_q = 0
        for _, q_s, _, _, _ in batch:
            max_len_q = max(max_len_q, len(q_s))
            output['quantization_lengths'].append(len(q_s))

        for spkr, qs, qe, itm_name, s_tokens in batch:
            # Build the mask before ``qs`` is padded: True marks padding.
            q_mask = np.array(
                [False] * len(qs) + [True] * (max_len_q - len(qs)))
            qs = np.pad(
                qs,
                [[0, max_len_q - len(qs)], [0, 0]],
                constant_values=self.n_codes,
            )
            qe = np.pad(
                qe,
                [[0, max_len_q - len(qe)], [0, 0]],
                constant_values=self.n_codes,
            )

            # Semantic tokens are flattened to 1-D before padding.
            s_tokens = s_tokens.flatten()
            s_tokens = np.pad(
                s_tokens,
                (0, max_len_q - len(s_tokens)),
                constant_values=self.sem_mask_id,
            )

            # Zero-pad the speaker features. The trailing shape is taken from
            # the sample itself rather than hard-coded, so any embedding
            # width is accepted.
            spkr = np.asarray(spkr)
            spkr_padding = np.zeros(
                (max_len_q - len(spkr),) + spkr.shape[1:])
            spkr = np.concatenate((spkr, spkr_padding))

            output['speaker'].append(spkr)
            output['tts_quantize_input'].append(qs)
            output['tts_quantize_output'].append(qe)
            output['quantize_mask'].append(q_mask)
            output['f_names'].append(itm_name)
            output['semantic_tokens'].append(s_tokens)

        long_keys = (
            'tts_quantize_input', 'tts_quantize_output',
            'semantic_tokens', 'quantization_lengths',
        )
        for key in output:
            if key == 'f_names':
                # File names stay a plain Python list.
                continue
            stacked = np.array(output[key])
            if 'mask' in key:
                output[key] = torch.BoolTensor(stacked)
            elif key in long_keys:
                output[key] = torch.LongTensor(stacked)
            else:
                output[key] = torch.FloatTensor(stacked)
        return output
|
||
|
||
class TextTokenCollater:
    """Map sequences of text tokens to a tensor of vocabulary indices.

    The vocabulary is laid out as: the pad symbol, then the BOS symbol (if
    ``add_bos``), then the EOS symbol (if ``add_eos``), then the two speaker
    symbols, then ``text_tokens`` in sorted order.
    """

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
        spkr_1_symbol: str = "spkr_1",
        spkr_2_symbol: str = "spkr_2",
    ):
        """Build the token <-> index tables.

        Args:
            text_tokens: Text symbols to place after the special symbols.
            add_eos: Whether to reserve and append the EOS symbol.
            add_bos: Whether to reserve and prepend the BOS symbol.
            pad_symbol: Symbol stored at index 0.
            bos_symbol: Beginning-of-sequence symbol.
            eos_symbol: End-of-sequence symbol.
            spkr_1_symbol: Marker emitted before the first text.
            spkr_2_symbol: Marker emitted before the second text.
        """
        self.pad_symbol = pad_symbol

        self.add_eos = add_eos
        self.add_bos = add_bos

        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol
        self.spkr_1_symbol = spkr_1_symbol
        self.spkr_2_symbol = spkr_2_symbol

        unique_tokens = (
            [pad_symbol]
            + ([bos_symbol] if add_bos else [])
            + ([eos_symbol] if add_eos else [])
            + [spkr_1_symbol]
            + [spkr_2_symbol]
            + sorted(text_tokens)
        )

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = list(unique_tokens)

    def __call__(
        self, texts: List[str], texts_2: Union[None, List[str]] = None
    ) -> torch.Tensor:
        """Encode a batch of token sequences as an int64 index tensor.

        Each row is ``[bos] spkr_1 text [spkr_2 text_2] [eos]``, where the
        bracketed parts depend on ``add_bos``, ``texts_2`` and ``add_eos``.
        No padding is applied, so all rows must end up the same length.

        Args:
            texts: Batch of token sequences for the first speaker.
            texts_2: Optional batch for the second speaker, paired with
                ``texts`` element by element.

        Returns:
            A single ``torch.Tensor`` of dtype int64, one row per input.

        Raises:
            KeyError: If a token is not in the vocabulary.
        """
        # The prefix and suffix are identical for every row; build them once.
        prefix = (
            ([self.bos_symbol] if self.add_bos else []) + [self.spkr_1_symbol]
        )
        suffix = [self.eos_symbol] if self.add_eos else []

        if texts_2 is None:
            seqs = [prefix + list(text) + suffix for text in texts]
        else:
            seqs = [
                prefix
                + list(text)
                + [self.spkr_2_symbol]
                + list(text_2)
                + suffix
                for text, text_2 in zip(texts, texts_2)
            ]

        tokens_batch = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )

        return tokens_batch
|
||
|
||
def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater:
    """Create a ``TextTokenCollater`` from a symbol-table file.

    Args:
        text_tokens_file: Path to the file read by ``SymbolTable.from_file``.

    Returns:
        A collater built from the table's symbols, with BOS and EOS enabled.
    """
    symbol_table = SymbolTable.from_file(Path(text_tokens_file))
    return TextTokenCollater(
        symbol_table.symbols, add_bos=True, add_eos=True
    )
|
||
|
||
def get_text_semantic_token_collater(
        text_tokens_file: str, n_semantic_tokens=1024) -> TextTokenCollater:
    """Create a ``TextTokenCollater`` covering text and semantic tokens.

    The symbol table loaded from ``text_tokens_file`` is extended with the
    strings ``"0"`` .. ``str(n_semantic_tokens - 1)`` before the collater is
    built.

    Args:
        text_tokens_file: Path to the file read by ``SymbolTable.from_file``.
        n_semantic_tokens: Number of semantic token ids to add.

    Returns:
        A collater built from the extended table, with BOS and EOS enabled.
    """
    symbol_table = SymbolTable.from_file(Path(text_tokens_file))
    semantic_idx = 0
    while semantic_idx < n_semantic_tokens:
        symbol_table.add(str(semantic_idx))
        semantic_idx += 1

    return TextTokenCollater(
        symbol_table.symbols, add_bos=True, add_eos=True
    )
|
||
|
||
if __name__ == "__main__":
    # Manual smoke run: build the text + semantic collater from the symbol
    # file under ckpt/.
    text_tokens_file = 'ckpt/unique_text_tokens.k2symbols'
    collater = get_text_semantic_token_collater(text_tokens_file)
Oops, something went wrong.