#----------------description----------------#
# Author : Lei yuan
# E-mail : [email protected]
# Company : Fudan University
# Date : 2020-10-11 15:28:41
# LastEditors : Zihao Zhao
# LastEditTime : 2020-12-21 20:42:35
# FilePath : /speech-to-text-wavenet/torch_lyuan/dataset.py
# Description :
#-------------------------------------------#
import torch
from torch.utils.data import Dataset
import utils
import random
import json
import os
import numpy as np
# import tensorflow as tf
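# Legacy TensorFlow TFRecord loader, kept below for reference; it has been
# superseded by the PyTorch VCTK Dataset defined after it.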
# def create(filepath, batch_size=1, repeat=False, buffsize=1000):
#     def _parse(record):
#         keys_to_features = {
#             'uid': tf.FixedLenFeature([], tf.string),
#             'audio/data': tf.VarLenFeature(tf.float32),
#             'audio/shape': tf.VarLenFeature(tf.int64),
#             'text': tf.VarLenFeature(tf.int64)
#         }
#         features = tf.parse_single_example(
#             record,
#             features=keys_to_features
#         )
#         audio = features['audio/data'].values
#         shape = features['audio/shape'].values
#         audio = tf.reshape(audio, shape)
#         audio = tf.contrib.layers.dense_to_sparse(audio)
#         text = features['text']
#         return audio, text, shape[0], features['uid']
#     dataset = tf.data.TFRecordDataset(filepath).map(_parse).batch(batch_size=batch_size)
#     loader = torch.utils.data.DataLoader(dataset, batch_size=32)
#     return loader
class VCTK(Dataset):
    def __init__(self, cfg, mode):
        self.cfg = cfg
        self.mode = mode
        assert self.mode in ['train', 'val']
        if not os.path.exists(self.cfg.datalist):
            raise ValueError('datalist must exist; creating an initial datalist is not supported')
        self.train_filenames, self.test_filenames = json.load(open(self.cfg.datalist, 'r', encoding='utf-8'))

        # Padding targets: waveform features are zero-padded to max_wave frames
        # and transcripts to max_text tokens (maximums measured over the
        # dataset; see the scan in __main__ below).
        if self.mode == 'train':
            self.max_wave = 520
            self.max_text = 256
        else:
            self.max_wave = 720
            self.max_text = 256
    def __getitem__(self, idx):
        if self.mode == 'train':
            filenames = self.train_filenames[idx]
        else:
            filenames = self.test_filenames[idx]
        wave_path = self.cfg.dataset + filenames[0]
        txt_path = self.cfg.dataset + filenames[1]
        try:
            text_tmp = utils.read_txt(txt_path)    # list of token ids
            wave_tmp = utils.read_wave(wave_path)  # numpy array, [40, length]
        except OSError:
            # Unreadable sample: report the offending paths and fall back to sample 0.
            print(txt_path)
            print(wave_path)
            return self.__getitem__(0)

        wave_tmp = torch.from_numpy(wave_tmp)
        # Zero-pad to a fixed width; truncate if a clip exceeds max_wave.
        wave = torch.zeros([40, self.max_wave])
        length_wave = min(wave_tmp.shape[1], self.max_wave)
        wave[:, :length_wave] = wave_tmp[:, :length_wave]

        # Drop token 27 from the transcript before padding.
        while 27 in text_tmp:
            text_tmp.remove(27)
        # Zero-pad to a fixed length; truncate if a transcript exceeds max_text.
        length_text = min(len(text_tmp), self.max_text)
        text_tmp = torch.tensor(text_tmp[:length_text])
        text = torch.zeros([self.max_text])
        text[:length_text] = text_tmp

        name = filenames[0].split('/')[-1]
        if length_text >= length_wave:
            # CTC needs the input sequence to be at least as long as the
            # target, so return an all-zero dummy sample for degenerate pairs.
            sample = {'name': name,
                      'wave': torch.zeros([40, self.max_wave], dtype=torch.float),
                      'text': torch.zeros([self.max_text], dtype=torch.float),
                      'length_wave': self.max_wave, 'length_text': self.max_text}
        else:
            sample = {'name': name, 'wave': wave, 'text': text,
                      'length_wave': length_wave, 'length_text': length_text}
        return sample
    def __len__(self):
        if self.mode == 'train':
            return len(self.train_filenames)
        else:
            return len(self.test_filenames)
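# Minimal usage sketch for the Dataset above (the config module and batch size
# are illustrative; cfg must provide the dataset root and datalist path):
#
# from torch.utils.data import DataLoader
# import config_train as cfg
# loader = DataLoader(VCTK(cfg, 'train'), batch_size=16, shuffle=True)
# for batch in loader:
#     wave, text = batch['wave'], batch['text']   # [B, 40, 520], [B, 256]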
if __name__ == '__main__':
    # train_filenames, test_filenames = json.load(open('/lyuan/code/speech-to-text-wavenet/data/list.json', 'r', encoding='utf-8'))
    # print(len(train_filenames), train_filenames)  # [['/VCTK-Corpus/wav48/p376/p376_076.wav', '/VCTK-Corpus/txt/p376/p376_076.txt'], ['/VCTK-Corpus/wav48/p376/p376_021.wav', '/VCTK-Corpus/txt/p376/p376_021.txt']]
    import config_train as cfg

    # vctk = VCTK(cfg, 'train')
    # length = len(vctk)
    # max_length = 0
    # for i in range(length):
    #     tmp = vctk[i]['wave'].shape[1]
    #     if tmp > max_length:
    #         max_length = tmp
    # print(f'train set {max_length}')

    # Scan the validation split for the longest wave and text,
    # printing running maxima for every sample.
    vctk = VCTK(cfg, 'val')
    length = len(vctk)
    max_wave = 0
    max_text = 0
    for i in range(length):
        length_wave = vctk[i]['length_wave']
        if length_wave > max_wave:
            max_wave = length_wave
        length_text = vctk[i]['length_text']
        if length_text > max_text:
            max_text = length_text
        print(f'val set {i}, {length}, {max_wave}, {max_text}, {length_wave}, {length_text}')