forked from as-ideas/ForwardTacotron
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen_forward.py
181 lines (151 loc) · 8.26 KB
/
gen_forward.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import torch
from models.fatchord_version import WaveRNN
from models.forward_tacotron import ForwardTacotron
from utils import hparams as hp
from utils.text.symbols import phonemes
from utils.paths import Paths
import argparse
from utils.text import text_to_sequence, clean_text
from utils.display import simple_table
from utils.dsp import reconstruct_waveform, save_wav
# Every accepted vocoder subcommand spelling (canonical name or alias) mapped
# to its canonical name; mirrors the subparsers registered in the CLI below.
VOCODER_ALIASES = {
    'griffinlim': 'griffinlim', 'gl': 'griffinlim',
    'wavernn': 'wavernn', 'wr': 'wavernn',
    'melgan': 'melgan', 'mg': 'melgan',
}


def resolve_vocoder(name):
    """Return the canonical vocoder name for a subcommand name or alias.

    Returns None when ``name`` is None (no subcommand given) or unknown.
    """
    return VOCODER_ALIASES.get(name)


def vocoder_tag(vocoder, batched=False):
    """Return the vocoder label embedded in generated output file names.

    WaveRNN output is tagged with its generation mode ('wavernn_batched' or
    'wavernn_unbatched'); every other vocoder is tagged with its own name.
    """
    if vocoder == 'wavernn':
        return 'wavernn_batched' if batched else 'wavernn_unbatched'
    return vocoder


