split-data.py (forked from NVIDIA/tacotron2)
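"""Split the raw LJSpeech tarball artifact into train and validation splits.

Downloads the raw-data artifact from Weights & Biases, extracts it, packs a
randomized train/validation split of the wav files into per-split tarballs,
and logs them together with the transcriptions as a new "split-ljs" artifact.
"""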
import os
import random
import tarfile

import pandas as pd
import wandb


def split_dataset(source_artifact, n_train, n_validation):
"""Split raw data artifact into train and validation sets.
Args:
source_artifact: <artifact:version> formatted raw data artifact path.
n_train: Number of examples to include in the train set.
n_validation: Number of examples to include in the validation set.
"""
# Initialize wandb Run of type split-data
run = wandb.init(job_type="split-data")
# Download the raw data
source = run.use_artifact(source_artifact)
tarball_path = source.get_path("tarball").download()
# Extract raw data
tarball = tarfile.open(tarball_path, "r:bz2")
tarball.extractall()
# Construct new artifact
split_dataset = wandb.Artifact(
"split-ljs",
type="split data",
metadata={
"train-examples": n_train,
"val-examples": n_validation,
},
)
# Add transcription data to artifact
split_dataset.add_file("LJSpeech-1.1/metadata.csv", name="transcriptions")
meta = pd.read_csv(
"LJSpeech-1.1/metadata.csv",
sep="|",
names=["file", "sentence"],
index_col=0,
)
# Get a list of all wav files and randomize the order
all_files = os.listdir("LJSpeech-1.1/wavs")
assert n_train + n_validation <= len(all_files)
random.shuffle(all_files)
idx = 0
# Construct a tarball for each split of the data
for size, split in (
(n_train, "train"),
(n_validation, "validation"),
):
with tarfile.open(f"{split}.tar.bz2", "w:bz2") as tarball:
jdx = 0
while jdx < size:
if not str(meta.loc[all_files[idx].split(".")[0], "sentence"]).strip():
idx += 1
continue
tarball.add(
f"LJSpeech-1.1/wavs/{all_files[idx]}", arcname=f"{all_files[idx]}"
)
idx += 1
jdx += 1
split_dataset.add_file(f"{split}.tar.bz2")
# Log final artifact
run.log_artifact(split_dataset)


if __name__ == "__main__":
    split_dataset("ljs-tarball:latest", 1024, 128)
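
# A minimal sketch of how a downstream job might consume the splits logged
# above. The artifact name "split-ljs" matches the one created in
# split_dataset(); the job_type and local handling are illustrative and not
# part of this script:
#
#   run = wandb.init(job_type="train")
#   split = run.use_artifact("split-ljs:latest")
#   data_dir = split.download()
#   # data_dir now contains "transcriptions", "train.tar.bz2", and
#   # "validation.tar.bz2", ready to be extracted for training.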