Skip to content

Commit

Permalink
minor update patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
SWivid committed Nov 15, 2024
1 parent 6f13ad4 commit 2a844ae
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 14 deletions.
19 changes: 11 additions & 8 deletions src/f5_tts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
vocoder_name="vocos",
local_path=None,
device=None,
hf_cache_dir=None,
):
# Initialize parameters
self.final_wave = None
Expand All @@ -46,29 +47,31 @@ def __init__(
)

# Load models
self.load_vocoder_model(vocoder_name, local_path=local_path)
self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, local_path=local_path)
self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
self.load_ema_model(
model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir
)

def load_vocoder_model(self, vocoder_name, local_path=None):
self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None):
self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir)

def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, local_path=None):
def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None):
if model_type == "F5-TTS":
if not ckpt_file:
if mel_spec_type == "vocos":
ckpt_file = str(
cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=local_path)
cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
)
elif mel_spec_type == "bigvgan":
ckpt_file = str(
cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=local_path)
cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir)
)
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model_cls = DiT
elif model_type == "E2-TTS":
if not ckpt_file:
ckpt_file = str(
cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=local_path)
cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
)
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
model_cls = UNetT
Expand Down
12 changes: 6 additions & 6 deletions src/f5_tts/infer/utils_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,18 @@ def chunk_text(text, max_chars=135):


# load vocoder
def load_vocoder(vocoder_name="vocos", is_local=False, local_path=None, device=device):
def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device, hf_cache_dir=None):
if vocoder_name == "vocos":
# vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
if is_local and local_path is not None:
if is_local:
print(f"Load vocos from local path {local_path}")
config_path = f"{local_path}/config.yaml"
model_path = f"{local_path}/pytorch_model.bin"
else:
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
repo_id = "charactr/vocos-mel-24khz"
config_path = hf_hub_download(repo_id=repo_id, cache_dir=local_path, filename="config.yaml")
model_path = hf_hub_download(repo_id=repo_id, cache_dir=local_path, filename="pytorch_model.bin")
config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
vocoder = Vocos.from_hparams(config_path)
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
from vocos.feature_extractors import EncodecFeatures
Expand All @@ -119,11 +119,11 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path=None, device=d
from third_party.BigVGAN import bigvgan
except ImportError:
print("You need to follow the README to init submodule and change the BigVGAN source code.")
if is_local and local_path is not None:
if is_local:
"""download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
else:
local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=local_path)
local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=hf_cache_dir)
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)

vocoder.remove_weight_norm()
Expand Down

0 comments on commit 2a844ae

Please sign in to comment.