Skip to content

Commit

Permalink
Update infer-gradio with ref_text auto-filling; minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
SWivid committed Nov 8, 2024
1 parent c33a83c commit 23409af
Showing 1 changed file with 51 additions and 49 deletions.
100 changes: 51 additions & 49 deletions src/f5_tts/infer/infer_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import re
import tempfile
from collections import OrderedDict

import click
import gradio as gr
Expand Down Expand Up @@ -116,7 +117,7 @@ def infer(
spectrogram_path = tmp_spectrogram.name
save_spectrogram(combined_spectrogram, spectrogram_path)

return (final_sample_rate, final_wave), spectrogram_path
return (final_sample_rate, final_wave), spectrogram_path, ref_text


with gr.Blocks() as app_credits:
Expand Down Expand Up @@ -172,7 +173,7 @@ def basic_tts(
cross_fade_duration_slider,
speed_slider,
):
return infer(
audio_out, spectrogram_path, ref_text_out = infer(
ref_audio_input,
ref_text_input,
gen_text_input,
Expand All @@ -181,6 +182,7 @@ def basic_tts(
cross_fade_duration_slider,
speed_slider,
)
return audio_out, spectrogram_path, gr.update(value=ref_text_out)

generate_btn.click(
basic_tts,
Expand All @@ -192,7 +194,7 @@ def basic_tts(
cross_fade_duration_slider,
speed_slider,
],
outputs=[audio_output, spectrogram_output],
outputs=[audio_output, spectrogram_output, ref_text_input],
)


Expand Down Expand Up @@ -262,26 +264,26 @@ def parse_speechtypes_text(gen_text):
with gr.Row():
with gr.Column():
regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
regular_insert = gr.Button("Insert", variant="secondary")
regular_insert = gr.Button("Insert Label", variant="secondary")
regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=2)

# Additional speech types (up to 99 more)
# Regular speech type (max 100)
max_speech_types = 100
speech_type_rows = []
speech_type_names = [regular_name]
speech_type_audios = []
speech_type_ref_texts = []
speech_type_delete_btns = []
speech_type_insert_btns = []
speech_type_insert_btns.append(regular_insert)

speech_type_rows = [] # 99
speech_type_names = [regular_name] # 100
speech_type_audios = [regular_audio] # 100
speech_type_ref_texts = [regular_ref_text] # 100
speech_type_delete_btns = [] # 99
speech_type_insert_btns = [regular_insert] # 100

# Additional speech types (99 more)
for i in range(max_speech_types - 1):
with gr.Row(visible=False) as row:
with gr.Column():
name_input = gr.Textbox(label="Speech Type Name")
delete_btn = gr.Button("Delete", variant="secondary")
insert_btn = gr.Button("Insert", variant="secondary")
delete_btn = gr.Button("Delete Type", variant="secondary")
insert_btn = gr.Button("Insert Label", variant="secondary")
audio_input = gr.Audio(label="Reference Audio", type="filepath")
ref_text_input = gr.Textbox(label="Reference Text", lines=2)
speech_type_rows.append(row)
Expand All @@ -295,22 +297,22 @@ def parse_speechtypes_text(gen_text):
add_speech_type_btn = gr.Button("Add Speech Type")

# Keep track of current number of speech types
speech_type_count = gr.State(value=0)
speech_type_count = gr.State(value=1)

# Function to add a speech type
def add_speech_type_fn(speech_type_count):
if speech_type_count < max_speech_types - 1:
if speech_type_count < max_speech_types:
speech_type_count += 1
# Prepare updates for the rows
row_updates = []
for i in range(max_speech_types - 1):
for i in range(1, max_speech_types):
if i < speech_type_count:
row_updates.append(gr.update(visible=True))
else:
row_updates.append(gr.update())
else:
# Optionally, show a warning
row_updates = [gr.update() for _ in range(max_speech_types - 1)]
row_updates = [gr.update() for _ in range(1, max_speech_types)]
return [speech_type_count] + row_updates

add_speech_type_btn.click(
Expand All @@ -323,13 +325,13 @@ def delete_speech_type_fn(speech_type_count):
# Prepare updates
row_updates = []

for i in range(max_speech_types - 1):
for i in range(1, max_speech_types):
if i == index:
row_updates.append(gr.update(visible=False))
else:
row_updates.append(gr.update())

speech_type_count = max(0, speech_type_count - 1)
speech_type_count = max(1, speech_type_count)

return [speech_type_count] + row_updates

Expand Down Expand Up @@ -367,7 +369,7 @@ def insert_speech_type_fn(current_text, speech_type_name):
with gr.Accordion("Advanced Settings", open=False):
remove_silence_multistyle = gr.Checkbox(
label="Remove Silences",
value=False,
value=True,
)

# Generate button
Expand All @@ -378,25 +380,25 @@ def insert_speech_type_fn(current_text, speech_type_name):

@gpu_decorator
def generate_multistyle_speech(
regular_audio,
regular_ref_text,
gen_text,
*args,
):
num_additional_speech_types = max_speech_types - 1
speech_type_names_list = args[:num_additional_speech_types]
speech_type_audios_list = args[num_additional_speech_types : 2 * num_additional_speech_types]
speech_type_ref_texts_list = args[2 * num_additional_speech_types : 3 * num_additional_speech_types]
remove_silence = args[3 * num_additional_speech_types + 1]

speech_type_names_list = args[:max_speech_types]
speech_type_audios_list = args[max_speech_types : 2 * max_speech_types]
speech_type_ref_texts_list = args[2 * max_speech_types : 3 * max_speech_types]
remove_silence = args[3 * max_speech_types]
# Collect the speech types and their audios into a dict
speech_types = {"Regular": {"audio": regular_audio, "ref_text": regular_ref_text}}
speech_types = OrderedDict()

ref_text_idx = 0
for name_input, audio_input, ref_text_input in zip(
speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list
):
if name_input and audio_input:
speech_types[name_input] = {"audio": audio_input, "ref_text": ref_text_input}
else:
speech_types[f"@{ref_text_idx}@"] = {"audio": "", "ref_text": ""}
ref_text_idx += 1

# Parse the gen_text into segments
segments = parse_speechtypes_text(gen_text)
Expand All @@ -419,26 +421,27 @@ def generate_multistyle_speech(
ref_text = speech_types[current_style].get("ref_text", "")

# Generate speech for this segment
audio, _ = infer(
audio_out, _, ref_text_out = infer(
ref_audio, ref_text, text, tts_model_choice, remove_silence, 0, show_info=print
) # show_info=print no pull to top when generating
sr, audio_data = audio
sr, audio_data = audio_out

generated_audio_segments.append(audio_data)
speech_types[current_style]["ref_text"] = ref_text_out

# Concatenate all audio segments
if generated_audio_segments:
final_audio_data = np.concatenate(generated_audio_segments)
return (sr, final_audio_data)
return [(sr, final_audio_data)] + [
gr.update(value=speech_types[style]["ref_text"]) for style in speech_types
]
else:
gr.Warning("No audio generated.")
return None
return [None] + [gr.update(value=speech_types[style]["ref_text"]) for style in speech_types]

generate_multistyle_btn.click(
generate_multistyle_speech,
inputs=[
regular_audio,
regular_ref_text,
gen_text_input_multistyle,
]
+ speech_type_names
Expand All @@ -447,13 +450,12 @@ def generate_multistyle_speech(
+ [
remove_silence_multistyle,
],
outputs=audio_output_multistyle,
outputs=[audio_output_multistyle] + speech_type_ref_texts,
)

# Validation function to disable Generate button if speech types are missing
def validate_speech_types(gen_text, regular_name, *args):
num_additional_speech_types = max_speech_types - 1
speech_type_names_list = args[:num_additional_speech_types]
speech_type_names_list = args[:max_speech_types]

# Collect the speech types names
speech_types_available = set()
Expand Down Expand Up @@ -561,7 +563,7 @@ def load_chat_model():
label="Type your message",
lines=1,
)
send_btn_chat = gr.Button("Send")
send_btn_chat = gr.Button("Send Message")
clear_btn_chat = gr.Button("Clear Conversation")

conversation_state = gr.State(
Expand Down Expand Up @@ -607,7 +609,7 @@ def generate_audio_response(history, ref_audio, ref_text, remove_silence):
if not last_ai_response:
return None

audio_result, _ = infer(
audio_result, _, ref_text_out = infer(
ref_audio,
ref_text,
last_ai_response,
Expand All @@ -617,7 +619,7 @@ def generate_audio_response(history, ref_audio, ref_text, remove_silence):
speed=1.0,
show_info=print, # show_info=print no pull to top when generating
)
return audio_result
return audio_result, gr.update(value=ref_text_out)

def clear_conversation():
"""Reset the conversation"""
Expand All @@ -641,7 +643,7 @@ def update_system_prompt(new_prompt):
).then(
generate_audio_response,
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
outputs=[audio_output_chat],
outputs=[audio_output_chat, ref_text_chat],
).then(
lambda: None,
None,
Expand All @@ -656,7 +658,7 @@ def update_system_prompt(new_prompt):
).then(
generate_audio_response,
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
outputs=[audio_output_chat],
outputs=[audio_output_chat, ref_text_chat],
).then(
lambda: None,
None,
Expand All @@ -671,7 +673,7 @@ def update_system_prompt(new_prompt):
).then(
generate_audio_response,
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
outputs=[audio_output_chat],
outputs=[audio_output_chat, ref_text_chat],
).then(
lambda: None,
None,
Expand Down Expand Up @@ -702,9 +704,9 @@ def update_system_prompt(new_prompt):
* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
The checkpoints support English and Chinese.
The checkpoints currently support English and Chinese.
If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s with ✂ in the bottom right corner (otherwise might have non-optimal auto-trimmed result).
**NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
"""
Expand All @@ -729,7 +731,7 @@ def switch_tts_model(new_choice):

gr.TabbedInterface(
[app_tts, app_multistyle, app_chat, app_credits],
["TTS", "Multi-Speech", "Voice-Chat", "Credits"],
["Basic-TTS", "Multi-Speech", "Voice-Chat", "Credits"],
)


Expand Down

0 comments on commit 23409af

Please sign in to comment.