librispeech.py

from tqdm import tqdm
from pathlib import Path
from os.path import join, getsize
from joblib import Parallel, delayed
from torch.utils.data import Dataset

# Additional (official) text source provided with LibriSpeech
OFFICIAL_TXT_SRC = ['librispeech-lm-norm.txt']
# Remove the longest N sentences in librispeech-lm-norm.txt
REMOVE_TOP_N_TXT = 5000000
# Default number of threads used for loading LibriSpeech
READ_FILE_THREADS = 4
def read_text(file):
    '''Get the transcription of the target wave file.
    Opening each .trans.txt file multiple times is somewhat redundant,
    but it works fine with multi-threading.'''
    src_file = '-'.join(file.split('-')[:-1]) + '.trans.txt'
    idx = file.split('/')[-1].split('.')[0]

    with open(src_file, 'r') as fp:
        for line in fp:
            if idx == line.split(' ')[0]:
                return line[:-1].split(' ', 1)[1]
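
# Illustrative example of the lookup above (paths hypothetical but follow
# the standard LibriSpeech layout): for a wave file such as
#   LibriSpeech/train-clean-100/19/198/19-198-0001.flac
# the transcription is read from the sibling file
#   LibriSpeech/train-clean-100/19/198/19-198.trans.txt
# on the line whose first token is the utterance id "19-198-0001".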
class LibriDataset(Dataset):
    def __init__(self, path, split, tokenizer, bucket_size, ascending=False):
        # Setup
        self.path = path
        self.bucket_size = bucket_size

        # List all wave files
        file_list = []
        for s in split:
            split_list = list(Path(join(path, s)).rglob("*.flac"))
            assert len(split_list) > 0, "No data found @ {}".format(join(path, s))
            file_list += split_list
        # Read text
        text = Parallel(n_jobs=READ_FILE_THREADS)(
            delayed(read_text)(str(f)) for f in file_list)
        # text = Parallel(n_jobs=-1)(delayed(tokenizer.encode)(txt) for txt in text)
        text = [tokenizer.encode(txt) for txt in text]

        # Sort dataset by text length
        # file_len = Parallel(n_jobs=READ_FILE_THREADS)(delayed(getsize)(f) for f in file_list)
        self.file_list, self.text = zip(*[(f_name, txt) for f_name, txt in
                                          sorted(zip(file_list, text),
                                                 reverse=not ascending,
                                                 key=lambda x: len(x[1]))])
    def __getitem__(self, index):
        if self.bucket_size > 1:
            # Return a bucket
            index = min(len(self.file_list) - self.bucket_size, index)
            return [(f_path, txt) for f_path, txt in
                    zip(self.file_list[index:index + self.bucket_size],
                        self.text[index:index + self.bucket_size])]
        else:
            return self.file_list[index], self.text[index]

    def __len__(self):
        return len(self.file_list)
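
# Minimal usage sketch, assuming a hypothetical `tokenizer` object exposing
# an `encode(str) -> list of int` method (which is what __init__ above calls):
#
#   ds = LibriDataset('LibriSpeech/', ['train-clean-100'], tokenizer,
#                     bucket_size=8)
#   bucket = ds[0]  # a list of 8 (file_path, token_ids) pairs; since the
#                   # dataset is sorted by text length, each bucket groups
#                   # utterances of similar length for efficient padding.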
class LibriTextDataset(Dataset):
    def __init__(self, path, split, tokenizer, bucket_size):
        # Setup
        self.path = path
        self.bucket_size = bucket_size
        self.encode_on_fly = False
        read_txt_src = []

        # List all wave files (official text sources are read directly)
        file_list, all_sent = [], []
        for s in split:
            if s in OFFICIAL_TXT_SRC:
                self.encode_on_fly = True
                with open(join(path, s), 'r') as f:
                    all_sent += f.readlines()
            file_list += list(Path(join(path, s)).rglob("*.flac"))
        assert (len(file_list) > 0) or (len(all_sent) > 0), \
            "No data found @ {}".format(path)

        # Read text
        text = Parallel(n_jobs=READ_FILE_THREADS)(
            delayed(read_text)(str(f)) for f in file_list)
        all_sent.extend(text)
        del text

        # Encode text
        if self.encode_on_fly:
            self.tokenizer = tokenizer
            self.text = all_sent
        else:
            self.text = [tokenizer.encode(txt) for txt in tqdm(all_sent)]
        del all_sent

        # Sort dataset by text length, longest first
        self.text = sorted(self.text, reverse=True, key=lambda x: len(x))
        if self.encode_on_fly:
            del self.text[:REMOVE_TOP_N_TXT]
    def __getitem__(self, index):
        if self.bucket_size > 1:
            index = min(len(self.text) - self.bucket_size, index)
            if self.encode_on_fly:
                for i in range(index, index + self.bucket_size):
                    if type(self.text[i]) is str:
                        self.text[i] = self.tokenizer.encode(self.text[i])
            # Return a bucket
            return self.text[index:index + self.bucket_size]
        else:
            if self.encode_on_fly and type(self.text[index]) is str:
                self.text[index] = self.tokenizer.encode(self.text[index])
            return self.text[index]

    def __len__(self):
        return len(self.text)
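
# Minimal usage sketch for text-only (LM) training, again assuming a
# hypothetical `tokenizer` with an `encode` method:
#
#   lm_ds = LibriTextDataset('LibriSpeech/',
#                            ['librispeech-lm-norm.txt'], tokenizer,
#                            bucket_size=32)
#   batch = lm_ds[0]  # a list of 32 token-id sequences; with the official
#                     # txt source, sentences are tokenized lazily in
#                     # __getitem__ above, and the longest REMOVE_TOP_N_TXT
#                     # sentences are dropped up front in __init__.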