diff --git a/src/f5_tts/infer/infer_gradio.py b/src/f5_tts/infer/infer_gradio.py
index 7c9fdf84..7095180e 100644
--- a/src/f5_tts/infer/infer_gradio.py
+++ b/src/f5_tts/infer/infer_gradio.py
@@ -3,6 +3,7 @@
 
 import re
 import tempfile
+from collections import OrderedDict
 
 import click
 import gradio as gr
@@ -116,7 +117,7 @@ def infer(
         spectrogram_path = tmp_spectrogram.name
     save_spectrogram(combined_spectrogram, spectrogram_path)
 
-    return (final_sample_rate, final_wave), spectrogram_path
+    return (final_sample_rate, final_wave), spectrogram_path, ref_text
 
 
 with gr.Blocks() as app_credits:
@@ -172,7 +173,7 @@ def basic_tts(
         cross_fade_duration_slider,
         speed_slider,
     ):
-        return infer(
+        audio_out, spectrogram_path, ref_text_out = infer(
             ref_audio_input,
             ref_text_input,
             gen_text_input,
@@ -181,6 +182,7 @@ def basic_tts(
             cross_fade_duration_slider,
             speed_slider,
         )
+        return audio_out, spectrogram_path, gr.update(value=ref_text_out)
 
     generate_btn.click(
         basic_tts,
@@ -192,7 +194,7 @@ def basic_tts(
             cross_fade_duration_slider,
             speed_slider,
         ],
-        outputs=[audio_output, spectrogram_output],
+        outputs=[audio_output, spectrogram_output, ref_text_input],
     )
 
 
@@ -262,26 +264,26 @@ def parse_speechtypes_text(gen_text):
     with gr.Row():
         with gr.Column():
             regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
-            regular_insert = gr.Button("Insert", variant="secondary")
+            regular_insert = gr.Button("Insert Label", variant="secondary")
         regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
         regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=2)
 
-    # Additional speech types (up to 99 more)
+    # Regular speech type (max 100)
     max_speech_types = 100
-    speech_type_rows = []
-    speech_type_names = [regular_name]
-    speech_type_audios = []
-    speech_type_ref_texts = []
-    speech_type_delete_btns = []
-    speech_type_insert_btns = []
-    speech_type_insert_btns.append(regular_insert)
-
+    speech_type_rows = []  # 99
+    speech_type_names = [regular_name]  # 100
+    speech_type_audios = [regular_audio]  # 100
+    speech_type_ref_texts = [regular_ref_text]  # 100
+    speech_type_delete_btns = []  # 99
+    speech_type_insert_btns = [regular_insert]  # 100
+
+    # Additional speech types (99 more)
     for i in range(max_speech_types - 1):
         with gr.Row(visible=False) as row:
             with gr.Column():
                 name_input = gr.Textbox(label="Speech Type Name")
-                delete_btn = gr.Button("Delete", variant="secondary")
-                insert_btn = gr.Button("Insert", variant="secondary")
+                delete_btn = gr.Button("Delete Type", variant="secondary")
+                insert_btn = gr.Button("Insert Label", variant="secondary")
             audio_input = gr.Audio(label="Reference Audio", type="filepath")
             ref_text_input = gr.Textbox(label="Reference Text", lines=2)
         speech_type_rows.append(row)
@@ -295,22 +297,22 @@ def parse_speechtypes_text(gen_text):
     add_speech_type_btn = gr.Button("Add Speech Type")
 
     # Keep track of current number of speech types
-    speech_type_count = gr.State(value=0)
+    speech_type_count = gr.State(value=1)
 
     # Function to add a speech type
     def add_speech_type_fn(speech_type_count):
-        if speech_type_count < max_speech_types - 1:
+        if speech_type_count < max_speech_types:
             speech_type_count += 1
             # Prepare updates for the rows
             row_updates = []
-            for i in range(max_speech_types - 1):
+            for i in range(1, max_speech_types):
                 if i < speech_type_count:
                     row_updates.append(gr.update(visible=True))
                 else:
                     row_updates.append(gr.update())
         else:
             # Optionally, show a warning
-            row_updates = [gr.update() for _ in range(max_speech_types - 1)]
+            row_updates = [gr.update() for _ in range(1, max_speech_types)]
         return [speech_type_count] + row_updates
 
     add_speech_type_btn.click(
@@ -323,13 +325,13 @@ def delete_speech_type_fn(speech_type_count):
             # Prepare updates
             row_updates = []
 
-            for i in range(max_speech_types - 1):
+            for i in range(1, max_speech_types):
                 if i == index:
                     row_updates.append(gr.update(visible=False))
                 else:
                     row_updates.append(gr.update())
 
-            speech_type_count = max(0, speech_type_count - 1)
+            speech_type_count = max(1, speech_type_count)
 
             return [speech_type_count] + row_updates
 
@@ -367,7 +369,7 @@ def insert_speech_type_fn(current_text, speech_type_name):
     with gr.Accordion("Advanced Settings", open=False):
         remove_silence_multistyle = gr.Checkbox(
             label="Remove Silences",
-            value=False,
+            value=True,
         )
 
     # Generate button
@@ -378,25 +380,25 @@ def insert_speech_type_fn(current_text, speech_type_name):
 
     @gpu_decorator
     def generate_multistyle_speech(
-        regular_audio,
-        regular_ref_text,
         gen_text,
         *args,
     ):
-        num_additional_speech_types = max_speech_types - 1
-        speech_type_names_list = args[:num_additional_speech_types]
-        speech_type_audios_list = args[num_additional_speech_types : 2 * num_additional_speech_types]
-        speech_type_ref_texts_list = args[2 * num_additional_speech_types : 3 * num_additional_speech_types]
-        remove_silence = args[3 * num_additional_speech_types + 1]
-
+        speech_type_names_list = args[:max_speech_types]
+        speech_type_audios_list = args[max_speech_types : 2 * max_speech_types]
+        speech_type_ref_texts_list = args[2 * max_speech_types : 3 * max_speech_types]
+        remove_silence = args[3 * max_speech_types]
         # Collect the speech types and their audios into a dict
-        speech_types = {"Regular": {"audio": regular_audio, "ref_text": regular_ref_text}}
+        speech_types = OrderedDict()
 
+        ref_text_idx = 0
         for name_input, audio_input, ref_text_input in zip(
             speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list
         ):
             if name_input and audio_input:
                 speech_types[name_input] = {"audio": audio_input, "ref_text": ref_text_input}
+            else:
+                speech_types[f"@{ref_text_idx}@"] = {"audio": "", "ref_text": ""}
+            ref_text_idx += 1
 
         # Parse the gen_text into segments
         segments = parse_speechtypes_text(gen_text)
@@ -419,26 +421,27 @@ def generate_multistyle_speech(
             ref_text = speech_types[current_style].get("ref_text", "")
 
             # Generate speech for this segment
-            audio, _ = infer(
+            audio_out, _, ref_text_out = infer(
                 ref_audio, ref_text, text, tts_model_choice, remove_silence, 0, show_info=print
             )  # show_info=print no pull to top when generating
-            sr, audio_data = audio
+            sr, audio_data = audio_out
 
             generated_audio_segments.append(audio_data)
+            speech_types[current_style]["ref_text"] = ref_text_out
 
         # Concatenate all audio segments
         if generated_audio_segments:
             final_audio_data = np.concatenate(generated_audio_segments)
-            return (sr, final_audio_data)
+            return [(sr, final_audio_data)] + [
+                gr.update(value=speech_types[style]["ref_text"]) for style in speech_types
+            ]
         else:
             gr.Warning("No audio generated.")
-            return None
+            return [None] + [gr.update(value=speech_types[style]["ref_text"]) for style in speech_types]
 
     generate_multistyle_btn.click(
         generate_multistyle_speech,
         inputs=[
-            regular_audio,
-            regular_ref_text,
             gen_text_input_multistyle,
         ]
         + speech_type_names
@@ -447,13 +450,12 @@ def generate_multistyle_speech(
         + [
             remove_silence_multistyle,
         ],
-        outputs=audio_output_multistyle,
+        outputs=[audio_output_multistyle] + speech_type_ref_texts,
     )
 
     # Validation function to disable Generate button if speech types are missing
     def validate_speech_types(gen_text, regular_name, *args):
-        num_additional_speech_types = max_speech_types - 1
-        speech_type_names_list = args[:num_additional_speech_types]
+        speech_type_names_list = args[:max_speech_types]
 
         # Collect the speech types names
         speech_types_available = set()
@@ -561,7 +563,7 @@ def load_chat_model():
                     label="Type your message",
                     lines=1,
                 )
-                send_btn_chat = gr.Button("Send")
+                send_btn_chat = gr.Button("Send Message")
                 clear_btn_chat = gr.Button("Clear Conversation")
 
         conversation_state = gr.State(
@@ -607,7 +609,7 @@ def generate_audio_response(history, ref_audio, ref_text, remove_silence):
         if not last_ai_response:
             return None
 
-        audio_result, _ = infer(
+        audio_result, _, ref_text_out = infer(
             ref_audio,
             ref_text,
             last_ai_response,
@@ -617,7 +619,7 @@ def generate_audio_response(history, ref_audio, ref_text, remove_silence):
             speed=1.0,
             show_info=print,  # show_info=print no pull to top when generating
         )
-        return audio_result
+        return audio_result, gr.update(value=ref_text_out)
 
     def clear_conversation():
         """Reset the conversation"""
@@ -641,7 +643,7 @@ def update_system_prompt(new_prompt):
     ).then(
         generate_audio_response,
         inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
-        outputs=[audio_output_chat],
+        outputs=[audio_output_chat, ref_text_chat],
     ).then(
         lambda: None,
         None,
@@ -656,7 +658,7 @@ def update_system_prompt(new_prompt):
     ).then(
        generate_audio_response,
         inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
-        outputs=[audio_output_chat],
+        outputs=[audio_output_chat, ref_text_chat],
     ).then(
         lambda: None,
         None,
@@ -671,7 +673,7 @@ def update_system_prompt(new_prompt):
     ).then(
         generate_audio_response,
         inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
-        outputs=[audio_output_chat],
+        outputs=[audio_output_chat, ref_text_chat],
     ).then(
         lambda: None,
         None,
@@ -702,9 +704,9 @@ def update_system_prompt(new_prompt):
 * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
 * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
 
-The checkpoints support English and Chinese.
+The checkpoints currently support English and Chinese.
 
-If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
+If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s with ✂ in the bottom right corner (otherwise might have non-optimal auto-trimmed result).
 
 **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
 """
@@ -729,7 +731,7 @@ def switch_tts_model(new_choice):
 
     gr.TabbedInterface(
         [app_tts, app_multistyle, app_chat, app_credits],
-        ["TTS", "Multi-Speech", "Voice-Chat", "Credits"],
+        ["Basic-TTS", "Multi-Speech", "Voice-Chat", "Credits"],
     )
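
A note on the recurring pattern above: `infer()` now returns a third value, the reference text, which is the Whisper auto-transcription whenever the user left the reference-text box empty. Every caller (basic TTS, multistyle, chat) writes that value back into its textbox so the transcription is computed once and reused on later generations. Below is a minimal sketch of the new calling convention, assuming the module's `infer` and `tts_model_choice` are in scope; the handler name and the literal argument values are illustrative, not part of the patch:

```python
import gradio as gr

def on_generate(ref_audio, ref_text, gen_text):
    # infer() now returns three values: the (sample_rate, wave) tuple,
    # the spectrogram path, and the possibly auto-transcribed ref text.
    audio_out, spectrogram_path, ref_text_out = infer(
        ref_audio, ref_text, gen_text, tts_model_choice, False, 0.15, 1.0
    )
    # Flushing ref_text_out back into the textbox means Whisper runs at
    # most once per reference clip; subsequent runs reuse the shown text.
    return audio_out, spectrogram_path, gr.update(value=ref_text_out)
```

The wiring mirrors the patch: the reference-text textbox is simply appended to the event's `outputs` list.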
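Because the multistyle generate button now lists all 100 reference-text textboxes as outputs, `generate_multistyle_speech` must return exactly one update per textbox, in component order. That is what the `@{ref_text_idx}@` placeholder entries are for: unconfigured slots still occupy a position in the `OrderedDict`, keeping its iteration order aligned with the component lists (the `@...@` keys are unlikely to collide with user-typed style names). A toy illustration of the alignment, using hypothetical values:

```python
from collections import OrderedDict

# Hypothetical UI state: only slots 0 and 2 are configured.
names = ["Regular", "", "Shouting", ""]
audios = ["regular.wav", None, "shouting.wav", None]
texts = ["Some reference text.", "", "HEY!", ""]

speech_types = OrderedDict()
ref_text_idx = 0
for name, audio, text in zip(names, audios, texts):
    if name and audio:
        speech_types[name] = {"audio": audio, "ref_text": text}
    else:
        # Placeholder keeps position i of the dict aligned with textbox i,
        # so the i-th returned gr.update() targets the right component.
        speech_types[f"@{ref_text_idx}@"] = {"audio": "", "ref_text": ""}
    ref_text_idx += 1

assert list(speech_types) == ["Regular", "@1@", "Shouting", "@3@"]
```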