diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000..8fe13894 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,15 @@ +# These are supported funding model platforms + +github: [Flux9665] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] +patreon: # Replace with a single Patreon username +open_collective: # Replace with a single Open Collective username +ko_fi: # Replace with a single Ko-fi username +tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel +community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry +liberapay: # Replace with a single Liberapay username +issuehunt: # Replace with a single IssueHunt username +lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry +polar: # Replace with a single Polar username +buy_me_a_coffee: # Replace with a single Buy Me a Coffee username +thanks_dev: # Replace with a single thanks.dev username +custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/InferenceInterfaces/ToucanTTSInterface.py b/InferenceInterfaces/ToucanTTSInterface.py index 7dde7b35..11ada6a0 100644 --- a/InferenceInterfaces/ToucanTTSInterface.py +++ b/InferenceInterfaces/ToucanTTSInterface.py @@ -143,8 +143,9 @@ def forward(self, energy=None, input_is_phones=False, return_plot_as_filepath=False, - loudness_in_db=-24.0, - prosody_creativity=0.1): + loudness_in_db=-29.0, + prosody_creativity=0.1, + return_everything=False): """ duration_scaling_factor: reasonable values are 0.8 < scale < 1.2. 1.0 means no scaling happens, higher values increase durations for the whole @@ -233,6 +234,8 @@ def forward(self, plt.savefig("tmp.png") plt.close() return wave, sr, "tmp.png" + if return_everything: + return wave, mel, durations, pitch return wave, sr def read_to_file(self, diff --git a/Modules/ToucanTTS/InferenceToucanTTS.py b/Modules/ToucanTTS/InferenceToucanTTS.py index f62ec6b1..eec9715a 100644 --- a/Modules/ToucanTTS/InferenceToucanTTS.py +++ b/Modules/ToucanTTS/InferenceToucanTTS.py @@ -242,7 +242,7 @@ def _forward(self, mask=text_masks.float(), n_timesteps=20, temperature=prosody_creativity, - c=utterance_embedding)), min=0.0).long().squeeze(1) if gold_durations is None else gold_durations + c=utterance_embedding)), min=0.0).long().squeeze(1) if gold_durations is None else gold_durations.squeeze(1) # modifying the predictions with control parameters for phoneme_index, phoneme_vector in enumerate(text_tensors.squeeze(0)): diff --git a/Preprocessing/multilinguality/create_lang_dist_dataset.py b/Preprocessing/multilinguality/create_lang_dist_dataset.py index 00fdc901..535bbcb3 100644 --- a/Preprocessing/multilinguality/create_lang_dist_dataset.py +++ b/Preprocessing/multilinguality/create_lang_dist_dataset.py @@ -10,7 +10,7 @@ from Preprocessing.multilinguality.SimilaritySolver import SimilaritySolver from Utility.utils import load_json_from_path -LANG_PAIRS_ORACLE_PATH = "lang_1_to_lang_2_to_oracle_dist.json" +LANG_PAIRS_ORACLE_PATH = "lang_1_to_lang_2_to_l1_dist.json" ISO_LOOKUP_PATH = hf_hub_download(repo_id="Flux9665/ToucanTTS", filename="iso_lookup.json") ISO_TO_FULLNAME_PATH = hf_hub_download(repo_id="Flux9665/ToucanTTS", filename="iso_to_fullname.json") LANG_PAIRS_MAP_PATH = hf_hub_download(repo_id="Flux9665/ToucanTTS", filename="lang_1_to_lang_2_to_map_dist.json") diff --git a/Preprocessing/multilinguality/eval_lang_emb_approximation.py 
b/Preprocessing/multilinguality/eval_lang_emb_approximation.py index f889820a..9e3108ea 100644 --- a/Preprocessing/multilinguality/eval_lang_emb_approximation.py +++ b/Preprocessing/multilinguality/eval_lang_emb_approximation.py @@ -13,6 +13,7 @@ import matplotlib.pyplot as plt from Utility.utils import load_json_from_path + def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embeddings, weighted_avg=False, min_n_langs=5, max_n_langs=30, threshold_percentile=95, loss_fn="MSE"): df = pd.read_csv(csv_path, sep="|") @@ -23,7 +24,7 @@ def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embe features_per_closest_lang = 2 # for combined, df has up to 5 features (if containing individual distances) per closest lang + 1 target lang column - if "combined_dist_0" in df.columns: + if "combined_dist_0" in df.columns: if "map_dist_0" in df.columns: features_per_closest_lang += 1 if "asp_dist_0" in df.columns: @@ -63,7 +64,7 @@ def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embe except KeyError: print(f"KeyError: Unable to retrieve language embedding for {row.target_lang}") continue - avg_emb = torch.zeros([16]) + avg_emb = torch.zeros([32]) dists = [getattr(row, d) for i, d in enumerate(closest_dist_columns) if i < min_n_langs or getattr(row, d) < threshold] langs = [getattr(row, l) for l in closest_lang_columns[:len(dists)]] @@ -77,7 +78,7 @@ def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embe lang_emb = language_embeddings[iso_lookup[-1][lang]] avg_emb += lang_emb normalization_factor = len(langs) - avg_emb /= normalization_factor # normalize + avg_emb /= normalization_factor # normalize current_loss = loss_fn(avg_emb, y).item() all_losses.append(current_loss) @@ -95,12 +96,10 @@ def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embe parser.add_argument("--loss_fn", choices=["MSE", "L1"], type=str, default="MSE", help="loss function used") args = parser.parse_args() csv_paths = [ - "distance_datasets/dataset_map_top30_furthest.csv", "distance_datasets/dataset_random_top30.csv", "distance_datasets/dataset_asp_top30.csv", - "distance_datasets/dataset_tree_top30.csv", "distance_datasets/dataset_map_top30.csv", - "distance_datasets/dataset_combined_top30_indiv-dists.csv", + "distance_datasets/dataset_tree_top30.csv", "distance_datasets/dataset_learned_top30.csv", "distance_datasets/dataset_oracle_top30.csv", ] @@ -112,49 +111,44 @@ def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embe OUT_DIR = "plots" os.makedirs(OUT_DIR, exist_ok=True) - fig, ax = plt.subplots(figsize=(3.15022, 3.15022*(2/3)), constrained_layout=True) + fig, ax = plt.subplots(figsize=(6, 4)) plt.ylabel(args.loss_fn) for i, csv_path in enumerate(csv_paths): print(f"csv_path: {os.path.basename(csv_path)}") for condition in weighted: - losses = compute_loss_for_approximated_embeddings(csv_path, - iso_lookup, - lang_embs, - condition, - min_n_langs=args.min_n_langs, - max_n_langs=args.max_n_langs, - threshold_percentile=args.threshold_percentile, - loss_fn=args.loss_fn) + losses = compute_loss_for_approximated_embeddings(csv_path, + iso_lookup, + lang_embs, + condition, + min_n_langs=args.min_n_langs, + max_n_langs=args.max_n_langs, + threshold_percentile=args.threshold_percentile, + loss_fn=args.loss_fn) print(f"weighted average: {condition} | mean loss: {np.mean(losses)}") losses_of_multiple_datasets.append(losses) - bp_dict = ax.boxplot(losses_of_multiple_datasets, - labels = [ - 
"map furthest", - "random", - "inv. ASP", - "tree", - "map", - "avg", - "meta-learned", - "oracle", - ], + bp_dict = ax.boxplot(losses_of_multiple_datasets, + labels=["Random", + "Inverse ASP", + "Map Distance", + "Tree Distance", + "Learned Distance", + "Oracle"], patch_artist=True, - boxprops=dict(facecolor = "lightblue", + boxprops=dict(facecolor="lightblue", ), - showfliers=False, - widths=0.45 - ) - + showfliers=False, + widths=0.55 + ) # major ticks every 0.1, minor ticks every 0.05, between 0.0 and 0.6 - major_ticks = np.arange(0, 0.6, 0.1) - minor_ticks = np.arange(0, 0.6, 0.05) + major_ticks = np.arange(0, 1.0, 0.1) + minor_ticks = np.arange(0, 1.0, 0.05) ax.set_yticks(major_ticks) ax.set_yticks(minor_ticks, minor=True) # horizontal grid lines for minor and major ticks ax.grid(which='both', linestyle='-', color='lightgray', linewidth=0.3, axis='y') - ax.set_aspect(4.5) - plt.title(f"min. {args.min_n_langs} kNN, max. {args.max_n_langs}\nthreshold: {args.threshold_percentile}th-percentile distance of {args.max_n_langs}th-closest language") + # plt.title(f"Using between {args.min_n_langs} and {args.max_n_langs} Nearest Neighbors to approximate an unseen Embedding") plt.xticks(rotation=45) - - plt.savefig(os.path.join(OUT_DIR, "example_boxplot_release.pdf"), bbox_inches='tight') + plt.tight_layout() + plt.show() + # plt.savefig(os.path.join(OUT_DIR, "example_boxplot_release.pdf"), bbox_inches='tight') diff --git a/README.md b/README.md index 867424df..cd1a73f0 100644 --- a/README.md +++ b/README.md @@ -12,12 +12,14 @@ IMS Toucan is a toolkit for training, using, and teaching state-of-the-art Text-to-Speech Synthesis, developed at the **Institute for Natural Language Processing (IMS), University of Stuttgart, Germany**, official home of the massively multilingual ToucanTTS system. Our system is fast, controllable, and doesn't require a ton of compute. -If you find this repo useful, consider giving it a star⭐. Large numbers make me happy, and they are quite motivating. -
![image](Utility/toucan.png) +
+ +If you find this repo useful, consider giving it a star. ⭐ Large numbers make me happy, and they are very motivating. If you want to motivate me even more, you can consider [sponsoring this toolkit](https://github.com/sponsors/Flux9665). We only use GitHub Sponsors for this; scammers on other platforms who pretend to be the creator are not affiliated with us. Don't let them fool you. The code and the models are absolutely free, and thanks to the generous support of Hugging Face🤗, we even have an [instance of the model running on GPU](https://huggingface.co/spaces/Flux9665/MassivelyMultilingualTTS) that is free for anyone to use. + ---
@@ -29,17 +31,13 @@ If you find this repo useful, consider giving it a star⭐. Large numbers make m [Cloning prosody across speakers](https://toucanprosodycloningdemo.github.io) -[Multi-lingual and multi-speaker audios](https://multilingualtoucan.github.io/) - -[Massively-Multi-Lingual audios and study setup](https://anondemos.github.io/MMDemo) - ### Interactive Demo -[Check out our interactive massively-multi-lingual demo on Huggingface🤗](https://huggingface.co/spaces/Flux9665/MassivelyMultilingualTTS) +[Check out our interactive massively-multi-lingual demo on Hugging Face🤗](https://huggingface.co/spaces/Flux9665/MassivelyMultilingualTTS) ### Dataset -[We have also published a massively multilingual TTS dataset on Huggingface🤗](https://huggingface.co/datasets/Flux9665/BibleMMS) +[We have also published a massively multilingual TTS dataset on Hugging Face🤗](https://huggingface.co/datasets/Flux9665/BibleMMS) ---
@@ -94,7 +92,7 @@ absolute). #### Pretrained Models -You don't need to use pretrained models, but it can speed things up tremendously. They will be downloaded on the fly automatically when they are needed, thanks to Huggingface🤗 and [VB](https://github.com/Vaibhavs10) in particular. +You don't need to use pretrained models, but it can speed things up tremendously. They will be downloaded on the fly automatically when they are needed, thanks to Hugging Face🤗 and [VB](https://github.com/Vaibhavs10) in particular. #### \[optional] eSpeak-NG diff --git a/Recipes/AlignerPipeline.py b/Recipes/AlignerPipeline.py index 11728d66..9c3a3b73 100644 --- a/Recipes/AlignerPipeline.py +++ b/Recipes/AlignerPipeline.py @@ -1,14 +1,15 @@ import torch from torch.utils.data import ConcatDataset -from Modules.Aligner.autoaligner_train_loop import train_loop as train_aligner -from Utility.corpus_preparation import prepare_aligner_corpus from Utility.path_to_transcript_dicts import * -from Utility.storage_config import MODELS_DIR -from Utility.storage_config import PREPROCESSING_DIR def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id, gpu_count): + from Modules.Aligner.autoaligner_train_loop import train_loop as train_aligner + from Utility.corpus_preparation import prepare_aligner_corpus + from Utility.storage_config import MODELS_DIR + from Utility.storage_config import PREPROCESSING_DIR + if gpu_id == "cpu": device = torch.device("cpu") else: diff --git a/Recipes/BigVGAN_combined.py b/Recipes/BigVGAN_combined.py index 90918799..d9cef50f 100644 --- a/Recipes/BigVGAN_combined.py +++ b/Recipes/BigVGAN_combined.py @@ -4,15 +4,16 @@ import torch import wandb -from Modules.Vocoder.BigVGAN import BigVGAN -from Modules.Vocoder.HiFiGAN_Dataset import HiFiGANDataset -from Modules.Vocoder.HiFiGAN_Discriminators import AvocodoHiFiGANJointDiscriminator -from Modules.Vocoder.HiFiGAN_train_loop import train_loop from Utility.path_to_transcript_dicts import * -from Utility.storage_config import MODELS_DIR def run(gpu_id, resume_checkpoint, finetune, resume, model_dir, use_wandb, wandb_resume_id, gpu_count): + from Modules.Vocoder.BigVGAN import BigVGAN + from Modules.Vocoder.HiFiGAN_Dataset import HiFiGANDataset + from Modules.Vocoder.HiFiGAN_Discriminators import AvocodoHiFiGANJointDiscriminator + from Modules.Vocoder.HiFiGAN_train_loop import train_loop + from Utility.storage_config import MODELS_DIR + if gpu_id == "cpu": device = torch.device("cpu") else: diff --git a/Recipes/BigVGAN_e2e.py b/Recipes/BigVGAN_e2e.py index 1b451c78..71fa28c3 100644 --- a/Recipes/BigVGAN_e2e.py +++ b/Recipes/BigVGAN_e2e.py @@ -3,15 +3,16 @@ import torch import wandb -from Modules.Vocoder.BigVGAN import BigVGAN -from Modules.Vocoder.HiFiGAN_Discriminators import AvocodoHiFiGANJointDiscriminator -from Modules.Vocoder.HiFiGAN_E2E_Dataset import HiFiGANDataset -from Modules.Vocoder.HiFiGAN_train_loop import train_loop from Utility.path_to_transcript_dicts import * -from Utility.storage_config import MODELS_DIR def run(gpu_id, resume_checkpoint, finetune, resume, model_dir, use_wandb, wandb_resume_id, gpu_count): + from Modules.Vocoder.BigVGAN import BigVGAN + from Modules.Vocoder.HiFiGAN_Dataset import HiFiGANDataset + from Modules.Vocoder.HiFiGAN_Discriminators import AvocodoHiFiGANJointDiscriminator + from Modules.Vocoder.HiFiGAN_train_loop import train_loop + from Utility.storage_config import MODELS_DIR + if gpu_id == "cpu": device = torch.device("cpu") else: diff --git a/Recipes/HiFiGAN_combined.py 
b/Recipes/HiFiGAN_combined.py index 03e9bd7b..f606f48d 100644 --- a/Recipes/HiFiGAN_combined.py +++ b/Recipes/HiFiGAN_combined.py @@ -4,15 +4,16 @@ import torch import wandb -from Modules.Vocoder.HiFiGAN_Dataset import HiFiGANDataset -from Modules.Vocoder.HiFiGAN_Discriminators import AvocodoHiFiGANJointDiscriminator -from Modules.Vocoder.HiFiGAN_Generator import HiFiGAN -from Modules.Vocoder.HiFiGAN_train_loop import train_loop from Utility.path_to_transcript_dicts import * -from Utility.storage_config import MODELS_DIR def run(gpu_id, resume_checkpoint, finetune, resume, model_dir, use_wandb, wandb_resume_id, gpu_count): + from Modules.Vocoder.HiFiGAN_Discriminators import AvocodoHiFiGANJointDiscriminator + from Modules.Vocoder.HiFiGAN_E2E_Dataset import HiFiGANDataset + from Modules.Vocoder.HiFiGAN_Generator import HiFiGAN + from Modules.Vocoder.HiFiGAN_train_loop import train_loop + from Utility.storage_config import MODELS_DIR + if gpu_id == "cpu": device = torch.device("cpu") else: diff --git a/Recipes/HiFiGAN_e2e.py b/Recipes/HiFiGAN_e2e.py index 6cfe3ebb..1f475580 100644 --- a/Recipes/HiFiGAN_e2e.py +++ b/Recipes/HiFiGAN_e2e.py @@ -3,15 +3,16 @@ import torch import wandb -from Modules.Vocoder.HiFiGAN_Discriminators import AvocodoHiFiGANJointDiscriminator -from Modules.Vocoder.HiFiGAN_E2E_Dataset import HiFiGANDataset -from Modules.Vocoder.HiFiGAN_Generator import HiFiGAN -from Modules.Vocoder.HiFiGAN_train_loop import train_loop from Utility.path_to_transcript_dicts import * -from Utility.storage_config import MODELS_DIR def run(gpu_id, resume_checkpoint, finetune, resume, model_dir, use_wandb, wandb_resume_id, gpu_count): + from Modules.Vocoder.HiFiGAN_Discriminators import AvocodoHiFiGANJointDiscriminator + from Modules.Vocoder.HiFiGAN_E2E_Dataset import HiFiGANDataset + from Modules.Vocoder.HiFiGAN_Generator import HiFiGAN + from Modules.Vocoder.HiFiGAN_train_loop import train_loop + from Utility.storage_config import MODELS_DIR + if gpu_id == "cpu": device = torch.device("cpu") else: diff --git a/Recipes/ToucanTTS_IntegrationTest.py b/Recipes/ToucanTTS_IntegrationTest.py index 2b4d586e..9496c547 100644 --- a/Recipes/ToucanTTS_IntegrationTest.py +++ b/Recipes/ToucanTTS_IntegrationTest.py @@ -7,15 +7,18 @@ import torch import wandb -from Modules.ToucanTTS.ToucanTTS import ToucanTTS -from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop -from Utility.corpus_preparation import prepare_tts_corpus from Utility.path_to_transcript_dicts import * -from Utility.storage_config import MODELS_DIR -from Utility.storage_config import PREPROCESSING_DIR def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id, gpu_count): + from torch.utils.data import ConcatDataset + + from Modules.ToucanTTS.ToucanTTS import ToucanTTS + from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop + from Utility.corpus_preparation import prepare_tts_corpus + from Utility.storage_config import MODELS_DIR + from Utility.storage_config import PREPROCESSING_DIR + if gpu_id == "cpu": device = torch.device("cpu") else: @@ -78,4 +81,4 @@ def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb train_samplers=[train_sampler], gpu_count=gpu_count) if use_wandb: - wandb.finish() \ No newline at end of file + wandb.finish() diff --git a/Recipes/ToucanTTS_Massive_English_stage1.py b/Recipes/ToucanTTS_Massive_English_stage1.py index 9b5cde2a..5033013b 100644 --- a/Recipes/ToucanTTS_Massive_English_stage1.py +++ 
b/Recipes/ToucanTTS_Massive_English_stage1.py @@ -2,17 +2,19 @@ import torch import wandb -from torch.utils.data import ConcatDataset -from Modules.ToucanTTS.ToucanTTS import ToucanTTS -from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop -from Utility.corpus_preparation import prepare_tts_corpus from Utility.path_to_transcript_dicts import * -from Utility.storage_config import MODELS_DIR -from Utility.storage_config import PREPROCESSING_DIR def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id, gpu_count): + from torch.utils.data import ConcatDataset + + from Modules.ToucanTTS.ToucanTTS import ToucanTTS + from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop + from Utility.corpus_preparation import prepare_tts_corpus + from Utility.storage_config import MODELS_DIR + from Utility.storage_config import PREPROCESSING_DIR + if gpu_id == "cpu": device = torch.device("cpu") else: diff --git a/Recipes/ToucanTTS_Massive_English_stage2.py b/Recipes/ToucanTTS_Massive_English_stage2.py index d339fbb4..c5443341 100644 --- a/Recipes/ToucanTTS_Massive_English_stage2.py +++ b/Recipes/ToucanTTS_Massive_English_stage2.py @@ -2,17 +2,19 @@ import torch import wandb -from torch.utils.data import ConcatDataset -from Modules.ToucanTTS.ToucanTTS import ToucanTTS -from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop -from Utility.corpus_preparation import prepare_tts_corpus from Utility.path_to_transcript_dicts import * -from Utility.storage_config import MODELS_DIR -from Utility.storage_config import PREPROCESSING_DIR def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id, gpu_count): + from torch.utils.data import ConcatDataset + + from Modules.ToucanTTS.ToucanTTS import ToucanTTS + from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop + from Utility.corpus_preparation import prepare_tts_corpus + from Utility.storage_config import MODELS_DIR + from Utility.storage_config import PREPROCESSING_DIR + if gpu_id == "cpu": device = torch.device("cpu") else: diff --git a/Recipes/ToucanTTS_Massive_German.py b/Recipes/ToucanTTS_Massive_German.py index 47923891..96a6f045 100644 --- a/Recipes/ToucanTTS_Massive_German.py +++ b/Recipes/ToucanTTS_Massive_German.py @@ -2,17 +2,19 @@ import torch import wandb -from torch.utils.data import ConcatDataset -from Modules.ToucanTTS.ToucanTTS import ToucanTTS -from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop -from Utility.corpus_preparation import prepare_tts_corpus from Utility.path_to_transcript_dicts import * -from Utility.storage_config import MODELS_DIR -from Utility.storage_config import PREPROCESSING_DIR def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id, gpu_count): + from torch.utils.data import ConcatDataset + + from Modules.ToucanTTS.ToucanTTS import ToucanTTS + from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop + from Utility.corpus_preparation import prepare_tts_corpus + from Utility.storage_config import MODELS_DIR + from Utility.storage_config import PREPROCESSING_DIR + if gpu_id == "cpu": device = torch.device("cpu") else: diff --git a/Recipes/ToucanTTS_Massive_stage1.py b/Recipes/ToucanTTS_Massive_stage1.py index 85254acf..bfbdf1f3 100644 --- a/Recipes/ToucanTTS_Massive_stage1.py +++ b/Recipes/ToucanTTS_Massive_stage1.py @@ -9,17 +9,19 @@ import torch import torch.multiprocessing import wandb -from torch.utils.data import ConcatDataset -from 
Modules.ToucanTTS.ToucanTTS import ToucanTTS -from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop -from Utility.corpus_preparation import prepare_tts_corpus from Utility.path_to_transcript_dicts import * -from Utility.storage_config import MODELS_DIR -from Utility.storage_config import PREPROCESSING_DIR def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id, gpu_count): + from torch.utils.data import ConcatDataset + + from Modules.ToucanTTS.ToucanTTS import ToucanTTS + from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop + from Utility.corpus_preparation import prepare_tts_corpus + from Utility.storage_config import MODELS_DIR + from Utility.storage_config import PREPROCESSING_DIR + # It is not recommended training this yourself or to finetune this, but you can. # The recommended use is to download the pretrained model from the GitHub release # page and finetune to your desired data diff --git a/Recipes/ToucanTTS_Massive_stage2.py b/Recipes/ToucanTTS_Massive_stage2.py index 707b3216..226c72c5 100644 --- a/Recipes/ToucanTTS_Massive_stage2.py +++ b/Recipes/ToucanTTS_Massive_stage2.py @@ -8,17 +8,19 @@ import torch import torch.multiprocessing import wandb -from torch.utils.data import ConcatDataset -from Modules.ToucanTTS.ToucanTTS import ToucanTTS -from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop -from Utility.corpus_preparation import prepare_tts_corpus from Utility.path_to_transcript_dicts import * -from Utility.storage_config import MODELS_DIR -from Utility.storage_config import PREPROCESSING_DIR def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id, gpu_count): + from torch.utils.data import ConcatDataset + + from Modules.ToucanTTS.ToucanTTS import ToucanTTS + from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop + from Utility.corpus_preparation import prepare_tts_corpus + from Utility.storage_config import MODELS_DIR + from Utility.storage_config import PREPROCESSING_DIR + # It is not recommended training this yourself or to finetune this, but you can. # The recommended use is to download the pretrained model from the GitHub release # page and finetune to your desired data diff --git a/Recipes/ToucanTTS_Massive_stage3.py b/Recipes/ToucanTTS_Massive_stage3.py index 586f168f..b319fe9b 100644 --- a/Recipes/ToucanTTS_Massive_stage3.py +++ b/Recipes/ToucanTTS_Massive_stage3.py @@ -9,17 +9,19 @@ import torch import torch.multiprocessing import wandb -from torch.utils.data import ConcatDataset -from Modules.ToucanTTS.ToucanTTS import ToucanTTS -from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop -from Utility.corpus_preparation import prepare_tts_corpus from Utility.path_to_transcript_dicts import * -from Utility.storage_config import MODELS_DIR -from Utility.storage_config import PREPROCESSING_DIR def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id, gpu_count): + from torch.utils.data import ConcatDataset + + from Modules.ToucanTTS.ToucanTTS import ToucanTTS + from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop + from Utility.corpus_preparation import prepare_tts_corpus + from Utility.storage_config import MODELS_DIR + from Utility.storage_config import PREPROCESSING_DIR + # It is not recommended training this yourself or to finetune this, but you can. 
# The recommended use is to download the pretrained model from the GitHub release # page and finetune to your desired data diff --git a/Recipes/ToucanTTS_Nancy.py b/Recipes/ToucanTTS_Nancy.py index 217e8304..b0e196ac 100644 --- a/Recipes/ToucanTTS_Nancy.py +++ b/Recipes/ToucanTTS_Nancy.py @@ -3,15 +3,18 @@ import torch import wandb -from Modules.ToucanTTS.ToucanTTS import ToucanTTS -from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop -from Utility.corpus_preparation import prepare_tts_corpus from Utility.path_to_transcript_dicts import * -from Utility.storage_config import MODELS_DIR -from Utility.storage_config import PREPROCESSING_DIR def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id, gpu_count): + from torch.utils.data import ConcatDataset + + from Modules.ToucanTTS.ToucanTTS import ToucanTTS + from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop + from Utility.corpus_preparation import prepare_tts_corpus + from Utility.storage_config import MODELS_DIR + from Utility.storage_config import PREPROCESSING_DIR + if gpu_id == "cpu": device = torch.device("cpu") else: @@ -62,6 +65,9 @@ def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb train_loop(net=model, datasets=[train_set], device=device, + warmup_steps=4000, + steps=200000, + batch_size=16, save_directory=save_dir, eval_lang="eng", path_to_checkpoint=resume_checkpoint, @@ -71,4 +77,4 @@ def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb train_samplers=[train_sampler], gpu_count=gpu_count) if use_wandb: - wandb.finish() \ No newline at end of file + wandb.finish() diff --git a/Recipes/finetuning_example_multilingual.py b/Recipes/finetuning_example_multilingual.py index 02dd0572..e18d70c4 100644 --- a/Recipes/finetuning_example_multilingual.py +++ b/Recipes/finetuning_example_multilingual.py @@ -8,17 +8,20 @@ import torch import wandb -from torch.utils.data import ConcatDataset -from Modules.ToucanTTS.ToucanTTS import ToucanTTS -from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop -from Utility.corpus_preparation import prepare_tts_corpus from Utility.path_to_transcript_dicts import * -from Utility.storage_config import MODELS_DIR -from Utility.storage_config import PREPROCESSING_DIR def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id, gpu_count): + from huggingface_hub import hf_hub_download + from torch.utils.data import ConcatDataset + + from Modules.ToucanTTS.ToucanTTS import ToucanTTS + from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop + from Utility.corpus_preparation import prepare_tts_corpus + from Utility.storage_config import MODELS_DIR + from Utility.storage_config import PREPROCESSING_DIR + if gpu_id == "cpu": device = torch.device("cpu") else: @@ -87,7 +90,7 @@ def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb warmup_steps=500, lr=1e-5, # if you have enough data (over ~1000 datapoints) you can increase this up to 1e-4 and it will still be stable, but learn quicker. 
# DOWNLOAD THESE INITIALIZATION MODELS FROM THE RELEASE PAGE OF THE GITHUB OR RUN THE DOWNLOADER SCRIPT TO GET THEM AUTOMATICALLY - path_to_checkpoint=os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt") if resume_checkpoint is None else resume_checkpoint, + path_to_checkpoint=hf_hub_download(repo_id="Flux9665/ToucanTTS", filename="ToucanTTS.pt") if resume_checkpoint is None else resume_checkpoint, fine_tune=True if resume_checkpoint is None and not resume else finetune, resume=resume, steps=5000, diff --git a/Recipes/finetuning_example_simple.py b/Recipes/finetuning_example_simple.py index 65c68e52..b1afd269 100644 --- a/Recipes/finetuning_example_simple.py +++ b/Recipes/finetuning_example_simple.py @@ -8,17 +8,20 @@ import torch import wandb -from torch.utils.data import ConcatDataset -from Modules.ToucanTTS.ToucanTTS import ToucanTTS -from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop -from Utility.corpus_preparation import prepare_tts_corpus from Utility.path_to_transcript_dicts import * -from Utility.storage_config import MODELS_DIR -from Utility.storage_config import PREPROCESSING_DIR def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id, gpu_count): + from huggingface_hub import hf_hub_download + from torch.utils.data import ConcatDataset + + from Modules.ToucanTTS.ToucanTTS import ToucanTTS + from Modules.ToucanTTS.toucantts_train_loop_arbiter import train_loop + from Utility.corpus_preparation import prepare_tts_corpus + from Utility.storage_config import MODELS_DIR + from Utility.storage_config import PREPROCESSING_DIR + if gpu_id == "cpu": device = torch.device("cpu") else: @@ -57,7 +60,7 @@ def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb warmup_steps=500, lr=1e-5, # if you have enough data (over ~1000 datapoints) you can increase this up to 1e-4 and it will still be stable, but learn quicker. 
# DOWNLOAD THESE INITIALIZATION MODELS FROM THE RELEASE PAGE OF THE GITHUB OR RUN THE DOWNLOADER SCRIPT TO GET THEM AUTOMATICALLY - path_to_checkpoint=os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt") if resume_checkpoint is None else resume_checkpoint, + path_to_checkpoint=hf_hub_download(repo_id="Flux9665/ToucanTTS", filename="ToucanTTS.pt") if resume_checkpoint is None else resume_checkpoint, fine_tune=True if resume_checkpoint is None and not resume else finetune, resume=resume, steps=5000, diff --git a/requirements.txt b/requirements.txt index bba186f1..5c2e0ba6 100644 Binary files a/requirements.txt and b/requirements.txt differ diff --git a/run_advanced_GUI_demo.py b/run_advanced_GUI_demo.py new file mode 100644 index 00000000..9d264aac --- /dev/null +++ b/run_advanced_GUI_demo.py @@ -0,0 +1,645 @@ +import sys + +import numpy as np +import pyqtgraph as pg +import scipy.io.wavfile +import sounddevice +import torch.cuda +from PyQt5.QtCore import QTimer +from PyQt5.QtCore import Qt +from PyQt5.QtCore import pyqtSignal +from PyQt5.QtGui import QColor +from PyQt5.QtGui import QCursor +from PyQt5.QtGui import QFont +from PyQt5.QtGui import QPen +from PyQt5.QtWidgets import QApplication +from PyQt5.QtWidgets import QComboBox +from PyQt5.QtWidgets import QFileDialog +from PyQt5.QtWidgets import QHBoxLayout +from PyQt5.QtWidgets import QLineEdit +from PyQt5.QtWidgets import QMainWindow +from PyQt5.QtWidgets import QMessageBox +from PyQt5.QtWidgets import QPushButton +from PyQt5.QtWidgets import QVBoxLayout +from PyQt5.QtWidgets import QWidget +from huggingface_hub import hf_hub_download + +from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface +from Utility.utils import load_json_from_path + + +class DraggableScatter(pg.ScatterPlotItem): + pointMoved = pyqtSignal(int, float) # Emits index and new y-value + + def __init__(self, x, y, pen=None, brush=None, size=3, **kwargs): + super().__init__(x=x, y=y, pen=pen, brush=brush, size=size, **kwargs) + self.setAcceptHoverEvents(True) + self.dragging = False + self.selected_point = None + self.x = list(x) + self.y = list(y) + + def getViewBox(self): + """ + Traverse up the parent hierarchy to locate the ViewBox. + Returns the ViewBox if found, else None. 
+ """ + parent = self.parentItem() + while parent is not None: + if isinstance(parent, pg.ViewBox): + return parent + parent = parent.parentItem() + return None + + def mousePressEvent(self, event): + threshold = 100 + if event.button() == Qt.LeftButton: + vb = self.getViewBox() + if vb is None: + super().mousePressEvent(event) + return + mousePoint = vb.mapSceneToView(event.scenePos()) + x_click = mousePoint.x() + # Find the closest point + min_dist = float('inf') + closest_idx = None + for i, (x, y) in enumerate(zip(self.x, self.y)): + dist = abs(x - x_click) + if dist < min_dist: + min_dist = dist + closest_idx = i + if min_dist < threshold: + self.selected_point = closest_idx + self.dragging = True + event.accept() + return + super().mousePressEvent(event) + + def mouseMoveEvent(self, event): + if self.dragging and self.selected_point is not None: + vb = self.getViewBox() + if vb is None: + super().mouseMoveEvent(event) + return + mousePoint = vb.mapSceneToView(event.scenePos()) + new_y = mousePoint.y() + if 0 < new_y < 2: + self.y[self.selected_point] = new_y + self.setData(x=self.x, y=self.y) + self.pointMoved.emit(self.selected_point, new_y) + event.accept() + return + super().mouseMoveEvent(event) + + def mouseReleaseEvent(self, event): + if self.dragging: + self.dragging = False + self.selected_point = None + event.accept() + return + super().mouseReleaseEvent(event) + + def hoverEvent(self, event): + threshold = 100 + if event.isExit(): + self.setCursor(Qt.ArrowCursor) + else: + vb = self.getViewBox() + if vb is None: + self.setCursor(Qt.ArrowCursor) + return + mousePoint = vb.mapSceneToView(event.scenePos()) + x_hover = mousePoint.x() + # Check if hovering near a point + min_dist = float('inf') + for i, (x, y) in enumerate(zip(self.x, self.y)): + dist = abs(x - x_hover) + if dist < min_dist: + min_dist = dist + if min_dist < threshold: + self.setCursor(QCursor(Qt.OpenHandCursor)) + else: + self.setCursor(QCursor(Qt.ArrowCursor)) + + +class TTSInterface(QMainWindow): + def __init__(self, tts_interface: ToucanTTSInterface): + super().__init__() + + path_to_iso_list = hf_hub_download(repo_id="Flux9665/ToucanTTS", filename="iso_to_fullname.json") + iso_to_name = load_json_from_path(path_to_iso_list) + self.name_to_iso = dict() + for iso in iso_to_name: + self.name_to_iso[iso_to_name[iso]] = iso + text_selection = [iso_to_name[iso_code] for iso_code in iso_to_name] + self.tts_backend = tts_interface + + # Define Placeholders + self.word_boundaries = [] + self.pitch_curve = None + self.phonemes = [] + self.durations = None + self.pitch = None + self.audio_file_path = None + self.result_audio = None + self.min_duration = 1 + + self.setWindowTitle("TTS Model Interface") + self.setGeometry(100, 100, 1200, 900) + + self.central_widget = QWidget() + self.setCentralWidget(self.central_widget) + self.main_layout = QVBoxLayout() + self.central_widget.setLayout(self.main_layout) + self.main_layout.setSpacing(15) # spacing between widgets + self.main_layout.setContentsMargins(20, 50, 20, 30) # Left, Top, Right, Bottom + + # Add Text Input + self.text_input_layout = QHBoxLayout() + self.dropdown_box = QComboBox(self) + self.dropdown_box.addItems(text_selection) # Add your options here + self.dropdown_box.setCurrentText("English") + self.dropdown_box.currentIndexChanged.connect(self.on_user_input_changed) + self.text_input_layout.addWidget(self.dropdown_box) + + self.text_input = QLineEdit() + self.text_input.setPlaceholderText("Enter the text you want to be read here...") + 
self.text_input.textChanged.connect(self.on_user_input_changed) + self.text_input_layout.addWidget(self.text_input) + self.main_layout.insertLayout(0, self.text_input_layout) + self.text_input.setFocus() + self.text_input.setText("") + + # Initialize plots + self.init_plots() + + # Initialize buttons + self.init_controls() + + # Initialize Timer for TTS Cooldown + self.tts_timer = QTimer() + self.tts_timer.setSingleShot(True) + self.tts_timer.timeout.connect(self.run_tts) + self.tts_update_required = False + + def clear_all_widgets(self): + self.spectrogram_view.setParent(None) + self.pitch_plot.setParent(None) + self.generate_button.setParent(None) + self.load_audio_button.setParent(None) + self.save_audio_button.setParent(None) + self.play_audio_button.setParent(None) + + def init_plots(self): + # Spectrogram Plot + self.spectrogram_view = pg.PlotWidget(background="#f5f5f5") + self.spectrogram_view.setLabel('left', 'Frequency Buckets', units='') + self.spectrogram_view.setLabel('bottom', 'Phonemes', units='') + self.main_layout.addWidget(self.spectrogram_view) + + # Pitch Plot + self.pitch_plot = pg.PlotWidget(background="#f5f5f5") + self.pitch_plot.setLabel('left', 'Intonation', units='') + self.pitch_plot.setLabel('bottom', 'Phonemes', units='') + self.main_layout.addWidget(self.pitch_plot) + + def load_data(self, durations, pitch, spectrogram): + + durations = remove_indexes(durations, self.word_boundaries) + pitch = remove_indexes(pitch, self.word_boundaries) + + self.durations = durations + self.cumulative_durations = np.cumsum(self.durations) + self.pitch = pitch + self.spectrogram = spectrogram + + # Display Spectrogram + self.spectrogram_view.setLimits(xMin=0, xMax=self.cumulative_durations[-1] + 10, yMin=0, yMax=1000) # Adjust as per your data + self.spectrogram_view.enableAutoRange(axis=pg.ViewBox.XYAxes, enable=True) + self.spectrogram_view.setMouseEnabled(x=False, y=False) # Disable panning and zooming + img = pg.ImageItem(self.spectrogram) + self.spectrogram_view.addItem(img) + img.setLookupTable(pg.colormap.get('GnBu', source='matplotlib').getLookupTable()) + spectrogram_ticks = self.get_phoneme_ticks(self.cumulative_durations) + self.spectrogram_view.getAxis('bottom').setTicks([spectrogram_ticks]) + spectrogram_label_color = QColor('#006400') + self.spectrogram_view.getAxis('bottom').setTextPen(QPen(spectrogram_label_color)) + self.spectrogram_view.getAxis('bottom').setStyle(tickFont=QFont('Times New Roman', 16)) + self.spectrogram_view.getAxis('left').setTextPen(QPen(QColor('#f5f5f5'))) + + # Display Pitch + self.pitch_curve = self.pitch_plot.plot(self.cumulative_durations, self.pitch, pen=pg.mkPen('#B8860B', width=4), name='Pitch') + self.pitch_plot.setMouseEnabled(x=False, y=False) # Disable panning and zooming + pitch_ticks = self.get_phoneme_ticks(self.cumulative_durations, for_pitch=True) + self.pitch_plot.getAxis('bottom').setTicks([pitch_ticks]) + pitch_label_color = QColor('#006400') + self.pitch_plot.getAxis('bottom').setTextPen(QPen(pitch_label_color)) + self.pitch_plot.getAxis('bottom').setStyle(tickFont=QFont('Times New Roman', 16)) + self.pitch_plot.getAxis('left').setTextPen(QPen(QColor('#f5f5f5'))) + + # Display Durations + self.duration_lines = [] + for i, cum_dur in enumerate(self.cumulative_durations): + line = pg.InfiniteLine(pos=cum_dur, angle=90, pen=pg.mkPen('orange', width=3)) + self.spectrogram_view.addItem(line) + line.setMovable(True) + # Use lambda with default argument to capture current i + line.sigPositionChanged.connect(lambda _, idx=i: 
self.on_duration_changed(idx)) + self.duration_lines.append(line) + + self.enable_interactions() + + def get_phoneme_ticks(self, cumulative_durations, for_pitch=False): + """ + Create ticks for phoneme labels centered between durations. + """ + ticks = [] + previous = 0 + for i, cum_dur in enumerate(cumulative_durations): + if for_pitch: + ticks.append((cum_dur, self.phonemes[i])) + previous = cum_dur + else: + if i == 0: + center = cum_dur / 2 + else: + center = (previous + cum_dur) / 2 + ticks.append((center, self.phonemes[i])) + previous = cum_dur + return ticks + + def init_controls(self): + # Main vertical layout for controls + self.controls_layout = QVBoxLayout() + self.main_layout.addLayout(self.controls_layout) + + # Lower row layout for buttons + self.lower_row = QHBoxLayout() + self.controls_layout.addLayout(self.lower_row) + + self.generate_button = QPushButton("Generate new Prosody") + self.generate_button.clicked.connect(self.generate_new_prosody) + self.lower_row.addWidget(self.generate_button) + + self.load_audio_button = QPushButton("Load Example of Voice to Mimic") + self.load_audio_button.clicked.connect(self.load_audio_file) + self.lower_row.addWidget(self.load_audio_button) + + self.save_audio_button = QPushButton("Save Audio File") + self.save_audio_button.clicked.connect(self.save_audio_file) + self.lower_row.addWidget(self.save_audio_button) + + self.play_audio_button = QPushButton("Play Audio") + self.play_audio_button.clicked.connect(self.play_audio) + self.lower_row.addWidget(self.play_audio_button) + + def enable_interactions(self): + x_pitch = self.cumulative_durations.copy() + y_pitch = self.pitch.copy() + self.pitch_scatter = DraggableScatter(x_pitch, + y_pitch, + pen=pg.mkPen(None), + brush=pg.mkBrush(218, 165, 32, 255), + size=18, ) + self.pitch_scatter.pointMoved.connect(self.on_pitch_point_moved) + self.pitch_plot.addItem(self.pitch_scatter) + self.pitch_plot.showGrid(x=True, y=False, alpha=0.1) + self.pitch_plot.setYRange(0, 2) + + def on_duration_changed(self, idx): + """ + Moving a duration line adjusts the position of that line and all subsequent lines. + Ensures that durations do not become negative. 
+ """ + min_duration = self.min_duration + + # Get new position of the moved line + new_pos = self.duration_lines[idx].value() + + # Calculate the minimum allowed position + if idx == 0: + min_allowed = min_duration + else: + min_allowed = self.duration_lines[idx - 1].value() + min_duration + + # Clamp new_pos + if new_pos < min_allowed: + new_pos = min_allowed + + # If the new_pos was clamped, update the line's position without emitting signal again + if new_pos != self.duration_lines[idx].value(): + self.duration_lines[idx].blockSignals(True) + self.duration_lines[idx].setValue(new_pos) + self.duration_lines[idx].blockSignals(False) + + # Calculate the delta change + delta = new_pos - self.cumulative_durations[idx] + + # Update current and subsequent cumulative durations + for i in range(idx, len(self.cumulative_durations)): + self.cumulative_durations[i] += delta + self.duration_lines[i].blockSignals(True) + self.duration_lines[i].setValue(self.cumulative_durations[i]) + self.duration_lines[i].blockSignals(False) + + # Update durations based on cumulative durations + self.durations = np.diff(np.insert(self.cumulative_durations, 0, 0)).tolist() + + # print(f"Updated Durations: {self.durations}") + + # Update pitch curve + self.pitch_curve.setData(self.cumulative_durations, self.pitch) + + # Update pitch scatter points + self.pitch_scatter.setData(x=self.cumulative_durations, y=self.pitch) + self.pitch_scatter.x = self.cumulative_durations + + # Update phoneme ticks + spectrogram_ticks = self.get_phoneme_ticks(self.cumulative_durations) + self.spectrogram_view.getAxis('bottom').setTicks([spectrogram_ticks]) + self.pitch_plot.getAxis('bottom').setTicks([spectrogram_ticks]) + + # Update spectrogram's X-axis limits + self.spectrogram_view.setLimits(xMin=0, xMax=self.cumulative_durations[-1] + 10) # Added buffer + + # Mark that an update is required + self.mark_tts_update() + + def on_pitch_point_moved(self, index, new_y): + # Update the pitch array with the new y-value + self.pitch[index] = new_y + # print(f"Pitch point {index} moved to {new_y:.2f} Hz") + # Update the pitch curve line + self.pitch_curve.setData(self.cumulative_durations, self.pitch) + # Update the scatter points' y-values (x remains the same) + self.pitch_scatter.y[index] = new_y + self.pitch_scatter.setData(x=self.cumulative_durations, y=self.pitch) + # Mark that an update is required + self.mark_tts_update() + + def on_user_input_changed(self, text): + """ + Handle changes in the text input field. + """ + # print(f"User input changed: {text}") + # Mark that an update is required + self.mark_tts_update() + + def generate_new_prosody(self): + """ + Generate new prosody. 
+ """ + if self.text_input.text().strip() == "": + return + wave, mel, durations, pitch = self.tts_backend(text=self.text_input.text(), + view=False, + duration_scaling_factor=1.0, + pitch_variance_scale=1.0, + energy_variance_scale=1.0, + pause_duration_scaling_factor=1.0, + durations=None, + pitch=None, + energy=None, + input_is_phones=False, + return_plot_as_filepath=False, + loudness_in_db=-29.0, + prosody_creativity=0.8, + return_everything=True) + # reset and clear everything + self.clear_all_widgets() + self.init_plots() + self.init_controls() + + self.load_data(durations=durations.cpu().numpy(), pitch=pitch.cpu().numpy(), spectrogram=mel.cpu().transpose(0, 1).numpy()) + + self.update_result_audio(wave) + self.cumulative_durations = np.cumsum(self.durations) + + # Update scatter points + self.pitch_scatter.setData(x=self.cumulative_durations, y=self.pitch) + + # Update curves + self.pitch_curve.setData(self.cumulative_durations, self.pitch) + + # Update duration lines positions + for i, line in enumerate(self.duration_lines): + line.blockSignals(True) + line.setValue(self.cumulative_durations[i]) + line.blockSignals(False) + + # Update phoneme ticks + self.spectrogram_view.getAxis('bottom').setTicks([self.get_phoneme_ticks(self.cumulative_durations)]) + self.pitch_plot.getAxis('bottom').setTicks([self.get_phoneme_ticks(self.cumulative_durations, for_pitch=True)]) + + # print("Generated new random prosody.") + + def load_audio_file(self): + """ + Open a file dialog to load an audio file. + """ + options = QFileDialog.Options() + options |= QFileDialog.ReadOnly + file_filter = "Audio Files (*.wav *.mp3 *.flac *.ogg);;All Files (*)" + file_path, _ = QFileDialog.getOpenFileName(self, "Load Example of Voice to Mimic", "", file_filter, options=options) + if file_path: + self.audio_file_path = file_path + # print(f"Loaded audio file: {self.audio_file_path}") + # Here, you can add code to process the loaded audio if needed + self.mark_tts_update() + + def save_audio_file(self): + """ + Open a file dialog to save the resulting audio NumPy array. + """ + if self.result_audio is None: + QMessageBox.warning(self, "Save Error", "No resulting audio to save.") + return + + options = QFileDialog.Options() + options |= QFileDialog.DontUseNativeDialog + file_filter = "WAV Files (*.wav);;All Files (*)" + save_path, _ = QFileDialog.getSaveFileName(self, "Save Audio File", "", file_filter, options=options) + if save_path: + try: + sample_rate = 24000 + if "." not in save_path: + save_path = save_path + ".wav" + + # Normalize the audio if it's not in the correct range + if self.result_audio.dtype != np.int16: + audio_normalized = np.int16(self.result_audio / np.max(np.abs(self.result_audio)) * 32767) + else: + audio_normalized = self.result_audio + + # Save using scipy.io.wavfile + scipy.io.wavfile.write(save_path, sample_rate, audio_normalized) + # print(f"Audio saved successfully at: {save_path}") + QMessageBox.information(self, "Save Successful", f"Audio saved successfully at:\n{save_path}") + except Exception as e: + print(f"Error saving audio: {e}") + QMessageBox.critical(self, "Save Error", f"Failed to save audio:\n{e}") + + def play_audio(self): + # print("playing current audio...") + if self.result_audio is not None: + sounddevice.play(self.result_audio, samplerate=24000) + + def update_result_audio(self, audio_array): + """ + Update the resulting audio array. + This method should be called with your TTS model's output. 
+ """ + self.result_audio = audio_array + # print("Resulting audio updated.") + + def mark_tts_update(self): + """ + Marks that a TTS update is required and starts/resets the timer. + """ + self.tts_update_required = True + self.tts_timer.start(800) # 800 milliseconds delay before the model starts to compute something + + def run_tts(self): + """ + Dummy method to simulate running the TTS model. + This should be replaced with actual TTS integration. + """ + text = self.text_input.text() + while self.tts_update_required: + self.tts_update_required = False + if text.strip() == "": + return + + # print(f"Running TTS with text: {text}") + + # reset and clear everything + self.clear_all_widgets() + self.init_plots() + self.init_controls() + + if self.audio_file_path is not None: + self.tts_backend.set_utterance_embedding(self.audio_file_path) + + self.tts_backend.set_language(self.name_to_iso[self.dropdown_box.currentText()]) + + phonemes = self.tts_backend.text2phone.get_phone_string(text=text) + self.phonemes = phonemes.replace(" ", "") + + forced_durations = None if self.durations is None or len(self.durations) != len(self.phonemes) else torch.LongTensor(insert_zeros_at_indexes(self.durations, self.word_boundaries)).unsqueeze(0) + forced_pitch = None if self.pitch is None or len(self.pitch) != len(self.phonemes) else torch.tensor(insert_zeros_at_indexes(self.pitch, self.word_boundaries)).unsqueeze(0) + + wave, mel, durations, pitch = self.tts_backend(text, + view=False, + duration_scaling_factor=1.0, + pitch_variance_scale=1.0, + energy_variance_scale=1.0, + pause_duration_scaling_factor=1.0, + durations=forced_durations, + pitch=forced_pitch, + energy=None, + input_is_phones=False, + return_plot_as_filepath=False, + loudness_in_db=-29.0, + prosody_creativity=0.1, + return_everything=True) + + self.word_boundaries = find_zero_indexes(durations) + + self.load_data(durations=durations.cpu().numpy(), pitch=pitch.cpu().numpy(), spectrogram=mel.cpu().transpose(0, 1).numpy()) + + self.update_result_audio(wave) + # print("TTS run completed and plots/audio updated.") + + +def main(): + app = QApplication(sys.argv) + stylesheet = """ + QMainWindow { + background-color: #f5f5f5; + color: #333333; + font-family: system-ui; + } + + QWidget { + background-color: #f5f5f5; + color: #333333; + font-size: 14px; + } + + QPushButton { + background-color: #b9770e; + border: 1px solid #ffffff; + color: #ffffff; + padding: 8px 16px; + border-radius: 10px; + } + + QPushButton:hover { + background-color: #228B22; + } + + QPushButton:pressed { + background-color: #006400; + } + + QSlider::groove:horizontal { + border: 1px solid #bbb; + background: #d3d3d3; + height: 8px; + border-radius: 4px; + } + + QSlider::handle:horizontal { + background: #D2691E; + border: 1px solid #D2691E; + width: 26px; + margin: -5px 0; + border-radius: 9px; + } + + QLabel { + color: #006400; + } + + QLineEdit { + background-color: #EEE8AA; + border: 10px solid #DAA520; + padding: 12px; + border-radius: 20px; + } + + QLineEdit:focus { + background-color: #EEE8AA; + border: 10px solid #DAA520; + padding: 12px; + border-radius: 20px; + } + """ + app.setStyleSheet(stylesheet) + + interface = TTSInterface(ToucanTTSInterface(device="cuda" if torch.cuda.is_available() else "cpu")) + interface.show() + sys.exit(app.exec_()) + + +def find_zero_indexes(numbers): + zero_indexes = [index for index, value in enumerate(numbers) if value == 0] + return zero_indexes + + +def remove_indexes(data_list, indexes_to_remove): + result = [value for i, value in 
enumerate(data_list) if i not in indexes_to_remove] + return result + + +def insert_zeros_at_indexes(data_list, indexes_to_add_zeros): + if len(indexes_to_add_zeros) == 0: + return data_list + result = data_list[:] + for index in sorted(indexes_to_add_zeros): + result.insert(index, 0) + return result + + +if __name__ == "__main__": + main() diff --git a/run_GUI_demo.py b/run_simple_GUI_demo.py similarity index 95% rename from run_GUI_demo.py rename to run_simple_GUI_demo.py index 87d42d60..3e6d9e43 100644 --- a/run_GUI_demo.py +++ b/run_simple_GUI_demo.py @@ -55,9 +55,9 @@ def __init__(self, outputs=[gr.Audio(type="numpy", label="Speech"), gr.Image(label="Visualization")], title=title, - theme="default", allow_flagging="never", - article=article) + article=article, + theme=gr.themes.Default(primary_hue="amber", secondary_hue="orange", spacing_size=gr.themes.sizes.spacing_lg, radius_size=gr.themes.sizes.radius_lg)) self.iface.launch() def read(self, @@ -88,7 +88,7 @@ def read(self, 0., 0., 0., - -12.) + -24.) return (sr, float2pcm(wav)), fig
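For context, the new return_everything flag on ToucanTTSInterface.forward is what the advanced GUI script above relies on to get the intermediate representations back. Below is a minimal usage sketch, assuming the interface, set_language, and the forward keyword arguments behave as shown in the diff; the example sentence and the final shape printout are illustrative assumptions, not part of the change set.

# Minimal sketch: calling the interface with the new return_everything flag.
import torch

from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface

tts = ToucanTTSInterface(device="cuda" if torch.cuda.is_available() else "cpu")
tts.set_language("eng")

# With return_everything=True, forward() returns the waveform together with the
# intermediate mel spectrogram, durations, and pitch instead of (wave, sr).
wave, mel, durations, pitch = tts(text="An example sentence.",
                                  loudness_in_db=-29.0,
                                  prosody_creativity=0.1,
                                  return_everything=True)

print(mel.shape, durations.shape, pitch.shape)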