Skip to content

Commit

Permalink
Improved batching and added reference text ending
Browse files Browse the repository at this point in the history
  • Loading branch information
jpgallegoar committed Oct 15, 2024
1 parent ced7864 commit 028421e
Showing 1 changed file with 41 additions and 99 deletions.
140 changes: 41 additions & 99 deletions gradio_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,101 +112,34 @@ def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
"E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
)

def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
if len(text.encode('utf-8')) <= max_chars:
return [text]
if text[-1] not in ['。', '.', '!', '!', '?', '?']:
text += '.'

sentences = re.split('([。.!?!?])', text)
sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]

batches = []
current_batch = ""

def split_by_words(text):
words = text.split()
current_word_part = ""
word_batches = []
for word in words:
if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
current_word_part += word + ' '
else:
if current_word_part:
# Try to find a suitable split word
for split_word in split_words:
split_index = current_word_part.rfind(' ' + split_word + ' ')
if split_index != -1:
word_batches.append(current_word_part[:split_index].strip())
current_word_part = current_word_part[split_index:].strip() + ' '
break
else:
# If no suitable split word found, just append the current part
word_batches.append(current_word_part.strip())
current_word_part = ""
current_word_part += word + ' '
if current_word_part:
word_batches.append(current_word_part.strip())
return word_batches
def chunk_text(text, max_chars=135):
    """
    Split the input text into chunks, each with at most roughly
    ``max_chars`` characters.

    Sentences are detected as runs of text ending in punctuation
    (``; : , . ! ?``) followed by whitespace, and are greedily packed
    into chunks in order.  A single sentence longer than ``max_chars``
    becomes its own (oversized) chunk — sentences are never split.

    Args:
        text (str): The text to be split.
        max_chars (int): Soft maximum number of characters per chunk.
            NOTE: this counts Unicode characters, not UTF-8 bytes.

    Returns:
        List[str]: A list of text chunks, in original order.
    """
    chunks = []
    current_chunk = ""
    # Split the text into sentences based on punctuation followed by
    # whitespace; the lookbehind keeps the punctuation with its sentence.
    sentences = re.split(r'(?<=[;:,.!?])\s+', text)

    for sentence in sentences:
        if len(current_chunk) + len(sentence) <= max_chars:
            # The sentence still fits — append it to the current chunk.
            current_chunk += sentence + " "
        else:
            # Flush the full chunk (if any) and start a new one with
            # this sentence, even if the sentence alone exceeds the cap.
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence + " "

    # Flush the trailing chunk.
    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks

@gpu_decorator
def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, progress=gr.Progress()):
Expand Down Expand Up @@ -306,7 +239,9 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
aseg = AudioSegment.from_file(ref_audio_orig)

non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
non_silent_segs = silence.split_on_silence(
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500
)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
non_silent_wave += non_silent_seg
Expand All @@ -332,13 +267,20 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
else:
gr.Info("Using custom reference text...")

# Split the input text into batches
# Add the functionality to ensure it ends with ". "
if not ref_text.endswith(". "):
if ref_text.endswith("."):
ref_text += " "
else:
ref_text += ". "

audio, sr = torchaudio.load(ref_audio)
max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)

# Use the new chunk_text function to split gen_text
gen_text_batches = chunk_text(gen_text, max_chars=135)
print('ref_text', ref_text)
for i, gen_text in enumerate(gen_text_batches):
print(f'gen_text {i}', gen_text)
for i, batch_text in enumerate(gen_text_batches):
print(f'gen_text {i}', batch_text)

gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
Expand Down Expand Up @@ -823,4 +765,4 @@ def main(port, host, share, api):


if __name__ == "__main__":
main()
main()

0 comments on commit 028421e

Please sign in to comment.