Skip to content

Commit

Permalink
load asr pipeline only if needed
Browse files Browse the repository at this point in the history
  • Loading branch information
SWivid committed Oct 21, 2024
1 parent 795cb19 commit b899a35
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 14 deletions.
4 changes: 2 additions & 2 deletions api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def __init__(
)

# Load models
self.load_vecoder_model(local_path)
self.load_vocoder_model(local_path)
self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)

def load_vecoder_model(self, local_path):
def load_vocoder_model(self, local_path):
self.vocos = load_vocoder(local_path is not None, local_path, self.device)

def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
Expand Down
4 changes: 2 additions & 2 deletions inference-cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
exp_name = "F5TTS_Base"
ckpt_step = 1200000
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
# ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path

elif model == "E2-TTS":
model_cls = UNetT
Expand All @@ -114,7 +114,7 @@
exp_name = "E2TTS_Base"
ckpt_step = 1200000
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
# ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path

print(f"Using {model}...")
ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
Expand Down
32 changes: 22 additions & 10 deletions model/utils_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,6 @@
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

asr_pipe = pipeline(
"automatic-speech-recognition",
model="openai/whisper-large-v3-turbo",
torch_dtype=torch.float16,
device=device,
)

vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")


Expand Down Expand Up @@ -82,8 +75,6 @@ def chunk_text(text, max_chars=135):


# load vocoder


def load_vocoder(is_local=False, local_path="", device=device):
if is_local:
print(f"Load vocos from local path {local_path}")
Expand All @@ -97,6 +88,22 @@ def load_vocoder(is_local=False, local_path="", device=device):
return vocos


# load asr pipeline

asr_pipe = None


def initialize_asr_pipeline(device=device):
global asr_pipe

asr_pipe = pipeline(
"automatic-speech-recognition",
model="openai/whisper-large",
torch_dtype=torch.float16,
device=device,
)


# load model for inference


Expand Down Expand Up @@ -133,7 +140,7 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method="euler
# preprocess reference audio and text


def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=device):
show_info("Converting audio...")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
aseg = AudioSegment.from_file(ref_audio_orig)
Expand All @@ -152,6 +159,9 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
ref_audio = f.name

if not ref_text.strip():
global asr_pipe
if asr_pipe is None:
initialize_asr_pipeline(device=device)
show_info("No reference text provided, transcribing reference audio...")
ref_text = asr_pipe(
ref_audio,
Expand Down Expand Up @@ -329,6 +339,8 @@ def infer_batch_process(


# remove silence from generated wav


def remove_silence_for_generated_wav(filename):
aseg = AudioSegment.from_file(filename)
non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
Expand Down

0 comments on commit b899a35

Please sign in to comment.