-
Notifications
You must be signed in to change notification settings - Fork 2
/
rtest.py
125 lines (100 loc) · 4.56 KB
/
rtest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from models import *
from utils.datasets import *
from utils.utils import *
from tqdm import tnrange, tqdm_notebook
def _match_predictions(pred, labels, img_size, iou_thres):
    """Flag each prediction of one image as correct (1) or incorrect (0).

    A prediction is correct when its class is among the image's target
    classes, its best IoU against the target boxes is strictly greater than
    ``iou_thres``, and that best-matching target was not already claimed by
    an earlier prediction in ``pred``.

    Args:
        pred: per-image NMS output; each row is unpacked as
            (*box, conf, cls_conf, cls), i.e. the last three columns are
            confidence, class confidence and class id.
        labels: target rows for the same image as (cls, x, y, w, h). The box
            columns go through ``xywh2xyxy`` and are multiplied by
            ``img_size``; presumably they are normalized to [0, 1] — TODO confirm.
        img_size: factor used to scale target boxes to the prediction space.
        iou_thres: IoU threshold for a match.

    Returns:
        list[int]: one 0/1 flag per row of ``pred``.
    """
    correct = [0] * len(pred)
    nl = len(labels)
    if not nl:
        return correct

    tcls = labels[:, 0].tolist()  # target classes
    tbox = xywh2xyxy(labels[:, 1:5]) * img_size  # target boxes
    detected = []  # indices of targets already matched
    for i, (*pbox, _pconf, _pcls_conf, pcls) in enumerate(pred):
        # Every target is already matched; remaining predictions stay incorrect.
        if len(detected) == nl:
            break
        # A prediction whose class is absent from the image cannot match.
        if pcls.item() not in tcls:
            continue
        # Best IoU and the index of the target achieving it.
        iou, bi = bbox_iou(pbox, tbox).max(0)
        bi = int(bi)
        if iou > iou_thres and bi not in detected:
            correct[i] = 1
            detected.append(bi)
    return correct


def test(cfg, data_cfg, weights=None, batch_size=16, img_size=416, iou_thres=0.5,
         conf_thres=0.001, nms_thres=0.5, model=None, verbose=False):
    """Evaluate a detection model on the 'valid' image list of a data config.

    If ``model`` is None, a ``Darknet(cfg, img_size)`` model is built on the
    device chosen by ``torch_utils.select_device()``. Its weights are loaded
    from ``weights``: a ``.pt`` file is read as a checkpoint dict holding a
    'model' state dict, and any other path goes through ``load_darknet_weights``.
    With more than one CUDA device the model is wrapped in ``nn.DataParallel``.
    Otherwise the given ``model`` is used on the device of its parameters.

    Side effect: on the first batch, if ``test_batch0.jpg`` does not exist,
    the images with their target boxes are plotted to that file.

    Args:
        cfg: model config passed to ``Darknet`` (unused when ``model`` is given).
        data_cfg: path parsed by ``parse_data_cfg``; its 'classes', 'valid'
            and 'names' entries are read.
        weights: weights path; required when ``model`` is None.
        batch_size: dataloader batch size.
        img_size: input image size; also scales target boxes for matching.
        iou_thres: IoU above which a prediction matches a target.
        conf_thres: confidence threshold passed to ``non_max_suppression``.
        nms_thres: NMS IoU threshold passed to ``non_max_suppression``.
        model: optional already-built model. Its train/eval mode is restored
            on return.
        verbose: if True, print overall and per-class P/R/mAP/F1.

    Returns:
        tuple: (mean precision, mean recall, mAP, mean F1, mean loss per
        batch). The four metrics are 0.0 when no statistics were collected.

    Raises:
        ValueError: if both ``model`` and ``weights`` are None.
    """
    import os  # used for the test_batch0.jpg existence check

    if model is None:
        if weights is None:
            raise ValueError('weights must be provided when model is None')
        device = torch_utils.select_device()
        model = Darknet(cfg, img_size).to(device)
        if weights.endswith('.pt'):  # PyTorch checkpoint
            model.load_state_dict(torch.load(weights, map_location=device)['model'])
        else:  # Darknet binary weights
            load_darknet_weights(model, weights)
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
    else:
        device = next(model.parameters()).device

    # Configure run
    data_cfg = parse_data_cfg(data_cfg)
    nc = int(data_cfg['classes'])  # number of classes
    test_path = data_cfg['valid']  # path to test images
    names = load_classes(data_cfg['names'])  # class names

    dataset = LoadDataset(test_path, img_size=img_size)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            num_workers=4,
                            pin_memory=True,
                            collate_fn=dataset.collate_fn)

    pf = '%20s' + '%10.3g' * 6  # result row format
    if verbose:
        print(('%20s' + '%10s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP', 'F1'))

    seen = 0
    loss = 0.
    mp, mr, mean_ap, mf1 = 0., 0., 0., 0.
    stats = []  # per image: (correct flags, conf, predicted cls, target cls)

    was_training = model.training
    model.eval()
    try:
        # No backward pass happens here, so skip building autograd graphs.
        with torch.no_grad():
            for batch_i, (imgs, targets, _, _) in enumerate(tqdm_notebook(dataloader)):
                targets = targets.to(device)
                imgs = imgs.to(device)

                # Plot images with bounding boxes once
                if batch_i == 0 and not os.path.exists('test_batch0.jpg'):
                    plot_images(imgs=imgs, targets=targets, fname='test_batch0.jpg')

                # In eval mode the model yields inference and training outputs.
                inf_out, train_out = model(imgs)

                target_list = build_targets(model, targets)
                loss_i, _ = compute_loss(train_out, target_list)
                loss += loss_i.item()

                output = non_max_suppression(inf_out, conf_thres=conf_thres, nms_thres=nms_thres)

                # Statistics per image
                for si, pred in enumerate(output):
                    labels = targets[targets[:, 0] == si, 1:]
                    nl = len(labels)
                    tcls = labels[:, 0].tolist() if nl else []  # target classes
                    seen += 1

                    if pred is None:
                        # Only images with targets contribute (as misses).
                        if nl:
                            stats.append(([], torch.Tensor(), torch.Tensor(), tcls))
                        continue

                    correct = _match_predictions(pred, labels, img_size, iou_thres)
                    stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls))
    finally:
        model.train(was_training)

    # Compute statistics; guard the empty case, which previously crashed on stats_np[3].
    stats_np = [np.concatenate(x, 0) for x in zip(*stats)]
    nt = np.zeros(nc, dtype=np.int64)  # number of targets per class
    ap_class = []
    if stats_np:
        p, r, ap, f1, ap_class = ap_per_class(*stats_np)
        mp, mr, mean_ap, mf1 = p.mean(), r.mean(), ap.mean(), f1.mean()
        nt = np.bincount(stats_np[3].astype(np.int64), minlength=nc)

    if verbose:
        print(pf % ('all', seen, nt.sum(), mp, mr, mean_ap, mf1), end='\n\n')
        if nc > 1:
            for i, c in enumerate(ap_class):
                print(pf % (names[c], seen, nt[c], p[i], r[i], ap[i], f1[i]))

    # max(..., 1) avoids ZeroDivisionError on an empty dataloader.
    return mp, mr, mean_ap, mf1, loss / max(len(dataloader), 1)