diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..49cc41f --- /dev/null +++ b/.flake8 @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 88 +exclude = .git,__pycache__,build,dist diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..85afc90 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +app.py +ckpt/* +*/__pycache__/* +__pycache__/* +exp/* +datasets/* +wandb/* \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..52bd145 --- /dev/null +++ b/LICENSE @@ -0,0 +1,395 @@ +Attribution 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. 
More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution 4.0 International Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution 4.0 International Public License ("Public License"). To the +extent this Public License may be interpreted as a contract, You are +granted the Licensed Rights in consideration of Your acceptance of +these terms and conditions, and the Licensor grants You such rights in +consideration of benefits the Licensor receives from making the +Licensed Material available under these terms and conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + j. 
Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + k. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part; and + + b. produce, reproduce, and Share Adapted Material. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. 
In all other cases the Licensor expressly + reserves any right to collect such royalties. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + + +======================================================================= + +Creative Commons is not a party to its public +licenses. 
Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..941edb4 --- /dev/null +++ b/README.md @@ -0,0 +1,135 @@ +# Pheme Model +This repo contains recipes and models used for training TTS models. + +Our model validates several hypotheses: +1. We can train VALL-E style models with 10x less training data. +2. The fundamental ingredients are the right semantic/acoustic token definition. +3. The training can be performed with conversational, podcast, and noisy data like GIGA. +4. The inference can be run parallelly through MASKGIT style inference. +5. The quality can be improved through student-teacher training with data generated by third-party providers. + + +Official implementation for the paper: TODO[] + +# Setup the environment +Setup conda environment: +``` +conda create --name pheme3 python=3.10 +conda activate pheme3 + +pip3 install torch torchvision torchaudio +pip3 install -r requirements.txt --no-deps +``` + +Download pre-trained SpeechTokenizer and unique token list models: +``` bash +st_dir="ckpt/speechtokenizer/" +mkdir -p ${st_dir} +cd ${st_dir} +wget "https://huggingface.co/fnlp/SpeechTokenizer/resolve/main/speechtokenizer_hubert_avg/SpeechTokenizer.pt" +wget "https://huggingface.co/fnlp/SpeechTokenizer/resolve/main/speechtokenizer_hubert_avg/config.json" +cd .. +wget "https://huggingface.co/fnlp/USLM/resolve/main/USLM_libritts/unique_text_tokens.k2symbols" +``` + +You need to create an access token to use the speaker embedding of pyannote. +``` +export HUGGING_FACE_HUB_TOKEN=YOUR_PRIVATE_TOKEN +``` + +Download pre-trained T2S and S2A models (100M): +``` bash +git clone https://huggingface.co/PolyAI/pheme_small ckpt/pheme +mkdir -p "ckpt/t2s" +mkdir -p "ckpt/s2a" +mv ckpt/pheme/config_t2s.json ckpt/t2s/config.json +mv ckpt/pheme/generation_config.json ckpt/t2s/generation_config.json +mv ckpt/pheme/t2s.bin ckpt/t2s/pytorch_model.bin +mv ckpt/pheme/config_s2a.json ckpt/s2a/config.json +mv ckpt/pheme/s2a.ckpt ckpt/s2a/s2a.ckpt +``` +or the larger version (300M) at `https://huggingface.co/PolyAI/pheme` + +# Prompt-based Generation +The generation can be invoked by: +``` +python transformer_infer.py +``` +# Training + +## Data Preparation +The package requires data of the format: `datasets/example/train.json` with `datasets/audios/` where you store wav files. 
+The manifest should follow the format: +``` +{ + "LJ001-0051.wav": { + "text": "and paying great attention to the press work or actual process of printing,", + "raw-text": "and paying great attention to the press work or actual process of printing,", + "duration": 4.860090702947846, + "phoneme": "æ|n|d|_|p|eɪ|ɪ|ŋ|_|ɡ|ɹ|eɪ|t|_|ɐ|t|ɛ|n|ʃ|ə|n|_|t|ə|_|ð|ə|_|\"|p|ɹ|ɛ|s|_|w|ɜː|k|\"|_|ɔː|ɹ|_|æ|k|tʃ|uː|əl|_|p|ɹ|ɑː|s|ɛ|s|_|ʌ|v|_|p|ɹ|ɪ|n|t|ɪ|ŋ|," + }, + "LJ001-0120.wav": { + ... + }, + ... +} + +``` +The following command will create semantic and acoustic tokens based on the `audios` folder. +``` +python utils/get_tokens_speech_tokenizer.py \ + --config_path ckpt/speechtokenizer/config.json \ + --ckpt_path ckpt/speechtokenizer/SpeechTokenizer.pt \ + --encoding_input datasets/example/audios \ + --encoding_output datasets/example/audios-speech-tokenizer +``` +## T2S +``` +python train_t2s.py --metapath datasets/example/train.json \ + --val_metapath datasets/example/train.json \ + --output_dir ~/experiments/t2s \ + --model_size tiny --batch_size 16 \ + --nworkers 12 --warmup_steps 10000 \ + --save_steps 500 --n_epochs 10 +``` +## A2S +``` +python train_s2a.py --saving_path exp/a2s --sampledir exp/a2s --vocoder_type SPEECHTOKENIZER \ + --n_codes 1024 --n_cluster_groups 7 --metapath datasets/example/train.json \ + --val_metapath datasets/example/train.json \ + --warmup_step 10000 --nworkers 12 --first_n_lvls 7 \ + --batch_size 1 --ffd_size 512 --hidden_size 512 --enc_nlayers 1 --nheads 8 \ + --depthwise_conv_kernel_size 5 \ + --val_check_interval 1 --sample_rate 16000 --lr 5e-4 \ + --check_val_every_n_epoch 1 --n_semantic_codes 1024 \ + --distributed +``` + +## Speed test +### A100 GPU +| Model | Batch Size | Steps | RTF (ms) | +| --------------------------- | --------- | ----------- | ----------- | +| T2S-S2A Short sentence | 1 | 16 | 0.133 | +| T2S-S2A Long sentence | 1 | 16 | 0.133 | + +### A10 GPU +| Model | Batch Size | Steps | RTF (ms) | +| --------------------------- | --------- | ----------- | ----------- | +| T2S-S2A Short sentence | 1 | 16 | 0.143 | +| T2S-S2A Long sentence | 1 | 16 | 0.143 | + + +## Acknowledge +[MQTTS](https://github.com/b04901014/MQTTS)\ +[SpeechTokenizer](https://github.com/ZhangXInFD/soundstorm-speechtokenizer)\ +[maskgit](https://github.com/google-research/maskgit)\ +[SoundStorm](https://github.com/lifeiteng/SoundStorm) + +## TODO +1. Add Tensorrt-LLM image + +## Citation +If you use this code or result in your paper, please cite our work as: +```Tex +@misc{TODO} +``` diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/constants.py b/constants.py new file mode 100644 index 0000000..b6fa07a --- /dev/null +++ b/constants.py @@ -0,0 +1,14 @@ +"""Constants file. + +Copyright PolyAI Limited. +""" +SPKR_EMB_SIZE = 512 + +PAD = 1024 + +SPKR_1 = 1025 +SPKR_2 = 1026 + +BOS_TOKEN_ID = 0 +PAD_TOKEN_ID = 0 +EOS_TOKEN_ID = 2 \ No newline at end of file diff --git a/data/collation.py b/data/collation.py new file mode 100644 index 0000000..9afdfc4 --- /dev/null +++ b/data/collation.py @@ -0,0 +1,182 @@ +"""Collators for T2S and S2A. + +Copyright PolyAI Limited. 
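+
+GlobalCollater pads acoustic/semantic token sequences and speaker embeddings
+to the longest item in a batch for S2A training; TextTokenCollater maps
+phoneme and semantic symbols to integer ids (adding BOS/EOS and speaker
+markers) for T2S training.
+
+Minimal usage sketch (assumes the unique token list referenced in the README
+has been downloaded and that the listed symbols exist in it)::
+
+    collater = get_text_semantic_token_collater(
+        "ckpt/unique_text_tokens.k2symbols")
+    token_ids = collater([["h", "ə", "l", "oʊ"]])  # LongTensor of shape (1, seq_len)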
+""" +from pathlib import Path +from typing import List, Tuple, Union + +import numpy as np +import torch + +from utils.symbol_table import SymbolTable + + +class GlobalCollater: + def __init__(self, n_codes, n_semantic_codes): + self.n_codes = n_codes + self.sem_mask_id = n_semantic_codes + + def collate(self, batch): + output = { + 'speaker': [], + 'tts_quantize_input': [], + 'tts_quantize_output': [], + 'quantize_mask': [], + 'f_names': [], + 'semantic_tokens': [], + 'quantization_lengths': [], + } + # Get the max length of everything + max_len_q = 0 + for _, q_s, q_e, _, _ in batch: + if len(q_s) > max_len_q: + max_len_q = len(q_s) + + output['quantization_lengths'].append(len(q_s)) + + # Pad each element, create mask + for spkr, qs, qe, itm_name, s_tokens in batch: + # Deal with quantizations + q_mask = np.array( + [False] * len(qs) + [True] * (max_len_q - len(qs))) + qs = np.pad( + qs, + [[0, max_len_q-len(qs)], [0, 0]], + constant_values=self.n_codes + ) + qe = np.pad( + qe, + [[0, max_len_q-len(qe)], [0, 0]], + constant_values=self.n_codes + ) + + # Deal with semantics + s_tokens = s_tokens.flatten() + s_tokens = np.pad( + s_tokens, + (0, max_len_q-len(s_tokens)), + constant_values=self.sem_mask_id + ) + + # Speaker padding + spkr = np.concatenate( + (spkr, np.zeros((max_len_q - len(spkr), 512)))) + + # Aggregate + output['speaker'].append(spkr) + output['tts_quantize_input'].append(qs) + output['tts_quantize_output'].append(qe) + output['quantize_mask'].append(q_mask) + output['f_names'].append(itm_name) + output["semantic_tokens"].append(s_tokens) + + for k in output.keys(): + if k == 'f_names': + continue + output[k] = np.array(output[k]) + if 'mask' in k: + output[k] = torch.BoolTensor(output[k]) + elif k in [ + 'tts_quantize_input', 'tts_quantize_output', + 'semantic_tokens', 'quantization_lengths' + ]: + output[k] = torch.LongTensor(output[k]) + else: + output[k] = torch.FloatTensor(output[k]) + return output + + +class TextTokenCollater: + def __init__( + self, + text_tokens: List[str], + add_eos: bool = True, + add_bos: bool = True, + pad_symbol: str = "", + bos_symbol: str = "", + eos_symbol: str = "", + spkr_1_symbol: str = "spkr_1", + spkr_2_symbol: str = "spkr_2", + ): + self.pad_symbol = pad_symbol + + self.add_eos = add_eos + self.add_bos = add_bos + + self.bos_symbol = bos_symbol + self.eos_symbol = eos_symbol + self.spkr_1_symbol = spkr_1_symbol + self.spkr_2_symbol = spkr_2_symbol + + unique_tokens = ( + [pad_symbol] + + ([bos_symbol] if add_bos else []) + + ([eos_symbol] if add_eos else []) + + ([spkr_1_symbol]) + + ([spkr_2_symbol]) + + sorted(text_tokens) + ) + + self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)} + self.idx2token = [token for token in unique_tokens] + + def __call__( + self, texts: List[str], texts_2: Union[None, List[str]] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + tokens_seqs = [[p for p in text] for text in texts] + + if texts_2 is None: + seqs = [ + ([self.bos_symbol] if self.add_bos else []) + + [self.spkr_1_symbol] + + list(seq) + + ([self.eos_symbol] if self.add_eos else []) + for seq in tokens_seqs + ] + else: + tokens_seqs_2 = [[p for p in text] for text in texts_2] + seqs = [ + ([self.bos_symbol] if self.add_bos else []) + + [self.spkr_1_symbol] + + list(seq) + + ([self.spkr_2_symbol]) + + list(seq_2) + + ([self.eos_symbol] if self.add_eos else []) + for seq, seq_2 in zip(tokens_seqs, tokens_seqs_2) + ] + + tokens_batch = torch.from_numpy( + np.array( + [[self.token2idx[token] for token in seq] for seq in 
seqs], + dtype=np.int64, + ) + ) + + return tokens_batch + + +def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater: + text_tokens_path = Path(text_tokens_file) + unique_tokens = SymbolTable.from_file(text_tokens_path) + collater = TextTokenCollater( + unique_tokens.symbols, add_bos=True, add_eos=True + ) + return collater + + +def get_text_semantic_token_collater( + text_tokens_file: str, n_semantic_tokens=1024) -> TextTokenCollater: + text_tokens_path = Path(text_tokens_file) + unique_tokens = SymbolTable.from_file(text_tokens_path) + for semantic_idx in range(n_semantic_tokens): + unique_tokens.add(str(semantic_idx)) + + collater = TextTokenCollater( + unique_tokens.symbols, add_bos=True, add_eos=True + ) + return collater + + +if __name__ == '__main__': + text_tokens_file = 'ckpt/unique_text_tokens.k2symbols' + collater = get_text_semantic_token_collater(text_tokens_file) diff --git a/data/data_module.py b/data/data_module.py new file mode 100644 index 0000000..7f2532e --- /dev/null +++ b/data/data_module.py @@ -0,0 +1,119 @@ +"""Data module. + +Copyright PolyAI Limited. +""" +import typing +from pathlib import Path +from typing import List + +import lightning.pytorch as pl +from torch.utils import data + +from data.collation import GlobalCollater +from data.sampler import RandomBucketSampler +from data.single_speaker_dataset import QuantizeDataset +from utils import breakpoint_on_error + + +class ConcatDataset(data.ConcatDataset): + def __init__(self, datasets) -> None: + super().__init__(datasets) + self.lengths = [] + for dataset in datasets: + self.lengths.extend(dataset.lengths) + + +class DataModule(pl.LightningDataModule): + def __init__( + self, hp, metapath: List[str], val_metapath: List[str], + world_size, local_rank + ): + super().__init__() + self.hp = hp + self.metapath = metapath + self.val_metapath = val_metapath + self.world_size = world_size + self.local_rank = local_rank + self.collater = GlobalCollater( + self.hp.n_codes, self.hp.n_semantic_codes) + + def setup(self, stage: str) -> None: + if stage == "fit": + self.train_data = self.concatenate_datasets( + self.metapath, dataset_class=QuantizeDataset + ) + + if stage == "valid": + self.val_data = [] + self.val_data_keys = [] + self.prepare_val_datasets() + assert len(self.val_data) > 0 + assert len(self.val_data_keys) > 0 + + @breakpoint_on_error + def concatenate_datasets( + self, metapaths, dataset_class: typing.Type[QuantizeDataset]): + data = [] + for _, metapath in enumerate(metapaths): + metapath = Path(metapath) + # assumption that audios and audios-embeddings + # are in the same folder as metapath + datadir = metapath.with_name("audios") + assert datadir.exists() + data.append( + dataset_class( + self.hp, + metapath, + datadir=datadir, + speaker_embedding_dir=None, + ) + ) + return ConcatDataset(data) + + def prepare_val_datasets(self): + for manifest in self.val_metapath: + self.val_data.append( + self.concatenate_datasets( + [manifest], dataset_class=QuantizeDataset) + ) + name = Path(manifest).parent.name + self.val_data_keys.append(name) + + assert len(self.val_data) == len(self.val_data_keys) + + def train_dataloader(self): + length = self.train_data.lengths + sampler = RandomBucketSampler( + self.hp.train_bucket_size, + length, + self.hp.batch_size, + drop_last=True, + distributed=self.hp.distributed, + world_size=self.world_size, + rank=self.local_rank, + ) + dataloader = data.DataLoader( + self.train_data, + num_workers=self.hp.nworkers, + batch_sampler=sampler, + 
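+            # `sampler` is a batch sampler: it yields duration-bucketed lists
+            # of indices, so batch_size/shuffle must not be passed alongside it.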
collate_fn=self.collater.collate, + pin_memory=True + ) + + return dataloader + + def val_dataloader(self): + val_loaders = [] + for dataset in self.val_data: + val_loaders.append( + data.DataLoader( + dataset, + num_workers=self.hp.nworkers, + batch_size=int(self.hp.batch_size), + collate_fn=self.collater.collate, + shuffle=False, + pin_memory=True + ) + ) + + return val_loaders diff --git a/data/sampler.py b/data/sampler.py new file mode 100644 index 0000000..3cc0b71 --- /dev/null +++ b/data/sampler.py @@ -0,0 +1,115 @@ +"""Original sampling logic of MQTTS. + +Copyright PolyAI Limited. +""" +import math +import random + +import numpy as np +from torch.utils import data + + +def StandardSampler(dataset, shuffle, distributed=False, + world_size=None, rank=None): + if distributed: + return data.distributed.DistributedSampler( + dataset, shuffle=shuffle, num_replicas=world_size, rank=rank) + if shuffle: + return data.RandomSampler(dataset) + return data.SequentialSampler(dataset) + + +def RandomBucketSampler( + nbuckets, length, batch_size, drop_last, distributed=False, + world_size=None, rank=None): + if distributed: + return DistributedRandomBucketSampler( + nbuckets, length, batch_size, drop_last, world_size, rank) + return SingleRandomBucketSampler(nbuckets, length, batch_size, drop_last) + + +class SingleRandomBucketSampler(data.Sampler): + def __init__(self, nbuckets, length, batch_size, drop_last): + self.length = length + self.batch_size = batch_size + self.drop_last = drop_last + indices = np.argsort([-x for x in length]) + split = len(indices) // nbuckets + self.indices = [] + for i in range(nbuckets): + self.indices.append(indices[i*split:(i+1)*split]) + if nbuckets * split < len(length): + self.indices.append(indices[nbuckets*split:]) + + def __iter__(self): + random.shuffle(self.indices) + for x in self.indices: + random.shuffle(x) + idxs = [i for x in self.indices for i in x] + batches, batch, sum_len, max_len = [], [], 0, 0 + for idx in idxs: + batch.append(idx) + sum_len += self.length[idx] + max_len = max(self.length[idx], max_len) + if max_len * len(batch) > self.batch_size: + batches.append(batch[:-1]) + batch, sum_len, max_len = [batch[-1]], self.length[idx], self.length[idx] # noqa + if len(batch) > 0 and not self.drop_last: + batches.append(batch) + random.shuffle(batches) + return iter(batches) + + +class DistributedRandomBucketSampler(data.Sampler): + def __init__(self, nbuckets, length, batch_size, + drop_last, num_replicas, rank, seed=1234): + if rank >= num_replicas or rank < 0: + raise ValueError( + "Invalid rank {}, rank should be in the interval" + " [0, {}]".format(rank, num_replicas - 1)) + indices = np.argsort(length) + split = len(indices) // nbuckets + self.length = length + self.batch_size = batch_size + self.drop_last = drop_last + self.indices = [] + for i in range(nbuckets): + self.indices.append(indices[i*split:(i+1)*split]) + if nbuckets * split < len(length): + self.indices.append(indices[nbuckets*split:]) + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.seed = seed + + def __iter__(self): + # Deterministic shuffling + random.Random(self.epoch + self.seed).shuffle(self.indices) + for i, x in enumerate(self.indices): + seed = self.epoch + self.seed + i * 5 + random.Random(seed).shuffle(x) + indices = [i for x in self.indices for i in x] + + # Batching + batches, batch, sum_len, max_len = [], [], 0, 0 + for idx in indices: + batch.append(idx) + sum_len += self.length[idx] + max_len = max(self.length[idx], max_len) + if 
max_len * len(batch) > self.batch_size: + batches.append(batch[:-1]) + batch, sum_len, max_len = [batch[-1]], self.length[idx], self.length[idx] # noqa + # Subsample + num_samples = math.ceil( + (len(batches) - self.num_replicas) / self.num_replicas) + total_size = num_samples * self.num_replicas + batches = batches[:total_size] + batches = batches[self.rank*num_samples: (self.rank+1)*num_samples] + assert len(batches) == num_samples + + # Stochastic suffling + random.shuffle(batches) + return iter(batches) + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/data/semantic_dataset.py b/data/semantic_dataset.py new file mode 100644 index 0000000..c660435 --- /dev/null +++ b/data/semantic_dataset.py @@ -0,0 +1,207 @@ +"""Semantic tokens loading logic. + +Copyright PolyAI Limited. +""" +import json +import logging +import random +import re +from logging import getLogger +from pathlib import Path +from typing import List, Pattern, Union + +import numpy as np +import torch +from phonemizer.backend import EspeakBackend +from phonemizer.backend.espeak.language_switch import LanguageSwitch +from phonemizer.backend.espeak.words_mismatch import WordMismatch +from phonemizer.punctuation import Punctuation +from phonemizer.separator import Separator +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from data.collation import get_text_semantic_token_collater + + +class TextTokenizer: + """Phonemize Text.""" + + def __init__( + self, + language="en-us", + backend="espeak", + separator=Separator(word="_", syllable="-", phone="|"), + preserve_punctuation=True, + punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), + with_stress: bool = False, + tie: Union[bool, str] = False, + language_switch: LanguageSwitch = "keep-flags", + words_mismatch: WordMismatch = "ignore", + ) -> None: + logger = getLogger("phonemizer") + logger.setLevel(logging.ERROR) + if backend == "espeak": + phonemizer = EspeakBackend( + language, + punctuation_marks=punctuation_marks, + preserve_punctuation=preserve_punctuation, + with_stress=with_stress, + tie=tie, + language_switch=language_switch, + words_mismatch=words_mismatch, + logger=logger, + ) + else: + raise NotImplementedError(f"{backend}") + + self.backend = phonemizer + self.separator = separator + + def to_list(self, phonemized: str) -> List[str]: + fields = [] + for word in phonemized.split(self.separator.word): + # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z. + pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE) + fields.extend( + [p for p in pp if p != self.separator.phone] + [self.separator.word] + ) + assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count( + self.separator.phone + ) + return fields[:-1] + + def __call__(self, text, strip=True) -> List[List[str]]: + if isinstance(text, str): + text = [text] + + phonemized = self.backend.phonemize( + text, separator=self.separator, strip=strip, njobs=1 + ) + return [self.to_list(p) for p in phonemized] + + +class Collator: + def collate(self, batch): + input_ids = [item["input_ids"] for item in batch] + output_sequences = [item["labels"] for item in batch] + + # Pad sequences to the maximum length in the batch + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, batch_first=True, padding_value=0 + ) + output_sequences = torch.nn.utils.rnn.pad_sequence( + output_sequences, batch_first=True, padding_value=-100 + ) + # 1 - token is unmasked, 0 - token is masked. 
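+        # Padding positions were filled with 0 by pad_sequence above and the
+        # collater's pad token id is also 0, so `!= 0` recovers the mask.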
+ attention_mask = input_ids != 0 + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": output_sequences, + } + +class ConcatenateSemanticDataset(Dataset): + def __init__( + self, manifest_path: str, symbol_table_path: str, + n_samples: int = 0, max_duration=15): + self.data = [] + self.phonemizer = TextTokenizer() + self.text_collater = get_text_semantic_token_collater( + symbol_table_path) + self.manifest_path = manifest_path + self.n_samples = n_samples + self.max_duration = max_duration + if manifest_path is not None: + self._build() + + def __len__(self): + if self.n_samples: + return min(self.n_samples, len(self.data)) + return len(self.data) + + def remove_unknown_symbols(self, text: List[str]): + res = [] + for sym in text: + if sym not in self.text_collater.token2idx: + # print(f'{sym} is unk') + continue + res.append(sym) + return res + + def __getitem__(self, idx): + item = self.data[idx] + + input_ids = item["phoneme"].split("|") + input_ids = self.remove_unknown_symbols(input_ids) + + input_ids_2 = None + if item.get("phoneme_2"): + input_ids_2 = item["phoneme_2"].split("|") + input_ids_2 = [self.remove_unknown_symbols(input_ids_2)] + + input_ids = self.text_collater( + [input_ids], input_ids_2).to(dtype=torch.long) + input_ids = input_ids.to(dtype=torch.long) + + labels = np.load(item["semantic_path"]) + labels = [str(lbl) for lbl in labels] + + labels_2 = None + if item.get("semantic_path_2"): + labels_2 = np.load(item["semantic_path_2"]) + labels_2 = [[str(lbl) for lbl in labels_2]] + + labels = self.text_collater([labels], labels_2).to(dtype=torch.long) + + return {"input_ids": input_ids.squeeze(0), "labels": labels.squeeze(0)} + + # TODO - remove this to not load to the memory + def _build(self): + for manifest_path in self.manifest_path: + dataset_path = Path(manifest_path).parent + + with open(manifest_path, "r") as manifest_file: + manifest_data = json.load(manifest_file) + + for key, value in tqdm(manifest_data.items()): + if float(value["duration"]) > self.max_duration: + continue + text = value["text"] + phoneme = value["phoneme"] + npy_path = f"{dataset_path}/audios-speech-tokenizer/semantic/{key.split('.wav')[0]}.npy" # noqa + datapoint = { + "text": text, + "semantic_path": npy_path, + "phoneme": phoneme + } + self.data.append(datapoint) + + print(f"Total length of the dataset {manifest_path}: {len(self.data)}") + + random.shuffle(self.data) + + +if __name__ == "__main__": + # Create an instance of the dataset + manifest_path = "datasets/ljspeech-training-data/dev.json" + text_tokens_file = "ckpt/unique_text_tokens.k2symbols" + seq2seq_dataset = ConcatenateSemanticDataset( + [manifest_path, manifest_path], text_tokens_file) + + # seq2seq_dataset.phonemize_and_rewrite_manifest() + batch_size = 1 # Adjust to your desired batch size + dataloader = DataLoader( + seq2seq_dataset, + batch_size=batch_size, + shuffle=True, + collate_fn=Collator().collate, + ) + + for batch in dataloader: + print(batch["input_ids"]) + print(batch["labels"]) + print(batch["input_ids"][0].unique().max()) + print(batch["input_ids"][0].unique().min()) + print(batch["input_ids"].shape) + print(batch["labels"].shape) + break # Stop after the first batch if needed diff --git a/data/single_speaker_dataset.py b/data/single_speaker_dataset.py new file mode 100644 index 0000000..def1595 --- /dev/null +++ b/data/single_speaker_dataset.py @@ -0,0 +1,167 @@ +"""Main loading function. + +Copyright PolyAI Limited. 
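+
+QuantizeDataset returns, per utterance, a pyannote speaker embedding repeated
+along the time axis, pre-extracted acoustic codes, semantic tokens and the
+file name; two-speaker items (``*_1.wav`` / ``*_2.wav``) are concatenated in
+time before truncation to ``max_length``.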
+""" +import json +import os +import random +from pathlib import Path + +import numpy as np +import soundfile as sf +import torch +from librosa.util import normalize +from pyannote.audio import Inference +from torch.utils import data + +import constants as c + + +def random_crop(x, maxseqlen): + if x.shape[0] >= maxseqlen: + offset = random.randrange(x.shape[0] - maxseqlen + 1) + x = x[offset: offset + maxseqlen] + else: + offset = 0 + return x, offset + + +def dynamic_range_compression(x, C=0.3, M=6.5, clip_val=1e-5): + return (np.log(np.clip(x, a_min=clip_val, a_max=None)) + M) * C + + +def dynamic_range_decompression(x, C=0.3, M=6.5): + return np.exp(x / C - M) + + +class QuantizeDataset(data.Dataset): + def __init__(self, hp, metapath, datadir=None, speaker_embedding_dir=None): + self.hp = hp + self.datadir = Path(datadir) + self.speaker_embedding_dir = speaker_embedding_dir + self.sem_mask_id = hp.n_semantic_codes + + print(f"Loading metadata in {metapath}...") + with open(metapath, "r") as f: + self.text = json.load(f) + if 0 < self.hp.max_dataset_samples < len(self.text): + self.new_text = {} + num = 0 + for k, v in self.text.items(): + if num >= self.hp.max_dataset_samples: + break + self.new_text[k] = v + num += 1 + self.text = self.new_text + + self.datasetbase = [x for x in self.text.keys()] + self.dataset = [ + os.path.join(self.datadir, x) for x in self.datasetbase] + + if self.speaker_embedding_dir is None: + self.spkr_embedding = Inference( + "pyannote/embedding", + window="whole", + use_auth_token=os.environ["HUGGING_FACE_HUB_TOKEN"], + ) + + # Print statistics: + n = len(self.dataset) + print(f"Total {n} examples") + + self.lengths = [float(v["duration"]) for v in self.text.values()] + total_duration = sum(self.lengths) + avglen = total_duration / len(self.lengths) + maxlen = max(self.lengths) + minlen = min(self.lengths) + print( + f"Average duration of audio: {avglen} sec, " + "Maximum duration: {maxlen} sec, Minimum duration: {minlen} sec" + ) + + def __len__(self): + return len(self.dataset) + + def load_quantization(self, _name): + if self.hp.vocoder_type == 'NATIVE': + metadata = self.text[_name] + quantization = np.array(metadata["quantization"]).T # ..., 4 + elif self.hp.vocoder_type == 'DAC': + codes_path = self.datadir.parent / 'audios-dac' / (os.path.splitext(_name)[0] + ".npy") # noqa + quantization = np.load(codes_path).T # ..., 12 + elif self.hp.vocoder_type == 'ENCODEC': + codes_path = self.datadir.parent / 'audios-encodec' / (os.path.splitext(_name)[0] + ".npy") # noqa + quantization = np.load(codes_path).squeeze(0).T # ..., 8 + elif self.hp.vocoder_type == 'SPEECHTOKENIZER': + codes_path = self.datadir.parent / 'audios-speech-tokenizer/acoustic' / (os.path.splitext(_name)[0] + ".npy") # noqa + quantization = np.load(codes_path).T # ..., 7 + else: + raise ValueError(f"Unknown vocoder_type {self.hp.vocoder_type}") + + return quantization + + def __getitem__(self, i): + dataname = self.dataset[i] + _name = self.datasetbase[i] + metadata = self.text[_name] + + # Speaker 1 + acoustic_tokens = self.load_quantization(_name) + acoustic_tokens = np.pad( + acoustic_tokens, [[1, 0],[0,0]], constant_values=c.SPKR_1) + + npy_path = self.datadir.parent / 'audios-speech-tokenizer/semantic' / (os.path.splitext(_name)[0] + ".npy") # noqa + semantic_tokens = np.load(npy_path)[None] + semantic_tokens = np.pad( + semantic_tokens,[[0,0], [1, 0]], constant_values=c.SPKR_1) + + if "name_2" in metadata: + wav, _ = sf.read(dataname.split(".")[0] + "_1.wav") + else: + wav, _ = 
sf.read(dataname) + audio = normalize(wav) * 0.95 + speaker_embedding = self.spkr_embedding( + {"waveform": torch.FloatTensor(audio).unsqueeze(0), + "sample_rate": self.hp.sample_rate,} + ).reshape(1, -1) + speaker_embedding = np.repeat( + speaker_embedding, semantic_tokens.shape[1], axis=0) + + # Speaker 2 + if "text_2" in metadata: + _name = _name.split(".wav")[0] + "_2.wav" + acoustic_tokens_2 = self.load_quantization(_name) + acoustic_tokens_2 = np.pad( + acoustic_tokens_2, [[1, 0],[0,0]], constant_values=c.SPKR_2) + + npy_path = self.datadir.parent / 'audios-speech-tokenizer/semantic' / (os.path.splitext(_name)[0] + ".npy") # noqa + semantic_tokens_2 = np.load(npy_path)[None] + semantic_tokens_2 = np.pad( + semantic_tokens_2,[[0,0], [1, 0]], constant_values=c.SPKR_2) + + wav, _ = sf.read(dataname.split(".wav")[0] + "_2.wav") + audio = normalize(wav) * 0.95 + speaker_embedding_2 = self.spkr_embedding( + {"waveform": torch.FloatTensor(audio).unsqueeze(0), + "sample_rate": self.hp.sample_rate,} + ).reshape(1, -1) + speaker_embedding_2 = np.repeat( + speaker_embedding_2, semantic_tokens_2.shape[1], axis=0) + + # Merge both speakers + acoustic_tokens = np.concatenate( + (acoustic_tokens, acoustic_tokens_2), axis=0) + semantic_tokens = np.concatenate( + (semantic_tokens, semantic_tokens_2), axis=1) + speaker_embedding = np.concatenate( + (speaker_embedding, speaker_embedding_2), axis=0) + + speaker_embedding = speaker_embedding[:self.hp.max_length, :] + acoustic_tokens = acoustic_tokens[:self.hp.max_length, :] + semantic_tokens = semantic_tokens[:, :self.hp.max_length] + + # # HACK - we have no 8 lvls pfb30 + # acoustic_tokens = np.concatenate((semantic_tokens.T, acoustic_tokens), axis=1) + # # END HACK + + return speaker_embedding, acoustic_tokens, acoustic_tokens, dataname, semantic_tokens # noqa diff --git a/datasets/example/train.json b/datasets/example/train.json new file mode 100644 index 0000000..eecad12 --- /dev/null +++ b/datasets/example/train.json @@ -0,0 +1,14 @@ +{ + "LJ001-0051.wav": { + "text": "and paying great attention to the \"press work\" or actual process of printing,", + "raw-text": "and paying great attention to the \"press work\" or actual process of printing,", + "duration": 4.860090702947846, + "phoneme": "æ|n|d|_|p|eɪ|ɪ|ŋ|_|ɡ|ɹ|eɪ|t|_|ɐ|t|ɛ|n|ʃ|ə|n|_|t|ə|_|ð|ə|_|\"|p|ɹ|ɛ|s|_|w|ɜː|k|\"|_|ɔː|ɹ|_|æ|k|tʃ|uː|əl|_|p|ɹ|ɑː|s|ɛ|s|_|ʌ|v|_|p|ɹ|ɪ|n|t|ɪ|ŋ|," + }, + "LJ001-0120.wav": { + "text": "In the old print each figure has its definite individuality, and one cannot be mistaken for the other;", + "raw-text": "In the old print each figure has its definite individuality, and one cannot be mistaken for the other;", + "duration": 6.973106575963719, + "phoneme": "ɪ|n|ð|ɪ|_|oʊ|l|d|_|p|ɹ|ɪ|n|t|_|iː|tʃ|_|f|ɪ|ɡ|j|ɚ|_|h|ɐ|z|_|ɪ|t|s|_|d|ɛ|f|ɪ|n|ə|t|_|ɪ|n|d|ɪ|v|ɪ|d|uː|æ|l|ɪ|ɾ|i|,|_|æ|n|d|_|w|ʌ|n|_|k|æ|n|ɑː|t|_|b|iː|_|m|ɪ|s|t|eɪ|k|ə|n|_|f|ɚ|ð|ɪ|_|ʌ|ð|ɚ|;" + } +} \ No newline at end of file diff --git a/demo/audios-speech-tokenizer/acoustic/POD0000004393_S0000029.npy b/demo/audios-speech-tokenizer/acoustic/POD0000004393_S0000029.npy new file mode 100644 index 0000000..f107152 Binary files /dev/null and b/demo/audios-speech-tokenizer/acoustic/POD0000004393_S0000029.npy differ diff --git a/demo/audios-speech-tokenizer/acoustic/POD0000007005_S0000568.npy b/demo/audios-speech-tokenizer/acoustic/POD0000007005_S0000568.npy new file mode 100644 index 0000000..6b39c02 Binary files /dev/null and b/demo/audios-speech-tokenizer/acoustic/POD0000007005_S0000568.npy differ diff --git 
a/demo/audios-speech-tokenizer/acoustic/POD0000009720_S0000244.npy b/demo/audios-speech-tokenizer/acoustic/POD0000009720_S0000244.npy new file mode 100644 index 0000000..0c9b15f Binary files /dev/null and b/demo/audios-speech-tokenizer/acoustic/POD0000009720_S0000244.npy differ diff --git a/demo/audios-speech-tokenizer/acoustic/POD0000014360_S0000082.npy b/demo/audios-speech-tokenizer/acoustic/POD0000014360_S0000082.npy new file mode 100644 index 0000000..468d5dd Binary files /dev/null and b/demo/audios-speech-tokenizer/acoustic/POD0000014360_S0000082.npy differ diff --git a/demo/audios-speech-tokenizer/acoustic/POD0000015908_S0000037.npy b/demo/audios-speech-tokenizer/acoustic/POD0000015908_S0000037.npy new file mode 100644 index 0000000..5f8d4cd Binary files /dev/null and b/demo/audios-speech-tokenizer/acoustic/POD0000015908_S0000037.npy differ diff --git a/demo/audios-speech-tokenizer/acoustic/POD1000000022_S0000028.npy b/demo/audios-speech-tokenizer/acoustic/POD1000000022_S0000028.npy new file mode 100644 index 0000000..4841da6 Binary files /dev/null and b/demo/audios-speech-tokenizer/acoustic/POD1000000022_S0000028.npy differ diff --git a/demo/audios-speech-tokenizer/acoustic/male_voice.npy b/demo/audios-speech-tokenizer/acoustic/male_voice.npy new file mode 100644 index 0000000..4573aa1 Binary files /dev/null and b/demo/audios-speech-tokenizer/acoustic/male_voice.npy differ diff --git a/demo/audios-speech-tokenizer/semantic/POD0000004393_S0000029.npy b/demo/audios-speech-tokenizer/semantic/POD0000004393_S0000029.npy new file mode 100644 index 0000000..440006d Binary files /dev/null and b/demo/audios-speech-tokenizer/semantic/POD0000004393_S0000029.npy differ diff --git a/demo/audios-speech-tokenizer/semantic/POD0000007005_S0000568.npy b/demo/audios-speech-tokenizer/semantic/POD0000007005_S0000568.npy new file mode 100644 index 0000000..f5c0ea6 Binary files /dev/null and b/demo/audios-speech-tokenizer/semantic/POD0000007005_S0000568.npy differ diff --git a/demo/audios-speech-tokenizer/semantic/POD0000009720_S0000244.npy b/demo/audios-speech-tokenizer/semantic/POD0000009720_S0000244.npy new file mode 100644 index 0000000..c3faee3 Binary files /dev/null and b/demo/audios-speech-tokenizer/semantic/POD0000009720_S0000244.npy differ diff --git a/demo/audios-speech-tokenizer/semantic/POD0000014360_S0000082.npy b/demo/audios-speech-tokenizer/semantic/POD0000014360_S0000082.npy new file mode 100644 index 0000000..4f16a9c Binary files /dev/null and b/demo/audios-speech-tokenizer/semantic/POD0000014360_S0000082.npy differ diff --git a/demo/audios-speech-tokenizer/semantic/POD0000015908_S0000037.npy b/demo/audios-speech-tokenizer/semantic/POD0000015908_S0000037.npy new file mode 100644 index 0000000..ed09d88 Binary files /dev/null and b/demo/audios-speech-tokenizer/semantic/POD0000015908_S0000037.npy differ diff --git a/demo/audios-speech-tokenizer/semantic/POD1000000022_S0000028.npy b/demo/audios-speech-tokenizer/semantic/POD1000000022_S0000028.npy new file mode 100644 index 0000000..93ac286 Binary files /dev/null and b/demo/audios-speech-tokenizer/semantic/POD1000000022_S0000028.npy differ diff --git a/demo/audios-speech-tokenizer/semantic/male_voice.npy b/demo/audios-speech-tokenizer/semantic/male_voice.npy new file mode 100644 index 0000000..f8a3a8c Binary files /dev/null and b/demo/audios-speech-tokenizer/semantic/male_voice.npy differ diff --git a/demo/audios/POD0000004393_S0000029.wav b/demo/audios/POD0000004393_S0000029.wav new file mode 100644 index 0000000..77a2b22 Binary files 
/dev/null and b/demo/audios/POD0000004393_S0000029.wav differ diff --git a/demo/audios/POD0000007005_S0000568.wav b/demo/audios/POD0000007005_S0000568.wav new file mode 100644 index 0000000..d0a6be2 Binary files /dev/null and b/demo/audios/POD0000007005_S0000568.wav differ diff --git a/demo/audios/POD0000009720_S0000244.wav b/demo/audios/POD0000009720_S0000244.wav new file mode 100644 index 0000000..cee4d9b Binary files /dev/null and b/demo/audios/POD0000009720_S0000244.wav differ diff --git a/demo/audios/POD0000014360_S0000082.wav b/demo/audios/POD0000014360_S0000082.wav new file mode 100644 index 0000000..963ed62 Binary files /dev/null and b/demo/audios/POD0000014360_S0000082.wav differ diff --git a/demo/audios/POD0000015908_S0000037.wav b/demo/audios/POD0000015908_S0000037.wav new file mode 100644 index 0000000..62e8b35 Binary files /dev/null and b/demo/audios/POD0000015908_S0000037.wav differ diff --git a/demo/audios/POD1000000022_S0000028.wav b/demo/audios/POD1000000022_S0000028.wav new file mode 100644 index 0000000..54ab1e3 Binary files /dev/null and b/demo/audios/POD1000000022_S0000028.wav differ diff --git a/demo/audios/male_voice.wav b/demo/audios/male_voice.wav new file mode 100644 index 0000000..3602790 Binary files /dev/null and b/demo/audios/male_voice.wav differ diff --git a/demo/male_voice.wav b/demo/male_voice.wav new file mode 100644 index 0000000..bb9180a Binary files /dev/null and b/demo/male_voice.wav differ diff --git a/demo/manifest.json b/demo/manifest.json new file mode 100644 index 0000000..404129b --- /dev/null +++ b/demo/manifest.json @@ -0,0 +1,6 @@ +{"audio_filepath":"male_voice.wav","text":"Welcome to Casino Lakes Charles. I'm very happy to help you today. We have a broad range of goods for you!","speaker":0,"audio_prompt_filepath":"audios/male_voice.wav"} +{"audio_filepath":"POD0000015908_S0000037.wav","text":"another important thing was that there was no long-term follow-up of the patients to see if they had really stayed cancer free.","speaker":0,"audio_prompt_filepath":"audios/POD0000015908_S0000037.wav"} +{"audio_filepath":"POD0000009720_S0000244.wav","text":"and the whole thing is just so cozy that he wants to be part of it. he wants to be in their club.","speaker":0,"audio_prompt_filepath":"audios/POD0000009720_S0000244.wav"} +{"audio_filepath":"POD0000014360_S0000082.wav","text":"and this is where a large amount of the profits come, such as elsevier making eight hundred and forty six million dollars profit last year.","speaker":0,"audio_prompt_filepath":"audios/POD0000014360_S0000082.wav"} +{"audio_filepath":"POD0000004393_S0000029.wav","text": "just like with uber, when there is less demand, a sports franchise can also lower prices to try to drive up ticket sales.","speaker":0,"audio_prompt_filepath":"audios/POD0000004393_S0000029.wav"} +{"audio_filepath":"POD0000007005_S0000568.wav","text":"but let's let's just cover it now. if you could make a plea or a suggestion to people involved with nonprofits out there and say,","speaker":0,"audio_prompt_filepath":"audios/POD0000007005_S0000568.wav"} diff --git a/docs/_config.yml b/docs/_config.yml new file mode 100644 index 0000000..277f1f2 --- /dev/null +++ b/docs/_config.yml @@ -0,0 +1 @@ +theme: jekyll-theme-cayman diff --git a/docs/_layouts/default.html b/docs/_layouts/default.html new file mode 100644 index 0000000..fb7be6d --- /dev/null +++ b/docs/_layouts/default.html @@ -0,0 +1,43 @@ + + + + + +{% seo %} + + + + + + + {% include head-custom.html %} + + + Skip to the content. + + + +
+ {{ content }} + + +
+ + + diff --git a/docs/assets/css/style.scss b/docs/assets/css/style.scss new file mode 100644 index 0000000..e9aafc4 --- /dev/null +++ b/docs/assets/css/style.scss @@ -0,0 +1,29 @@ +--- +--- + +@import "{{ site.theme }}"; + +.page-header { + background-image: linear-gradient(170deg, #2876c4, #d9ee50); +} +.main-content { + max-width: 90%; + font-size: 0.8rem; +} + +audio { + width: 140px; +} + +.footnotes { + list-style: none; + padding-left: 0; +} + +.footnotes li { + font-size: 0.8em; +} + +.footnotes a { + text-decoration: none; +} diff --git a/docs/assets/img/polyai-logo.webp b/docs/assets/img/polyai-logo.webp new file mode 100644 index 0000000..74cab92 Binary files /dev/null and b/docs/assets/img/polyai-logo.webp differ diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..a65154b --- /dev/null +++ b/docs/index.md @@ -0,0 +1,45 @@ +## PHEME: Efficient and Conversational Speech Generation. + + - Abstract. In recent years, speech generation has seen remarkable progress, now achieving one-shot generation capability that is often virtually indistinguishable from real human voice. Integrating such advancements in speech generation with large language models might revolutionize a wide range of applications. However, certain applications, such as assistive conversational systems, require natural and conversational speech generation which also operates efficiently in real time. Current state-of-the-art models like VALL-E and SoundStorm, powered by hierarchical neural audio codecs, require large neural components and extensive training data to work well. In contrast, MQTTS aims to build more compact conversational TTS models while capitalizing on smaller-scale real-life conversational speech data. However, its autoregressive nature yields high inference latency and thus limits its real-time usage. In order to mitigate the current limitations of the state-of-the-art TTS models while capitalizing on their strengths, in this work we propose the *PHEME* model series that **1)** offers compact yet high-performing models, **2)** allows for parallel speech generation of **3)** natural conversational speech, and **4)** it can be trained efficiently on smaller-scale conversational data, cutting data demands by more than 10x but still matching the quality of the autoregressive TTS models. We also show that through simple teacher-student distillation we can meet significant improvements in voice quality for single-speaker setups on top of pretrained *PHEME* checkpoints, relying solely on synthetic speech generated by much larger teacher models. + - [Code](https://github.com/PolyAI-LDN/pheme) + - [Paper](...) + + +### GigaSpeech One-shot1 TTS Examples +
+
+1. One-shot: inference setup for voices unseen at training time, where prompts and speaker embeddings are provided as additional model inputs.
+
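For context, the voice prompts in the examples below enter the model as discrete codes. The sketch that follows is a minimal illustration, based on `modules/speech_tokenizer.py` introduced later in this diff, of how a prompt wav (here `demo/audios/male_voice.wav` from the demo folder) is tokenized into the semantic and acoustic codes used as one-shot conditioning. The checkpoint and config paths mirror the defaults in `modules/vocoder.py`; the output folder name is an assumed example.

```python
# Minimal sketch: tokenize a voice prompt into semantic + acoustic codes.
# Checkpoint/config paths follow the defaults in modules/vocoder.py;
# "demo/encodings" is an assumed output location for this example.
from modules.speech_tokenizer import SpeechTokenizer

tokenizer = SpeechTokenizer(
    config_path="ckpt/speechtokenizer/config.json",
    ckpt_path="ckpt/speechtokenizer/SpeechTokenizer.pt",
)

# Writes demo/encodings/semantic/male_voice.npy and
# demo/encodings/acoustic/male_voice.npy for use as the one-shot prompt.
tokenizer.encode_file(
    folder_path="demo/audios",
    destination_folder="demo/encodings",
    filename="male_voice.wav",
)
```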
+ + +| Prompt audio | Reference audio | PHEME (100M) | PHEME (300M) no speaker embeddings | PHEME (300M) | Prompt text | Reference text | +| :----------------------------------------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | +| | | | | | let's just say in her own words, once i sat down and watched it i never moved, i w as enthralled by it. | and she told me the next time she went back she would take me with her. and i waited, of course, like i said, thirteen years. | +| | | | | | in early twenty-twenty, blue apron put the word out that it was interested in possibly getting scooped up. maybe by a big grocery chain. or someone else with deep pockets who wanted to own a meal kit delivery business. | at the same time, garcia says, the company acted like it was in turnaround mode. it decid ed to streamline operations, including shutting down its fulfillment center in texas | +| | | | | | aside from influencing basically everyone who matters he was one of the first if not, in fact the first artist to bring an electric guitar player with him on to the grand oleopry stag e. | if you want to call it a honky tonk, and it happened after ernest tubb. it was influenced by ernest tubb. before i get to the story and episode, i'd like to address one other thing. | +| | | | | | so it's ah i think there's a range of risks, but generally speaking ah there's goi ng to be a study increase in the floor of the skill level as these ah a i technologies diffuse. | that is, there will be more and more ah capabilities available to people at the bottom of the scale, that is individuals as well as people with more access to computing power, ah money, and data at the higher end. | +| | | | | | so after they put in their name, phone number, email address onto your landing pag e. where would you like to send them? would you like to send them to your facebook page your website? | book an appointment to a buyer on facebook messenger bot, a seller messenger bot. where w ould you like to send them? so for this example i'm just gonna say book an appointment. | + + + +### Artificial Voice TTS Examples + +| Prompt audio | Reference audio | PHEME (300M) no training on artificial voice | PHEME (300M) | Prompt text | Reference text | +| ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | +| | | | | Our garden terrace is a lovely spot for afternoon tea. | The city’s ghost walk is a spooky and fascinating evening adventure. | +| | | | | If you need a quiet place to work, our library is just perfect. | Our hotel’s evening bonfires are a great place to socialize. | +| | | | | There’s a delightful chocolate factory tour, great for families. | Our rooftop jazz nights feature some of the best local talent. | +| | | | | The rooftop bar hosts a live DJ on Friday nights. | Our in-house sommelier leads an exquisite wine and cheese pairing event. 
| +| | | | | The comedy club in town is known for its hilarious acts. | The annual food fair showcases the best of local cuisine. | + +### Inference speed with Triton-LLM (RTFs, lower is better) for short and long sentences + +| Model | *short* | *long* | GPU | +| ------------------ | --------- | --------- |--------- | +| MQTTS (100M) | 1.930 | 1.842 | A100 | +| PHEME-SMALL (100M) | **0.133** | **0.133** | A100 | +| PHEME-LARGE (300M) | 0.143 | 0.143 | A100 | + diff --git a/docs/samples/empress/114.wav b/docs/samples/empress/114.wav new file mode 100644 index 0000000..1644071 Binary files /dev/null and b/docs/samples/empress/114.wav differ diff --git a/docs/samples/empress/148.wav b/docs/samples/empress/148.wav new file mode 100644 index 0000000..5a5c55c Binary files /dev/null and b/docs/samples/empress/148.wav differ diff --git a/docs/samples/empress/161.wav b/docs/samples/empress/161.wav new file mode 100644 index 0000000..53f829c Binary files /dev/null and b/docs/samples/empress/161.wav differ diff --git a/docs/samples/empress/189.wav b/docs/samples/empress/189.wav new file mode 100644 index 0000000..360e29a Binary files /dev/null and b/docs/samples/empress/189.wav differ diff --git a/docs/samples/empress/217.wav b/docs/samples/empress/217.wav new file mode 100644 index 0000000..56a6a3d Binary files /dev/null and b/docs/samples/empress/217.wav differ diff --git a/docs/samples/empress/226.wav b/docs/samples/empress/226.wav new file mode 100644 index 0000000..a913063 Binary files /dev/null and b/docs/samples/empress/226.wav differ diff --git a/docs/samples/empress/234.wav b/docs/samples/empress/234.wav new file mode 100644 index 0000000..de76478 Binary files /dev/null and b/docs/samples/empress/234.wav differ diff --git a/docs/samples/empress/242.wav b/docs/samples/empress/242.wav new file mode 100644 index 0000000..b6260d0 Binary files /dev/null and b/docs/samples/empress/242.wav differ diff --git a/docs/samples/empress/262.wav b/docs/samples/empress/262.wav new file mode 100644 index 0000000..fc75dc7 Binary files /dev/null and b/docs/samples/empress/262.wav differ diff --git a/docs/samples/empress/269.wav b/docs/samples/empress/269.wav new file mode 100644 index 0000000..1207c16 Binary files /dev/null and b/docs/samples/empress/269.wav differ diff --git a/docs/samples/empress/29.wav b/docs/samples/empress/29.wav new file mode 100644 index 0000000..aa04f1e Binary files /dev/null and b/docs/samples/empress/29.wav differ diff --git a/docs/samples/empress/46.wav b/docs/samples/empress/46.wav new file mode 100644 index 0000000..ec0b7fa Binary files /dev/null and b/docs/samples/empress/46.wav differ diff --git a/docs/samples/gigaspeech/POD1000000004_S0000246.wav b/docs/samples/gigaspeech/POD1000000004_S0000246.wav new file mode 100644 index 0000000..1a1815a Binary files /dev/null and b/docs/samples/gigaspeech/POD1000000004_S0000246.wav differ diff --git a/docs/samples/gigaspeech/POD1000000004_S0000247.wav b/docs/samples/gigaspeech/POD1000000004_S0000247.wav new file mode 100644 index 0000000..2a25cc2 Binary files /dev/null and b/docs/samples/gigaspeech/POD1000000004_S0000247.wav differ diff --git a/docs/samples/gigaspeech/POD1000000018_S0000253.wav b/docs/samples/gigaspeech/POD1000000018_S0000253.wav new file mode 100644 index 0000000..7021c5a Binary files /dev/null and b/docs/samples/gigaspeech/POD1000000018_S0000253.wav differ diff --git a/docs/samples/gigaspeech/POD1000000018_S0000254.wav b/docs/samples/gigaspeech/POD1000000018_S0000254.wav new file mode 100644 index 0000000..0ec4e1d 
Binary files /dev/null and b/docs/samples/gigaspeech/POD1000000018_S0000254.wav differ diff --git a/docs/samples/gigaspeech/POD1000000048_S0000035.wav b/docs/samples/gigaspeech/POD1000000048_S0000035.wav new file mode 100644 index 0000000..68c9b69 Binary files /dev/null and b/docs/samples/gigaspeech/POD1000000048_S0000035.wav differ diff --git a/docs/samples/gigaspeech/POD1000000048_S0000036.wav b/docs/samples/gigaspeech/POD1000000048_S0000036.wav new file mode 100644 index 0000000..a7e5d5c Binary files /dev/null and b/docs/samples/gigaspeech/POD1000000048_S0000036.wav differ diff --git a/docs/samples/gigaspeech/YOU1000000006_S0000051.wav b/docs/samples/gigaspeech/YOU1000000006_S0000051.wav new file mode 100644 index 0000000..41464a3 Binary files /dev/null and b/docs/samples/gigaspeech/YOU1000000006_S0000051.wav differ diff --git a/docs/samples/gigaspeech/YOU1000000006_S0000052.wav b/docs/samples/gigaspeech/YOU1000000006_S0000052.wav new file mode 100644 index 0000000..bbe06f1 Binary files /dev/null and b/docs/samples/gigaspeech/YOU1000000006_S0000052.wav differ diff --git a/docs/samples/gigaspeech/YOU1000000044_S0000798.wav b/docs/samples/gigaspeech/YOU1000000044_S0000798.wav new file mode 100644 index 0000000..6e0db37 Binary files /dev/null and b/docs/samples/gigaspeech/YOU1000000044_S0000798.wav differ diff --git a/docs/samples/gigaspeech/YOU1000000044_S0000799.wav b/docs/samples/gigaspeech/YOU1000000044_S0000799.wav new file mode 100644 index 0000000..59eb0e8 Binary files /dev/null and b/docs/samples/gigaspeech/YOU1000000044_S0000799.wav differ diff --git a/docs/samples/pheme-100/019.wav b/docs/samples/pheme-100/019.wav new file mode 100644 index 0000000..f697780 Binary files /dev/null and b/docs/samples/pheme-100/019.wav differ diff --git a/docs/samples/pheme-100/042.wav b/docs/samples/pheme-100/042.wav new file mode 100644 index 0000000..9ddd61e Binary files /dev/null and b/docs/samples/pheme-100/042.wav differ diff --git a/docs/samples/pheme-100/080.wav b/docs/samples/pheme-100/080.wav new file mode 100644 index 0000000..db3e5e5 Binary files /dev/null and b/docs/samples/pheme-100/080.wav differ diff --git a/docs/samples/pheme-100/188.wav b/docs/samples/pheme-100/188.wav new file mode 100644 index 0000000..4cf2cf1 Binary files /dev/null and b/docs/samples/pheme-100/188.wav differ diff --git a/docs/samples/pheme-100/209.wav b/docs/samples/pheme-100/209.wav new file mode 100644 index 0000000..2a77b3b Binary files /dev/null and b/docs/samples/pheme-100/209.wav differ diff --git a/docs/samples/pheme-300/019.wav b/docs/samples/pheme-300/019.wav new file mode 100644 index 0000000..10c9993 Binary files /dev/null and b/docs/samples/pheme-300/019.wav differ diff --git a/docs/samples/pheme-300/042.wav b/docs/samples/pheme-300/042.wav new file mode 100644 index 0000000..f9316dd Binary files /dev/null and b/docs/samples/pheme-300/042.wav differ diff --git a/docs/samples/pheme-300/080.wav b/docs/samples/pheme-300/080.wav new file mode 100644 index 0000000..7cc16e8 Binary files /dev/null and b/docs/samples/pheme-300/080.wav differ diff --git a/docs/samples/pheme-300/188.wav b/docs/samples/pheme-300/188.wav new file mode 100644 index 0000000..85a82a8 Binary files /dev/null and b/docs/samples/pheme-300/188.wav differ diff --git a/docs/samples/pheme-300/209.wav b/docs/samples/pheme-300/209.wav new file mode 100644 index 0000000..dead72c Binary files /dev/null and b/docs/samples/pheme-300/209.wav differ diff --git a/docs/samples/pheme-empress-300/001.wav b/docs/samples/pheme-empress-300/001.wav new 
file mode 100644 index 0000000..261a3d9 Binary files /dev/null and b/docs/samples/pheme-empress-300/001.wav differ diff --git a/docs/samples/pheme-empress-300/002.wav b/docs/samples/pheme-empress-300/002.wav new file mode 100644 index 0000000..3a6662a Binary files /dev/null and b/docs/samples/pheme-empress-300/002.wav differ diff --git a/docs/samples/pheme-empress-300/190.wav b/docs/samples/pheme-empress-300/190.wav new file mode 100644 index 0000000..8da67a5 Binary files /dev/null and b/docs/samples/pheme-empress-300/190.wav differ diff --git a/docs/samples/pheme-empress-300/227.wav b/docs/samples/pheme-empress-300/227.wav new file mode 100644 index 0000000..dfe23f4 Binary files /dev/null and b/docs/samples/pheme-empress-300/227.wav differ diff --git a/docs/samples/pheme-empress-300/235.wav b/docs/samples/pheme-empress-300/235.wav new file mode 100644 index 0000000..9142699 Binary files /dev/null and b/docs/samples/pheme-empress-300/235.wav differ diff --git a/docs/samples/pheme-empress-300/243.wav b/docs/samples/pheme-empress-300/243.wav new file mode 100644 index 0000000..40a2968 Binary files /dev/null and b/docs/samples/pheme-empress-300/243.wav differ diff --git a/docs/samples/pheme-empress-300/270.wav b/docs/samples/pheme-empress-300/270.wav new file mode 100644 index 0000000..307a65a Binary files /dev/null and b/docs/samples/pheme-empress-300/270.wav differ diff --git a/docs/samples/pheme-no-empress-300/190.wav b/docs/samples/pheme-no-empress-300/190.wav new file mode 100644 index 0000000..8075d18 Binary files /dev/null and b/docs/samples/pheme-no-empress-300/190.wav differ diff --git a/docs/samples/pheme-no-empress-300/227.wav b/docs/samples/pheme-no-empress-300/227.wav new file mode 100644 index 0000000..749de07 Binary files /dev/null and b/docs/samples/pheme-no-empress-300/227.wav differ diff --git a/docs/samples/pheme-no-empress-300/235.wav b/docs/samples/pheme-no-empress-300/235.wav new file mode 100644 index 0000000..2dfa0e7 Binary files /dev/null and b/docs/samples/pheme-no-empress-300/235.wav differ diff --git a/docs/samples/pheme-no-empress-300/243.wav b/docs/samples/pheme-no-empress-300/243.wav new file mode 100644 index 0000000..9aaff9a Binary files /dev/null and b/docs/samples/pheme-no-empress-300/243.wav differ diff --git a/docs/samples/pheme-no-empress-300/270.wav b/docs/samples/pheme-no-empress-300/270.wav new file mode 100644 index 0000000..387dd9b Binary files /dev/null and b/docs/samples/pheme-no-empress-300/270.wav differ diff --git a/docs/samples/pheme-no-spkr-300/019.wav b/docs/samples/pheme-no-spkr-300/019.wav new file mode 100644 index 0000000..34638a9 Binary files /dev/null and b/docs/samples/pheme-no-spkr-300/019.wav differ diff --git a/docs/samples/pheme-no-spkr-300/042.wav b/docs/samples/pheme-no-spkr-300/042.wav new file mode 100644 index 0000000..307ccba Binary files /dev/null and b/docs/samples/pheme-no-spkr-300/042.wav differ diff --git a/docs/samples/pheme-no-spkr-300/080.wav b/docs/samples/pheme-no-spkr-300/080.wav new file mode 100644 index 0000000..a7b3785 Binary files /dev/null and b/docs/samples/pheme-no-spkr-300/080.wav differ diff --git a/docs/samples/pheme-no-spkr-300/188.wav b/docs/samples/pheme-no-spkr-300/188.wav new file mode 100644 index 0000000..d50fbcb Binary files /dev/null and b/docs/samples/pheme-no-spkr-300/188.wav differ diff --git a/docs/samples/pheme-no-spkr-300/209.wav b/docs/samples/pheme-no-spkr-300/209.wav new file mode 100644 index 0000000..68c376c Binary files /dev/null and b/docs/samples/pheme-no-spkr-300/209.wav differ 
diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/conformer.py b/modules/conformer.py new file mode 100644 index 0000000..a7685d0 --- /dev/null +++ b/modules/conformer.py @@ -0,0 +1,671 @@ +"""Conformer definition adjusted given the Lucidrain's repo. +https://github.com/lucidrains/soundstorm-pytorch/blob/main/soundstorm_pytorch/soundstorm.py # noqa + +Copyright PolyAI Limited. +""" +from collections import namedtuple +from functools import wraps +from typing import Dict, Union + +import torch +import torch.nn.functional as F +from einops import rearrange, reduce +from einops.layers.torch import EinMix, Rearrange +from torch import einsum, nn + + +# rotary embedding +class RotaryEmbedding(nn.Module): + def __init__(self, dim, theta = 10000): + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq, persistent = False) + + @property + def device(self): + return next(self.buffers()).device + + def forward(self, seq_len): + t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq) + freqs = torch.einsum('i , j -> i j', t, self.inv_freq) + freqs = torch.cat((freqs, freqs), dim = -1) + return freqs + +def rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb(pos, t): + return (t * pos.cos()) + (rotate_half(t) * pos.sin()) + + +# constants +EfficientAttentionConfig = namedtuple( + 'EfficientAttentionConfig', + ['enable_flash', 'enable_math', 'enable_mem_efficient'] +) + +# helpers +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +def divisible_by(numer, denom): + return (numer % denom) == 0 + +def calc_same_padding(kernel_size): + pad = kernel_size // 2 + return (pad, pad - (kernel_size + 1) % 2) + +def eval_decorator(fn): + @wraps(fn) + def inner(model, *args, **kwargs): + was_training = model.training + model.eval() + out = fn(model, *args, **kwargs) + model.train(was_training) + return out + return inner + + +def once(fn): + called = False + @wraps(fn) + def inner(x): + nonlocal called + if called: + return + called = True + return fn(x) + return inner + +print_once = once(print) + + +# t5 relative positional bias +class T5RelativePositionBias(nn.Module): + def __init__( + self, + scale = 1., + num_buckets = 32, + max_distance = 128, + heads = 8 + ): + super().__init__() + self.scale = scale + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket( + relative_position, + num_buckets = 32, + max_distance = 128 + ): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log( + max_distance / max_exact) * (num_buckets - max_exact) + ).long() + + val_if_large = torch.min( + val_if_large, + torch.full_like(val_if_large, num_buckets - 1) + ) + + ret += torch.where(is_small, n, val_if_large) + return ret + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, n): + pos = torch.arange(n, device = self.device).long() + rel_pos = rearrange(pos, 'j -> 1 j') - rearrange(pos, 'i -> i 1') + + rp_bucket = self._relative_position_bucket( + rel_pos, num_buckets = self.num_buckets, + 
max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + + bias = rearrange(values, 'i j h -> h i j') + return bias * self.scale + + +# main class +class Attend(nn.Module): + def __init__( + self, + causal = False, + dropout = 0., + flash = False + ): + super().__init__() + self.dropout = dropout + self.attn_dropout = nn.Dropout(dropout) + + self.causal = causal + self.flash = flash + + # determine efficient attention configs for cuda and cpu + self.cpu_config = EfficientAttentionConfig(True, True, True) + self.cuda_config = None + + if not torch.cuda.is_available() or not flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + + if device_properties.major == 8 and device_properties.minor == 0: + print_once('A100 GPU detected, using flash attention if input tensor is on cuda') # noqa + self.cuda_config = EfficientAttentionConfig(True, True, True) + else: + print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') # noqa + self.cuda_config = EfficientAttentionConfig(False, True, True) + + def get_mask(self, i, j, device): + return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1) # noqa + + def flash_attn(self, q, k, v, mask = None, attn_bias = None): + _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device # noqa + + # single headed key / values + + if k.ndim == 3: + k = rearrange(k, 'b n d -> b 1 n d') + + if v.ndim == 3: + v = rearrange(v, 'b n d -> b 1 n d') + + # Check if mask exists and expand to compatible shape + # The mask is B L, so it would have to be expanded to B H N L + if exists(mask) and mask.ndim != 4: + mask = rearrange(mask, 'b j -> b 1 1 j') + mask = mask.expand(-1, heads, q_len, -1) + + # Check if there is a compatible device for flash attention + config = self.cuda_config if is_cuda else self.cpu_config + causal = self.causal + + # handle attention bias + if exists(attn_bias): + mask_value = -torch.finfo(q.dtype).max // 2 + causal_mask = self.get_mask(q_len, k_len, device) + attn_bias = attn_bias.masked_fill(causal_mask, mask_value) + + if exists(mask): + attn_bias = attn_bias.masked_fill(~mask, mask_value) + + mask = attn_bias + causal = False + + # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale + with torch.backends.cuda.sdp_kernel(**config._asdict()): + out = F.scaled_dot_product_attention( + q, k, v, + attn_mask = mask, + dropout_p = self.dropout if self.training else 0., + is_causal = causal + ) + + return out + + def forward(self, q, k, v, mask = None, attn_bias = None): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + q_len, k_len, device = q.shape[-2], k.shape[-2], q.device + + scale = q.shape[-1] ** -0.5 + + kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d' + + if self.flash: + assert not exists(attn_bias) + return self.flash_attn(q, k, v, mask = mask) + + # similarity + + sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale + + # attention bias + + if exists(attn_bias): + sim = sim + attn_bias + + # causal mask + if self.causal: + causal_mask = self.get_mask(q_len, k_len, device) + sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) + + # key padding mask + if exists(mask): + if mask.ndim != 4: + mask = rearrange(mask, 'b j -> b 1 1 j') + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + # attention + attn = sim.softmax(dim=-1) + attn 
= self.attn_dropout(attn) + + # aggregate values + out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v) + + return out + + +class Swish(nn.Module): + def forward(self, x): + return x * x.sigmoid() + + +class GLU(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + out, gate = x.chunk(2, dim=self.dim) + return out * gate.sigmoid() + + +class DepthWiseConv1d(nn.Module): + def __init__(self, chan_in, chan_out, kernel_size, padding): + super().__init__() + self.padding = padding + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in) + + def forward(self, x): + x = F.pad(x, self.padding) + return self.conv(x) + + +class Scale(nn.Module): + def __init__(self, scale, fn): + super().__init__() + self.fn = fn + self.scale = scale + + def forward(self, x, **kwargs): + return self.fn(x, **kwargs) * self.scale + + +class ChanLayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.ones(1, dim, 1)) + + def forward(self, x): + eps = 1e-6 if x.dtype == torch.float32 else 1e-4 + var = torch.var(x, dim = 1, unbiased = False, keepdim = True) + mean = torch.mean(x, dim = 1, keepdim = True) + return (x - mean) * var.clamp(min = eps).rsqrt() * self.gamma + + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = nn.LayerNorm(dim) + + def forward(self, x, **kwargs): + x = self.norm(x) + return self.fn(x, **kwargs) + + +class Attention(nn.Module): + def __init__( + self, + dim, + heads = 8, + dim_head = 64, + dropout = 0., + flash = True + ): + super().__init__() + inner_dim = dim_head * heads + self.heads= heads + self.scale = dim_head ** -0.5 + + self.attend = Attend( + flash = flash, + dropout = dropout + ) + + self.dropout = nn.Dropout(dropout) + + self.to_q = nn.Linear(dim, inner_dim, bias = False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) + self.to_out = nn.Linear(inner_dim, dim) + + def forward( + self, + x, + context = None, + mask = None, + rotary_emb = None, + attn_bias = None + ): + n, device, h, has_context = x.shape[-2], x.device, self.heads, exists(context) + context = default(context, x) + + q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1)) + q, k, v = map( + lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + + if exists(rotary_emb): + q = apply_rotary_pos_emb(rotary_emb, q) + k = apply_rotary_pos_emb(rotary_emb, k) + + out = self.attend(q, k, v, mask = mask, attn_bias = attn_bias) + + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class FeedForward(nn.Module): + def __init__( + self, + dim, + mult = 4, + dropout = 0. + ): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, dim * mult), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim * mult, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + + +class ConformerConvModule(nn.Module): + def __init__( + self, + dim, + causal = False, + expansion_factor = 2, + kernel_size = 31, + dropout = 0. 
+ ): + super().__init__() + + inner_dim = dim * expansion_factor + padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0) + + self.net = nn.Sequential( + nn.LayerNorm(dim), + Rearrange('b n c -> b c n'), + nn.Conv1d(dim, inner_dim * 2, 1), + GLU(dim=1), + DepthWiseConv1d( + inner_dim, inner_dim, kernel_size = kernel_size, + padding = padding + ), + Swish(), + ChanLayerNorm(inner_dim), + nn.Conv1d(inner_dim, dim, 1), + Rearrange('b c n -> b n c'), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + + +# Conformer Block +class ConformerBlock(nn.Module): + def __init__( + self, + *, + dim, + dim_head = 64, + heads = 8, + ff_mult = 4, + conv_expansion_factor = 2, + conv_kernel_size = 31, + attn_dropout = 0., + attn_flash = True, + ff_dropout = 0., + conv_dropout = 0., + conv_causal = False + ): + super().__init__() + self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout) + self.attn = Attention( + dim = dim, dim_head = dim_head, heads = heads, + dropout = attn_dropout, flash = attn_flash + ) + self.conv = ConformerConvModule( + dim = dim, causal = conv_causal, + expansion_factor = conv_expansion_factor, + kernel_size = conv_kernel_size, dropout = conv_dropout + ) + self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout) + + self.attn = PreNorm(dim, self.attn) + self.ff1 = Scale(0.5, PreNorm(dim, self.ff1)) + self.ff2 = Scale(0.5, PreNorm(dim, self.ff2)) + + self.post_norm = nn.LayerNorm(dim) + + def forward( + self, + x, + mask = None, + rotary_emb = None, + attn_bias = None + ): + x = self.ff1(x) + x + x = self.attn(x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias) + x # noqa + x = self.conv(x) + x + x = self.ff2(x) + x + x = self.post_norm(x) + return x + + +# Conformer +class Conformer(nn.Module): + def __init__( + self, + dim, + *, + num_layers, + dim_head = 64, + heads = 8, + ff_mult = 4, + conv_expansion_factor = 2, + conv_kernel_size = 31, + attn_dropout = 0., + ff_dropout = 0., + conv_dropout = 0., + conv_causal = False, + attn_flash = True, + t5_rel_pos_bias = False + ): + super().__init__() + + assert not (t5_rel_pos_bias and attn_flash), 'flash attention is not compatible with learned bias' # noqa + + self.dim = dim + self.layers = nn.ModuleList([]) + + self.rotary_emb = RotaryEmbedding( + dim_head) if not t5_rel_pos_bias else None + self.rel_pos_bias = T5RelativePositionBias( + dim_head ** 0.5, heads = heads) if t5_rel_pos_bias else None + + for _ in range(num_layers): + self.layers.append(ConformerBlock( + dim = dim, + dim_head = dim_head, + heads = heads, + ff_mult = ff_mult, + conv_expansion_factor = conv_expansion_factor, + conv_kernel_size = conv_kernel_size, + attn_dropout = attn_dropout, + ff_dropout = ff_dropout, + conv_dropout = conv_dropout, + conv_causal = conv_causal, + attn_flash = attn_flash + )) + + def forward(self, x, mask = None): + seq_len = x.shape[-2] + + rotary_emb = self.rotary_emb(seq_len) if exists(self.rotary_emb) else None # noqa + attn_bias = self.rel_pos_bias(seq_len) if exists(self.rel_pos_bias) else None #noqa + + for block in self.layers: + x = block( + x, + mask = mask, + rotary_emb = rotary_emb, + attn_bias = attn_bias + ) + return x + + +# conformer with sum reduction across quantized tokens at the beginning, +# along with heads +class ConformerWrapper(nn.Module): + def __init__( + self, + *, + codebook_size, + num_quantizers, + conformer: Union[Conformer, Dict[str, any]], + grouped_quantizers = 1 + ): + super().__init__() + self.conformer = conformer + 
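+ # `conformer` can be passed either as an already-built Conformer module or as a dict of constructor kwargs; the dict case is instantiated into a Conformer just below.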
+ if isinstance(conformer, dict): + self.conformer = Conformer(**self.conformer) + + dim = self.conformer.dim + + self.embedding_proj = nn.Sequential( + nn.Linear(dim * grouped_quantizers, dim), + nn.LayerNorm(dim) + ) if grouped_quantizers > 1 else nn.Identity() + + num_codes_with_mask = codebook_size + 1 + num_effective_quantizers = num_quantizers * grouped_quantizers + + self.code_embeds = nn.Embedding( + num_codes_with_mask * num_effective_quantizers, dim) + + self.register_buffer( + 'quantizer_offsets', + torch.arange(num_effective_quantizers) * num_codes_with_mask, + persistent = False + ) + self.register_buffer( + 'mask_tokens', self.quantizer_offsets + num_codes_with_mask, + persistent = False + ) + + self.dim = dim + self.codebook_size = codebook_size + + self.num_codes_with_mask = num_codes_with_mask + self.num_quantizers = num_quantizers + self.grouped_quantizers = grouped_quantizers + + self.heads = nn.Sequential( + nn.Linear(dim, dim * num_effective_quantizers), + Rearrange('b n (h d) -> b (n h) d', h = num_effective_quantizers) + ) + + # each quantizer codebook would require its own logits weight + # and bias matrices + # the amazing einops makes this easy with 'EinMix' + self.to_logits = nn.Sequential( + nn.LayerNorm(dim), + Rearrange('b (n gq) d -> b n gq d', gq = num_effective_quantizers), + EinMix( + 'b n gq d -> b n gq l', + weight_shape = 'gq d l', + bias_shape = 'gq l', + gq = num_effective_quantizers, + l = codebook_size, + d = dim + ), + Rearrange('b ... d -> b (...) d') + ) + + def forward( + self, + x, + *, + mask = None, + cond = None, + sum_embeds = None, + return_embeddings = False, + return_logits_and_embeddings = False + ): + """ + einops notation: + b - batch + n - sequence + g - groups + q - quantizers + d - feature dimension + """ + + n, q, g = x.shape[-1], self.num_quantizers, self.grouped_quantizers + assert divisible_by(n, g * q), 'sequence must be divisible by number of quantizers' # noqa + + x = rearrange(x, 'b (n gq) -> b n gq', gq = g * q) + x = x + self.quantizer_offsets + + x = self.code_embeds(x) + + x = reduce(x, 'b n (g q) d -> b n (g d)', 'sum', g = g) + + x = self.embedding_proj(x) + + if exists(sum_embeds): + x = x + sum_embeds + + if exists(cond): + if cond.ndim == 2: + cond = rearrange(cond, 'b d -> b 1 d') + + x = x + cond + + x = self.conformer(x, mask = mask) + embeds = self.heads(x) + + if return_embeddings or not exists(self.to_logits): + return embeds + + logits = self.to_logits(embeds) + + if return_logits_and_embeddings: + return logits, embeds + + return logits diff --git a/modules/masking_logic.py b/modules/masking_logic.py new file mode 100644 index 0000000..e3451b2 --- /dev/null +++ b/modules/masking_logic.py @@ -0,0 +1,111 @@ +"""Masking and sampling logic adapted from MaskGIT original paper: +https://github.com/google-research/maskgit + +Copyright PolyAI Limited. +""" +from dataclasses import dataclass + +import numpy as np +import torch +import torch.nn.functional as F + + +@dataclass +class State: + """Holds decoding state data.""" + # The position of the decoding loop in the length dimension. + cur_index: None + # The active sequence log probabilities and finished sequence scores. 
+ cur_seqs: None + final_seqs: None + + +def state_init(init_indices, num_iter, start_iter=0): + """Initializes the decoding state data structure.""" + cur_index_0 = start_iter + cur_seqs_0 = init_indices + final_seqs_0 = torch.unsqueeze(init_indices, 1) + final_seqs_0 = torch.tile(final_seqs_0, (1, num_iter, 1)) + return State( + cur_index=cur_index_0, cur_seqs=cur_seqs_0, final_seqs=final_seqs_0) + + +def schedule(ratio, method="cosine"): + if method == "uniform": + mask_ratio = 1. - ratio + elif "pow" in method: + exponent = float(method.replace("pow", "")) + mask_ratio = 1. - ratio**exponent + elif method == "cosine": + mask_ratio = np.cos(ratio * (np.pi/2)) + + mask_ratio = np.clip(mask_ratio, 1e-6, 1.) + return mask_ratio + + +def mask_by_random_topk(mask_len, probs, temperature=1.0): + noise = gumbel_noise_like(probs) + confidence = torch.log(probs) + temperature * noise + sorted_confidence, _ = torch.sort(confidence, dim=-1) + # Obtains cut off threshold given the mask lengths. + cut_off = torch.take_along_dim(sorted_confidence, mask_len.long(), dim=-1) + # Masks tokens with lower confidence. + masking = (confidence < cut_off) + return masking + + +def gumbel_noise_like(t): + noise = torch.zeros_like(t).uniform_(1e-20, 1) + return -torch.log(-torch.log(noise)) + + +def sample_from_logits( + logits, + sample: bool = True, + temperature: float = 1.0, + top_k: int = None, + top_p: float = None, + return_probs: bool = False +): + shp = logits.shape[:-1] + + # Apply top_k sampling + if top_k is not None: + v, _ = logits.topk(top_k) + logits[logits < v[..., [-1]]] = -float("inf") + + # Apply top_p (nucleus) sampling + if top_p is not None and top_p < 1.0: + v, sorted_indices = logits.sort(descending=True) + cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1) + + sorted_indices_to_remove = cumulative_probs > top_p + # Right shift indices_to_remove to keep 1st token over threshold + sorted_indices_to_remove = F.pad( + sorted_indices_to_remove, (1, 0), value=False)[..., :-1] + + # Compute indices_to_remove in unsorted array + indices_to_remove = sorted_indices_to_remove.scatter( + -1, sorted_indices, sorted_indices_to_remove + ) + + logits[indices_to_remove] = -float("inf") + + # Perform multinomial sampling after normalizing logits + probs = ( + F.softmax(logits / temperature, dim=-1) + if temperature > 0 + else logits.softmax(dim=-1) + ) + token = ( + probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp) + if sample + else logits.argmax(-1) + ) + + if return_probs: + token_probs = probs.take_along_dim( + token.unsqueeze(-1), dim=-1).squeeze(-1) + return token, token_probs + else: + return token diff --git a/modules/s2a_model.py b/modules/s2a_model.py new file mode 100644 index 0000000..5808279 --- /dev/null +++ b/modules/s2a_model.py @@ -0,0 +1,563 @@ +"""A2S model definition. + +Copyright PolyAI Limited. 
+""" +from typing import Union + +import pytorch_lightning as pl +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from einops import rearrange + +import constants as c +from modules import masking_logic +from modules.conformer import Conformer +from modules.masking_logic import (State, mask_by_random_topk, + sample_from_logits, state_init) +from utils import load_checkpoint + + +class Pheme(pl.LightningModule): + def __init__(self, hp): + super().__init__() + self.hp = hp + self.model = TTSConformer(hp) + self.cross_entropy = nn.CrossEntropyLoss( + label_smoothing=self.hp.label_smoothing, + ignore_index=self.hp.n_codes + ) + if self.hp.pretrained_path: + self.load() + else: + self.apply(self.init_weights) + + if self.hp.only_inference: + self.model.eval() + + self.save_hyperparameters() + + def load(self): + state_dict = load_checkpoint(self.hp.pretrained_path) + print(f"Parameters loaded from {self.hp.pretrained_path}") + self.load_state_dict(state_dict, strict=True) + + def init_weights(self, module): + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=0.02) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=0.02) + module._fill_padding_idx_with_zero() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + module.weight.data.normal_(mean=0.0, std=0.02) + if module.bias is not None: + module.bias.data.zero_() + + def configure_optimizers(self): + optimizer_adam = optim.AdamW( + self.parameters(), lr=self.hp.lr, + betas=(self.hp.adam_beta1, self.hp.adam_beta2)) + + # Learning rate scheduler + num_training_steps = self.hp.training_step + num_warmup_steps = self.hp.warmup_step + num_flat_steps = int(self.hp.optim_flat_percent * num_training_steps) + + def lambda_lr(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + elif current_step < (num_warmup_steps + num_flat_steps): + return 1.0 + return max( + 0.0, + float(num_training_steps - current_step) + / float( + max(1, num_training_steps - (num_warmup_steps + num_flat_steps)) # noqa + ), + ) + + scheduler_adam = { + "scheduler": optim.lr_scheduler.LambdaLR( + optimizer_adam, lambda_lr), + "interval": "step", + } + return [optimizer_adam], [scheduler_adam] + + def top_k_accuracy(self, y_true, y_pred_probabilities, k): + _, sorted_indices = torch.sort(y_pred_probabilities, descending=True) + + # Get the top-k predictions + top_k_indices = sorted_indices[:, :k] + expanded_y_true = y_true.unsqueeze(1).expand_as(top_k_indices) + + # Check if true labels exist in top-k predictions + hits = torch.sum(torch.eq(top_k_indices, expanded_y_true)) + accuracy = hits.item() / (len(y_true) + 1e-7) + + return accuracy + + def training_step(self, batch, batch_idx): + # Sample training level + rvq_level = torch.randint( + 0, min(self.hp.first_n_lvls, self.hp.n_cluster_groups),(1,)).item() + + target, chosen_tokens, _, _ = self.model( + batch["tts_quantize_input"], rvq_level, batch["semantic_tokens"], + batch["quantization_lengths"], + speaker_emb=batch["speaker"], + min_seq_length=batch["quantization_lengths"].min().item()) + + # Mask targets and labels + mask = chosen_tokens + target = target[mask] + + labels = batch["tts_quantize_input"][:, :, rvq_level] + labels = labels[mask] + + loss = self.cross_entropy(target, labels) + acc 
= (target.argmax(-1) == labels).float().mean() + self.log("train/loss", loss, on_step=True, prog_bar=True) + self.log("train/acc", acc, on_step=True, prog_bar=True) + self.log( + f"train/acc_lvl_{rvq_level}", acc, on_step=True, prog_bar=False) + + return loss + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + speaker_emb = batch["speaker"] + acoustic_tokens = batch["tts_quantize_input"] + semantic_tokens = batch["semantic_tokens"] + + if self.hp.only_inference: + self.inference( + acoustic_tokens, semantic_tokens, self.hp.first_n_lvls) + else: + rvq_level = torch.randint( + 0, min(self.hp.first_n_lvls, self.hp.n_cluster_groups),(1,) + ).item() + + # FIXME: edge case + if len(semantic_tokens.shape) == 3: + semantic_tokens = rearrange(semantic_tokens, "B 1 T -> B T") + + target, chosen_tokens, _, _ = self.model( + acoustic_tokens, rvq_level, semantic_tokens, + torch.tensor([acoustic_tokens.shape[1]]).to(self.device), + speaker_emb=speaker_emb, + min_seq_length=acoustic_tokens.shape[1] + ) + + target = target[chosen_tokens] + labels = acoustic_tokens[:, :, rvq_level][chosen_tokens] + loss = self.cross_entropy(target, labels) + + acc = (target.argmax(-1) == labels).float().mean() + acc_5 = self.top_k_accuracy(labels, target, 5) + + self.log( + f"val/dataset_{dataloader_idx}/loss", + loss, + on_epoch=True, + logger=True, + add_dataloader_idx=False, + ) + self.log( + f"val/dataset_{dataloader_idx}/acc_lvl", + acc, + on_epoch=True, + logger=True, + add_dataloader_idx=False, + ) + self.log( + f"val/dataset_{dataloader_idx}/acc_lvl_{rvq_level}", + acc, + on_epoch=True, + logger=True, + add_dataloader_idx=False, + ) + self.log( + f"val/dataset_{dataloader_idx}/acc_top_5", + acc_5, + on_epoch=True, + logger=True, + add_dataloader_idx=False, + ) + self.log( + f"val/dataset_{dataloader_idx}/acc_top_5_lvl_{rvq_level}", + acc_5, + on_epoch=True, + logger=True, + add_dataloader_idx=False, + ) + + def compute_stats(self, logits, labels, mask_ratio=0, rvq_level=0): + acc = (logits.argmax(-1) == labels).float().mean() + acc_5 = self.top_k_accuracy(labels, logits, 5) + acc_10 = self.top_k_accuracy(labels, logits, 10) + + idx = torch.randperm(logits.shape[0]) + logits_shuffled = logits[idx] + random = self.top_k_accuracy(labels, logits_shuffled, 10) + print(f"Mask ratio: {mask_ratio}, Level {rvq_level}: acc {acc}," + f"acc 5 {acc_5}, acc 10 {acc_10}, quasi random {random}") + + +class TTSConformer(pl.LightningModule): + def __init__(self, hp): + super().__init__() + self.hp = hp + self.padding_id = self.hp.n_codes + + additional_codes = [c.PAD, c.SPKR_1, c.SPKR_2] + + self.embedding = nn.ModuleList( + [ + nn.Embedding( + self.hp.n_codes + len(additional_codes), + self.hp.hidden_size, + padding_idx=self.padding_id) + for _ in range(self.hp.n_cluster_groups) + ] + ) + + # Additional modules + self.semantic_embedding = nn.Embedding( + self.hp.n_semantic_codes + len(additional_codes), + self.hp.hidden_size, + padding_idx=self.padding_id) + + if self.hp.use_spkr_emb: + self.spkr_linear = nn.Linear(c.SPKR_EMB_SIZE, self.hp.hidden_size) + + self.conformer = Conformer( + dim=self.hp.hidden_size, + num_layers=self.hp.enc_nlayers, + heads=self.hp.nheads, + dim_head=64, + ff_mult=4, # 512*4=2048 + conv_expansion_factor=2, + conv_kernel_size=self.hp.depthwise_conv_kernel_size, + attn_dropout=self.hp.dropout, + ff_dropout=self.hp.dropout, + conv_dropout=self.hp.dropout, + attn_flash=True, + t5_rel_pos_bias=False + ) + + self.heads = nn.ModuleList( + [ + nn.Linear( + self.hp.hidden_size, + self.hp.n_codes 
+ len(additional_codes) + ) + for _ in range(self.hp.n_cluster_groups) + ] + ) + + def build_mask_from_lengths(self, length, max_len=None): + max_len = max_len or length.max().item() + mask = torch.arange( + max_len, device=length.device)[None, :] >= length[:, None] + return mask.bool() + + @torch.no_grad() + def create_mask( + self, B, T, lengths, mask_ratio=None, start_t=None, + min_seq_length=None + ): + # 1. Define the random length of condition tokens given the shortest + # audio in the batch + if start_t is None: + start_t = torch.randint(1, min_seq_length - 1, (1,)).item() + + # 2. Mask other tokens - sample different masking levels per + if mask_ratio is None: + ratio = torch.rand(1).item() + mask_ratio = masking_logic.schedule(ratio) + + # Create a random tensor with values between 0 and 1 + random_tensor = torch.rand( + (B, T - start_t), dtype=torch.float).to(self.device) + # Create a mask where values less than p are set to True + initial_mask = random_tensor < mask_ratio + length_mask = self.build_mask_from_lengths( + lengths - start_t, T - start_t) + # we can't pick up tokens past token lengths + initial_mask = torch.logical_and(initial_mask, ~length_mask) + + # Constrain ratio to always include some samples + # If all are False let's pick up at least one: + if torch.sum(initial_mask) == 0: + choose_steps = torch.randint(low=0, high=(T - start_t), size=(B,)) + initial_mask[torch.arange(B), choose_steps] = torch.tensor( + True, device=self.device) + + # 3. Add condition tokens containing information + acoustic_token_mask = torch.cat( + (torch.full((B, start_t), False, device=self.device), initial_mask), # noqa + 1 + ) + + return acoustic_token_mask, start_t, mask_ratio + + def process_input( + self, data, lengths, rvq_level, min_seq_length=None, + mask_ratio=None, start_t=None, acoustic_token_mask=None + ): + """ + data: (B, T, code_level, D) + rvq_level: int + """ + B = data.size(0) + T = data.size(1) + level_data = data[:, :, rvq_level, :] # [B, T, C, D] -> [B, T, D] + + # Choose acoustic tokens to mask + if acoustic_token_mask is None: + acoustic_token_mask, start_t, mask_ratio = self.create_mask( + B, T, lengths, mask_ratio=mask_ratio, start_t=start_t, + min_seq_length=min_seq_length) + # Remove code information from chosen tokens + level_data[acoustic_token_mask, :] = 0 + + # Embed only lower rvq_level + lower_code_data = data[:, :, :rvq_level, :].sum(dim=2) + + # Combine with chosen tokens at rvq_level. + # Note: all tokens at rvq_level+1: will be discarded. + summed_data = torch.add(lower_code_data, level_data) + + return summed_data, acoustic_token_mask, mask_ratio, start_t + + def forward( + self, x, code_level, semantic_tokens, lengths, + speaker_emb=None, min_seq_length=10, mask_ratio=None, start_t=None, + acoustic_token_mask=None + ): + # FIXME: parallelize this + batch = [] + for lvl, embed in enumerate(self.embedding[:(code_level + 1)]): + batch.append(embed(x[:, :, lvl])) # [B T D] + + x = torch.stack(batch, dim=2) # [B T C D] + x, acoustic_token_mask, mask_ratio, start_t = self.process_input( + x, lengths, code_level, min_seq_length=min_seq_length, + mask_ratio=mask_ratio, start_t=start_t, + acoustic_token_mask=acoustic_token_mask + ) + + # Add phoneme embeddings + # Cross attention for all tokens? 
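+ # Note: conditioning here is purely additive; the semantic-token embeddings (and, when enabled, the projected speaker embedding) are summed element-wise onto the acoustic-token embeddings.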
+ + # Add semantic tokens + # HACK ME + semantic_emb = self.semantic_embedding(semantic_tokens) + x = torch.add(x, semantic_emb) + # FIXME pfb30 + + # Merge different modalities + if self.hp.use_spkr_emb: + spkr_emb = F.normalize(speaker_emb, dim=-1) + spkr_emb = self.spkr_linear( + F.dropout(spkr_emb, self.hp.speaker_embed_dropout) + ) + x = torch.add(x, spkr_emb) + + output_frames = self.conformer(x, None) + + x = self.heads[code_level](output_frames) + + return x, acoustic_token_mask, mask_ratio, start_t + + @torch.no_grad() + def inference( + self, codes, semantic_tokens, + length: torch.LongTensor, rvq_levels=7, + mask_ratio=0.99, maskgit_inference=True, + start_t: Union[torch.LongTensor, None] = None, + speaker_emb=None, steps=16 + ): + # Use half of the recording for the conditioning + if start_t is None: + start_t = torch.tensor(int((codes.shape[1]) / 2)).long() + + start_t = start_t.item() + + for rvq_level in range(rvq_levels): + original_codes = torch.clone(codes) + if rvq_level == 0 and maskgit_inference: + codes = self.multi_step_inference( + original_codes, semantic_tokens, length, + start_t=start_t, vamp_filtering=False, + speaker_emb=speaker_emb, steps=16 + ) + else: + codes = self.one_step_inference( + original_codes, semantic_tokens, length, + code_level=rvq_level, + mask_ratio=mask_ratio, start_t=start_t, + speaker_emb=speaker_emb + ) + + codes = rearrange(codes, 'T C -> 1 T C') + + # Remove any padding left + codes = rearrange(codes, '1 T C -> 1 C T') + codes = torch.where(codes >= self.hp.n_codes, 0, codes) + acoustic_tokens = codes + semantic_tokens = rearrange(semantic_tokens, 'b c -> b 1 c') + semantic_tokens = torch.where( + semantic_tokens >= self.hp.n_codes, 0, semantic_tokens) + codes = torch.cat([semantic_tokens, acoustic_tokens], dim=1) + + return codes + + @torch.no_grad() + def one_step_inference( + self, original_codes, semantic_tokens, lengths, code_level=0, + mask_ratio=0.99, start_t=0, inference_setup="argmax", speaker_emb=None + ): + codes = torch.clone(original_codes) + logits, _, _, _ = self.forward( + codes, code_level, semantic_tokens, lengths, + mask_ratio=mask_ratio, start_t=start_t, + speaker_emb=speaker_emb, acoustic_token_mask=False) + + if inference_setup == "argmax": + probs = torch.nn.functional.softmax(logits, dim=-1) + top_indeces = torch.argmax(probs, dim=-1) + + if inference_setup == "sampling": + top_indeces = torch.distributions.Categorical( + logits=logits).sample() + + codes = rearrange(codes, '1 T C -> T C') + codes[start_t:, code_level] = top_indeces[0, start_t:] + + return codes + + @torch.no_grad() + def multi_step_inference( + self, original_codes, semantic_tokens, lengths, + start_t: torch.LongTensor=None, + choice_temperature=1.0, start_iter=0, + steps=16, vamp_filtering=False, speaker_emb=None + ): + codes = torch.clone(original_codes) + code_level = 0 + _, seq_len, _ = original_codes.shape + mask_token_id = self.padding_id + + # Get true codes for the prompt + prompt_mask = codes[:, :start_t, code_level] + + # Fill up rest with masks + mask = torch.full( + (1, seq_len - start_t), mask_token_id, device=self.device) + inputs = torch.cat((prompt_mask, mask), 1) + + num_mask_tokens_at_start = torch.sum(inputs == mask_token_id, axis=-1) + + # Initializes state + state = state_init(inputs, steps, start_iter=start_iter) + + def loop_cond_fn(state): + """Beam search loop termination condition.""" + not_at_end = (state.cur_index < steps) + return not_at_end + + while loop_cond_fn(state): + """Beam search loop state update 
function.""" + step = state.cur_index + # Current input ids: [batch_size, seq_length]. + cur_ids = state.cur_seqs + + # Calls model on current seqs to get next-iteration seqs. + with torch.no_grad(): + logits, _, _, _ = self.forward( + rearrange(inputs, 'B T -> B T 1'), + code_level, + semantic_tokens, lengths, + acoustic_token_mask=False, + speaker_emb=speaker_emb) + + # Samples the ids using categorical sampling: + if vamp_filtering: + typical_mass = 0.2 + typical_min_tokens = 1 + top_p = None + sample_cutoff = 0.5 + typical_filtering = False + sampled_ids, selected_probs = sample_from_logits( + logits, sample=((step / steps) <= sample_cutoff), + temperature=choice_temperature, + typical_filtering=typical_filtering, + typical_mass=typical_mass, + typical_min_tokens=typical_min_tokens, + top_k=None, top_p=top_p, return_probs=True, + ) + else: + sampled_ids = torch.distributions.Categorical( + logits=logits).sample() + + # Just updates the masked tokens. + unknown_map = (cur_ids == mask_token_id) + sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids) + # Defines the mask ratio for the next round. The number to mask out + # is determined by mask_ratio * unknown_number_in_the_beginning. + ratio = 1. * (step + 1) / steps + mask_ratio = masking_logic.schedule(ratio) + + # Updates final seqs with the current sampled_ids. + final_seqs = torch.clone(state.final_seqs) + final_seqs[:, step, :] = sampled_ids + # Computes the probabilities of each selected tokens. + probs = torch.nn.functional.softmax(logits, dim=-1) + # Extract the probabilities of sampled ids + selected_probs = torch.squeeze( + torch.take_along_dim( + probs, torch.unsqueeze(sampled_ids, -1) , -1), + -1 + ) + + # Ignores the tokens given in the input + # by overwriting their confidence. + selected_probs = torch.where( + unknown_map, selected_probs, torch.inf) + # Gets mask lens for each sample in the + # batch according to the mask ratio. + num_to_mask = torch.unsqueeze( + torch.floor(num_mask_tokens_at_start * mask_ratio), 1) + + # Keeps at least one of prediction in this + # round and also masks out at least + # one and for the next iteration + num_to_mask = torch.maximum( + torch.tensor(1), + torch.minimum( + torch.sum(unknown_map, dim=-1, keepdim=True) - 1, + num_to_mask) + ) + # Adds noise for randomness + masking = mask_by_random_topk( + num_to_mask, selected_probs, choice_temperature * (1. - ratio)) + # Masks tokens with lower confidence. + sampled_ids = torch.where(masking, mask_token_id, sampled_ids) + + state = State( + cur_index=state.cur_index + 1, + cur_seqs=sampled_ids, + final_seqs=final_seqs + ) + + codes = torch.clone(original_codes) + codes = rearrange(codes, '1 T C -> T C') + codes[:, 0] = state.final_seqs[0][-1] + + return codes diff --git a/modules/speech_tokenizer.py b/modules/speech_tokenizer.py new file mode 100644 index 0000000..5c3581f --- /dev/null +++ b/modules/speech_tokenizer.py @@ -0,0 +1,86 @@ +"""Speech tokenizer class. + +Copyright PolyAI Limited. 
+""" +import logging +import os + +import numpy as np +import torch +import torchaudio +from speechtokenizer import SpeechTokenizer as ST + +from modules.tokenizer import BaseTokenizer + + +class SpeechTokenizer(BaseTokenizer): + def __init__(self, config_path: str, ckpt_path: str): + self.device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu") + self.model = ST.load_from_checkpoint( + config_path, ckpt_path).to(self.device) + self.model.eval() + + def encode_file( + self, folder_path: str, destination_folder: str, filename: str): + dest_path = os.path.join( + destination_folder, "semantic", + os.path.splitext(filename)[0] + ".npy" + ) + dest_path2 = os.path.join( + destination_folder, "acoustic", + os.path.splitext(filename)[0] + ".npy" + ) + if os.path.exists(dest_path) and os.path.exists(dest_path2): + pass + else: + self._create_subfolders(destination_folder=destination_folder) + + file_path = os.path.join(folder_path, filename) + wav_info = torchaudio.info(file_path) + wav_dur_sec = wav_info.num_frames / wav_info.sample_rate + if wav_dur_sec > 60: + logging.info( + f"Skipping {file_path} is too long: {wav_dur_sec:.3f} sec," + "can cause CUDA OOM" + ) + return + wav, sr = torchaudio.load(file_path) + if sr != self.model.sample_rate: + logging.warning( + "Wav sample rate %(wav_sr)s does not match the model" + "sampling rate %(model_sr)s. Resampling audio", + {"wav_sr": sr, "model_sr": self.model.sample_rate}, + ) + wav = torchaudio.functional.resample( + wav, sr, self.model.sample_rate) + wav = wav.unsqueeze(0) + wav = wav.to(self.device) + + # Extract discrete codes from SpeechTokenizer + with torch.no_grad(): + codes = self.model.encode(wav) # codes: (n_q, B, T) + + semantic_tokens = codes[0, 0, :] + acoustic_tokens = codes[1:, 0, :] + + # Save the encoding as .npy + dest_path = os.path.join( + destination_folder, "acoustic", + os.path.splitext(filename)[0] + ".npy" + ) + np.save(dest_path, acoustic_tokens.cpu().numpy()) + + dest_path = os.path.join( + destination_folder, "semantic", + os.path.splitext(filename)[0] + ".npy" + ) + np.save(dest_path, semantic_tokens.cpu().numpy()) + + @staticmethod + def _create_subfolders(destination_folder: str): + if not os.path.exists(destination_folder + "/acoustic"): + os.makedirs(destination_folder + "/acoustic") + + if not os.path.exists(destination_folder + "/semantic"): + os.makedirs(destination_folder + "/semantic") diff --git a/modules/t2s_model.py b/modules/t2s_model.py new file mode 100644 index 0000000..7759bd2 --- /dev/null +++ b/modules/t2s_model.py @@ -0,0 +1,111 @@ +"""T2S model definition. + +Copyright PolyAI Limited. 
+""" +import os + +import numpy as np +from torch import nn +from transformers import EvalPrediction, T5Config, T5ForConditionalGeneration + +from data.collation import get_text_semantic_token_collater + + +def compute_custom_metrics(eval_prediction: EvalPrediction): + # eval_prediction: tuple + # eval_prediction[0]: tensor of decoder outputs(logits) (n_batch, n_semantic, n_tokens) # noqa + # eval_prediction[1]: tensor of encoder outputs (n_batch, n_text/n_phone, n_hidden) # noqa + logits = eval_prediction.predictions[0] + labels = eval_prediction.label_ids + n_vocab = logits.shape[-1] + mask = labels == -100 + top_1 = np.argmax(logits, axis=-1) == labels + top_1[mask] = False + top_5 = np.argsort(logits, axis=-1)[:, :, -5:] + top_5 = np.any(top_5 == np.expand_dims(labels, axis=-1), axis=-1) + top_5[mask] = False + + top_10 = np.argsort(logits, axis=-1)[:, :, -10:] + top_10 = np.any(top_10 == np.expand_dims(labels, axis=-1), axis=-1) + top_10[mask] = False + + top_1_accuracy = np.sum(top_1) / np.sum(~mask) + top_5_accuracy = np.sum(top_5) / np.sum(~mask) + top_10_accuracy = np.sum(top_10) / np.sum(~mask) + + return { + "top_1_accuracy": top_1_accuracy, + "top_5_accuracy": top_5_accuracy, + "top_10_accuracy": top_10_accuracy, + } + + +class T2S(nn.Module): + def __init__(self, hp): + super().__init__() + self.text_tokens_file = "ckpt/unique_text_tokens.k2symbols" + self.collater = get_text_semantic_token_collater(self.text_tokens_file) + self.model_size = hp.model_size + self.vocab_size = len(self.collater.idx2token) + self.config = self._define_model_config(self.model_size) + + print(f"{self.config = }") + self.t2s = T5ForConditionalGeneration(self.config) + + def _define_model_config(self, model_size): + if model_size == "test": + # n_params = 16M + d_ff = 16 + d_model = 8 + d_kv = 32 + num_heads = 1 + num_decoder_layers = 1 + num_layers = 1 + elif model_size == "tiny": + # n_params = 16M + d_ff = 1024 + d_model = 256 + d_kv = 32 + num_heads = 4 + num_decoder_layers = 4 + num_layers = 4 + elif model_size == "t5small": + # n_params = 60M + d_ff = 2048 + d_model = 512 + d_kv = 64 + num_heads = 8 + num_decoder_layers = 6 + num_layers = 6 + elif model_size == "large": + # n_params = 100M + d_ff = 2048 + d_model = 512 + d_kv = 64 + num_heads = 8 + num_decoder_layers = 14 + num_layers = 14 + elif model_size == "Large": + # n_params = 114M + d_ff = 4096 + d_model = 512 + d_kv = 64 + num_heads = 8 + num_decoder_layers = 6 + num_layers = 10 + else: + raise ValueError(f"unknown {model_size}") + + config = T5Config( + d_ff=d_ff, + d_model=d_model, + d_kv=d_kv, + num_heads=num_heads, + num_decoder_layers=num_decoder_layers, + num_layers=num_layers, + decoder_start_token_id=0, + eos_token_id=2, + vocab_size=self.vocab_size, + ) + + return config diff --git a/modules/tokenizer.py b/modules/tokenizer.py new file mode 100644 index 0000000..c84156a --- /dev/null +++ b/modules/tokenizer.py @@ -0,0 +1,73 @@ +"""Base tokenizer class. + +Copyright PolyAI Limited. 
+""" +import os +from asyncio import as_completed +from concurrent.futures import ThreadPoolExecutor + +from tqdm import tqdm + +from utils import measure_duration + + +class BaseTokenizer: + @measure_duration + def encode_files_with_model_seq( + self, folder_path: str, destination_folder: str): + # Ensure destination folder exists + if not os.path.exists(destination_folder): + os.makedirs(destination_folder) + + # Go through each file in the folder + filenames = os.listdir(folder_path) + # encoding files has no side effects + for filename in tqdm(filenames): + self.encode_file( + folder_path=folder_path, + destination_folder=destination_folder, + filename=filename, + ) + + def get_chunk(self, folder_path, start_percent=0, end_percent=100): + filenames = os.listdir(folder_path) + total_files = len(filenames) + + start_idx = int(total_files * (start_percent / 100)) + end_idx = int(total_files * (end_percent / 100)) + + return filenames[start_idx:end_idx] + + @measure_duration + def encode_files_with_model_concurrent( + self, folder_path: str, destination_folder: str, start_percent: int, + end_percent: int, + ): + # Ensure destination folder exists + if not os.path.exists(destination_folder): + os.makedirs(destination_folder) + + # Go through each file in the folder + filenames = self.get_chunk(folder_path, start_percent, end_percent) + + # encoding files has no side effects + with ThreadPoolExecutor(max_workers=40) as executor: + futures = [ + executor.submit( + self.encode_file, + folder_path=folder_path, + destination_folder=destination_folder, + filename=filename, + ) + for filename in filenames + ] + # Wait for all tasks to complete + for future in as_completed(futures): + future.result() + + # Explicitly shut down the thread pool + executor.shutdown() + + def encode_file( + self, folder_path: str, destination_folder: str, filename: str): + raise NotImplementedError diff --git a/modules/vocoder.py b/modules/vocoder.py new file mode 100644 index 0000000..0a941f4 --- /dev/null +++ b/modules/vocoder.py @@ -0,0 +1,79 @@ +"""Vocoder wrapper. + +Copyright PolyAI Limited. 
+""" +import enum + +import numpy as np +import soundfile as sf +import torch +import torch.nn as nn +from speechtokenizer import SpeechTokenizer + + +class VocoderType(enum.Enum): + SPEECHTOKENIZER = ("SPEECHTOKENIZER", 320) + + def __init__(self, name, compression_ratio): + self._name_ = name + self.compression_ratio = compression_ratio + + def get_vocoder(self, ckpt_path, config_path, **kwargs): + if self.name == "SPEECHTOKENIZER": + if ckpt_path: + vocoder = STWrapper(ckpt_path, config_path) + else: + vocoder = STWrapper() + else: + raise ValueError(f"Unknown vocoder type {self.name}") + return vocoder + + +class STWrapper(nn.Module): + def __init__( + self, + ckpt_path: str = './ckpt/speechtokenizer/SpeechTokenizer.pt', + config_path = './ckpt/speechtokenizer/config.json', + ): + super().__init__() + self.model = SpeechTokenizer.load_from_checkpoint( + config_path, ckpt_path) + + def eval(self): + self.model.eval() + + @torch.no_grad() + def decode(self, codes: torch.Tensor, verbose: bool = False): + original_device = codes.device + + codes = codes.to(self.device) + audio_array = self.model.decode(codes) + + return audio_array.to(original_device) + + def decode_to_file(self, codes_path, out_path) -> None: + codes = np.load(codes_path) + codes = torch.from_numpy(codes) + wav = self.decode(codes).cpu().numpy() + sf.write(out_path, wav, samplerate=self.model.sample_rate) + + @torch.no_grad() + def encode(self, wav, verbose=False, n_quantizers: int = None): + original_device = wav.device + wav = wav.to(self.device) + codes = self.model.encode(wav) # codes: (n_q, B, T) + return codes.to(original_device) + + def encode_to_file(self, wav_path, out_path) -> None: + wav, _ = sf.read(wav_path, dtype='float32') + wav = torch.from_numpy(wav).unsqueeze(0).unsqueeze(0) + codes = self.encode(wav).cpu().numpy() + np.save(out_path, codes) + + def remove_weight_norm(self): + pass + + @property + def device(self): + return next(self.model.parameters()).device + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..7088c87 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,161 @@ +aiofiles==23.2.1 +aiohttp==3.9.1 +aiosignal==1.3.1 +alembic==1.13.0 +altair==5.2.0 +annotated-types==0.6.0 +antlr4-python3-runtime==4.9.3 +anyio==3.7.1 +asteroid-filterbanks==0.4.0 +async-timeout==4.0.3 +attrs==23.1.0 +audioread==3.0.1 +Babel==2.13.1 +certifi==2023.11.17 +cffi==1.16.0 +charset-normalizer==3.3.2 +click==8.1.7 +clldutils==3.20.0 +colorama==0.4.6 +colorlog==6.8.0 +contourpy==1.2.0 +csvw==3.2.1 +cycler==0.12.1 +decorator==5.1.1 +dlinfo==1.2.1 +docopt==0.6.2 +einops==0.7.0 +exceptiongroup==1.2.0 +fastapi==0.104.1 +ffmpy==0.3.1 +filelock==3.13.1 +fonttools==4.46.0 +frozenlist==1.4.0 +fsspec==2023.10.0 +gradio==3.48.0 +gradio_client==0.6.1 +greenlet==3.0.1 +h11==0.14.0 +httpcore==1.0.2 +httpx==0.25.2 +huggingface-hub==0.19.4 +HyperPyYAML==1.2.2 +idna==3.6 +importlib-resources==6.1.1 +isodate==0.6.1 +Jinja2==3.1.2 +joblib==1.3.2 +jsonschema==4.20.0 +jsonschema-specifications==2023.11.2 +julius==0.2.7 +kiwisolver==1.4.5 +language-tags==1.2.0 +lazy_loader==0.3 +librosa==0.10.1 +lightning==2.1.2 +lightning-utilities==0.10.0 +llvmlite==0.41.1 +lxml==4.9.3 +Mako==1.3.0 +Markdown==3.5.1 +markdown-it-py==3.0.0 +MarkupSafe==2.1.3 +matplotlib==3.8.2 +mdurl==0.1.2 +mpmath==1.3.0 +msgpack==1.0.7 +multidict==6.0.4 +networkx==3.2.1 +numba==0.58.1 +numpy==1.26.2 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 
+nvidia-cudnn-cu12==8.9.2.26 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.18.1 +nvidia-nvjitlink-cu12==12.3.101 +nvidia-nvtx-cu12==12.1.105 +omegaconf==2.3.0 +optuna==3.4.0 +orjson==3.9.10 +packaging==23.2 +pandas==2.1.3 +phonemizer==3.2.1 +Pillow==10.1.0 +platformdirs==4.0.0 +pooch==1.8.0 +primePy==1.3 +protobuf==4.25.1 +pyannote.audio @ https://github.com/pyannote/pyannote-audio/archive/develop.zip +pyannote.core==5.0.0 +pyannote.database==5.0.1 +pyannote.metrics==3.2.1 +pyannote.pipeline==3.0.1 +pycparser==2.21 +pydantic==2.5.2 +pydantic_core==2.14.5 +pydub==0.25.1 +Pygments==2.17.2 +pylatexenc==2.10 +pyparsing==3.1.1 +python-dateutil==2.8.2 +python-multipart==0.0.6 +pytorch-lightning==2.1.2 +pytorch-metric-learning==2.3.0 +pytz==2023.3.post1 +PyYAML==6.0.1 +rdflib==7.0.0 +referencing==0.31.1 +regex==2023.10.3 +requests==2.31.0 +rfc3986==1.5.0 +rich==13.7.0 +rpds-py==0.13.2 +ruamel.yaml==0.18.5 +ruamel.yaml.clib==0.2.8 +safetensors==0.4.1 +scikit-learn==1.3.2 +scipy==1.11.4 +segments==2.2.1 +semantic-version==2.10.0 +semver==3.0.2 +sentencepiece==0.1.99 +shellingham==1.5.4 +six==1.16.0 +sniffio==1.3.0 +sortedcontainers==2.4.0 +soundfile==0.12.1 +soxr==0.3.7 +speechbrain==0.5.16 +speechtokenizer==0.1.2 +SQLAlchemy==2.0.23 +starlette==0.27.0 +sympy==1.12 +tabulate==0.9.0 +tensorboardX==2.6.2.2 +threadpoolctl==3.2.0 +tokenizers==0.15.0 +tomlkit==0.12.0 +toolz==0.12.0 +torch==2.1.1 +torch-audiomentations==0.11.0 +torch-pitch-shift==1.2.4 +torchaudio==2.1.1 +torchmetrics==1.2.1 +torchvision==0.16.1 +tqdm==4.66.1 +transformers==4.35.2 +triton==2.1.0 +typer==0.9.0 +typing_extensions==4.8.0 +tzdata==2023.3 +uritemplate==4.1.1 +urllib3==2.1.0 +uvicorn==0.24.0.post1 +websockets==11.0.3 +yarl==1.9.3 \ No newline at end of file diff --git a/train_s2a.py b/train_s2a.py new file mode 100644 index 0000000..e56f7d9 --- /dev/null +++ b/train_s2a.py @@ -0,0 +1,214 @@ +"""S2A training logic. + +Copyright PolyAI Limited. 
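+
+Example invocation (illustrative; the metadata paths are placeholders for
+manifests produced by the data preparation step):
+
+python train_s2a.py \
+    --metapath datasets/example/train.json \
+    --val_metapath datasets/example/dev.json \
+    --saving_path exp/s2a \
+    --vocoder_type SPEECHTOKENIZER \
+    --use_spkr_emb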
+""" +import argparse +import json +import os +from pathlib import Path +from typing import List + +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger + +from data.data_module import DataModule +from modules.s2a_model import Pheme +from modules.vocoder import VocoderType + + +def parse_args(): + parser = argparse.ArgumentParser() + # Paths + parser.add_argument("--saving_path", type=str, default="./ckpt") + parser.add_argument("--resume_checkpoint", type=str, default=None) + parser.add_argument( + "--vocoder_type", + type=str, + choices=[voc_type.name for voc_type in VocoderType], + default=VocoderType.SPEECHTOKENIZER.name, + ) + parser.add_argument("--vocoder_config_path", type=str, default=None) + parser.add_argument("--vocoder_ckpt_path", type=str, default=None) + parser.add_argument( + "--metapath", type=str, nargs="+", help="paths to train metadata", + required=True + ) + parser.add_argument( + "--val_metapath", type=str, nargs="+", default=[], + help="paths to validation metadata", + ) + parser.add_argument("--pretrained_path", type=str, default=None) + parser.add_argument("--speaker_embedding_dir", type=str, default=None) + parser.add_argument("--sampledir", type=str, default="./logs") + + # Optimizer + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--batch_size", type=float, default=200) + parser.add_argument("--max_length", type=int, default=1600) + parser.add_argument("--train_bucket_size", type=int, default=8192) + parser.add_argument("--training_step", type=int, default=800000) + parser.add_argument("--optim_flat_percent", type=float, default=0.0) + parser.add_argument("--warmup_step", type=int, default=50) + parser.add_argument("--adam_beta1", type=float, default=0.9) + parser.add_argument("--adam_beta2", type=float, default=0.98) + + # Architecture + parser.add_argument("--ffd_size", type=int, default=3072) + parser.add_argument("--hidden_size", type=int, default=768) + parser.add_argument("--enc_nlayers", type=int, default=6) + parser.add_argument("--dec_nlayers", type=int, default=6) + parser.add_argument("--nheads", type=int, default=12) + parser.add_argument("--dropout", type=float, default=0.1) + parser.add_argument("--depthwise_conv_kernel_size", type=int, default=5) + parser.add_argument("--aligner_softmax_temp", type=float, default=1.0) + parser.add_argument("--layer_norm_eps", type=float, default=1e-5) + parser.add_argument("--use_sem_tokens", type=bool, default=True) + parser.add_argument("--use_spkr_emb", action="store_true") + parser.add_argument("--use_text_emb", action="store_true") + parser.add_argument("--only_inference", action="store_true") + + # Dropout + parser.add_argument("--speaker_embed_dropout", type=float, default=0.05) + parser.add_argument("--label_smoothing", type=float, default=0.0) + + # Trainer + parser.add_argument("--val_check_interval", type=int, default=1) + parser.add_argument("--max_dataset_samples", type=int, default=-1) + parser.add_argument("--check_val_every_n_epoch", type=int, default=1) + parser.add_argument( + "--precision", type=str, choices=["16", "32", "bf16"], default="bf16" + ) + parser.add_argument("--nworkers", type=int, default=16) + parser.add_argument("--distributed", action="store_true") + parser.add_argument( + "--accelerator", + type=str, + choices=["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"], + default="gpu", + ) + parser.add_argument("--version", 
type=int, default=None) + parser.add_argument("--accumulate_grad_batches", type=int, default=1) + + # Data + parser.add_argument("--sample_rate", type=int, default=16000) + parser.add_argument("--n_codes", type=int, default=1024) + parser.add_argument("--n_cluster_groups", type=int, default=7) + parser.add_argument("--first_n_lvls", type=int, default=7) + parser.add_argument("--use_pretrained_ckpt_cfg", action="store_true") + parser.add_argument("--n_semantic_codes", type=int, default=1024) + + # Distribution + parser.add_argument("--sagemaker", action="store_true") + + args = parser.parse_args() + + return args + + +def split_metapath(in_paths: List[str]): + podidx_paths, other_paths = [], [] + + for itm_path in in_paths: + if itm_path.endswith("jsonl"): + podidx_paths.append(itm_path) + else: + other_paths.append(itm_path) + + return podidx_paths, other_paths + + +if __name__ == "__main__": + args = parse_args() + os.makedirs(args.saving_path, exist_ok=True) + + with open(os.path.join(args.saving_path, "config.json"), "w") as f: + json.dump(args.__dict__, f, indent=2) + + if args.pretrained_path: + if ( + Path(args.pretrained_path).with_name("config.json").exists() + and args.use_pretrained_ckpt_cfg + ): + with open( + Path(args.pretrained_path).with_name("config.json"), "r") as f: + prev_cfg = json.load(f) + for k, v in prev_cfg.items(): + if isinstance(v, (int,)): + if args.__dict__[k] != v: + print(f"updating {k}!", args.__dict__[k], v) + args.__dict__[k] = v + + fname_prefix = f"" + checkpoint_callback = ModelCheckpoint( + dirpath=args.saving_path, + filename=(fname_prefix + "{epoch}-{step}"), + every_n_train_steps=( + None if args.val_check_interval == 1.0 else args.val_check_interval # noqa + ), + every_n_epochs=( + None if args.check_val_every_n_epoch == 1 else args.check_val_every_n_epoch # noqa + ), + verbose=True, + save_last=True, + save_top_k=3, + monitor="val/dataset_0/acc_top_5", + mode='max' + ) + lr_monitor = LearningRateMonitor(logging_interval="step") + + logger_tb = TensorBoardLogger( + args.saving_path, name="VQ-TTS", version=args.version) + logger_wandb = WandbLogger(project="mqtts", log_model=True, config=args) + + distribution_kwargs = { + "accelerator": "gpu", + "strategy": "ddp_find_unused_parameters_true" if args.distributed else "auto", # noqa + } + + wrapper = Trainer( + precision=args.precision, + callbacks=[checkpoint_callback, lr_monitor], + num_sanity_val_steps=20, + max_steps=args.training_step, + accumulate_grad_batches=args.accumulate_grad_batches, + logger=[logger_tb, logger_wandb], + check_val_every_n_epoch=args.check_val_every_n_epoch, + profiler="simple", + use_distributed_sampler=False, + # distribution + **distribution_kwargs, + ) + model = Pheme(args) + logger_wandb.watch(model=model) + _, other_metapath = split_metapath(args.metapath) + _, other_val_metapath = split_metapath(args.val_metapath) + + print( + f"Received datasets: \n{other_metapath = } " + f"\n \n{other_val_metapath = }" + ) + + other_meta = {} + if len(other_metapath) > 0: + other_meta["fit"] = other_metapath + if len(other_val_metapath) > 0: + other_meta["valid"] = other_val_metapath + + data_module = DataModule( + args, other_metapath, other_val_metapath, + wrapper.world_size, wrapper.local_rank + ) + data_module.setup(stage="fit") + train_data_module = data_module + + valid_dataloaders = [] + data_module.setup(stage="valid") + valid_dataloaders.extend(data_module.val_dataloader()) + + wrapper.fit( + model, + train_dataloaders=train_data_module.train_dataloader(), + 
val_dataloaders=valid_dataloaders, + ckpt_path=args.resume_checkpoint, + ) diff --git a/train_t2s.py b/train_t2s.py new file mode 100644 index 0000000..6fb9b7d --- /dev/null +++ b/train_t2s.py @@ -0,0 +1,127 @@ +"""Train T2S to generate semantic tokens. + +Copyright PolyAI Limited. +""" +import argparse +import logging +from datetime import datetime +from pathlib import Path + +import torch +from transformers import Trainer, TrainingArguments + +from data.semantic_dataset import Collator, ConcatenateSemanticDataset +from modules.t2s_model import T2S, compute_custom_metrics +from utils import split_metapath + + +# Synchronize the GPU +torch.cuda.synchronize() + +# Check for CUDA errors +if torch.cuda.is_available(): + device = torch.cuda.current_device() + print(torch.cuda.get_device_properties(device)) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--metapath", type=str, nargs="+", help="paths to train metadata", + required=True + ) + parser.add_argument( + "--val_metapath", type=str, nargs="+", default=[], + help="paths to validation metadata", + ) + parser.add_argument( + "--train_path", type=str, + default="datasets/giga-training-data/train.json" + ) + parser.add_argument( + "--eval_path", type=str, + default="datasets/giga-training-data/dev.json" + ) + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument( + "--model_size", choices=["test", "tiny", "t5small", "large", "Large"], + default="tiny" + ) + parser.add_argument("--eval_accumulation_steps", type=int, default=10) + parser.add_argument("--warmup_steps", type=int, default=5000) + parser.add_argument("--save_steps", type=int, default=500) + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--n_epochs", type=int, default=20) + parser.add_argument("--nworkers", type=int, default=8) + parser.add_argument("--max_duration", type=int, default=15) + parser.add_argument("--eval_n_samples", type=int, default=400) + parser.add_argument("--learning_rate", type=float, default=5E-4) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--resume_from_checkpoint", type=str, default=None) + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + args = parse_args() + + model = T2S(args) + n_params = sum([param.numel() for param in model.parameters()]) + print(f"Model has {n_params = }") + + train_path = split_metapath(args.metapath) + eval_paths = split_metapath(args.val_metapath) + + dataset_train = ConcatenateSemanticDataset( + manifest_path=train_path, + symbol_table_path=model.text_tokens_file, + max_duration=args.max_duration + ) + + dataset_eval = ConcatenateSemanticDataset( + manifest_path=eval_paths, + symbol_table_path=model.text_tokens_file, + n_samples=args.eval_n_samples, + max_duration=args.max_duration + ) + + current_timestamp = datetime.now() + current_timestamp = current_timestamp.strftime("%Y-%m-%d-%H:%M:%S") + if args.resume_from_checkpoint is not None: + output_dir = Path(args.resume_from_checkpoint).parent + else: + output_dir = Path(args.output_dir) + + training_args = TrainingArguments( + output_dir=output_dir, + learning_rate=args.learning_rate, + per_device_train_batch_size=args.batch_size, + per_device_eval_batch_size=args.batch_size, + gradient_accumulation_steps=args.gradient_accumulation_steps, + num_train_epochs=args.n_epochs, + save_steps=args.save_steps, + eval_steps=args.save_steps, + save_total_limit=3, + 
dataloader_num_workers=args.nworkers, + evaluation_strategy="steps", + save_strategy="steps", + load_best_model_at_end=True, + report_to=["all"], + bf16=False, + warmup_steps=args.warmup_steps, + ddp_find_unused_parameters=False, + eval_accumulation_steps=args.eval_accumulation_steps + ) + + trainer = Trainer( + model=model.t2s, + args=training_args, + data_collator=Collator().collate, + train_dataset=dataset_train, + eval_dataset=dataset_eval, + compute_metrics=compute_custom_metrics, + ) + + trainer.train(resume_from_checkpoint=args.resume_from_checkpoint) diff --git a/transformer_infer.py b/transformer_infer.py new file mode 100644 index 0000000..a2b72fb --- /dev/null +++ b/transformer_infer.py @@ -0,0 +1,262 @@ +"""Inference logic. + +Copyright PolyAI Limited. +""" +import argparse +import json +import logging +import os +import time +from pathlib import Path + +import numpy as np +import soundfile as sf +import torch +from einops import rearrange +from librosa.util import normalize +from pyannote.audio import Inference +from transformers import GenerationConfig, T5ForConditionalGeneration + +import constants as c +from data.collation import get_text_semantic_token_collater +from data.semantic_dataset import TextTokenizer +from modules.s2a_model import Pheme +from modules.vocoder import VocoderType + +# How many times one token can be generated +MAX_TOKEN_COUNT = 100 + +logging.basicConfig(level=logging.DEBUG) +device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu" + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--text", type=str, + default="I gotta say, I would never expect that to happen!" + ) + parser.add_argument( + "--manifest_path", type=str, default="demo/manifest.json") + parser.add_argument("--outputdir", type=str, default="demo/") + parser.add_argument("--featuredir", type=str, default="demo/") + parser.add_argument( + "--text_tokens_file", type=str, + default="ckpt/unique_text_tokens.k2symbols" + ) + parser.add_argument("--t2s_path", type=str, default="ckpt/t2s/") + parser.add_argument( + "--a2s_path", type=str, default="ckpt/s2a/s2a.ckpt") + + parser.add_argument("--target_sample_rate", type=int, default=16_000) + + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--top_k", type=int, default=210) + parser.add_argument("--voice", type=str, default="male_voice") + + return parser.parse_args() + + +class PhemeClient(): + def __init__(self, args): + self.args = args + self.outputdir = args.outputdir + self.target_sample_rate = args.target_sample_rate + self.featuredir = Path(args.featuredir).expanduser() + self.collater = get_text_semantic_token_collater(args.text_tokens_file) + self.phonemizer = TextTokenizer() + + self.load_manifest(args.manifest_path) + + # T2S model + self.t2s = T5ForConditionalGeneration.from_pretrained(args.t2s_path) + self.t2s.to(device) + self.t2s.eval() + + # S2A model + self.s2a = Pheme.load_from_checkpoint(args.a2s_path) + self.s2a.to(device=device) + self.s2a.eval() + + # Vocoder + vocoder = VocoderType["SPEECHTOKENIZER"].get_vocoder(None, None) + self.vocoder = vocoder.to(device) + self.vocoder.eval() + + self.spkr_embedding = Inference( + "pyannote/embedding", + window="whole", + use_auth_token=os.environ["HUGGING_FACE_HUB_TOKEN"], + ) + + def load_manifest(self, input_path): + input_file = {} + with open(input_path, "rb") as f: + for line in f: + temp = json.loads(line) + input_file[temp["audio_filepath"].split(".wav")[0]] = temp + self.input_file 
= input_file + + def lazy_decode(self, decoder_output, symbol_table): + semantic_tokens = map(lambda x: symbol_table[x], decoder_output) + semantic_tokens = [int(x) for x in semantic_tokens if x.isdigit()] + + return np.array(semantic_tokens) + + def infer_text(self, text, voice, sampling_config): + semantic_prompt = np.load(self.args.featuredir + "/audios-speech-tokenizer/semantic/" + f"{voice}.npy") # noqa + phones_seq = self.phonemizer(text)[0] + input_ids = self.collater([phones_seq]) + input_ids = input_ids.type(torch.IntTensor).to(device) + + labels = [str(lbl) for lbl in semantic_prompt] + labels = self.collater([labels])[:, :-1] + decoder_input_ids = labels.to(device).long() + logging.debug(f"decoder_input_ids: {decoder_input_ids}") + + counts = 1E10 + while (counts > MAX_TOKEN_COUNT): + output_ids = self.t2s.generate( + input_ids, decoder_input_ids=decoder_input_ids, + generation_config=sampling_config).sequences + + # check repetitiveness + _, counts = torch.unique_consecutive(output_ids, return_counts=True) + counts = max(counts).item() + + output_semantic = self.lazy_decode( + output_ids[0], self.collater.idx2token) + + # remove the prompt + return output_semantic[len(semantic_prompt):].reshape(1, -1) + + def _load_speaker_emb(self, element_id_prompt): + wav, _ = sf.read(self.featuredir / element_id_prompt) + audio = normalize(wav) * 0.95 + speaker_emb = self.spkr_embedding( + { + "waveform": torch.FloatTensor(audio).unsqueeze(0), + "sample_rate": self.target_sample_rate + } + ).reshape(1, -1) + + return speaker_emb + + def _load_prompt(self, prompt_file_path): + element_id_prompt = Path(prompt_file_path).stem + acoustic_path_prompt = self.featuredir / "audios-speech-tokenizer/acoustic" / f"{element_id_prompt}.npy" # noqa + semantic_path_prompt = self.featuredir / "audios-speech-tokenizer/semantic" / f"{element_id_prompt}.npy" # noqa + + acoustic_prompt = np.load(acoustic_path_prompt).squeeze().T + semantic_prompt = np.load(semantic_path_prompt)[None] + + return acoustic_prompt, semantic_prompt + + def infer_acoustic(self, output_semantic, prompt_file_path): + semantic_tokens = output_semantic.reshape(1, -1) + acoustic_tokens = np.full( + [semantic_tokens.shape[1], 7], fill_value=c.PAD) + + acoustic_prompt, semantic_prompt = self._load_prompt(prompt_file_path) # noqa + + # Prepend prompt + acoustic_tokens = np.concatenate( + [acoustic_prompt, acoustic_tokens], axis=0) + semantic_tokens = np.concatenate([ + semantic_prompt, semantic_tokens], axis=1) + + # Add speaker + acoustic_tokens = np.pad( + acoustic_tokens, [[1, 0], [0, 0]], constant_values=c.SPKR_1) + semantic_tokens = np.pad( + semantic_tokens, [[0,0], [1, 0]], constant_values=c.SPKR_1) + + speaker_emb = None + if self.s2a.hp.use_spkr_emb: + speaker_emb = self._load_speaker_emb(prompt_file_path) + speaker_emb = np.repeat( + speaker_emb, semantic_tokens.shape[1], axis=0) + speaker_emb = torch.from_numpy(speaker_emb).to(device) + else: + speaker_emb = None + + acoustic_tokens = torch.from_numpy( + acoustic_tokens).unsqueeze(0).to(device).long() + semantic_tokens = torch.from_numpy(semantic_tokens).to(device).long() + start_t = torch.tensor( + [acoustic_prompt.shape[0]], dtype=torch.long, device=device) + length = torch.tensor([ + semantic_tokens.shape[1]], dtype=torch.long, device=device) + + codes = self.s2a.model.inference( + acoustic_tokens, + semantic_tokens, + start_t=start_t, + length=length, + maskgit_inference=True, + speaker_emb=speaker_emb + ) + + # Remove the prompt + synth_codes = codes[:, :, start_t:] + 
synth_codes = rearrange(synth_codes, "b c t -> c b t") + + return synth_codes + + def generate_audio(self, text, voice, sampling_config, prompt_file_path): + start_time = time.time() + output_semantic = self.infer_text( + text, voice, sampling_config + ) + logging.debug(f"semantic_tokens: {time.time() - start_time}") + + start_time = time.time() + codes = self.infer_acoustic(output_semantic, prompt_file_path) + logging.debug(f"acoustic_tokens: {time.time() - start_time}") + + start_time = time.time() + audio_array = self.vocoder.decode(codes) + audio_array = rearrange(audio_array, "1 1 T -> T").cpu().numpy() + logging.debug(f"vocoder time: {time.time() - start_time}") + + return audio_array + + @torch.no_grad() + def infer( + self, text, voice="male_voice", temperature=0.7, + top_k=210, max_new_tokens=750, + ): + sampling_config = GenerationConfig.from_pretrained( + self.args.t2s_path, + top_k=top_k, + num_beams=1, + do_sample=True, + temperature=temperature, + num_return_sequences=1, + max_new_tokens=max_new_tokens, + return_dict_in_generate=True, + output_scores=True + ) + + voice_data = self.input_file[voice] + prompt_file_path = voice_data["audio_prompt_filepath"] + text = voice_data["text"] + " " + text + + audio_array = self.generate_audio( + text, voice, sampling_config, prompt_file_path) + + return audio_array + + +if __name__ == "__main__": + args = parse_arguments() + args.outputdir = Path(args.outputdir).expanduser() + args.outputdir.mkdir(parents=True, exist_ok=True) + args.manifest_path = Path(args.manifest_path).expanduser() + + client = PhemeClient(args) + audio_array = client.infer(args.text, voice=args.voice) + sf.write(os.path.join( + args.outputdir, f"{args.voice}.wav"), audio_array, + args.target_sample_rate + ) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..717ab99 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,73 @@ +"""Copyright PolyAI Limited.""" +import logging +import pdb +import sys +import traceback +from functools import wraps +from time import time +from typing import List + +import torch + +from .symbol_table import SymbolTable + + +def load_checkpoint(ckpt_path: str) -> dict: + """ + Loads checkpoint, while matching phone embedding size. 
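+
+    All entries whose names start with "vocoder" are dropped from the
+    returned state dict. Illustrative call (the path is only an example):
+
+        state_dict = load_checkpoint("ckpt/s2a/s2a.ckpt")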
+ """ + state_dict: dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] + new_state_dict = dict() + for p_name in state_dict.keys(): + if p_name.startswith("vocoder"): + continue + + new_state_dict[p_name] = state_dict[p_name] + + return new_state_dict + + +def breakpoint_on_error(fn): + """Creates a breakpoint on error + + Use as a wrapper + + Args: + fn: the function + + Returns: + inner function + """ + + def inner(*args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception: + """Standard python way of creating a breakpoint on error""" + extype, value, tb = sys.exc_info() + print(f"extype={extype},\nvalue={value}") + traceback.print_exc() + pdb.post_mortem(tb) + + return inner + + +def measure_duration(f): + @wraps(f) + def wrap(*args, **kw): + ts = time() + result = f(*args, **kw) + te = time() + logging.debug("func:%r took: %2.4f sec" % (f.__name__, te - ts)) + return result + + return wrap + + +def split_metapath(in_paths: List[str]): + other_paths = [] + + for itm_path in in_paths: + other_paths.append(itm_path) + + return other_paths diff --git a/utils/get_tokens_speech_tokenizer.py b/utils/get_tokens_speech_tokenizer.py new file mode 100644 index 0000000..eed17ed --- /dev/null +++ b/utils/get_tokens_speech_tokenizer.py @@ -0,0 +1,70 @@ +"""Get tokens using the SpeechTokenizer. + +Apply SpeechTokenizer to extract acoustic and semantic tokens. +The tokens will be extracted to +encoding_output/acoustic and encoding_output/semantic. + +python utils/get_tokens_speech_tokenizer.py \ + --config_path ckpt/speechtokenizer/config.json \ + --ckpt_path ckpt/speechtokenizer/SpeechTokenizer.pt \ + --encoding_input datasets/example/audios \ + --encoding_output datasets/example/audios-speech-tokenizer + +Copyright PolyAI Limited. +""" +import argparse +import pathlib + +from modules.speech_tokenizer import SpeechTokenizer + +MQTTS_ROOT_PATH = str(pathlib.Path(__file__).parent.resolve()) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config_path", + type=str, + help="Path to the SpeechTokenizer config", + default=MQTTS_ROOT_PATH + "/ckpt/speechtokenizer/config.json", + ) + parser.add_argument( + "--ckpt_path", + type=str, + help="Path to the SpeechTokenizer checkpoint", + default=MQTTS_ROOT_PATH + "/ckpt/speechtokenizer/SpeechTokenizer.pt", + ) + parser.add_argument( + "--encoding_input", + type=str, + help="Path to the input folder for encoding", + default=MQTTS_ROOT_PATH + "/datasets/giga-training-data/audios", + ) + parser.add_argument( + "--encoding_output", + type=str, + help="Path where to save the encoded tokens", + default="/tmp/encoding_output", + ) + parser.add_argument( + "--start_percent", + type=int, + default=0, + ) + parser.add_argument( + "--end_percent", + type=int, + default=100, + ) + + args = parser.parse_args() + print("Parsed args") + print(args) + + tokenizer = SpeechTokenizer( + config_path=args.config_path, + ckpt_path=args.ckpt_path, + ) + tokenizer.encode_files_with_model_concurrent( + folder_path=args.encoding_input, destination_folder=args.encoding_output, + start_percent=args.start_percent, end_percent=args.end_percent + ) diff --git a/utils/symbol_table.py b/utils/symbol_table.py new file mode 100644 index 0000000..19b366b --- /dev/null +++ b/utils/symbol_table.py @@ -0,0 +1,281 @@ +"""Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang) + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+Copyright PolyAI Limited.
+"""
+from dataclasses import dataclass, field
+from typing import Dict, Generic, List, Optional, TypeVar, Union
+
+Symbol = TypeVar('Symbol')
+
+
+# Disable __repr__ otherwise it could freeze e.g. Jupyter.
+@dataclass(repr=False)
+class SymbolTable(Generic[Symbol]):
+    '''SymbolTable that maps symbol IDs, found on the FSA arcs to
+    actual objects. These objects can be arbitrary Python objects
+    that can serve as keys in a dictionary (i.e. they need to be
+    hashable and immutable).
+
+    The SymbolTable can only be read to/written from disk if the
+    symbols are strings.
+    '''
+    _id2sym: Dict[int, Symbol] = field(default_factory=dict)
+    '''Map an integer to a symbol.
+    '''
+
+    _sym2id: Dict[Symbol, int] = field(default_factory=dict)
+    '''Map a symbol to an integer.
+    '''
+
+    _next_available_id: int = 1
+    '''A helper internal field that helps adding new symbols
+    to the table efficiently.
+    '''
+
+    eps: Symbol = '<eps>'
+    '''Null symbol, always mapped to index 0.
+    '''
+
+    def __post_init__(self):
+        for idx, sym in self._id2sym.items():
+            assert self._sym2id[sym] == idx
+            assert idx >= 0
+
+        for sym, idx in self._sym2id.items():
+            assert idx >= 0
+            assert self._id2sym[idx] == sym
+
+        if 0 not in self._id2sym:
+            self._id2sym[0] = self.eps
+            self._sym2id[self.eps] = 0
+        else:
+            assert self._id2sym[0] == self.eps
+            assert self._sym2id[self.eps] == 0
+
+        self._next_available_id = max(self._id2sym) + 1
+
+    @staticmethod
+    def from_str(s: str) -> 'SymbolTable':
+        '''Build a symbol table from a string.
+
+        The string consists of lines. Every line has two fields separated
+        by space(s), tab(s) or both. The first field is the symbol and the
+        second the integer id of the symbol.
+
+        Args:
+          s:
+            The input string with the format described above.
+        Returns:
+          An instance of :class:`SymbolTable`.
+        '''
+        id2sym: Dict[int, str] = dict()
+        sym2id: Dict[str, int] = dict()
+
+        for line in s.split('\n'):
+            fields = line.split()
+            if len(fields) == 0:
+                continue  # skip empty lines
+            assert len(fields) == 2, \
+                f'Expect a line with 2 fields. Given: {len(fields)}'
+            sym, idx = fields[0], int(fields[1])
+            assert sym not in sym2id, f'Duplicated symbol {sym}'
+            assert idx not in id2sym, f'Duplicated id {idx}'
+            id2sym[idx] = sym
+            sym2id[sym] = idx
+
+        eps = id2sym.get(0, '<eps>')
+
+        return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps)
+
+    @staticmethod
+    def from_file(filename: str) -> 'SymbolTable':
+        '''Build a symbol table from file.
+
+        Every line in the symbol table file has two fields separated by
+        space(s), tab(s) or both. The following is an example file:
+
+        .. code-block::
+
+            <eps> 0
+            a 1
+            b 2
+            c 3
+
+        Args:
+          filename:
+            Name of the symbol table file. Its format is documented above.
+
+        Returns:
+          An instance of :class:`SymbolTable`.
+
+        '''
+        with open(filename, 'r', encoding='utf-8') as f:
+            return SymbolTable.from_str(f.read().strip())
+
+    def to_str(self) -> str:
+        '''
+        Returns:
+          Return a string representation of this object. You can pass
+          it to the method ``from_str`` to recreate an identical object.
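+
+          Example (an illustrative sketch, not taken from the pipeline)::
+
+              table = SymbolTable()
+              table.add('a')
+              # Prints "symbol id" pairs sorted by id:
+              #   <eps> 0
+              #   a 1
+              print(table.to_str())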
+        '''
+        s = ''
+        for idx, symbol in sorted(self._id2sym.items()):
+            s += f'{symbol} {idx}\n'
+        return s
+
+    def to_file(self, filename: str):
+        '''Serialize the SymbolTable to a file.
+
+        Every line in the symbol table file has two fields separated by
+        space(s), tab(s) or both. The following is an example file:
+
+        .. code-block::
+
+            <eps> 0
+            a 1
+            b 2
+            c 3
+
+        Args:
+          filename:
+            Name of the symbol table file. Its format is documented above.
+        '''
+        with open(filename, 'w') as f:
+            for idx, symbol in sorted(self._id2sym.items()):
+                print(symbol, idx, file=f)
+
+    def add(self, symbol: Symbol, index: Optional[int] = None) -> int:
+        '''Add a new symbol to the SymbolTable.
+
+        Args:
+          symbol:
+            The symbol to be added.
+          index:
+            Optional int id to which the symbol should be assigned.
+            If it is not available, a ValueError will be raised.
+
+        Returns:
+          The int id to which the symbol has been assigned.
+        '''
+        # Already in the table? Return its ID.
+        if symbol in self._sym2id:
+            return self._sym2id[symbol]
+        # Specific ID not provided - use next available.
+        if index is None:
+            index = self._next_available_id
+        # Specific ID provided but not available.
+        if index in self._id2sym:
+            raise ValueError(f"Cannot assign id '{index}' to '{symbol}' - "
+                             f"already occupied by {self._id2sym[index]}")
+        self._sym2id[symbol] = index
+        self._id2sym[index] = symbol
+
+        # Update next available ID if needed
+        if self._next_available_id <= index:
+            self._next_available_id = index + 1
+
+        return index
+
+    def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]:
+        '''Get a symbol for an id or get an id for a symbol
+
+        Args:
+          k:
+            If it is an id, it tries to find the symbol corresponding
+            to the id; if it is a symbol, it tries to find the id
+            corresponding to the symbol.
+
+        Returns:
+          An id or a symbol depending on the given `k`.
+        '''
+        if isinstance(k, int):
+            return self._id2sym[k]
+        else:
+            return self._sym2id[k]
+
+    def merge(self, other: 'SymbolTable') -> 'SymbolTable':
+        '''Create a union of two SymbolTables.
+        Raises an AssertionError if the same IDs are occupied by
+        different symbols.
+
+        Args:
+          other:
+            A symbol table to merge with ``self``.
+
+        Returns:
+          A new symbol table.
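+
+          Example (an illustrative sketch, not taken from the pipeline)::
+
+              a = SymbolTable.from_str('foo 1')
+              b = SymbolTable.from_str('bar 2')
+              merged = a.merge(b)
+              assert merged['foo'] == 1 and merged['bar'] == 2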
+ ''' + self._check_compatible(other) + + id2sym = {**self._id2sym, **other._id2sym} + sym2id = {**self._sym2id, **other._sym2id} + + return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=self.eps) + + def _check_compatible(self, other: 'SymbolTable') -> None: + # Epsilon compatibility + assert self.eps == other.eps, f'Mismatched epsilon symbol: ' \ + f'{self.eps} != {other.eps}' + # IDs compatibility + common_ids = set(self._id2sym).intersection(other._id2sym) + for idx in common_ids: + assert self[idx] == other[idx], f'ID conflict for id: {idx}, ' \ + f'self[idx] = "{self[idx]}", ' \ + f'other[idx] = "{other[idx]}"' + # Symbols compatibility + common_symbols = set(self._sym2id).intersection(other._sym2id) + for sym in common_symbols: + assert self[sym] == other[sym], f'ID conflict for id: {sym}, ' \ + f'self[sym] = "{self[sym]}", ' \ + f'other[sym] = "{other[sym]}"' + + def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]: + return self.get(item) + + def __contains__(self, item: Union[int, Symbol]) -> bool: + if isinstance(item, int): + return item in self._id2sym + else: + return item in self._sym2id + + def __len__(self) -> int: + return len(self._id2sym) + + def __eq__(self, other: 'SymbolTable') -> bool: + if len(self) != len(other): + return False + + for s in self.symbols: + if self[s] != other[s]: + return False + + return True + + @property + def ids(self) -> List[int]: + '''Returns a list of integer IDs corresponding to the symbols. + ''' + ans = list(self._id2sym.keys()) + ans.sort() + return ans + + @property + def symbols(self) -> List[Symbol]: + '''Returns a list of symbols (e.g., strings) corresponding to + the integer IDs. + ''' + ans = list(self._sym2id.keys()) + ans.sort() + return ans