forked from yl4579/StarGANv2-VC
-
Notifications
You must be signed in to change notification settings - Fork 1
/
realtime_infer.py
146 lines (112 loc) · 5.19 KB
/
realtime_infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#!/usr/bin/env python3
"""Pass input directly to output.
See https://www.assembla.com/spaces/portaudio/subversion/source/HEAD/portaudio/trunk/test/patest_wire.c
"""
import argparse
import logging
import os.path
import pickle
from vocoder import *
from models import *
MAX_WAV_VALUE = 32768.0
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("-i", "--input-device", type=int, help="input device ID")
parser.add_argument("-o", "--output-device", type=int, help="output device ID")
parser.add_argument("-c", "--channels", type=int, default=2,
help="number of channels")
parser.add_argument("-t", "--dtype", help="audio data type")
parser.add_argument("-s", "--samplerate", type=float, help="sampling rate", default=24000)
parser.add_argument("-b", "--blocksize", type=int, help="block size", default=32 * 300)
parser.add_argument("-l", "--latency", type=float, help="latency in seconds", default=0)
parser.add_argument("-p", "--speaker", type=float, help="spekaer id", default=8)
parser.add_argument("-m", "--mel-buffer-size", type=float, help="mel buffer size", default=192)
parser.add_argument("-w", "--wave-buffer-size", type=float, help="wave buffer size in frame", default=10)
args = parser.parse_args()
generator = load_vocoder()
starganv2, F0_model = load_models()
ref = compute_style('reference.wav', starganv2, args.speaker)
def next_waves(wave, wave_next, wave_average=None):
wave_left = wave_next.shape[-1] - args.wave_buffer_size * 300 - args.blocksize
wave_right = wave_next.shape[-1] - args.wave_buffer_size * 300
a = wave.squeeze()[-300:]
b = wave_next[..., wave_left - a.shape[-1]:wave_right].squeeze()
print(a.shape, b.shape)
b = b[:a.shape[-1] + 300]
wave = wave_next[..., wave_left - a.shape[-1]:wave_right + 1 + 300].squeeze()
wave = np.clip(wave, -1, 1)
wave = wave[a.shape[-1] + 1:]
wave_return = wave[:-300]
wave_buffer_return = wave[-300:]
if wave_average is not None:
buffer_weight = np.linspace(1, 0, 300)
wave_return[0:300] = buffer_weight * wave_average + (1 - buffer_weight) * wave_return[0:300]
return wave_return, wave_buffer_return
try:
import sounddevice as sd
mel_buffer = torch.zeros(1, 80, int(args.mel_buffer_size - args.blocksize / 300)).to('cuda')
callback_status = sd.CallbackFlags()
wave_buffer = np.zeros(args.mel_buffer_size * 300)
previous_wave = None
wave_average = None
noisy_part = None
def callback(indata, outdata, frames, time, status):
""" @lw
callback is used to process the indata (recording audio) and yield outdata (converted audio)
"""
global callback_status
global mel_buffer
global wave_buffer
global previous_wave
global wave_average
# @lw: why use |= in here?
# @lw: check the overflow or underflow status
# @lw: TODO: how to adjust the buffer size dynamically?
callback_status |= status
with torch.no_grad():
audio = indata[:, 0].squeeze()
# wave buffer
wave_buffer = np.concatenate((wave_buffer, audio), axis=-1)
buffer_size = args.mel_buffer_size * 300 - args.blocksize
buffer_cut = int(wave_buffer.shape[-1] - buffer_size)
wave_buffer = wave_buffer[..., max(0, buffer_cut):]
print('mel_buffer', wave_buffer.shape)
mel = preprocess_GPU(wave_buffer)
mel = mel[..., 1:]
print('mel', mel.shape)
out = convert(mel, starganv2, F0_model, ref, args.speaker)
mel_left = out.shape[-1] - args.wave_buffer_size - int(args.blocksize / 300)
mel_right = out.shape[-1] - args.wave_buffer_size
mel_out = out[..., mel_left:mel_right]
# out = mel
wave = generator(out.squeeze().unsqueeze(0))
wave = wave.cpu().numpy()
wave.dtype = np.float32
if previous_wave is None:
wave_left = wave.shape[-1] - args.wave_buffer_size * 300 - args.blocksize
wave_right = wave.shape[-1] - args.wave_buffer_size * 300
print('wave range', wave_left, wave_right)
print('wave shape', wave.shape)
wave = wave[..., wave_left:wave_right]
previous_wave = wave
else:
wave, wave_average = next_waves(previous_wave, wave, wave_average)
previous_wave = wave
wave = np.expand_dims(wave.squeeze(), axis=1)
out = np.repeat(wave, 2, axis=1).squeeze()
if out.shape[0] == 2:
outdata[:] = out.transpose()
else:
outdata[:] = out
with sd.Stream(device=(args.input_device, args.output_device),
samplerate=args.samplerate, blocksize=args.blocksize,
dtype=args.dtype, latency=args.latency,
channels=args.channels, callback=callback):
print("#" * 80)
print("press Return to quit")
print("#" * 80)
input()
if callback_status:
logging.warning(str(callback_status))
except BaseException as e:
# This avoids printing the traceback, especially if Ctrl-C is used.
raise SystemExit(str(e))