if __name__ == '__main__':
    # Parse Arguments
    parser = argparse.ArgumentParser(description='TTS Generator')
    parser.add_argument('--input_text', '-i', type=str, help='[string] Type in something here and TTS will generate it!')
    parser.add_argument('--tts_weights', type=str, help='[string/path] Load in different FastSpeech weights')
    parser.add_argument('--save_attention', '-a', dest='save_attn', action='store_true', help='Save Attention Plots')
    parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
    parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
    parser.add_argument('--alpha', type=float, default=1., help='Parameter for controlling length regulator for speedup '
                                                                'or slow-down of generated speech, e.g. alpha=2.0 is double-time')
    parser.add_argument('--amp', type=float, default=1., help='Parameter for controlling pitch amplification')
    parser.set_defaults(input_text=None)
    parser.set_defaults(weights_path=None)

    # name of subcommand goes to args.vocoder
    subparsers = parser.add_subparsers(dest='vocoder')
    wr_parser = subparsers.add_parser('wavernn', aliases=['wr'])
    wr_parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation')
    wr_parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slow Unbatched Generation')
    wr_parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
    wr_parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index')
    wr_parser.add_argument('--voc_weights', type=str, help='[string/path] Load in different WaveRNN weights')
    wr_parser.set_defaults(batched=None)
    gl_parser = subparsers.add_parser('griffinlim', aliases=['gl'])
    gl_parser.add_argument('--iters', type=int, default=32, help='[int] number of griffinlim iterations')
    subparsers.add_parser('melgan', aliases=['mg'])

    args = parser.parse_args()

    vocoder = resolve_vocoder(args.vocoder)
    if vocoder is None:
        # parser.error prints usage and exits with status 2; argparse.ArgumentError
        # cannot be raised with only a message (it also requires the argument).
        parser.error('Must provide a valid vocoder type!')
    args.vocoder = vocoder

    hp.configure(args.hp_file)  # Load hparams from file

    # set defaults for any arguments that depend on hparams; these options are
    # registered only on the wavernn subparser, so they exist only in that case
    if args.vocoder == 'wavernn':
        if args.target is None:
            args.target = hp.voc_target
        if args.overlap is None:
            args.overlap = hp.voc_overlap
        if args.batched is None:
            args.batched = hp.voc_gen_batched
        batched = args.batched
        target = args.target
        overlap = args.overlap

    input_text = args.input_text
    tts_weights = args.tts_weights

    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

    if not args.force_cpu and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print('Using device:', device)

    if args.vocoder == 'wavernn':
        print('\nInitialising WaveRNN Model...\n')
        # Instantiate WaveRNN Model
        voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                            fc_dims=hp.voc_fc_dims,
                            bits=hp.bits,
                            pad=hp.voc_pad,
                            upsample_factors=hp.voc_upsample_factors,
                            feat_dims=hp.num_mels,
                            compute_dims=hp.voc_compute_dims,
                            res_out_dims=hp.voc_res_out_dims,
                            res_blocks=hp.voc_res_blocks,
                            hop_length=hp.hop_length,
                            sample_rate=hp.sample_rate,
                            mode=hp.voc_mode).to(device)
        voc_load_path = args.voc_weights if args.voc_weights else paths.voc_latest_weights
        voc_model.load(voc_load_path)

    print('\nInitialising Forward TTS Model...\n')
    tts_model = ForwardTacotron(embed_dims=hp.forward_embed_dims,
                                num_chars=len(phonemes),
                                durpred_rnn_dims=hp.forward_durpred_rnn_dims,
                                durpred_conv_dims=hp.forward_durpred_conv_dims,
                                durpred_dropout=hp.forward_durpred_dropout,
                                pitch_rnn_dims=hp.forward_pitch_rnn_dims,
                                pitch_conv_dims=hp.forward_pitch_conv_dims,
                                pitch_dropout=hp.forward_pitch_dropout,
                                pitch_emb_dims=hp.forward_pitch_emb_dims,
                                pitch_proj_dropout=hp.forward_pitch_proj_dropout,
                                rnn_dim=hp.forward_rnn_dims,
                                postnet_k=hp.forward_postnet_K,
                                postnet_dims=hp.forward_postnet_dims,
                                prenet_k=hp.forward_prenet_K,
                                prenet_dims=hp.forward_prenet_dims,
                                highways=hp.forward_num_highways,
                                dropout=hp.forward_dropout,
                                n_mels=hp.num_mels).to(device)
    tts_load_path = tts_weights if tts_weights else paths.forward_latest_weights
    tts_model.load(tts_load_path)

    if input_text:
        text = clean_text(input_text.strip())
        inputs = [text_to_sequence(text)]
    else:
        # One utterance per line; blank lines carry no text to synthesize and
        # are skipped. Explicit UTF-8 avoids depending on the platform locale.
        with open('sentences.txt', encoding='utf-8') as f:
            inputs = [clean_text(line.strip()) for line in f if line.strip()]
        inputs = [text_to_sequence(t) for t in inputs]

    tts_k = tts_model.get_step() // 1000

    if args.vocoder == 'wavernn':
        voc_k = voc_model.get_step() // 1000
        simple_table([('Forward Tacotron', str(tts_k) + 'k'),
                      ('Vocoder Type', 'WaveRNN'),
                      ('WaveRNN', str(voc_k) + 'k'),
                      ('Generation Mode', 'Batched' if batched else 'Unbatched'),
                      ('Target Samples', target if batched else 'N/A'),
                      ('Overlap Samples', overlap if batched else 'N/A')])
    elif args.vocoder == 'griffinlim':
        simple_table([('Forward Tacotron', str(tts_k) + 'k'),
                      ('Vocoder Type', 'Griffin-Lim'),
                      ('GL Iters', args.iters)])
    elif args.vocoder == 'melgan':
        simple_table([('Forward Tacotron', str(tts_k) + 'k'),
                      ('Vocoder Type', 'MelGAN')])

    # simple amplification of pitch
    def pitch_function(x):
        return x * args.amp

    # The vocoder label in file names does not change between utterances.
    # args.batched exists only for the wavernn subcommand.
    v_type = vocoder_tag(args.vocoder, getattr(args, 'batched', False))

    for i, x in enumerate(inputs, 1):
        print(f'\n| Generating {i}/{len(inputs)}')
        _, m, _, _ = tts_model.generate(x, alpha=args.alpha, pitch_function=pitch_function)

        if input_text:
            save_path = paths.forward_output/f'{input_text[:10]}_{args.alpha}_{v_type}_{tts_k}k_amp{args.amp}.wav'
        else:
            save_path = paths.forward_output/f'{i}_{v_type}_{tts_k}k_alpha{args.alpha}_amp{args.amp}.wav'

        if args.vocoder == 'wavernn':
            # prepend a size-1 leading dim (presumably batch) before vocoding
            m = torch.tensor(m).unsqueeze(0)
            # honour the CLI --target/--overlap (already defaulted from hparams above)
            voc_model.generate(m, save_path, batched, target, overlap, hp.mu_law)
        elif args.vocoder == 'melgan':
            m = torch.tensor(m).unsqueeze(0)
            torch.save(m, paths.forward_output/f'{i}_{tts_k}_alpha{args.alpha}_amp{args.amp}.mel')
        elif args.vocoder == 'griffinlim':
            wav = reconstruct_waveform(m, n_iter=args.iters)
            save_wav(wav, save_path)

    print('\n\nDone.\n')