diff --git a/api.py b/api.py index 8c69fa74..efe098b6 100644 --- a/api.py +++ b/api.py @@ -7,6 +7,9 @@ from model.utils import save_spectrogram from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav +from model.utils import seed_everything +import random +import sys class F5TTS: @@ -26,6 +29,7 @@ def __init__( self.n_mel_channels = 100 self.hop_length = 256 self.target_rms = 0.1 + self.seed = -1 # Set device self.device = device or ( @@ -56,11 +60,11 @@ def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema) self.ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema, self.device) def export_wav(self, wav, file_wave, remove_silence=False): + sf.write(file_wave, wav, self.target_sample_rate) + if remove_silence: remove_silence_for_generated_wav(file_wave) - sf.write(file_wave, wav, self.target_sample_rate) - def export_spectrogram(self, spect, file_spect): save_spectrogram(spect, file_spect) @@ -81,7 +85,12 @@ def infer( remove_silence=False, file_wave=None, file_spect=None, + seed=-1, ): + if seed == -1: + seed = random.randint(0, sys.maxsize) + seed_everything(seed) + self.seed = seed wav, sr, spect = infer_process( ref_file, ref_text, @@ -116,4 +125,7 @@ def infer( gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""", file_wave="tests/out.wav", file_spect="tests/out.png", + seed=-1, # random seed = -1 ) + + print("seed :", f5tts.seed)