-
Notifications
You must be signed in to change notification settings - Fork 916
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add api * update infer limits
- Loading branch information
Showing
2 changed files
with
171 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import soundfile as sf | ||
import torch | ||
import tqdm | ||
from cached_path import cached_path | ||
|
||
from model import DiT, UNetT | ||
from model.utils import save_spectrogram | ||
|
||
from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav | ||
|
||
|
||
class F5TTS: | ||
def __init__( | ||
self, | ||
model_type="F5-TTS", | ||
ckpt_file="", | ||
vocab_file="", | ||
ode_method="euler", | ||
use_ema=True, | ||
local_path=None, | ||
device=None, | ||
): | ||
# Initialize parameters | ||
self.final_wave = None | ||
self.target_sample_rate = 24000 | ||
self.n_mel_channels = 100 | ||
self.hop_length = 256 | ||
self.target_rms = 0.1 | ||
|
||
# Set device | ||
self.device = device or ( | ||
"cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" | ||
) | ||
|
||
# Load models | ||
self.load_vecoder_model(local_path) | ||
self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema) | ||
|
||
def load_vecoder_model(self, local_path): | ||
self.vocos = load_vocoder(local_path is not None, local_path, self.device) | ||
|
||
def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema): | ||
if model_type == "F5-TTS": | ||
if not ckpt_file: | ||
ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")) | ||
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) | ||
model_cls = DiT | ||
elif model_type == "E2-TTS": | ||
if not ckpt_file: | ||
ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors")) | ||
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) | ||
model_cls = UNetT | ||
else: | ||
raise ValueError(f"Unknown model type: {model_type}") | ||
|
||
self.ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema, self.device) | ||
|
||
def export_wav(self, wav, file_wave, remove_silence=False): | ||
if remove_silence: | ||
remove_silence_for_generated_wav(file_wave) | ||
|
||
sf.write(file_wave, wav, self.target_sample_rate) | ||
|
||
def export_spectrogram(self, spect, file_spect): | ||
save_spectrogram(spect, file_spect) | ||
|
||
def infer( | ||
self, | ||
ref_file, | ||
ref_text, | ||
gen_text, | ||
sway_sampling_coef=-1, | ||
cfg_strength=2, | ||
nfe_step=32, | ||
speed=1.0, | ||
fix_duration=None, | ||
remove_silence=False, | ||
file_wave=None, | ||
file_spect=None, | ||
cross_fade_duration=0.15, | ||
show_info=print, | ||
progress=tqdm, | ||
): | ||
wav, sr, spect = infer_process( | ||
ref_file, | ||
ref_text, | ||
gen_text, | ||
self.ema_model, | ||
cross_fade_duration, | ||
speed, | ||
show_info, | ||
progress, | ||
nfe_step, | ||
cfg_strength, | ||
sway_sampling_coef, | ||
fix_duration, | ||
) | ||
|
||
if file_wave is not None: | ||
self.export_wav(wav, file_wave, remove_silence) | ||
|
||
if file_spect is not None: | ||
self.export_spectrogram(spect, file_spect) | ||
|
||
return wav, sr, spect | ||
|
||
|
||
if __name__ == "__main__": | ||
f5tts = F5TTS() | ||
|
||
wav, sr, spect = f5tts.infer( | ||
ref_file="tests/ref_audio/test_en_1_ref_short.wav", | ||
ref_text="some call me nature, others call me mother nature.", | ||
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""", | ||
file_wave="tests/out.wav", | ||
file_spect="tests/out.png", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters