-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
102 lines (77 loc) · 2.78 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from __future__ import absolute_import, division, print_function, unicode_literals
import glob
import os
import argparse
import json
import torch
import numpy as np
import torchaudio
from tqdm import tqdm
from env import AttrDict
from models import Generator
from meldataset import mel_spectrogram, spectral_normalize_torch, MAX_WAV_VALUE
# Module-level state that main() fills in: `h` becomes the AttrDict built from
# config.json and `device` the torch.device that inference() runs on.
h = device = None
def load_checkpoint(filepath, device):
    """Load a torch checkpoint from *filepath*, mapping its tensors onto *device*.

    Args:
        filepath: Path to a file written by ``torch.save``.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The deserialized checkpoint object (the caller in this file indexes it
        with ``'generator'``).

    Raises:
        FileNotFoundError: if *filepath* is not an existing regular file.
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
    if not os.path.isfile(filepath):
        raise FileNotFoundError("Checkpoint file not found: '{}'".format(filepath))
    print("Loading '{}'".format(filepath))
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from trusted sources.
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict
def get_mel(x):
    """Return ``mel_spectrogram(x, ...)`` using the STFT/mel settings of the global config ``h``.

    Assumes *x* is an audio tensor accepted by ``mel_spectrogram`` -- TODO confirm.
    Not called within this file.
    """
    stft_settings = (h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size)
    frequency_band = (h.fmin, h.fmax)
    return mel_spectrogram(x, *stft_settings, *frequency_band)
def scan_checkpoint(cp_dir, prefix):
    """Return the lexicographically greatest path in *cp_dir* whose name starts with *prefix*.

    *prefix* is used as a glob pattern prefix, so glob metacharacters in it are
    interpreted. With zero-padded step suffixes the greatest name is presumably
    the newest checkpoint. Returns ``''`` when nothing matches.
    """
    candidates = glob.glob(os.path.join(cp_dir, prefix + '*'))
    return max(candidates) if candidates else ''
def inference(a):
    """Vocode every mel file in ``a.input_mels_dir`` into a 16-bit WAV in ``a.output_dir``.

    Uses the module globals ``h`` (config) and ``device`` that ``main()`` sets.

    Args:
        a: Parsed CLI namespace providing ``checkpoint_file``, ``input_mels_dir``,
            ``output_dir`` and ``normalized_mel``.
    """
    generator = Generator(h).to(device)
    state_dict_g = load_checkpoint(a.checkpoint_file, device)
    generator.load_state_dict(state_dict_g['generator'])

    # Regular files only (np.load cannot read a sub-directory), sorted so the
    # processing order is deterministic across platforms.
    filelist = sorted(
        name for name in os.listdir(a.input_mels_dir)
        if os.path.isfile(os.path.join(a.input_mels_dir, name))
    )
    os.makedirs(a.output_dir, exist_ok=True)

    generator.eval()
    generator.remove_weight_norm()
    with torch.no_grad():
        for filename in tqdm(filelist):
            mel = np.load(os.path.join(a.input_mels_dir, filename))
            # np.load returns the dtype the array was saved with (possibly
            # float64) and from_numpy preserves it; cast so it matches the
            # generator's default float32 weights.
            x = torch.from_numpy(mel).float().to(device)
            if not a.normalized_mel:
                x = spectral_normalize_torch(x)
            y_g_hat = generator(x)
            audio = y_g_hat.squeeze() * MAX_WAV_VALUE
            # Clip before the int16 cast: out-of-range floats would otherwise
            # overflow (platform-dependent wrap-around -> audible clicks).
            audio = np.clip(audio.cpu().numpy(), -32768, 32767).astype('int16')
            output_file = os.path.join(a.output_dir, os.path.splitext(filename)[0] + '.wav')
            torchaudio.save(
                output_file,
                src=torch.from_numpy(audio).view(1, -1),
                sample_rate=h.sampling_rate,
            )
            print(output_file)
def main():
    """CLI entry point: parse arguments, load config.json, seed RNGs, choose a device, run inference.

    config.json is read from the same directory as ``--checkpoint_file``. Sets the
    module globals ``h`` and ``device`` that ``inference()`` reads.
    """
    global h, device

    print('Initializing Inference Process..')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_mels_dir', default='test_files')
    parser.add_argument('--output_dir', default='generated_files')
    parser.add_argument('--checkpoint_file', required=True)
    parser.add_argument('--normalized_mel', action="store_true")
    a = parser.parse_args()

    config_file = os.path.join(os.path.dirname(a.checkpoint_file), 'config.json')
    # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(config_file, encoding='utf-8') as f:
        json_config = json.load(f)
    h = AttrDict(json_config)

    torch.manual_seed(h.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    inference(a)


if __name__ == '__main__':
    main()