diff --git a/finetune-cli.py b/finetune-cli.py
index 2f49e969..79ce9bb0 100644
--- a/finetune-cli.py
+++ b/finetune-cli.py
@@ -28,6 +28,7 @@ def parse_args():
     parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps')
     parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps')
     parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps')
+    parser.add_argument('--finetune', type=bool, default=True, help='Use Finetune')
 
     return parser.parse_args()
 
@@ -42,17 +43,21 @@ def main():
         wandb_resume_id = None
         model_cls = DiT
         model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
-        ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
+        if args.finetune:
+            ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
     elif args.exp_name == "E2TTS_Base":
         wandb_resume_id = None
         model_cls = UNetT
         model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-        ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
+        if args.finetune:
+            ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
+
+    if args.finetune:
+        path_ckpt = os.path.join("ckpts", args.dataset_name)
+        if os.path.isdir(path_ckpt)==False:
+            os.makedirs(path_ckpt, exist_ok=True)
+            shutil.copy2(ckpt_path, os.path.join(path_ckpt, os.path.basename(ckpt_path)))
 
-    path_ckpt = os.path.join("ckpts",args.dataset_name)
-    if os.path.isdir(path_ckpt)==False:
-        os.makedirs(path_ckpt,exist_ok=True)
-        shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
     checkpoint_path=os.path.join("ckpts",args.dataset_name)
 
     # Use the dataset_name provided in the command line
diff --git a/finetune_gradio.py b/finetune_gradio.py
index 4908d787..d6db8cc4 100644
--- a/finetune_gradio.py
+++ b/finetune_gradio.py
@@ -9,24 +9,19 @@
 import librosa
 import numpy as np
 from scipy.io import wavfile
-from tqdm import tqdm
 import shutil
 import time
 
 import json
-from datasets import Dataset
 from model.utils import convert_char_to_pinyin
 import signal
 import psutil
 import platform
 import subprocess
 from datasets.arrow_writer import ArrowWriter
-from datasets import load_dataset, load_from_disk
 
 import json
-
-
 
 training_process = None
 system = platform.system()
 python_executable = sys.executable or "python"
@@ -265,8 +260,20 @@ def start_training(dataset_name="",
                    finetune=True,
                    ):
+    global training_process
+
+    path_project = os.path.join(path_data, dataset_name + "_pinyin")
+
+    if os.path.isdir(path_project)==False:
+        yield f"There is no project with name {dataset_name}",gr.update(interactive=True),gr.update(interactive=False)
+        return
+
+    file_raw = os.path.join(path_project,"raw.arrow")
+    if os.path.isfile(file_raw)==False:
+        yield f"There is no file {file_raw}",gr.update(interactive=True),gr.update(interactive=False)
+        return
+
     # Check if a training process is already running
     if training_process is not None:
         return "Train run already!",gr.update(interactive=False),gr.update(interactive=True)
