
Commit

Merge pull request #88 from JarodMica/main
Update to make passing in custom paths easier for finetuning/training
SWivid authored Oct 15, 2024
2 parents 03048d6 + 4cdcccf commit 658cfa5
Showing 3 changed files with 32 additions and 6 deletions.
20 changes: 18 additions & 2 deletions model/dataset.py
@@ -184,10 +184,15 @@ def __len__(self):

def load_dataset(
dataset_name: str,
tokenizer: str,
tokenizer: str = "pinyon",
dataset_type: str = "CustomDataset",
audio_type: str = "raw",
mel_spec_kwargs: dict = dict()
) -> CustomDataset | HFDataset:
'''
dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
- "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
'''
) -> CustomDataset:

print("Loading dataset ...")
@@ -206,7 +211,18 @@ def load_dataset(
data_dict = json.load(f)
durations = data_dict["duration"]
train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)


elif dataset_type == "CustomDatasetPath":
try:
train_dataset = load_from_disk(f"{dataset_name}/raw")
except:
train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")

with open(f"{dataset_name}/duration.json", 'r', encoding='utf-8') as f:
data_dict = json.load(f)
durations = data_dict["duration"]
train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)

elif dataset_type == "HFDataset":
print("Should manually modify the path of huggingface dataset to your need.\n" +
"May also the corresponding script cuz different dataset may have different format.")
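For reference, a minimal usage sketch of the new "CustomDatasetPath" branch, assuming a preprocessed dataset directory that contains raw/ (or raw.arrow) and duration.json as read above, and assuming the repo root is on the import path; the path and mel-spectrogram values below are placeholders, not defaults shipped with the repo:

from model.dataset import load_dataset

mel_spec_kwargs = dict(
    target_sample_rate=24000,   # placeholder values; match your training config
    n_mel_channels=100,
    hop_length=256,
)

train_dataset = load_dataset(
    dataset_name="/path/to/my_dataset",   # full path to the preprocessed dataset directory
    dataset_type="CustomDatasetPath",     # bypasses the tokenizer-based default data path
    mel_spec_kwargs=mel_spec_kwargs,
)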
7 changes: 7 additions & 0 deletions model/utils.py
@@ -129,6 +129,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
- "char" for char-wise tokenizer, need .txt vocab_file
- "byte" for utf-8 tokenizer
- "custom" if you're directly passing in a path to the vocab.txt you want to use
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
- if use "char", derived from unfiltered character & symbol counts of custom dataset
- if use "byte", set to 256 (unicode byte range)
@@ -144,6 +145,12 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
elif tokenizer == "byte":
vocab_char_map = None
vocab_size = 256
elif tokenizer == "custom":
with open (dataset_name, "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
vocab_size = len(vocab_char_map)

return vocab_char_map, vocab_size

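A quick, illustrative sketch of driving the new "custom" branch; the file path is hypothetical, and the vocab file is expected to hold one token per line (the loop above strips the trailing newline with char[:-1]):

from model.utils import get_tokenizer

vocab_path = "/path/to/vocab.txt"   # hypothetical path to your own vocab file
with open(vocab_path, "w", encoding="utf-8") as f:
    f.write("\n".join([" ", "a", "b", "c"]) + "\n")   # one token per line

vocab_char_map, vocab_size = get_tokenizer(vocab_path, tokenizer="custom")
print(vocab_size)   # 4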
11 changes: 7 additions & 4 deletions train.py
@@ -9,10 +9,10 @@
n_mel_channels = 100
hop_length = 256

tokenizer = "pinyin"
tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
dataset_name = "Emilia_ZH_EN"


# -------------------------- Training Settings -------------------------- #

exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
@@ -44,8 +44,11 @@
# ----------------------------------------------------------------------- #

def main():

vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
if tokenizer == "custom":
tokenizer_path = tokenizer_path
else:
tokenizer_path = dataset_name
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

mel_spec_kwargs = dict(
target_sample_rate = target_sample_rate,
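Taken together, a hedged sketch of the settings a custom-vocab run would use under this change, mirroring the selection logic in main(); the paths are placeholders and the import assumes the repo layout shown above:

from model.utils import get_tokenizer

tokenizer = "custom"                       # 'pinyin', 'char', or 'custom'
tokenizer_path = "/path/to/vocab.txt"      # only read when tokenizer == "custom"
dataset_name = "/path/to/my_dataset"       # full path if also using dataset_type == "CustomDatasetPath"

vocab_file = tokenizer_path if tokenizer == "custom" else dataset_name
vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)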

