Skip to content

Commit

Permalink
add arrow dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
lpscr committed Nov 6, 2024
1 parent 2d2452e commit 73bf6be
Showing 1 changed file with 81 additions and 0 deletions.
81 changes: 81 additions & 0 deletions src/f5_tts/train/finetune_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import numpy as np
import torch
import torchaudio
from datasets import load_dataset
from datasets import Dataset as Dataset_
from datasets.arrow_writer import ArrowWriter
from safetensors.torch import save_file
Expand All @@ -31,6 +32,7 @@
from f5_tts.api import F5TTS
from f5_tts.model.utils import convert_char_to_pinyin
from importlib.resources import files
import soundfile as sf

training_process = None
system = platform.system()
Expand All @@ -44,6 +46,7 @@
path_data = str(files("f5_tts").joinpath("../../data"))
path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts"))


file_train = "src/f5_tts/train/finetune_cli.py"

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
Expand Down Expand Up @@ -786,6 +789,65 @@ def has_supported_extension(file_name):
return file_audio


def get_nested_value(data, format):
keys = format.split("/")

item = data
for key in keys:
item = item.get(key)
if item is None:
return None

return item


def create_metadata_from_arrow(
name_project, arrow_type, arrow_path, arrow_name, arrow_text, arrow_audio, arrow_split, progress=gr.Progress()
):
path_project = os.path.join(path_data, name_project)
path_project_wavs = os.path.join(path_project, "wavs")
file_metadata = os.path.join(path_project, "metadata.csv")
file_custom_dataset_dir = os.path.join(path_project, "custom_dataset_dir")
os.makedirs(file_custom_dataset_dir, exist_ok=True)
os.makedirs(path_project_wavs, exist_ok=True)
data = ""
num = 0
if arrow_type == "Local" or arrow_type == "Online":
if arrow_type == "locals":
dataset = Dataset_.from_file(arrow_path)

if arrow_type == "Online":
if arrow_split == "":
arrow_split = None
if arrow_name == "":
arrow_name = None
dataset = load_dataset(arrow_path, arrow_name, split=arrow_split, cache_dir=file_custom_dataset_dir)

is_audio_path = None
for item in progress.tqdm(dataset):
text = get_nested_value(item, arrow_text)
audio = get_nested_value(item, arrow_audio)

if is_audio_path is None:
if isinstance(audio, str):
is_audio_path = True
else:
is_audio_path = False

if not is_audio_path:
namefile = "segment_{num}"
filename = os.path.join(path_project_wavs, namefile + ".wav")
sf.write(filename, audio, 24000)
num += 1
else:
filename = audio

data += f"{filename}|{text}\n"

with open(file_metadata, "w", encoding="utf-sig") as f:
f.write(data)


def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
path_project = os.path.join(path_data, name_project)
path_project_wavs = os.path.join(path_project, "wavs")
Expand Down Expand Up @@ -1505,6 +1567,18 @@ def get_audio_select(file_sample):
Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
```""")

with gr.Accordion("Dataset Arrow", open=False):
arrow_type = gr.Radio(label="Type", choices=["Local", "Online"], value="Local")
with gr.Row():
arrow_path = gr.Textbox(label="Path", value="")
arrow_name = gr.Textbox(label="Name", value="")
arrow_split = gr.Textbox(label="Split", value="")

with gr.Row():
arrow_text = gr.Textbox(label="Text", value="audio/array")
arrow_audio = gr.Textbox(label="Audio", value="transcript")
bt_covert_metadata = bt_create = gr.Button("Create Metadata")

gr.Markdown(
"""```plaintext
Place all your "wavs" folder and your "metadata.csv" file in your project name directory.
Expand All @@ -1530,6 +1604,7 @@ def get_audio_select(file_sample):
```"""
)
ch_tokenizern = gr.Checkbox(label="Create Vocabulary", value=False, visible=False)

bt_prepare = bt_create = gr.Button("Prepare")
txt_info_prepare = gr.Text(label="Info", value="")
txt_vocab_prepare = gr.Text(label="Vocab", value="")
Expand All @@ -1538,6 +1613,12 @@ def get_audio_select(file_sample):
fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare]
)

bt_covert_metadata.click(
fn=create_metadata_from_arrow,
inputs=[cm_project, arrow_type, arrow_path, arrow_name, arrow_text, arrow_audio, arrow_split],
outputs=[],
)

random_sample_prepare = gr.Button("Random Sample")

with gr.Row():
Expand Down

0 comments on commit 73bf6be

Please sign in to comment.