@@ -274,7 +281,7 @@ def start_training(dataset_name="",
     yield "start train",gr.update(interactive=False),gr.update(interactive=False)
 
     # Command to run the training script with the specified arguments
-    cmd = f"{python_executable} finetune-cli.py --exp_name {exp_name} " \
+    cmd = f"accelerate launch finetune-cli.py --exp_name {exp_name} " \
           f"--learning_rate {learning_rate} " \
          f"--batch_size_per_gpu {batch_size_per_gpu} " \
          f"--batch_size_type {batch_size_type} " \
@@ -346,6 +353,8 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Prog
     path_project_wavs = os.path.join(path_project,"wavs")
     file_metadata = os.path.join(path_project,"metadata.csv")
 
+    if audio_files is None: return "You need to load an audio file."
+
     if os.path.isdir(path_project_wavs):
         shutil.rmtree(path_project_wavs)
 
@@ -356,16 +365,17 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Prog
     if user:
         file_audios = [file for format in ('*.wav', '*.ogg', '*.opus', '*.mp3', '*.flac') for file in glob(os.path.join(path_dataset, format))]
+        if file_audios==[]: return "No audio file was found in the dataset."
     else:
         file_audios = audio_files
-
-    print([file_audios])
+
     alpha = 0.5
     _max = 1.0
     slicer = Slicer(24000)
 
     num = 0
+    error_num = 0
     data=""
 
     for file_audio in progress.tqdm(file_audios, desc="transcribe files",total=len((file_audios))):
@@ -381,18 +391,26 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Prog
             if(tmp_max>1):chunk/=tmp_max
             chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
             wavfile.write(file_segment,24000, (chunk * 32767).astype(np.int16))
+
+            try:
+                text=transcribe(file_segment,language)
+                text = text.lower().strip().replace('"',"")
-            text=transcribe(file_segment,language)
-            text = text.lower().strip().replace('"',"")
+                data+= f"{name_segment}|{text}\n"
-            data+= f"{name_segment}|{text}\n"
+                num+=1
+            except:
+                error_num += 1
-            num+=1
-
     with open(file_metadata,"w",encoding="utf-8") as f:
         f.write(data)
-
-    return f"transcribe complete samples : {num} in path {path_project_wavs}"
+
+    if error_num != 0:
+        error_text=f"\nerror files : {error_num}"
+    else:
+        error_text=""
+
+    return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
 
 def format_seconds_to_hms(seconds):
     hours = int(seconds / 3600)
@@ -408,6 +426,8 @@ def create_metadata(name_project,progress=gr.Progress()):
     file_raw = os.path.join(path_project,"raw.arrow")
     file_duration = os.path.join(path_project,"duration.json")
     file_vocab = os.path.join(path_project,"vocab.txt")
+
+    if os.path.isfile(file_metadata)==False: return "The file was not found: " + file_metadata
 
     with open(file_metadata,"r",encoding="utf-8") as f:
         data=f.read()
@@ -419,11 +439,18 @@ def create_metadata(name_project,progress=gr.Progress()):
     count=data.split("\n")
     lenght=0
     result=[]
+    error_files=[]
     for line in progress.tqdm(data.split("\n"),total=count):
         sp_line=line.split("|")
         if len(sp_line)!=2:continue
-        name_audio,text = sp_line[:2]
+        name_audio,text = sp_line[:2]
+        file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
+
+        if os.path.isfile(file_audio)==False:
+            error_files.append(file_audio)
+            continue
+
         duraction = get_audio_duration(file_audio)
         if duraction<2 and duraction>15:continue
         if len(text)<4:continue
@@ -439,6 +466,10 @@ def create_metadata(name_project,progress=gr.Progress()):
         lenght+=duraction
 
+    if duration_list==[]:
+        error_files_text="\n".join(error_files)
+        return f"Error: No audio files were found in the specified path:\n{error_files_text}"
+
     min_second = round(min(duration_list),2)
     max_second = round(max(duration_list),2)
 
@@ -450,9 +481,15 @@ def create_metadata(name_project,progress=gr.Progress()):
         json.dump({"duration": duration_list}, f, ensure_ascii=False)
 
     file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
+    if os.path.isfile(file_vocab_finetune)==False: return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
     shutil.copy2(file_vocab_finetune, file_vocab)
-
-    return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n"
+
+    if error_files!=[]:
+        error_text="error files\n" + "\n".join(error_files)
+    else:
+        error_text=""
+
+    return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
 
 def check_user(value):
     return gr.update(visible=not value),gr.update(visible=value)
@@ -466,15 +503,19 @@ def calculate_train(name_project,batch_size_type,max_samples,learning_rate,num_w
         data = json.load(file)
 
     duration_list = data['duration']
+
     samples = len(duration_list)
 
-    gpu_properties = torch.cuda.get_device_properties(0)
-    total_memory = gpu_properties.total_memory / (1024 ** 3)
+    if torch.cuda.is_available():
+        gpu_properties = torch.cuda.get_device_properties(0)
+        total_memory = gpu_properties.total_memory / (1024 ** 3)
+    elif torch.backends.mps.is_available():
+        total_memory = psutil.virtual_memory().available / (1024 ** 3)
 
     if batch_size_type=="frame":
         batch = int(total_memory * 0.5)
         batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
-        batch_size_per_gpu = int(36800 / batch )
+        batch_size_per_gpu = int(38400 / batch )
     else:
         batch_size_per_gpu = int(total_memory / 8)
         batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
@@ -509,13 +550,12 @@ def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -
         if ema_model_state_dict is not None:
             new_checkpoint = {'ema_model_state_dict': ema_model_state_dict}
             torch.save(new_checkpoint, new_checkpoint_path)
-            print(f"New checkpoint saved at: {new_checkpoint_path}")
+            return f"New checkpoint saved at: {new_checkpoint_path}"
         else:
-            print("No 'ema_model_state_dict' found in the checkpoint.")
+            return "No 'ema_model_state_dict' found in the checkpoint."
 
     except Exception as e:
-        print(f"An error occurred: {e}")
-
+        return f"An error occurred: {e}"
 
 def vocab_check(project_name):
     name_project = project_name + "_pinyin"
@@ -524,12 +564,17 @@
     file_metadata = os.path.join(path_project, "metadata.csv")
 
     file_vocab="data/Emilia_ZH_EN_pinyin/vocab.txt"
+    if os.path.isfile(file_vocab)==False:
+        return f"The file {file_vocab} was not found!"
 
     with open(file_vocab,"r",encoding="utf-8") as f:
         data=f.read()
 
     vocab = data.split("\n")
 
+    if os.path.isfile(file_metadata)==False:
+        return f"The file {file_metadata} was not found!"
+
     with open(file_metadata,"r",encoding="utf-8") as f:
         data=f.read()
 
@@ -548,6 +593,7 @@ def vocab_check(project_name):
     if miss_symbols==[]:info ="You can train using your language !"
     else:info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
 
+    return info
 
@@ -652,8 +698,9 @@ def vocab_check(project_name):
     with gr.TabItem("reduse checkpoint"):
         txt_path_checkpoint = gr.Text(label="path checkpoint :")
         txt_path_checkpoint_small = gr.Text(label="path output :")
+        txt_info_reduse = gr.Text(label="info",value="")
         reduse_button = gr.Button("reduse")
-        reduse_button.click(fn=extract_and_save_ema_model,inputs=[txt_path_checkpoint,txt_path_checkpoint_small])
+        reduse_button.click(fn=extract_and_save_ema_model,inputs=[txt_path_checkpoint,txt_path_checkpoint_small],outputs=[txt_info_reduse])
 
     with gr.TabItem("vocab check experiment"):
         check_button = gr.Button("check vocab")
 
@@ -680,10 +727,4 @@ def main(port, host, share, api):
     )
 
 if __name__ == "__main__":
-    name="my_speak"
-
-    #create_data_project(name)
-    #transcribe_all(name)
-    #create_metadata(name)
-
     main()