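# benchmark_pytorch.py
#
# Benchmarks audio loading speed through a PyTorch DataLoader: for every
# library in `libs`, every file found under AUDIO/<duration>/ is decoded
# `repeat` times and the average per-file loading time is recorded.
# Results are pickled to results/benchmark_pytorch_<ext>.pickle.
# Usage: python benchmark_pytorch.py --ext wav
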
import matplotlib
matplotlib.use('Agg')
import torch.utils
import os
import os.path
import random
import time
import argparse
import librosa
import utils
import loaders
import torch
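
# Decoding is delegated to the local `loaders` module: each benchmarked library
# is expected to expose a `load_<name>(path)` function there that returns the
# decoded audio as an array-like object (see AudioFolder below).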


def get_files(dir, extension):
    """Recursively collect all files below `dir` that end with `extension`."""
    audio_files = []
    dir = os.path.expanduser(dir)
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if fname.endswith(extension):
                audio_files.append(os.path.join(root, fname))
    return audio_files


class AudioFolder(torch.utils.data.Dataset):
    """Dataset that yields one decoded audio file per item."""

    def __init__(
        self,
        root,
        extension='wav',
        lib="librosa",
    ):
        self.root = os.path.expanduser(root)
        self.data = []
        self.audio_files = get_files(dir=self.root, extension=extension)
        # look up the loader callable (e.g. loaders.load_librosa) by name
        self.loader_function = getattr(loaders, lib)

    def __getitem__(self, index):
        audio = self.loader_function(self.audio_files[index])
        # return the decoded audio as a tensor of shape (1, 1, num_samples)
        return torch.as_tensor(audio).view(1, 1, -1)

    def __len__(self):
        return len(self.audio_files)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Benchmark audio loading libraries via a PyTorch DataLoader.'
    )
    parser.add_argument('--ext', type=str, default="wav")
    args = parser.parse_args()

    # each dataset is iterated `repeat` times and timings are averaged
    repeat = 3

    columns = [
        'ext',
        'lib',
        'duration',
        'time',
    ]

    store = utils.DF_writer(columns)
    # libraries to be benchmarked
    libs = [
        'stempeg',
        'soxbindings',
        'ar_ffmpeg',
        'aubio',
        'pydub',
        'soundfile',
        'librosa',
        'scipy',
        'scipy_mmap',
    ]
    # the torchaudio backends are only benchmarked for formats other than mp4
    if args.ext != "mp4":
        libs.append('torchaudio-sox_io')
        libs.append('torchaudio-soundfile')

    for lib in libs:
        print("Testing: %s" % lib)
        if "torchaudio" in lib:
            # e.g. 'torchaudio-sox_io' -> backend 'sox_io'
            backend = lib.split("torchaudio-")[-1]
            import torchaudio
            torchaudio.set_audio_backend(backend)
            call_fun = "load_torchaudio"
        else:
            call_fun = 'load_' + lib

        # AUDIO contains one subdirectory per track duration
        for root, dirs, fnames in sorted(os.walk('AUDIO')):
            for audio_dir in dirs:
                try:
                    duration = int(audio_dir)
                    data = torch.utils.data.DataLoader(
                        AudioFolder(
                            os.path.join(root, audio_dir),
                            lib=call_fun,
                            extension=args.ext
                        ),
                        batch_size=1,
                        num_workers=0,
                        shuffle=False
                    )

                    start = time.time()
                    for i in range(repeat):
                        for X in data:
                            # access the tensor so each loaded file is actually used
                            X.max()
                    end = time.time()

                    # store the average loading time per file
                    store.append(
                        ext=args.ext,
                        lib=lib,
                        duration=duration,
                        time=float(end - start) / (len(data) * repeat),
                    )
                except Exception:
                    # a failing library/format combination is skipped so the
                    # remaining benchmarks can still run
                    continue

    store.df.to_pickle("results/benchmark_%s_%s.pickle" % ("pytorch", args.ext))