add api for easy use (#186)
* add api
* update infer limits
lpscr authored Oct 21, 2024
1 parent 0f9f878 commit 25cdc51
Showing 2 changed files with 171 additions and 20 deletions.
117 changes: 117 additions & 0 deletions api.py
@@ -0,0 +1,117 @@
import soundfile as sf
import torch
import tqdm
from cached_path import cached_path

from model import DiT, UNetT
from model.utils import save_spectrogram

from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav


class F5TTS:
    def __init__(
        self,
        model_type="F5-TTS",
        ckpt_file="",
        vocab_file="",
        ode_method="euler",
        use_ema=True,
        local_path=None,
        device=None,
    ):
        # Initialize parameters
        self.final_wave = None
        self.target_sample_rate = 24000
        self.n_mel_channels = 100
        self.hop_length = 256
        self.target_rms = 0.1

        # Set device
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        )

        # Load models
        self.load_vocoder_model(local_path)
        self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)

    def load_vocoder_model(self, local_path):
        self.vocos = load_vocoder(local_path is not None, local_path, self.device)

    def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
        if model_type == "F5-TTS":
            if not ckpt_file:
                ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
            model_cls = DiT
        elif model_type == "E2-TTS":
            if not ckpt_file:
                ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
            model_cls = UNetT
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        self.ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema, self.device)

    def export_wav(self, wav, file_wave, remove_silence=False):
        # Write the audio first, then optionally strip silence from the saved file
        sf.write(file_wave, wav, self.target_sample_rate)

        if remove_silence:
            remove_silence_for_generated_wav(file_wave)

    def export_spectrogram(self, spect, file_spect):
        save_spectrogram(spect, file_spect)

    def infer(
        self,
        ref_file,
        ref_text,
        gen_text,
        sway_sampling_coef=-1,
        cfg_strength=2,
        nfe_step=32,
        speed=1.0,
        fix_duration=None,
        remove_silence=False,
        file_wave=None,
        file_spect=None,
        cross_fade_duration=0.15,
        show_info=print,
        progress=tqdm,
    ):
        wav, sr, spect = infer_process(
            ref_file,
            ref_text,
            gen_text,
            self.ema_model,
            cross_fade_duration,
            speed,
            show_info,
            progress,
            nfe_step,
            cfg_strength,
            sway_sampling_coef,
            fix_duration,
        )

        if file_wave is not None:
            self.export_wav(wav, file_wave, remove_silence)

        if file_spect is not None:
            self.export_spectrogram(spect, file_spect)

        return wav, sr, spect


if __name__ == "__main__":
    f5tts = F5TTS()

    wav, sr, spect = f5tts.infer(
        ref_file="tests/ref_audio/test_en_1_ref_short.wav",
        ref_text="some call me nature, others call me mother nature.",
        gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
        file_wave="tests/out.wav",
        file_spect="tests/out.png",
    )
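
For reference, a minimal sketch of the same class driven with the alternate E2-TTS backbone, a fixed total duration, and silence trimming. The output path and generated text here are hypothetical placeholders, and fix_duration is taken to be total seconds (reference plus generated), per the conversion added to infer_batch_process below:

from api import F5TTS

# Hypothetical usage sketch; output path below is a placeholder.
e2tts = F5TTS(model_type="E2-TTS")  # resolves E2TTS_Base via cached_path

wav, sr, spect = e2tts.infer(
    ref_file="tests/ref_audio/test_en_1_ref_short.wav",
    ref_text="some call me nature, others call me mother nature.",
    gen_text="A short fixed-length sample.",
    fix_duration=8.0,      # seconds; converted to mel frames in infer_batch_process
    remove_silence=True,   # strips silence from the written wav via pydub
    file_wave="tests/out_e2.wav",
)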
74 changes: 54 additions & 20 deletions model/utils_infer.py
@@ -38,12 +38,12 @@
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
-nfe_step = 32  # 16, 32
-cfg_strength = 2.0
-ode_method = "euler"
-sway_sampling_coef = -1.0
-speed = 1.0
-fix_duration = None
+# nfe_step = 32  # 16, 32
+# cfg_strength = 2.0
+# ode_method = "euler"
+# sway_sampling_coef = -1.0
+# speed = 1.0
+# fix_duration = None

# -----------------------------------------

@@ -84,7 +84,7 @@ def chunk_text(text, max_chars=135):
# load vocoder


-def load_vocoder(is_local=False, local_path=""):
+def load_vocoder(is_local=False, local_path="", device=device):
    if is_local:
        print(f"Load vocos from local path {local_path}")
        vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
@@ -100,14 +100,14 @@ def load_vocoder(is_local=False, local_path=""):
# load model for inference


-def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""):
+def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method="euler", use_ema=True, device=device):
    if vocab_file == "":
        vocab_file = "Emilia_ZH_EN"
        tokenizer = "pinyin"
    else:
        tokenizer = "custom"

-    print("\nvocab : ", vocab_file, tokenizer)
+    print("\nvocab : ", vocab_file)
+    print("tokenizer : ", tokenizer)
    print("model : ", ckpt_path, "\n")

@@ -125,7 +125,7 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""):
        vocab_char_map=vocab_char_map,
    ).to(device)

-    model = load_checkpoint(model, ckpt_path, device, use_ema=True)
+    model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)

    return model

@@ -178,7 +178,18 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):


def infer_process(
-    ref_audio, ref_text, gen_text, model_obj, cross_fade_duration=0.15, speed=speed, show_info=print, progress=tqdm
+    ref_audio,
+    ref_text,
+    gen_text,
+    model_obj,
+    cross_fade_duration=0.15,
+    speed=1.0,
+    show_info=print,
+    progress=tqdm,
+    nfe_step=32,
+    cfg_strength=2,
+    sway_sampling_coef=-1,
+    fix_duration=None,
):
    # Split the input text into batches
    audio, sr = torchaudio.load(ref_audio)
@@ -188,14 +199,36 @@ def infer_process(
        print(f"gen_text {i}", gen_text)

    show_info(f"Generating audio in {len(gen_text_batches)} batches...")
-    return infer_batch_process((audio, sr), ref_text, gen_text_batches, model_obj, cross_fade_duration, speed, progress)
+    return infer_batch_process(
+        (audio, sr),
+        ref_text,
+        gen_text_batches,
+        model_obj,
+        cross_fade_duration,
+        speed,
+        progress,
+        nfe_step,
+        cfg_strength,
+        sway_sampling_coef,
+        fix_duration,
+    )


# infer batches


def infer_batch_process(
-    ref_audio, ref_text, gen_text_batches, model_obj, cross_fade_duration=0.15, speed=1, progress=tqdm
+    ref_audio,
+    ref_text,
+    gen_text_batches,
+    model_obj,
+    cross_fade_duration=0.15,
+    speed=1,
+    progress=tqdm,
+    nfe_step=32,
+    cfg_strength=2.0,
+    sway_sampling_coef=-1,
+    fix_duration=None,
):
    audio, sr = ref_audio
    if audio.shape[0] > 1:
@@ -219,11 +252,14 @@ def infer_batch_process(
        text_list = [ref_text + gen_text]
        final_text_list = convert_char_to_pinyin(text_list)

-        # Calculate duration
-        ref_audio_len = audio.shape[-1] // hop_length
-        ref_text_len = len(ref_text.encode("utf-8"))
-        gen_text_len = len(gen_text.encode("utf-8"))
-        duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
+        if fix_duration is not None:
+            duration = int(fix_duration * target_sample_rate / hop_length)
+        else:
+            # Calculate duration
+            ref_audio_len = audio.shape[-1] // hop_length
+            ref_text_len = len(ref_text.encode("utf-8"))
+            gen_text_len = len(gen_text.encode("utf-8"))
+            duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)

        # inference
        with torch.inference_mode():
@@ -293,8 +329,6 @@ def infer_batch_process(


# remove silence from generated wav
-
-
def remove_silence_for_generated_wav(filename):
    aseg = AudioSegment.from_file(filename)
    non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
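
The duration branch added above converts fix_duration from seconds to mel-frame counts. A standalone restatement of that arithmetic, for illustration only; the helper name is made up here, and target_sample_rate = 24000 is taken from api.py:

# Illustration only: mirrors the duration logic added to infer_batch_process.
target_sample_rate = 24000  # matches F5TTS.target_sample_rate in api.py
hop_length = 256

def estimate_duration_frames(ref_audio_samples, ref_text, gen_text, speed=1.0, fix_duration=None):
    if fix_duration is not None:
        # seconds -> mel frames, e.g. 5.0 s -> int(5.0 * 24000 / 256) = 468
        return int(fix_duration * target_sample_rate / hop_length)
    # otherwise extrapolate from the reference: scale its frame length by the
    # UTF-8 byte-length ratio of the texts, divided by the speed factor
    ref_audio_len = ref_audio_samples // hop_length
    ref_text_len = len(ref_text.encode("utf-8"))
    gen_text_len = len(gen_text.encode("utf-8"))
    return ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)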

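The widened load_vocoder and load_model signatures can also be driven directly, mirroring what F5TTS.__init__ and F5TTS.infer do. A sketch under that assumption, with the checkpoint path and model config copied from api.py and a hypothetical gen_text:

import torch
from cached_path import cached_path

from model import DiT
from model.utils_infer import load_vocoder, load_model, infer_process

device = "cuda" if torch.cuda.is_available() else "cpu"

# Vocoder from the default HF location; pass is_local=True and a local_path for a local copy.
vocos = load_vocoder(device=device)

ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
ema_model = load_model(DiT, model_cfg, ckpt_file, ode_method="euler", use_ema=True, device=device)

wav, sr, spect = infer_process(
    "tests/ref_audio/test_en_1_ref_short.wav",
    "some call me nature, others call me mother nature.",
    "Hello from the refactored inference helpers.",
    ema_model,
    nfe_step=32,            # defaults shown in the new signature
    cfg_strength=2.0,
    sway_sampling_coef=-1.0,
)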