-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
114 lines (94 loc) · 4.49 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import tensorflow as tf
import argparse, math, os, time, traceback
from hparams import hparams
from models import create_model
from models.datafeeder import DataFeeder
from util import audio, infolog, plot, ValueWindow
from util.text import sequence_to_text
log = infolog.log
def add_stats(model):
    """Register training summaries for ``model`` and return the merged op.

    Inside the ``'stats'`` variable scope this adds:

    * histogram summaries of the model's linear/mel outputs and targets,
    * scalar summaries of the mel loss, linear loss, learning rate and
      total loss,
    * a histogram of the norm of every entry in ``model.gradients`` and a
      scalar holding the largest of those norms.

    Args:
        model: Object exposing ``linear_outputs``, ``linear_targets``,
            ``mel_outputs``, ``mel_targets``, ``mel_loss``, ``linear_loss``,
            ``learning_rate``, ``loss`` and ``gradients``.

    Returns:
        The value of ``tf.summary.merge_all()``.
    """
    # Histogram tags double as the model attribute names they are read from.
    histogram_tags = (
        'linear_outputs',
        'linear_targets',
        'mel_outputs',
        'mel_targets',
    )
    # (summary tag, model attribute) pairs for the scalar summaries.
    scalar_sources = (
        ('loss_mel', 'mel_loss'),
        ('loss_linear', 'linear_loss'),
        ('learning_rate', 'learning_rate'),
        ('loss', 'loss'),
    )

    with tf.variable_scope('stats'):
        for tag in histogram_tags:
            tf.summary.histogram(tag, getattr(model, tag))
        for tag, attribute in scalar_sources:
            tf.summary.scalar(tag, getattr(model, attribute))

        # One norm per gradient tensor, in the order the model lists them.
        norms = list(map(tf.norm, model.gradients))
        tf.summary.histogram('gradient_norm', norms)
        tf.summary.scalar('max_gradient_norm', tf.reduce_max(norms))

        return tf.summary.merge_all()
def train(log_dir, args):
    """Build the training graph and run the training loop.

    The loop runs until the coordinator reports that it should stop. Every
    step is logged; every ``args.summary_interval`` steps the merged summaries
    are written, and every ``args.checkpoint_interval`` steps a checkpoint is
    saved together with a synthesized audio sample and an alignment plot for
    the first item of a batch.

    Args:
        log_dir: Directory that receives the ``model.ckpt-<step>``
            checkpoints, the summary event files, and the
            ``step-<step>-audio.wav`` / ``step-<step>-align.png`` samples.
        args: Parsed command-line namespace. Reads ``base_dir``, ``input``,
            ``model``, ``restore_step``, ``summary_interval`` and
            ``checkpoint_interval``.

    Any ``Exception`` raised inside the session is logged, its traceback is
    printed, and it is handed to the coordinator through ``request_stop``; it
    is not re-raised. The summary writer is closed on every exit path.
    """
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, args.input)

    # Set up the data feeder.
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder'):
        feeder = DataFeeder(coord, input_path, hparams)

    # Set up the model.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model'):
        model = create_model(args.model, hparams)
        model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.linear_targets)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    # Bookkeeping
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

    # Train
    with tf.Session() as sess:
        # Bound before the try so the finally clause can tell whether the
        # writer was ever created.
        summary_writer = None
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            if args.restore_step:
                # Overwrite the freshly initialized variables with the
                # requested checkpoint.
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s' % (restore_path), slack=True)

            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                # model.optimize is run for its side effect only; its fetched
                # value is discarded.
                step, loss, _, mel_loss, linear_loss = sess.run(
                    [global_step, model.loss, model.optimize, model.mel_loss, model.linear_loss])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, mel_loss=%.5f, linear_loss=%.5f]' % (
                    step, time_window.average, loss, loss_window.average, mel_loss, linear_loss)
                log(message, slack=(step % args.checkpoint_interval == 0))

                if step % args.summary_interval == 0:
                    log('Writing summary at step: %d' % step)
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    log('Saving audio and alignment...')
                    input_seq, spectrogram, alignment = sess.run([
                        model.inputs[0], model.linear_outputs[0], model.alignments[0]])
                    waveform = audio.inv_spectrogram(spectrogram.T)
                    audio.save_wav(waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step))
                    input_seq = sequence_to_text(input_seq)
                    plot.plot_alignment(alignment, os.path.join(log_dir, 'step-%d-align.png' % step), input_seq,
                                        info='%s, step=%d, loss=%.5f' % (args.model, step, loss), istrain=1)
                    log('Input: %s' % input_seq)

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
        finally:
            # Close the event file on every exit path so the handle is not
            # leaked and pending summaries are not left unwritten.
            if summary_writer is not None:
                summary_writer.close()
def main():
    """Parse command-line options, prepare the log directory, and train.

    Options (all optional):
        --base_dir             default ``'./'``
        --input                default ``'training/train.txt'``
        --model                default ``'tacotron'``
        --restore_step         int, no default
        --summary_interval     int, default 100
        --checkpoint_interval  int, default 1000

    The log directory is ``<base_dir>/logs-<model>``; it is created if
    missing, ``infolog`` is initialized to write ``train.log`` inside it,
    and ``train`` is then called with that directory and the parsed options.
    """
    cli = argparse.ArgumentParser()

    # Plain string options: (flag, default).
    for flag, fallback in (
        ('--base_dir', './'),
        ('--input', 'training/train.txt'),
        ('--model', 'tacotron'),
    ):
        cli.add_argument(flag, default=fallback)

    # Integer option without a default (None when not given).
    cli.add_argument('--restore_step', type=int)

    # Integer options with defaults: (flag, default).
    for flag, fallback in (
        ('--summary_interval', 100),
        ('--checkpoint_interval', 1000),
    ):
        cli.add_argument(flag, type=int, default=fallback)

    options = cli.parse_args()

    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    # The model name doubles as the run name.
    log_dir = os.path.join(options.base_dir, 'logs-%s' % options.model)
    os.makedirs(log_dir, exist_ok=True)
    infolog.init(os.path.join(log_dir, 'train.log'), options.model)

    train(log_dir, options)


if __name__ == '__main__':
    main()