forked from qiuqiangkong/mini_source_separation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
60 lines (45 loc) · 1.42 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import time
import librosa
import numpy as np
import soundfile
from pathlib import Path
from models.unet import UNet
from tqdm import tqdm
import museval
import argparse
from train import get_model, separate
def inference(args):
    """Run music source separation on one hard-coded MUSDB18HQ test song.

    Loads the latest checkpoint for ``args.model_name``, separates the
    mixture with ``separate`` from ``train.py``, and writes the result
    to ``sep.wav`` in the current directory.

    Args:
        args: argparse namespace with a ``model_name`` attribute.
    """
    # Arguments
    model_name = args.model_name

    # Default parameters
    sr = 44100
    segment_seconds = 2.
    clip_samples = round(segment_seconds * sr)
    batch_size = 16
    # Fall back to CPU when no GPU is present (was hard-coded "cuda",
    # which crashes on CPU-only machines).
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load checkpoint
    checkpoint_path = Path("checkpoints", "train", model_name, "latest.pth")
    model = get_model(model_name)
    # map_location lets a checkpoint saved on GPU load on any device.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)
    # Switch dropout/batch-norm layers to inference behavior.
    model.eval()

    # Load audio. Change this path to your favorite song.
    root = "/datasets/musdb18hq/test"
    mixture_path = Path(root, "Al James - Schoolboy Facination", "mixture.wav")
    # sr=None keeps the file's native sample rate; mono=False keeps stereo.
    mixture, orig_sr = librosa.load(path=mixture_path, sr=None, mono=False)
    # (channels_num, audio_samples)

    sep_wav = separate(
        model=model,
        audio=mixture,
        clip_samples=clip_samples,
        batch_size=batch_size
    )

    # Write out separated audio (transpose to (samples, channels) for soundfile).
    sep_path = "sep.wav"
    soundfile.write(file=sep_path, data=sep_wav.T, samplerate=orig_sr)
    print("Write to {}".format(sep_path))
if __name__ == "__main__":
    # CLI entry point: parse the model name and run separation.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default="UNet")
    inference(parser.parse_args())