forked from xcmyz/FastSpeech
-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
34 lines (23 loc) · 1.15 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
def cut(A, B, C):
    """Truncate three tensors to their shortest common length along dim 1.

    Each argument is sliced as ``t[:, :n, :]`` (so each needs at least three
    dimensions), where ``n`` is the smallest ``size(1)`` among the three.
    The slices are returned as a tuple in the same order as the arguments.
    """
    tensors = (A, B, C)
    shortest = min(t.size(1) for t in tensors)
    return tuple(t[:, :shortest, :] for t in tensors)
class FastSpeechLoss(nn.Module):
    """FastSpeech training loss.

    Produces three separate terms (returned individually, not summed):

    * mean absolute error (L1) between the decoder mel output and the target,
    * the same L1 error for the post-net mel output,
    * mean squared error between the predicted durations and
      ``log(duration_target + 1)``.
    """

    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss()

    def forward(self, mel, mel_postnet, duration_predictor, mel_target,
                duration_predictor_target):
        """Compute the FastSpeech loss terms.

        Args:
            mel: decoder mel output. Must have at least 3 dims because
                ``cut`` slices ``[:, :n, :]``; presumably
                (batch, frames, n_mels) -- TODO confirm.
            mel_postnet: post-net mel output, same layout as ``mel``.
            duration_predictor: predicted (log-domain) durations.
            mel_target: ground-truth mel, same layout as ``mel``.
            duration_predictor_target: ground-truth durations in linear
                domain (may be an integer tensor); shape assumed to match
                ``duration_predictor`` -- TODO confirm.

        Returns:
            Tuple ``(mel_loss, mel_postnet_loss, duration_predictor_loss)``
            of scalar tensors.
        """
        # Targets must not receive gradients. Use detach() rather than
        # assigning ``requires_grad = False``: the assignment mutates the
        # caller's tensor and raises RuntimeError for non-leaf tensors.
        mel_target = mel_target.detach()
        duration_predictor_target = duration_predictor_target.detach()

        # Outputs and target can differ in length along dim 1; compare only
        # the overlapping prefix.
        mel, mel_postnet, mel_target = cut(mel, mel_postnet, mel_target)

        mel_loss = torch.mean(torch.abs(mel - mel_target))
        mel_postnet_loss = torch.mean(torch.abs(mel_postnet - mel_target))

        # The duration predictor is trained in the log domain. Cast the
        # (possibly integer) target to the prediction's dtype so the MSE
        # operands agree, then take log(d + 1) via the more accurate log1p.
        duration_predictor_target = torch.log1p(
            duration_predictor_target.to(duration_predictor.dtype))
        duration_predictor_loss = self.mse_loss(
            duration_predictor, duration_predictor_target)

        return mel_loss, mel_postnet_loss, duration_predictor_loss