# train_postnet.py

from preprocess import get_post_dataset, DataLoader, collate_fn_postnet
from network import *
from tensorboardX import SummaryWriter
import torchvision.utils as vutils
import torch as t
import torch.nn as nn
import os
from tqdm import tqdm

# Note: hp (the hyperparameter module) and ModelPostNet are expected to be provided
# by the wildcard import from network above.

def adjust_learning_rate(optimizer, step_num, warmup_step=4000):
    """Warm-up schedule: ramp the learning rate up for `warmup_step` steps, then decay it."""
    lr = hp.lr * warmup_step**0.5 * min(step_num * warmup_step**-1.5, step_num**-0.5)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
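
# A worked sketch of the schedule above (assuming the default warmup_step=4000):
# the learning rate grows linearly until step 4000, then decays as step**-0.5
# (the "Noam" schedule from "Attention Is All You Need").
#   step =   400 -> lr = hp.lr * 4000**0.5 * ( 400 * 4000**-1.5) = 0.1 * hp.lr
#   step =  4000 -> lr = hp.lr * 4000**0.5 * (4000**-0.5)        = 1.0 * hp.lr
#   step = 16000 -> lr = hp.lr * 4000**0.5 * (16000**-0.5)       = 0.5 * hp.lr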

def main():
    dataset = get_post_dataset()
    global_step = 0

    # PostNet model that converts predicted mel spectrograms to linear (magnitude) spectrograms
    m = nn.DataParallel(ModelPostNet().cuda())
    m.train()
    optimizer = t.optim.Adam(m.parameters(), lr=hp.lr)

    writer = SummaryWriter()

    for epoch in range(hp.epochs):
        dataloader = DataLoader(dataset, batch_size=hp.batch_size, shuffle=True,
                                collate_fn=collate_fn_postnet, drop_last=True, num_workers=8)
        pbar = tqdm(dataloader)
        for i, data in enumerate(pbar):
            pbar.set_description("Processing at epoch %d" % epoch)
            global_step += 1
            # Apply the warm-up schedule for the first 400k steps
            if global_step < 400000:
                adjust_learning_rate(optimizer, global_step)

            mel, mag = data
            mel = mel.cuda()
            mag = mag.cuda()

            mag_pred = m.forward(mel)
            loss = nn.L1Loss()(mag_pred, mag)

            writer.add_scalars('training_loss', {
                'loss': loss,
            }, global_step)

            optimizer.zero_grad()
            # Calculate gradients
            loss.backward()

            # Clip gradient norms to 1.0 to stabilise training
            nn.utils.clip_grad_norm_(m.parameters(), 1.)

            # Update weights
            optimizer.step()

            if global_step % hp.save_step == 0:
                t.save({'model': m.state_dict(),
                        'optimizer': optimizer.state_dict()},
                       os.path.join(hp.checkpoint_path, 'checkpoint_postnet_%d.pth.tar' % global_step))

if __name__ == '__main__':
    main()
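
# A minimal, hypothetical sketch (not part of the original script) of how one of the
# checkpoints saved above could be restored to resume training; `restore_step` is an
# assumed variable naming the step whose checkpoint should be loaded, and `m` must
# again be an nn.DataParallel-wrapped ModelPostNet:
#
#   restore_step = 100000
#   state = t.load(os.path.join(hp.checkpoint_path, 'checkpoint_postnet_%d.pth.tar' % restore_step))
#   m.load_state_dict(state['model'])
#   optimizer.load_state_dict(state['optimizer'])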