From 0ee258c0251fc6a314aecbf667a61def2fa768de Mon Sep 17 00:00:00 2001 From: ZhikangNiu Date: Sat, 23 Nov 2024 23:51:31 +0800 Subject: [PATCH 1/4] support hydra config training --- src/f5_tts/config/E2TTS_Base_train.yaml | 40 +++++++++ src/f5_tts/config/F5TTS_Base_train.yaml | 42 ++++++++++ src/f5_tts/train/README.md | 2 +- src/f5_tts/train/train.py | 104 ++++++++---------------- 4 files changed, 118 insertions(+), 70 deletions(-) create mode 100644 src/f5_tts/config/E2TTS_Base_train.yaml create mode 100644 src/f5_tts/config/F5TTS_Base_train.yaml diff --git a/src/f5_tts/config/E2TTS_Base_train.yaml b/src/f5_tts/config/E2TTS_Base_train.yaml new file mode 100644 index 00000000..4672bb85 --- /dev/null +++ b/src/f5_tts/config/E2TTS_Base_train.yaml @@ -0,0 +1,40 @@ +hydra: + run: + dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} + +datasets: + name: Emilia_ZH_EN + batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 + batch_size_type: frame # "frame" or "sample" + max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models + +optim: + epochs: 15 + learning_rate: 7.5e-5 + num_warmup_updates: 20000 # warmup steps + grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps + max_grad_norm: 1.0 + +model: + name: E2TTS + tokenizer: char + tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) + arch: + dim: 1024 + depth: 24 + heads: 16 + ff_mult: 4 + mel_spec: + target_sample_rate: 24000 + n_mel_channels: 100 + hop_length: 256 + win_length: 1024 + n_fft: 1024 + mel_spec_type: vocos # 'vocos' or 'bigvgan' + is_local_vocoder: False + local_vocoder_path: None + +ckpts: + save_per_updates: 50000 # save checkpoint per steps + last_per_steps: 5000 # save last checkpoint per steps + save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} \ No newline at end of file diff --git a/src/f5_tts/config/F5TTS_Base_train.yaml b/src/f5_tts/config/F5TTS_Base_train.yaml new file mode 100644 index 00000000..f3ead3ec --- /dev/null +++ b/src/f5_tts/config/F5TTS_Base_train.yaml @@ -0,0 +1,42 @@ +hydra: + run: + dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} + +datasets: + name: Emilia_ZH_EN + batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 + batch_size_type: frame # "frame" or "sample" + max_samples: 64 # max sequences per batch if use frame-wise batch_size. 
we set 32 for small models, 64 for base models + +optim: + epochs: 15 + learning_rate: 7.5e-5 + num_warmup_updates: 20000 # warmup steps + grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps + max_grad_norm: 1.0 + +model: + name: F5TTS + tokenizer: char + tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) + arch: + dim: 1024 + depth: 22 + heads: 16 + ff_mult: 2 + text_dim: 512 + conv_layers: 4 + mel_spec: + target_sample_rate: 24000 + n_mel_channels: 100 + hop_length: 256 + win_length: 1024 + n_fft: 1024 + mel_spec_type: vocos # 'vocos' or 'bigvgan' + is_local_vocoder: False + local_vocoder_path: None + +ckpts: + save_per_updates: 50000 # save checkpoint per steps + last_per_steps: 5000 # save last checkpoint per steps + save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} \ No newline at end of file diff --git a/src/f5_tts/train/README.md b/src/f5_tts/train/README.md index d114db58..a6dfda37 100644 --- a/src/f5_tts/train/README.md +++ b/src/f5_tts/train/README.md @@ -35,7 +35,7 @@ Once your datasets are prepared, you can start the training process. # setup accelerate config, e.g. use multi-gpu ddp, fp16 # will be to: ~/.cache/huggingface/accelerate/default_config.yaml accelerate config -accelerate launch src/f5_tts/train/train.py +accelerate launch src/f5_tts/train/train.py --config-name F5TTS_Base_train.yaml # F5TTS_Base_train.yaml | E2TTS_Base_train.yaml ``` ### 2. Finetuning practice diff --git a/src/f5_tts/train/train.py b/src/f5_tts/train/train.py index fac0fe5a..48341f20 100644 --- a/src/f5_tts/train/train.py +++ b/src/f5_tts/train/train.py @@ -1,98 +1,64 @@ # training script. - +import os from importlib.resources import files +import hydra + from f5_tts.model import CFM, DiT, Trainer, UNetT from f5_tts.model.dataset import load_dataset from f5_tts.model.utils import get_tokenizer -# -------------------------- Dataset Settings --------------------------- # - -target_sample_rate = 24000 -n_mel_channels = 100 -hop_length = 256 -win_length = 1024 -n_fft = 1024 -mel_spec_type = "vocos" # 'vocos' or 'bigvgan' - -tokenizer = "pinyin" # 'pinyin', 'char', or 'custom' -tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) -dataset_name = "Emilia_ZH_EN" - -# -------------------------- Training Settings -------------------------- # -exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base +@hydra.main(config_path=os.path.join("..", "configs"), config_name=None) +def main(cfg): + tokenizer = cfg.model.tokenizer + mel_spec_type = cfg.model.mel_spec.mel_spec_type + exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}" -learning_rate = 7.5e-5 - -batch_size_per_gpu = 38400 # 8 GPUs, 8 * 38400 = 307200 -batch_size_type = "frame" # "frame" or "sample" -max_samples = 64 # max sequences per batch if use frame-wise batch_size. 
we set 32 for small models, 64 for base models -grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps -max_grad_norm = 1.0 - -epochs = 11 # use linear decay, thus epochs control the slope -num_warmup_updates = 20000 # warmup steps -save_per_updates = 50000 # save checkpoint per steps -last_per_steps = 5000 # save last checkpoint per steps - -# model params -if exp_name == "F5TTS_Base": - wandb_resume_id = None - model_cls = DiT - model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) -elif exp_name == "E2TTS_Base": - wandb_resume_id = None - model_cls = UNetT - model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) - - -# ----------------------------------------------------------------------- # - - -def main(): - if tokenizer == "custom": - tokenizer_path = tokenizer_path + # set text tokenizer + if tokenizer != "custom": + tokenizer_path = cfg.datasets.name else: - tokenizer_path = dataset_name + tokenizer_path = cfg.model.tokenizer_path vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) - mel_spec_kwargs = dict( - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - n_mel_channels=n_mel_channels, - target_sample_rate=target_sample_rate, - mel_spec_type=mel_spec_type, - ) + # set model + if "F5TTS" in cfg.model.name: + model_cls = DiT + elif "E2TTS" in cfg.model.name: + model_cls = UNetT + wandb_resume_id = None model = CFM( - transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), - mel_spec_kwargs=mel_spec_kwargs, + transformer=model_cls(**cfg.model.arch, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels), + mel_spec_kwargs=cfg.model.mel_spec, vocab_char_map=vocab_char_map, ) + # init trainer trainer = Trainer( model, - epochs, - learning_rate, - num_warmup_updates=num_warmup_updates, - save_per_updates=save_per_updates, - checkpoint_path=str(files("f5_tts").joinpath(f"../../ckpts/{exp_name}")), - batch_size=batch_size_per_gpu, - batch_size_type=batch_size_type, - max_samples=max_samples, - grad_accumulation_steps=grad_accumulation_steps, - max_grad_norm=max_grad_norm, + epochs=cfg.optim.epochs, + learning_rate=cfg.optim.learning_rate, + num_warmup_updates=cfg.optim.num_warmup_updates, + save_per_updates=cfg.ckpts.save_per_updates, + checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")), + batch_size=cfg.datasets.batch_size_per_gpu, + batch_size_type=cfg.datasets.batch_size_type, + max_samples=cfg.datasets.max_samples, + grad_accumulation_steps=cfg.optim.grad_accumulation_steps, + max_grad_norm=cfg.optim.max_grad_norm, wandb_project="CFM-TTS", wandb_run_name=exp_name, wandb_resume_id=wandb_resume_id, - last_per_steps=last_per_steps, + last_per_steps=cfg.ckpts.last_per_steps, log_samples=True, mel_spec_type=mel_spec_type, + is_local_vocoder=cfg.model.mel_spec.is_local_vocoder, + local_vocoder_path=cfg.model.mel_spec.local_vocoder_path, ) - train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs) + train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec) trainer.train( train_dataset, resumable_with_seed=666, # seed for shuffling dataset From e29a0d5b2321c97125747a1ad7f55bf2039fa233 Mon Sep 17 00:00:00 2001 From: ZhikangNiu Date: Sun, 24 Nov 2024 01:11:24 +0800 Subject: [PATCH 2/4] support command line set args --- src/f5_tts/eval/README.md | 4 +- src/f5_tts/eval/eval_infer_batch.py | 4 +- .../eval/eval_librispeech_test_clean.py | 115 +++++++++-------- 
src/f5_tts/eval/eval_seedtts_testset.py | 117 ++++++++++-------- 4 files changed, 130 insertions(+), 110 deletions(-) diff --git a/src/f5_tts/eval/README.md b/src/f5_tts/eval/README.md index b0f9e00d..ed358207 100644 --- a/src/f5_tts/eval/README.md +++ b/src/f5_tts/eval/README.md @@ -42,8 +42,8 @@ Then update in the following scripts with the paths you put evaluation model ckp Update the path with your batch-inferenced results, and carry out WER / SIM evaluations: ```bash # Evaluation for Seed-TTS test set -python src/f5_tts/eval/eval_seedtts_testset.py +python src/f5_tts/eval/eval_seedtts_testset.py --gen_wav_dir # Evaluation for LibriSpeech-PC test-clean (cross-sentence) -python src/f5_tts/eval/eval_librispeech_test_clean.py +python src/f5_tts/eval/eval_librispeech_test_clean.py --gen_wav_dir ``` \ No newline at end of file diff --git a/src/f5_tts/eval/eval_infer_batch.py b/src/f5_tts/eval/eval_infer_batch.py index 8598f487..785880cc 100644 --- a/src/f5_tts/eval/eval_infer_batch.py +++ b/src/f5_tts/eval/eval_infer_batch.py @@ -34,8 +34,6 @@ n_fft = 1024 target_rms = 0.1 - -tokenizer = "pinyin" rel_path = str(files("f5_tts").joinpath("../../")) @@ -49,6 +47,7 @@ def main(): parser.add_argument("-n", "--expname", required=True) parser.add_argument("-c", "--ckptstep", default=1200000, type=int) parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"]) + parser.add_argument("-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"]) parser.add_argument("-nfe", "--nfestep", default=32, type=int) parser.add_argument("-o", "--odemethod", default="euler") @@ -64,6 +63,7 @@ def main(): ckpt_step = args.ckptstep ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt" mel_spec_type = args.mel_spec_type + tokenizer = args.tokenizer nfe_step = args.nfestep ode_method = args.odemethod diff --git a/src/f5_tts/eval/eval_librispeech_test_clean.py b/src/f5_tts/eval/eval_librispeech_test_clean.py index 7f13ab1c..a5f76e09 100644 --- a/src/f5_tts/eval/eval_librispeech_test_clean.py +++ b/src/f5_tts/eval/eval_librispeech_test_clean.py @@ -2,6 +2,7 @@ import sys import os +import argparse sys.path.append(os.getcwd()) @@ -19,55 +20,65 @@ rel_path = str(files("f5_tts").joinpath("../../")) -eval_task = "wer" # sim | wer -lang = "en" -metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst" -librispeech_test_clean_path = "/LibriSpeech/test-clean" # test-clean path -gen_wav_dir = "PATH_TO_GENERATED" # generated wavs - -gpus = [0, 1, 2, 3, 4, 5, 6, 7] -test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path) - -## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book, -## leading to a low similarity for the ground truth in some cases. 
-# test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True) # eval ground truth - -local = False -if local: # use local custom checkpoint dir - asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3" -else: - asr_ckpt_dir = "" # auto download to cache dir - -wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth" - - -# --------------------------- WER --------------------------- - -if eval_task == "wer": - wers = [] - - with mp.Pool(processes=len(gpus)) as pool: - args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set] - results = pool.map(run_asr_wer, args) - for wers_ in results: - wers.extend(wers_) - - wer = round(np.mean(wers) * 100, 3) - print(f"\nTotal {len(wers)} samples") - print(f"WER : {wer}%") - - -# --------------------------- SIM --------------------------- - -if eval_task == "sim": - sim_list = [] - - with mp.Pool(processes=len(gpus)) as pool: - args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set] - results = pool.map(run_sim, args) - for sim_ in results: - sim_list.extend(sim_) - - sim = round(sum(sim_list) / len(sim_list), 3) - print(f"\nTotal {len(sim_list)} samples") - print(f"SIM : {sim}") +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"]) + parser.add_argument("-l", "--lang", type=str, default="en") + parser.add_argument("-g", "--gen_wav_dir", type=str, required=True) + parser.add_argument("-p", "--librispeech_test_clean_path", type=str, required=True) + parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use") + parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory") + return parser.parse_args() + + +def main(): + args = get_args() + eval_task = args.eval_task + lang = args.lang + librispeech_test_clean_path = args.librispeech_test_clean_path # test-clean path + gen_wav_dir = args.gen_wav_dir + metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst" + + gpus = list(range(args.gpu_nums)) + test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path) + + ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book, + ## leading to a low similarity for the ground truth in some cases. 
+ # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True) # eval ground truth + + local = args.local + if local: # use local custom checkpoint dir + asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3" + else: + asr_ckpt_dir = "" # auto download to cache dir + wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth" + + # --------------------------- WER --------------------------- + if eval_task == "wer": + wers = [] + with mp.Pool(processes=len(gpus)) as pool: + args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set] + results = pool.map(run_asr_wer, args) + for wers_ in results: + wers.extend(wers_) + + wer = round(np.mean(wers) * 100, 3) + print(f"\nTotal {len(wers)} samples") + print(f"WER : {wer}%") + + # --------------------------- SIM --------------------------- + if eval_task == "sim": + sim_list = [] + with mp.Pool(processes=len(gpus)) as pool: + args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set] + results = pool.map(run_sim, args) + for sim_ in results: + sim_list.extend(sim_) + + sim = round(sum(sim_list) / len(sim_list), 3) + print(f"\nTotal {len(sim_list)} samples") + print(f"SIM : {sim}") + + +if __name__ == "__main__": + main() diff --git a/src/f5_tts/eval/eval_seedtts_testset.py b/src/f5_tts/eval/eval_seedtts_testset.py index 88b1a8d6..5cc19877 100644 --- a/src/f5_tts/eval/eval_seedtts_testset.py +++ b/src/f5_tts/eval/eval_seedtts_testset.py @@ -2,6 +2,7 @@ import sys import os +import argparse sys.path.append(os.getcwd()) @@ -19,57 +20,65 @@ rel_path = str(files("f5_tts").joinpath("../../")) -eval_task = "wer" # sim | wer -lang = "zh" # zh | en -metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst" # seed-tts testset -# gen_wav_dir = rel_path + f"/data/seedtts_testset/{lang}/wavs" # ground truth wavs -gen_wav_dir = "PATH_TO_GENERATED" # generated wavs - - -# NOTE. 
paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different -# zh 1.254 seems a result of 4 workers wer_seed_tts -gpus = [0, 1, 2, 3, 4, 5, 6, 7] -test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus) - -local = False -if local: # use local custom checkpoint dir - if lang == "zh": - asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr - elif lang == "en": - asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3" -else: - asr_ckpt_dir = "" # auto download to cache dir - -wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth" - - -# --------------------------- WER --------------------------- - -if eval_task == "wer": - wers = [] - - with mp.Pool(processes=len(gpus)) as pool: - args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set] - results = pool.map(run_asr_wer, args) - for wers_ in results: - wers.extend(wers_) - - wer = round(np.mean(wers) * 100, 3) - print(f"\nTotal {len(wers)} samples") - print(f"WER : {wer}%") - - -# --------------------------- SIM --------------------------- - -if eval_task == "sim": - sim_list = [] - - with mp.Pool(processes=len(gpus)) as pool: - args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set] - results = pool.map(run_sim, args) - for sim_ in results: - sim_list.extend(sim_) - - sim = round(sum(sim_list) / len(sim_list), 3) - print(f"\nTotal {len(sim_list)} samples") - print(f"SIM : {sim}") +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"]) + parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"]) + parser.add_argument("-g", "--gen_wav_dir", type=str, required=True) + parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use") + parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory") + return parser.parse_args() + + +def main(): + args = get_args() + eval_task = args.eval_task + lang = args.lang + gen_wav_dir = args.gen_wav_dir + metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst" # seed-tts testset + + # NOTE. 
paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different + # zh 1.254 seems a result of 4 workers wer_seed_tts + gpus = list(range(args.gpu_nums)) + test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus) + + local = args.local + if local: # use local custom checkpoint dir + if lang == "zh": + asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr + elif lang == "en": + asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3" + else: + asr_ckpt_dir = "" # auto download to cache dir + wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth" + + # --------------------------- WER --------------------------- + + if eval_task == "wer": + wers = [] + with mp.Pool(processes=len(gpus)) as pool: + args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set] + results = pool.map(run_asr_wer, args) + for wers_ in results: + wers.extend(wers_) + + wer = round(np.mean(wers) * 100, 3) + print(f"\nTotal {len(wers)} samples") + print(f"WER : {wer}%") + + # --------------------------- SIM --------------------------- + if eval_task == "sim": + sim_list = [] + with mp.Pool(processes=len(gpus)) as pool: + args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set] + results = pool.map(run_sim, args) + for sim_ in results: + sim_list.extend(sim_) + + sim = round(sum(sim_list) / len(sim_list), 3) + print(f"\nTotal {len(sim_list)} samples") + print(f"SIM : {sim}") + + +if __name__ == "__main__": + main() From c8c8e4725e5a7cc3881e6b4b78738d55c1098f25 Mon Sep 17 00:00:00 2001 From: ZhikangNiu Date: Sun, 24 Nov 2024 13:27:05 +0800 Subject: [PATCH 3/4] update F5 and E2 config --- src/f5_tts/config/E2TTS_Base_train.yaml | 34 +++++++++---------- src/f5_tts/config/E2TTS_Small_train.yaml | 40 ++++++++++++++++++++++ src/f5_tts/config/F5TTS_Base_train.yaml | 38 ++++++++++----------- src/f5_tts/config/F5TTS_Small_train.yaml | 42 ++++++++++++++++++++++++ 4 files changed, 118 insertions(+), 36 deletions(-) create mode 100644 src/f5_tts/config/E2TTS_Small_train.yaml create mode 100644 src/f5_tts/config/F5TTS_Small_train.yaml diff --git a/src/f5_tts/config/E2TTS_Base_train.yaml b/src/f5_tts/config/E2TTS_Base_train.yaml index 4672bb85..9d7d77b8 100644 --- a/src/f5_tts/config/E2TTS_Base_train.yaml +++ b/src/f5_tts/config/E2TTS_Base_train.yaml @@ -3,36 +3,36 @@ hydra: dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} datasets: - name: Emilia_ZH_EN + name: Emilia_ZH_EN # dataset name batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 batch_size_type: frame # "frame" or "sample" max_samples: 64 # max sequences per batch if use frame-wise batch_size. 
we set 32 for small models, 64 for base models optim: - epochs: 15 - learning_rate: 7.5e-5 + epochs: 15 # max epochs + learning_rate: 7.5e-5 # learning rate num_warmup_updates: 20000 # warmup steps grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps - max_grad_norm: 1.0 + max_grad_norm: 1.0 # gradient clipping model: - name: E2TTS - tokenizer: char + name: E2TTS_Base # model name + tokenizer: pinyin # tokenizer type tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) arch: - dim: 1024 - depth: 24 - heads: 16 - ff_mult: 4 + dim: 1024 # model dimension + depth: 24 # number of transformer layers + heads: 16 # number of transformer heads + ff_mult: 4 # ff layer expansion mel_spec: - target_sample_rate: 24000 - n_mel_channels: 100 - hop_length: 256 - win_length: 1024 - n_fft: 1024 + target_sample_rate: 24000 # target sample rate + n_mel_channels: 100 # mel channel + hop_length: 256 # hop length + win_length: 1024 # window length + n_fft: 1024 # fft length mel_spec_type: vocos # 'vocos' or 'bigvgan' - is_local_vocoder: False - local_vocoder_path: None + is_local_vocoder: False # use local vocoder or not + local_vocoder_path: None # path to local vocoder ckpts: save_per_updates: 50000 # save checkpoint per steps diff --git a/src/f5_tts/config/E2TTS_Small_train.yaml b/src/f5_tts/config/E2TTS_Small_train.yaml new file mode 100644 index 00000000..a836dc36 --- /dev/null +++ b/src/f5_tts/config/E2TTS_Small_train.yaml @@ -0,0 +1,40 @@ +hydra: + run: + dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} + +datasets: + name: Emilia_ZH_EN + batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 + batch_size_type: frame # "frame" or "sample" + max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models + +optim: + epochs: 15 + learning_rate: 7.5e-5 + num_warmup_updates: 20000 # warmup steps + grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps + max_grad_norm: 1.0 + +model: + name: E2TTS_Small + tokenizer: pinyin + tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) + arch: + dim: 768 + depth: 20 + heads: 12 + ff_mult: 4 + mel_spec: + target_sample_rate: 24000 + n_mel_channels: 100 + hop_length: 256 + win_length: 1024 + n_fft: 1024 + mel_spec_type: vocos # 'vocos' or 'bigvgan' + is_local_vocoder: False + local_vocoder_path: None + +ckpts: + save_per_updates: 50000 # save checkpoint per steps + last_per_steps: 5000 # save last checkpoint per steps + save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} \ No newline at end of file diff --git a/src/f5_tts/config/F5TTS_Base_train.yaml b/src/f5_tts/config/F5TTS_Base_train.yaml index f3ead3ec..73299f5f 100644 --- a/src/f5_tts/config/F5TTS_Base_train.yaml +++ b/src/f5_tts/config/F5TTS_Base_train.yaml @@ -3,38 +3,38 @@ hydra: dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} datasets: - name: Emilia_ZH_EN + name: Emilia_ZH_EN # dataset name batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 batch_size_type: frame # "frame" or "sample" max_samples: 64 # max sequences per batch if use frame-wise batch_size. 
we set 32 for small models, 64 for base models optim: - epochs: 15 - learning_rate: 7.5e-5 + epochs: 15 # max epochs + learning_rate: 7.5e-5 # learning rate num_warmup_updates: 20000 # warmup steps grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps - max_grad_norm: 1.0 + max_grad_norm: 1.0 # gradient clipping model: - name: F5TTS - tokenizer: char + name: F5TTS_Base # model name + tokenizer: pinyin # tokenizer type tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) arch: - dim: 1024 - depth: 22 - heads: 16 - ff_mult: 2 - text_dim: 512 - conv_layers: 4 + dim: 1024 # model dim + depth: 22 # model depth + heads: 16 # model heads + ff_mult: 2 # feedforward expansion + text_dim: 512 # text encoder dim + conv_layers: 4 # convolution layers mel_spec: - target_sample_rate: 24000 - n_mel_channels: 100 - hop_length: 256 - win_length: 1024 - n_fft: 1024 + target_sample_rate: 24000 # target sample rate + n_mel_channels: 100 # mel channel + hop_length: 256 # hop length + win_length: 1024 # window length + n_fft: 1024 # fft length mel_spec_type: vocos # 'vocos' or 'bigvgan' - is_local_vocoder: False - local_vocoder_path: None + is_local_vocoder: False # use local vocoder or not + local_vocoder_path: None # local vocoder path ckpts: save_per_updates: 50000 # save checkpoint per steps diff --git a/src/f5_tts/config/F5TTS_Small_train.yaml b/src/f5_tts/config/F5TTS_Small_train.yaml new file mode 100644 index 00000000..d165b4fa --- /dev/null +++ b/src/f5_tts/config/F5TTS_Small_train.yaml @@ -0,0 +1,42 @@ +hydra: + run: + dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} + +datasets: + name: Emilia_ZH_EN + batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 + batch_size_type: frame # "frame" or "sample" + max_samples: 64 # max sequences per batch if use frame-wise batch_size. 
we set 32 for small models, 64 for base models
+
+optim:
+  epochs: 15
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000 # warmup steps
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0
+
+model:
+  name: F5TTS_Small
+  tokenizer: pinyin
+  tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 768
+    depth: 18
+    heads: 12
+    ff_mult: 2
+    text_dim: 512
+    conv_layers: 4
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos # 'vocos' or 'bigvgan'
+    is_local_vocoder: False
+    local_vocoder_path: None
+
+ckpts:
+  save_per_updates: 50000 # save checkpoint per steps
+  last_per_steps: 5000 # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
\ No newline at end of file

From 3a84b7972c840356d52243b4073bed73822c323c Mon Sep 17 00:00:00 2001
From: ZhikangNiu
Date: Mon, 25 Nov 2024 15:05:00 +0800
Subject: [PATCH 4/4] support finetune_cli hydra and fix some minor bugs

---
 src/f5_tts/config/E2TTS_Base_finetune.yaml | 44 ++++++
 src/f5_tts/config/E2TTS_Base_train.yaml    |  3 +
 src/f5_tts/config/E2TTS_Small_train.yaml   |  3 +
 src/f5_tts/config/F5TTS_Base_finetune.yaml | 46 ++++++
 src/f5_tts/config/F5TTS_Base_train.yaml    |  3 +
 src/f5_tts/config/F5TTS_Small_train.yaml   |  3 +
 src/f5_tts/model/trainer.py                |  2 +-
 src/f5_tts/model/utils.py                  |  4 +-
 src/f5_tts/train/finetune_cli.py           | 183 ++++++---------------
 src/f5_tts/train/train.py                  |  3 +
 10 files changed, 156 insertions(+), 138 deletions(-)
 create mode 100644 src/f5_tts/config/E2TTS_Base_finetune.yaml
 create mode 100644 src/f5_tts/config/F5TTS_Base_finetune.yaml

diff --git a/src/f5_tts/config/E2TTS_Base_finetune.yaml b/src/f5_tts/config/E2TTS_Base_finetune.yaml
new file mode 100644
index 00000000..4451a898
--- /dev/null
+++ b/src/f5_tts/config/E2TTS_Base_finetune.yaml
@@ -0,0 +1,44 @@
+hydra:
+  run:
+    dir: ckpts/finetune_${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN # dataset name
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # "frame" or "sample"
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16 # number of workers
+
+optim:
+  epochs: 15 # max epochs
+  learning_rate: 7.5e-5 # learning rate
+  num_warmup_updates: 20000 # warmup steps
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb optimizer or not
+
+model:
+  name: E2TTS_Base # model name
+  tokenizer: pinyin # tokenizer type
+  tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 1024 # model dimension
+    depth: 24 # number of transformer layers
+    heads: 16 # number of transformer heads
+    ff_mult: 4 # ff layer expansion
+  mel_spec:
+    target_sample_rate: 24000 # target sample rate
+    n_mel_channels: 100 # mel channel
+    hop_length: 256 # hop length
+    win_length: 1024 # window length
+    n_fft: 1024 # fft length
+    mel_spec_type: vocos # 'vocos' or 'bigvgan'
+    is_local_vocoder: False # use local vocoder or not
+    local_vocoder_path: None # local vocoder path
+
+ckpts:
+  logger: wandb # wandb | tensorboard | None
+  save_per_updates: 50000 # save checkpoint per steps
+  last_per_steps: 5000 # save last checkpoint per steps
+  pretain_ckpt_path: ckpts/E2TTS_Base/model_1200000.pt
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
\ No newline at end of file
diff --git a/src/f5_tts/config/E2TTS_Base_train.yaml b/src/f5_tts/config/E2TTS_Base_train.yaml
index 9d7d77b8..09c3f8cb 100644
--- a/src/f5_tts/config/E2TTS_Base_train.yaml
+++ b/src/f5_tts/config/E2TTS_Base_train.yaml
@@ -7,6 +7,7 @@ datasets:
   batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
   batch_size_type: frame # "frame" or "sample"
   max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16 # number of workers
 
 optim:
   epochs: 15 # max epochs
@@ -14,6 +15,7 @@ optim:
   num_warmup_updates: 20000 # warmup steps
   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
   max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb optimizer or not
 
 model:
   name: E2TTS_Base # model name
@@ -35,6 +37,7 @@ model:
     local_vocoder_path: None # path to local vocoder
 
 ckpts:
+  logger: wandb # wandb | tensorboard | None
   save_per_updates: 50000 # save checkpoint per steps
   last_per_steps: 5000 # save last checkpoint per steps
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
\ No newline at end of file
diff --git a/src/f5_tts/config/E2TTS_Small_train.yaml b/src/f5_tts/config/E2TTS_Small_train.yaml
index a836dc36..e3de8703 100644
--- a/src/f5_tts/config/E2TTS_Small_train.yaml
+++ b/src/f5_tts/config/E2TTS_Small_train.yaml
@@ -7,6 +7,7 @@ datasets:
   batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
   batch_size_type: frame # "frame" or "sample"
   max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16 # number of workers
 
 optim:
   epochs: 15
@@ -14,6 +15,7 @@ optim:
   num_warmup_updates: 20000 # warmup steps
   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
   max_grad_norm: 1.0
+  bnb_optimizer: False
 
 model:
   name: E2TTS_Small
@@ -35,6 +37,7 @@ model:
     local_vocoder_path: None
 
 ckpts:
+  logger: wandb # wandb | tensorboard | None
   save_per_updates: 50000 # save checkpoint per steps
   last_per_steps: 5000 # save last checkpoint per steps
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
\ No newline at end of file
diff --git a/src/f5_tts/config/F5TTS_Base_finetune.yaml b/src/f5_tts/config/F5TTS_Base_finetune.yaml
new file mode 100644
index 00000000..32478deb
--- /dev/null
+++ b/src/f5_tts/config/F5TTS_Base_finetune.yaml
@@ -0,0 +1,46 @@
+hydra:
+  run:
+    dir: ckpts/finetune_${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN # dataset name
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # "frame" or "sample"
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16 # number of workers
+
+optim:
+  epochs: 15 # max epochs
+  learning_rate: 7.5e-5 # learning rate
+  num_warmup_updates: 20000 # warmup steps
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb optimizer or not
+
+model:
+  name: F5TTS_Base # model name
+  tokenizer: pinyin # tokenizer type
+  tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 1024 # model dim
+    depth: 22 # model depth
+    heads: 16 # model heads
+    ff_mult: 2 # feedforward expansion
+    text_dim: 512 # text encoder dim
+    conv_layers: 4 # convolution layers
+  mel_spec:
+    target_sample_rate: 24000 # target sample rate
+    n_mel_channels: 100 # mel channel
+    hop_length: 256 # hop length
+    win_length: 1024 # window length
+    n_fft: 1024 # fft length
+    mel_spec_type: vocos # 'vocos' or 'bigvgan'
+    is_local_vocoder: False # use local vocoder or not
+    local_vocoder_path: None # local vocoder path
+
+ckpts:
+  logger: wandb # wandb | tensorboard | None
+  save_per_updates: 50000 # save checkpoint per steps
+  last_per_steps: 5000 # save last checkpoint per steps
+  pretain_ckpt_path: ckpts/F5TTS_Base/model_1200000.pt
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
\ No newline at end of file
diff --git a/src/f5_tts/config/F5TTS_Base_train.yaml b/src/f5_tts/config/F5TTS_Base_train.yaml
index 73299f5f..11e0d8fd 100644
--- a/src/f5_tts/config/F5TTS_Base_train.yaml
+++ b/src/f5_tts/config/F5TTS_Base_train.yaml
@@ -7,6 +7,7 @@ datasets:
   batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
   batch_size_type: frame # "frame" or "sample"
   max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16 # number of workers
 
 optim:
   epochs: 15 # max epochs
@@ -14,6 +15,7 @@ optim:
   num_warmup_updates: 20000 # warmup steps
   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
   max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb optimizer or not
 
 model:
   name: F5TTS_Base # model name
@@ -37,6 +39,7 @@ model:
     local_vocoder_path: None # local vocoder path
 
 ckpts:
+  logger: wandb # wandb | tensorboard | None
   save_per_updates: 50000 # save checkpoint per steps
   last_per_steps: 5000 # save last checkpoint per steps
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
\ No newline at end of file
diff --git a/src/f5_tts/config/F5TTS_Small_train.yaml b/src/f5_tts/config/F5TTS_Small_train.yaml
index d165b4fa..5136f3b7 100644
--- a/src/f5_tts/config/F5TTS_Small_train.yaml
+++ b/src/f5_tts/config/F5TTS_Small_train.yaml
@@ -7,6 +7,7 @@ datasets:
   batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
   batch_size_type: frame # "frame" or "sample"
   max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16 # number of workers
 
 optim:
   epochs: 15
@@ -14,6 +15,7 @@ optim:
   num_warmup_updates: 20000 # warmup steps
   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
   max_grad_norm: 1.0
+  bnb_optimizer: False
 
 model:
   name: F5TTS_Small
@@ -37,6 +39,7 @@ model:
     local_vocoder_path: None
 
 ckpts:
+  logger: wandb # wandb | tensorboard | None
   save_per_updates: 50000 # save checkpoint per steps
   last_per_steps: 5000 # save last checkpoint per steps
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
\ No newline at end of file
diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py
index 0825b024..7aab3473 100644
--- a/src/f5_tts/model/trainer.py
+++ b/src/f5_tts/model/trainer.py
@@ -91,7 +91,7 @@ def __init__(
         elif self.logger == "tensorboard":
             from torch.utils.tensorboard import SummaryWriter
 
-            self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}")
+            self.writer = SummaryWriter(log_dir=f"{checkpoint_path}/runs/{wandb_run_name}")
 
         self.model = model
 
diff --git a/src/f5_tts/model/utils.py b/src/f5_tts/model/utils.py
index 76cfa4d0..8da7aa4d 100644
--- a/src/f5_tts/model/utils.py
+++ b/src/f5_tts/model/utils.py
@@ -113,7 +113,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
         with open(tokenizer_path, "r", encoding="utf-8") as f:
             vocab_char_map = {}
             for i, char in enumerate(f):
-                vocab_char_map[char[:-1]] = i
+                vocab_char_map[char.strip()] = i # ignore \n
         vocab_size = len(vocab_char_map)
         assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
 
@@ -125,7 +125,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
         with open(dataset_name, "r", encoding="utf-8") as f:
             vocab_char_map = {}
             for i, char in enumerate(f):
-                vocab_char_map[char[:-1]] = i
+                vocab_char_map[char.strip()] = i
         vocab_size = len(vocab_char_map)
 
     return vocab_char_map, vocab_size
diff --git a/src/f5_tts/train/finetune_cli.py b/src/f5_tts/train/finetune_cli.py
index 187fd68a..f33ed48c 100644
--- a/src/f5_tts/train/finetune_cli.py
+++ b/src/f5_tts/train/finetune_cli.py
@@ -1,6 +1,6 @@
-import argparse
 import os
 import shutil
+import hydra
 
 from cached_path import cached_path
 from f5_tts.model import CFM, UNetT, DiT, Trainer
@@ -9,163 +9,76 @@
 from importlib.resources import files
 
 
-# -------------------------- Dataset Settings --------------------------- #
-target_sample_rate = 24000
-n_mel_channels = 100
-hop_length = 256
-win_length = 1024
-n_fft = 1024
-mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
+@hydra.main(config_path=os.path.join("..", "configs"), config_name=None)
+def main(cfg):
+    tokenizer = cfg.model.tokenizer
+    mel_spec_type = cfg.model.mel_spec.mel_spec_type
+    exp_name = f"finetune_{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
 
+    # set text tokenizer
+    if tokenizer != "custom":
+        tokenizer_path = cfg.datasets.name
+    else:
+        tokenizer_path = cfg.model.tokenizer_path
+    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
 
-# -------------------------- Argument Parsing --------------------------- #
-def parse_args():
-    # batch_size_per_gpu = 1000 settting for gpu 8GB
-    # batch_size_per_gpu = 1600 settting for gpu 12GB
-    # batch_size_per_gpu = 2000 settting for gpu 16GB
-    # batch_size_per_gpu = 3200 settting for gpu 24GB
-
-    # num_warmup_updates = 300 for 5000 sample about 10 hours
-
-    # change save_per_updates , last_per_steps change this value what you need ,
-
-    parser = argparse.ArgumentParser(description="Train CFM Model")
-
-    parser.add_argument(
-        "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
-    )
-    parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
-    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
-    parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU")
-    parser.add_argument(
-        "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
-    )
-    parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch")
-    parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
-    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
-    parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
-    parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup steps")
-    parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
-    parser.add_argument("--last_per_steps", type=int, default=50000, help="Save last checkpoint every X steps")
-    parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
-    parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
-    parser.add_argument(
-        "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
-    )
-    parser.add_argument(
-        "--tokenizer_path",
-        type=str,
-        default=None,
-        help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
-    )
-    parser.add_argument(
-        "--log_samples",
-        type=bool,
-        default=False,
-        help="Log inferenced samples per ckpt save steps",
-    )
-    parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
-    parser.add_argument(
-        "--bnb_optimizer",
-        type=bool,
-        default=False,
-        help="Use 8-bit Adam optimizer from bitsandbytes",
-    )
-
-    return parser.parse_args()
-
-
-# -------------------------- Training Settings -------------------------- #
-
-
-def main():
-    args = parse_args()
-
-    checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
+    print("\nvocab : ", vocab_size)
+    print("\nvocoder : ", mel_spec_type)
 
     # Model parameters based on experiment name
-    if args.exp_name == "F5TTS_Base":
-        wandb_resume_id = None
+    if "F5TTS" in cfg.model.name:
         model_cls = DiT
-        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
-        if args.finetune:
-            if args.pretrain is None:
-                ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
-            else:
-                ckpt_path = args.pretrain
-    elif args.exp_name == "E2TTS_Base":
-        wandb_resume_id = None
+        ckpt_path = cfg.ckpts.pretain_ckpt_path or str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
+    elif "E2TTS" in cfg.model.name:
         model_cls = UNetT
-        model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-        if args.finetune:
-            if args.pretrain is None:
-                ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
-            else:
-                ckpt_path = args.pretrain
-
-    if args.finetune:
-        if not os.path.isdir(checkpoint_path):
-            os.makedirs(checkpoint_path, exist_ok=True)
+        ckpt_path = cfg.ckpts.pretain_ckpt_path or str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
+    wandb_resume_id = None
 
-        file_checkpoint = os.path.join(checkpoint_path, os.path.basename(ckpt_path))
-        if not os.path.isfile(file_checkpoint):
-            shutil.copy2(ckpt_path, file_checkpoint)
-            print("copy checkpoint for finetune")
+    checkpoint_path = str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}"))
 
-    # Use the tokenizer and tokenizer_path provided in the command line arguments
-    tokenizer = args.tokenizer
-    if tokenizer == "custom":
-        if not args.tokenizer_path:
-            raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.")
-        tokenizer_path = args.tokenizer_path
-    else:
-        tokenizer_path = args.dataset_name
+    if not os.path.isdir(checkpoint_path):
+        os.makedirs(checkpoint_path, exist_ok=True)
 
-    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
-
-    print("\nvocab : ", vocab_size)
-    print("\nvocoder : ", mel_spec_type)
-
-    mel_spec_kwargs = dict(
-        n_fft=n_fft,
-        hop_length=hop_length,
-        win_length=win_length,
-        n_mel_channels=n_mel_channels,
-        target_sample_rate=target_sample_rate,
-        mel_spec_type=mel_spec_type,
-    )
+    file_checkpoint = os.path.join(checkpoint_path, os.path.basename(ckpt_path))
+    if not os.path.isfile(file_checkpoint):
+        shutil.copy2(ckpt_path, file_checkpoint)
+        print("copy checkpoint for finetune")
 
     model = CFM(
-        transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
-        mel_spec_kwargs=mel_spec_kwargs,
+        transformer=model_cls(**cfg.model.arch, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
+        mel_spec_kwargs=cfg.model.mel_spec,
         vocab_char_map=vocab_char_map,
     )
 
     trainer = Trainer(
         model,
-        args.epochs,
-        args.learning_rate,
-        num_warmup_updates=args.num_warmup_updates,
-        save_per_updates=args.save_per_updates,
+        epochs=cfg.optim.epochs,
+        learning_rate=cfg.optim.learning_rate,
+        num_warmup_updates=cfg.optim.num_warmup_updates,
+        save_per_updates=cfg.ckpts.save_per_updates,
         checkpoint_path=checkpoint_path,
-        batch_size=args.batch_size_per_gpu,
-        batch_size_type=args.batch_size_type,
-        max_samples=args.max_samples,
-        grad_accumulation_steps=args.grad_accumulation_steps,
-        max_grad_norm=args.max_grad_norm,
-        logger=args.logger,
-        wandb_project=args.dataset_name,
-        wandb_run_name=args.exp_name,
+        batch_size=cfg.datasets.batch_size_per_gpu,
+        batch_size_type=cfg.datasets.batch_size_type,
+        max_samples=cfg.datasets.max_samples,
+        grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
+        max_grad_norm=cfg.optim.max_grad_norm,
+        logger=cfg.ckpts.logger,
+        wandb_project=cfg.datasets.name,
+        wandb_run_name=exp_name,
         wandb_resume_id=wandb_resume_id,
-        log_samples=args.log_samples,
-        last_per_steps=args.last_per_steps,
-        bnb_optimizer=args.bnb_optimizer,
+        log_samples=True,
+        last_per_steps=cfg.ckpts.last_per_steps,
+        bnb_optimizer=cfg.optim.bnb_optimizer,
+        mel_spec_type=mel_spec_type,
+        is_local_vocoder=cfg.model.mel_spec.is_local_vocoder,
+        local_vocoder_path=cfg.model.mel_spec.local_vocoder_path,
     )
 
-    train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
+    train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
     trainer.train(
         train_dataset,
+        num_workers=cfg.datasets.num_workers,
         resumable_with_seed=666, # seed for shuffling dataset
     )
 
diff --git a/src/f5_tts/train/train.py b/src/f5_tts/train/train.py
index 48341f20..9ecd4c34 100644
--- a/src/f5_tts/train/train.py
+++ b/src/f5_tts/train/train.py
@@ -48,11 +48,13 @@ def main(cfg):
         max_samples=cfg.datasets.max_samples,
         grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
         max_grad_norm=cfg.optim.max_grad_norm,
+        logger=cfg.ckpts.logger,
         wandb_project="CFM-TTS",
         wandb_run_name=exp_name,
         wandb_resume_id=wandb_resume_id,
         last_per_steps=cfg.ckpts.last_per_steps,
         log_samples=True,
+        bnb_optimizer=cfg.optim.bnb_optimizer,
         mel_spec_type=mel_spec_type,
         is_local_vocoder=cfg.model.mel_spec.is_local_vocoder,
         local_vocoder_path=cfg.model.mel_spec.local_vocoder_path,
@@ -61,6 +63,7 @@ def main(cfg):
 
     train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
     trainer.train(
         train_dataset,
+        num_workers=cfg.datasets.num_workers,
         resumable_with_seed=666, # seed for shuffling dataset
     )
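Assuming the Hydra decorators above resolve the config directory added in this series, the new entry points would be launched roughly as follows. The train.py invocation mirrors the README snippet earlier in the patch; the dotted key=value overrides are standard Hydra CLI syntax and the specific values are only illustrative; the finetune_cli.py line is a sketch assuming the finetune configs added here.

```bash
# train from scratch with one of the provided configs (as in the README above)
accelerate launch src/f5_tts/train/train.py --config-name F5TTS_Base_train.yaml

# illustrative only: override individual config values from the command line
accelerate launch src/f5_tts/train/train.py --config-name F5TTS_Small_train.yaml \
    optim.epochs=10 datasets.batch_size_per_gpu=19200 ckpts.logger=tensorboard

# sketch: finetune from the copied pretrained checkpoint using a finetune config
accelerate launch src/f5_tts/train/finetune_cli.py --config-name F5TTS_Base_finetune.yaml
```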