-
Notifications
You must be signed in to change notification settings - Fork 2
/
test.py
79 lines (61 loc) · 2.3 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import os
import glob
import torchaudio
import torchaudio.transforms as T
class AudioPipeline(torch.nn.Module):
    """Turn a waveform into a mel-scaled magnitude spectrogram.

    The module chains a complex-valued ``torchaudio`` ``Spectrogram`` with a
    ``MelScale`` projection. The magnitude is computed by hand from the real
    and imaginary parts, with a small epsilon added under the square root.

    Args:
        freq: Sample rate handed to the mel filterbank (also kept as
            ``self.freq``).
        n_fft: FFT size.
        n_mel: Number of mel bins.
        win_length: Analysis window length.
        hop_length: Hop between successive frames.
    """

    def __init__(
        self,
        freq: int = 16000,
        n_fft: int = 1024,
        n_mel: int = 128,
        win_length: int = 1024,
        hop_length: int = 256,
    ):
        super().__init__()
        self.freq = freq

        # ``center`` is switched off below, so the transform is instead given
        # an explicit two-sided padding of half the frame/hop overlap.
        overlap_pad = int((n_fft - hop_length) / 2)

        stft_options = dict(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad=overlap_pad,
            power=None,  # keep the complex STFT; magnitude is taken in forward()
            center=False,
            pad_mode='reflect',
            normalized=False,
            onesided=True,
        )
        self.spec = T.Spectrogram(**stft_options)

        # A one-sided STFT yields n_fft // 2 + 1 frequency bins.
        self.mel_scale = T.MelScale(
            n_mels=n_mel,
            sample_rate=freq,
            n_stft=n_fft // 2 + 1,
        )

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return the mel-scaled magnitude spectrogram of ``waveform``."""
        complex_spec = self.spec(waveform)

        # |z| = sqrt(re^2 + im^2); the 1e-6 term keeps the square root away
        # from an exact zero.
        real_sq = complex_spec.real.pow(2)
        imag_sq = complex_spec.imag.pow(2)
        magnitude = torch.sqrt(real_sq + imag_sq + 1e-6)

        return self.mel_scale(magnitude)
def _find_latest_checkpoint(log_root="logs_48k/lightning_logs"):
    """Return ``last.ckpt`` of the highest-numbered ``version_*`` run, or ``None``."""
    numbered_runs = []
    for run_dir in glob.glob(f"{log_root}/version_*"):
        suffix = run_dir.rsplit("_", 1)[-1]
        # Skip directories such as ``version_bak`` whose suffix is not a
        # plain number; they cannot be ordered and would break ``int()``.
        if suffix.isdecimal():
            numbered_runs.append((int(suffix), run_dir))

    if not numbered_runs:
        return None

    # Compare numerically so that version_10 outranks version_2.
    _, newest_run = max(numbered_runs, key=lambda item: item[0])
    candidate = os.path.join(newest_run, "checkpoints/last.ckpt")
    return candidate if os.path.exists(candidate) else None


def load_local():
    """Load the network stored in the most recent local training checkpoint.

    Looks under ``logs_48k/lightning_logs`` for the highest-numbered
    ``version_*`` directory and loads its ``checkpoints/last.ckpt``. Only that
    newest run is considered; older runs are not used as a fallback.

    Returns:
        The ``net_g`` attribute of the ``HifiGAN`` module restored from the
        checkpoint.

    Raises:
        FileNotFoundError: If no usable checkpoint exists, instead of passing
            ``None`` on to the checkpoint loader.
    """
    ckpt_path = _find_latest_checkpoint()
    if ckpt_path is None:
        raise FileNotFoundError(
            "No checkpoint found: expected "
            "logs_48k/lightning_logs/version_<N>/checkpoints/last.ckpt"
        )
    print(ckpt_path)

    # Imported lazily so a missing checkpoint is reported before the model
    # package has to be importable.
    from hifigan.light.hifigan import HifiGAN

    model = HifiGAN.load_from_checkpoint(checkpoint_path=ckpt_path, strict=False)
    return model.net_g
def load_remote():
    """Fetch the ``hifigan_48k`` entry point from the project's torch.hub repo.

    ``force_reload=True`` is passed, so torch.hub is asked to refresh its
    cached copy of the repository on every call.
    """
    hub_repo = "vtuber-plan/hifi-gan:v0.2.1"
    entry_point = "hifigan_48k"
    return torch.hub.load(hub_repo, entry_point, force_reload=True)
device = "cpu"

# Vocoder weights come from the newest local training run. To use the
# published weights instead, replace the call with ``load_remote()``.
hifigan = load_local().to(device)
# Inference only: put the network in evaluation mode.
hifigan.eval()

# Load audio
wav, sr = torchaudio.load("zszy_48k.wav")
# Explicit check rather than ``assert`` so it still runs under ``python -O``.
if sr != 48000:
    raise ValueError(f"expected a 48000 Hz input file, got {sr} Hz")

# Original author's note: stands in for
# mel_spectrogram_torch(wav, 2048, 128, 48000, 512, 2048, 0, None, False)
audio_pipeline = AudioPipeline(
    freq=48000,
    n_fft=2048,
    n_mel=128,
    win_length=2048,
    hop_length=512,
)

# No gradients are needed for synthesis; skipping autograd also keeps the
# output tensor detached before it is written to disk.
with torch.no_grad():
    # Move the features to the same device as the model.
    mel = audio_pipeline(wav).to(device)
    out = hifigan(mel)

wav_out = out.squeeze(0).cpu()
torchaudio.save("test_out.wav", wav_out, sr)