forked from yl4579/StarGANv2-VC
-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
168 lines (135 loc) · 5.27 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import yaml, os
from munch import Munch
import torch
import librosa
import torchaudio
from Utils.JDC.model import JDCNet
from models import Generator, MappingNetwork, StyleEncoder
from parallel_wavegan.utils import load_model # load vocoder
# Shared mel-spectrogram front end used by preprocess(): 80 mel bins with a
# 2048-point FFT, 1200-sample window and 300-sample hop (these values match
# 24 kHz audio elsewhere in this file — see librosa.load(sr=24000) below).
to_mel = torchaudio.transforms.MelSpectrogram(n_mels=80,
                                              n_fft=2048,
                                              win_length=1200,
                                              hop_length=300)
# Log-mel normalization constants applied in preprocess(): (log_mel - mean) / std.
mean, std = -4, 4
# def build_speakers():
# cwd = os.getcwd()
# data_path = os.path.join(cwd, 'Data')
# speakers = []
# for file in os.listdir(data_path):
# # is directory and not raw
# if os.path.isdir(os.path.join(data_path, file)) and '_' in file:
# speakers.append(file)
# speakers_dict = {}
# for t in enumerate(speakers):
# # @lw: key = speaker name, value = index
# speakers_dict[t[1]] = t[0]
# return speakers_dict
def preprocess(wave):
    """Turn a numpy waveform into a normalized log-mel tensor.

    :wave: 1-D numpy float waveform.
    :return: mel tensor of shape (1, n_mels, frames), log-compressed and
        normalized with the module-level ``mean``/``std`` constants.
    """
    audio = torch.from_numpy(wave).float()
    # Small epsilon keeps the log finite for silent frames.
    log_mel = torch.log(1e-5 + to_mel(audio).unsqueeze(0))
    return (log_mel - mean) / std
def build_model(model_params=None):
    """Build the StarGANv2-VC sub-networks from a hyperparameter dict.

    :model_params: dict with keys dim_in, style_dim, max_conv_dim, w_hpf,
        F0_channel, latent_dim, num_domains (attribute access via Munch);
        defaults to an empty dict.
    :return: Munch holding ``generator``, ``mapping_network`` and
        ``style_encoder`` (the EMA networks used for inference).
    """
    # `None` sentinel instead of a mutable `{}` default, which would be
    # shared across calls (classic Python pitfall).
    args = Munch(model_params if model_params is not None else {})
    generator = Generator(args.dim_in,
                          args.style_dim,
                          args.max_conv_dim,
                          w_hpf=args.w_hpf,
                          F0_channel=args.F0_channel)
    mapping_network = MappingNetwork(args.latent_dim,
                                     args.style_dim,
                                     args.num_domains,
                                     hidden_dim=args.max_conv_dim)
    style_encoder = StyleEncoder(args.dim_in, args.style_dim, args.num_domains,
                                 args.max_conv_dim)
    nets_ema = Munch(generator=generator,
                     mapping_network=mapping_network,
                     style_encoder=style_encoder)
    return nets_ema
def compute_style(speaker_dicts):
    """Compute one style embedding per target speaker.

    :speaker_dicts: mapping of speaker index -> (wav_path, speaker_name).
        An empty path means no reference utterance: a style is sampled from
        the mapping network with a random latent code instead.
    :return: dict of speaker index -> (style embedding, label tensor).

    NOTE(review): depends on the module-level ``starganv2`` and requires CUDA.
    """
    reference_embeddings = {}
    for key, (path, speaker) in speaker_dicts.items():
        # Inference only — run everything without gradients (the original
        # left the mapping-network branch outside torch.no_grad()).
        with torch.no_grad():
            if path == "":
                # @lw: speaker = idx of the speaker name
                label = torch.LongTensor([key]).to('cuda')
                latent_dim = starganv2.mapping_network.shared[0].in_features
                # @lw: get the reference embedding from the mapping network
                ref = starganv2.mapping_network(
                    torch.randn(1, latent_dim).to('cuda'), label)
            else:
                # librosa.load already resamples to 24 kHz here.
                wave, sr = librosa.load(path, sr=24000)
                if sr != 24000:
                    # Defensive only; keyword args are required by librosa >= 0.10.
                    wave = librosa.resample(wave, orig_sr=sr, target_sr=24000)
                # (The original also called librosa.effects.trim but discarded
                # the result, so the dead call was removed.)
                mel_tensor = preprocess(wave).to('cuda')
                label = torch.LongTensor([key])
                ref = starganv2.style_encoder(mel_tensor.unsqueeze(1), label)
        reference_embeddings[key] = (ref, label)
    return reference_embeddings
