# test.py
# evaluate, model_utils, utils, np, torch and tqdm are expected to be provided
# by this wildcard import; they are used below without an explicit import.
from waveUNet.test import *

import os
import pickle
from math import ceil
from functools import partial

from torch.utils.tensorboard import SummaryWriter

from model.waveunet_params import waveunet_params
from model.waveunet import Waveunet
from data.musdb import get_musdb_folds
from data.dataset import SeparationDataset
from data.utils import crop_targets, random_amplify
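
# Evaluation entry point: this script builds a Wave-U-Net from the parsed
# configuration, loads a trained checkpoint, runs the mir_eval-based `evaluate`
# on the MUSDB test split, and logs per-instrument SDR/SIR to TensorBoard
# (see the __main__ block at the bottom of the file).
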
def _compute_metrics(args, musdb, model, writer, state):
    '''
    Run mir_eval metrics on the MUSDB test split, pickle the raw results and
    log per-instrument average SDR/SIR to TensorBoard at the checkpoint's step.
    '''
    # Mir_eval metrics
    test_metrics = evaluate(args, musdb["test"], model, args.instruments)

    # Dump all metrics results into pickle file for later analysis if needed
    with open(os.path.join(args.checkpoint_dir, "results.pkl"), "wb") as f:
        pickle.dump(test_metrics, f)

    # Write most important metrics into Tensorboard log
    avg_SDRs = {inst: np.mean([np.nanmean(song[inst]["SDR"]) for song in test_metrics]) for inst in args.instruments}
    avg_SIRs = {inst: np.mean([np.nanmean(song[inst]["SIR"]) for song in test_metrics]) for inst in args.instruments}
    for inst in args.instruments:
        sdr_name = "test_SDR_" + inst
        writer.add_scalar(sdr_name, avg_SDRs[inst], state["step"])
        print(f'{sdr_name}: {avg_SDRs[inst]}')

        sir_name = "test_SIR_" + inst
        writer.add_scalar(sir_name, avg_SIRs[inst], state["step"])
        print(f'{sir_name}: {avg_SIRs[inst]}')

    # Log the overall SDR at the same step as the per-instrument scalars
    overall_SDR = np.mean([v for v in avg_SDRs.values()])
    writer.add_scalar("test_SDR", overall_SDR, state["step"])
    print("SDR: " + str(overall_SDR))


def _create_waveunet(args):
    '''
    Build the Waveunet model from the parsed options and move it to GPU if requested.
    '''
    # Per-level feature counts: "add" grows linearly (features, 2*features, 3*features, ...),
    # anything else doubles per level (features, 2*features, 4*features, ...)
    num_features = [args.features * i for i in range(1, args.levels + 1)] if args.feature_growth == "add" else \
                   [args.features * 2 ** i for i in range(0, args.levels)]
    target_outputs = ceil(args.output_size * args.sr)
    model = Waveunet(args.channels, num_features, args.channels, args.instruments,
                     downsampling_kernel_size=args.downsampling_kernel_size,
                     upsampling_kernel_size=args.upsampling_kernel_size,
                     bottleneck_kernel_size=args.bottleneck_kernel_size,
                     target_output_size=target_outputs, depth=args.depth, strides=args.strides,
                     conv_type=args.conv_type, res=args.res, separate=args.separate, num_convs=args.num_convs)

    if args.cuda:
        model = model_utils.DataParallel(model)
        print("move model to gpu")
        model.cuda()

    print('model: ', model)
    print('parameter count: ', str(sum(p.numel() for p in model.parameters())))
    return model


def _load_musdb(args, data_shapes):
    '''
    Load the MUSDB train/val/test folds as SeparationDatasets and build the
    shuffled, augmented training dataloader.
    '''
    musdb = get_musdb_folds(args.dataset_dir)

    # If no data augmentation is applied, at least crop targets to fit model output shape
    crop_func = partial(crop_targets, shapes=data_shapes)
    # Data augmentation function for training
    augment_func = partial(random_amplify, shapes=data_shapes, min=0.7, max=1.0)

    train_data = SeparationDataset(musdb, "train", args.instruments, args.sr, args.channels, data_shapes, True,
                                   args.hdf_dir, audio_transform=augment_func)
    val_data = SeparationDataset(musdb, "val", args.instruments, args.sr, args.channels, data_shapes, False,
                                 args.hdf_dir, audio_transform=crop_func)
    test_data = SeparationDataset(musdb, "test", args.instruments, args.sr, args.channels, data_shapes, False,
                                  args.hdf_dir, audio_transform=crop_func)

    dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                                             num_workers=args.num_workers, worker_init_fn=utils.worker_init_fn)
    return train_data, val_data, test_data, dataloader, musdb


def validate(args, model, criterion, dataloader):
    '''
    Iterate with a given model over a given dataset and compute the desired loss
    :param args: Options dictionary
    :param model: Pytorch model
    :param criterion: Loss function to use (similar to Pytorch criterions)
    :param dataloader: Loads validation samples. Must have property 'num_samples' with number of samples in split.
    :return: Mean loss over all batches
    '''
    # VALIDATE
    model.eval()
    total_loss = 0.
    with tqdm(total=dataloader.num_samples // args.batch_size) as pbar, torch.no_grad():
        for example_num, (x, targets) in enumerate(dataloader):
            if args.cuda:
                x = x.cuda()
                for k in list(targets.keys()):
                    targets[k] = targets[k].cuda()

            _, avg_loss = model_utils.compute_loss(model, x, targets, criterion)

            # Running mean of the per-batch loss
            total_loss += (1. / float(example_num + 1)) * (avg_loss - total_loss)

            pbar.set_description("Current loss: {:.4f}".format(total_loss))
            pbar.update(1)

    return total_loss
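
# Note: `validate` above is not invoked by the __main__ block below. A hypothetical
# call (the `val_loader` name and the L1 criterion are illustrative, not part of
# this repo) could look like:
#
#     val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size,
#                                              num_workers=args.num_workers)
#     val_loader.num_samples = len(val_data)  # `validate` expects this attribute
#     val_loss = validate(args, model, torch.nn.L1Loss(), val_loader)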


if __name__ == '__main__':
    # Experiment name comes from the JOB_NAME environment variable when present
    # (e.g. when launched by a job scheduler), otherwise 'test' is used.
    experiment_name = os.environ['JOB_NAME'] if 'JOB_NAME' in os.environ else 'test'
    args = waveunet_params.parse_args().get_comb_partition(0, 1).get_comb(0)

    hdf_subdir = "_".join(args.instruments) + f"_{args.sr}_{args.channels}"
    args.hdf_dir = os.path.join(args.hdf_dir, hdf_subdir)

    if not os.path.exists(args.dataset_dir):
        raise ValueError(f"Dataset directory {args.dataset_dir} does not exist.")

    # Save checkpoints and logs in separate directories for each experiment
    args.checkpoint_dir = os.path.join(args.checkpoint_dir, experiment_name)
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    args.log_dir = os.path.join(args.log_dir, experiment_name)

    model = _create_waveunet(args)
    writer = SummaryWriter(args.log_dir)
    train_data, val_data, test_data, dataloader, musdb = _load_musdb(args, model.shapes)

    # Load the trained weights and evaluate on the MUSDB test split
    print("Test model from checkpoint " + str(args.load_model))
    state = model_utils.load_model(model, None, args.load_model, args.cuda)
    _compute_metrics(args, musdb, model, writer, state)
    writer.close()