def load_F0(f0_path="./Utils/JDC/bst.t7"):
    ''' @lw
    Return the pretrained JDC F0 (pitch) model on GPU, in eval mode.

    :f0_path: checkpoint path, default is "./Utils/JDC/bst.t7"
    :raises RuntimeError: if CUDA is unavailable.
    '''
    # A hard requirement must not live in an `assert` — asserts are
    # stripped when Python runs with -O.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is unavailable.")
    F0_model = JDCNet(num_class=1, seq_len=192)
    # Load onto CPU first so the checkpoint's saved device doesn't matter,
    # then move the whole model to GPU explicitly.
    params = torch.load(f0_path, map_location='cpu')['net']
    F0_model.load_state_dict(params)
    F0_model.eval()
    return F0_model.to('cuda')
def load_vocoder(vocoder_path="./Vocoder/checkpoint-400000steps.pkl"):
    '''@lw
    Return the Parallel WaveGAN vocoder on GPU, in eval mode.

    :vocoder_path: checkpoint path, default is "./Vocoder/checkpoint-400000steps.pkl"
    :raises RuntimeError: if CUDA is unavailable.
    '''
    # Raise instead of assert: asserts disappear under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is unavailable.")
    # The original called .eval() twice; once is enough.
    vocoder = load_model(vocoder_path).to('cuda').eval()
    # Weight norm is a training-time reparameterization; fold it for inference.
    vocoder.remove_weight_norm()
    return vocoder
def load_starganv2(gan_path='Models/epoch_v2_00248.pth'):
    '''@lw
    Return the StarGANv2-VC networks (EMA weights) on GPU, in eval mode.

    :gan_path: checkpoint path, default = 'Models/epoch_v2_00248.pth'
    :raises RuntimeError: if CUDA is unavailable.
    '''
    # Raise instead of assert: asserts disappear under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is unavailable.")
    with open('Models/config.yml') as f:
        starganv2_config = yaml.safe_load(f)
    starganv2 = build_model(model_params=starganv2_config["model_params"])
    params = torch.load(gan_path, map_location='cpu')['model_ema']
    # @lw: rebuild the parameter dictionaries to elude key-name mismatch:
    # drop the first dotted component of every state-dict key (e.g. a
    # "module." prefix left by DataParallel-style saving).
    for k in params:
        params[k] = {'.'.join(s.split('.')[1:]): v
                     for s, v in params[k].items()}
    for key in starganv2:
        starganv2[key].load_state_dict(params[key])
        starganv2[key].eval()
    starganv2.style_encoder = starganv2.style_encoder.to('cuda')
    starganv2.mapping_network = starganv2.mapping_network.to('cuda')
    starganv2.generator = starganv2.generator.to('cuda')
    return starganv2
# speakers = build_speakers()
# Hard-coded speaker table (index -> speaker name), replacing the
# commented-out build_speakers() directory scan above.
speakers = {
    0: 'Dong_Mingzhu',
    1: 'Hua_Chunying',
    2: 'Li_Fanping',
    3: 'Li_Gan',
    4: 'Luo_Xiang',
    5: 'Ma_Yun',
    6: 'Shi_Zhuguo',
    7: 'Wang_Cheng',
    8: 'Wang_Kun',
    9: 'Zhao_Lijian'
}
# Eagerly load every model at import time, so importing this module prepares
# all inference components (requires CUDA and the checkpoint files on disk).
starganv2 = load_starganv2()
F0_model = load_F0()
vocoder = load_vocoder()
# Debug output: object identities let a maintainer verify whether separate
# imports/processes share the same module-level singletons.
print('speakers id is {}'.format(id(speakers)))
print('starganv2 id is {}'.format(id(starganv2)))
print('F0_model id is {}'.format(id(F0_model)))
print('vocoder id is {}'.format(id(vocoder)))