-
Notifications
You must be signed in to change notification settings - Fork 5
/
val.py
114 lines (77 loc) · 3.39 KB
/
val.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
import torchvision
import os
import sys
import cv2
import numpy as np
from pathlib import Path
from tqdm import tqdm
from argparse import ArgumentParser
from models.tracknet import TrackNet
from utils.dataloaders import create_dataloader
from utils.general import check_dataset, outcome, evaluation
# Path bootstrapping borrowed from yolov5 detect.py: make the project root
# importable regardless of the directory the script is launched from.
# from yolov5 detect.py
FILE = Path(__file__).resolve()  # absolute path of this script
ABS_ROOT = FILE.parents[0]  # YOLOv5 root directory (the directory containing this file)
if str(ABS_ROOT) not in sys.path:
    sys.path.append(str(ABS_ROOT))  # add ROOT to PATH so `models`/`utils` resolve
ROOT = Path(os.path.relpath(ABS_ROOT, Path.cwd()))  # relative form, used for CLI defaults
def wbce_loss(y_true, y_pred):
    """Focal weighted binary cross-entropy, summed over every element.

    Confident wrong predictions are up-weighted by the squared error factor
    ((1-p)^2 on positives, p^2 on negatives); predictions are clamped away
    from 0 before the log to avoid -inf.
    """
    eps = 1e-07
    p = torch.clamp(y_pred, min=eps, max=1)
    q = torch.clamp(1 - y_pred, min=eps, max=1)
    pos_term = ((1 - y_pred) ** 2) * y_true * torch.log(p)
    neg_term = (y_pred ** 2) * (1 - y_true) * torch.log(q)
    return -(pos_term + neg_term).sum()
def validation_loop(device, model, val_loader, save_dir):
    """Run one evaluation pass over `val_loader`.

    Accumulates the weighted-BCE loss and the TP/TN/FP1/FP2/FN outcome
    counts over all batches, printing running metrics on the progress bar
    and the final F1-score.

    Args:
        device: 'cuda' or 'cpu' — where inputs are moved before the forward pass.
        model: TrackNet model (switched to eval mode here).
        val_loader: iterable of (X, y) batches.
        save_dir: kept for interface compatibility; not used in this loop.

    Returns:
        Mean loss per batch (0.0 for an empty loader).
    """
    model.eval()
    loss_sum = 0.0
    sample_count = 0  # true number of samples seen; last batch may be smaller
    TP = TN = FP1 = FP2 = FN = 0
    # Pre-define metrics so an empty loader doesn't raise NameError below.
    accuracy = precision = recall = 0.0
    with torch.inference_mode():
        pbar = tqdm(val_loader, ncols=180)
        for X, y in pbar:
            X, y = X.to(device), y.to(device)
            y_pred = model(X)
            loss_sum += wbce_loss(y, y_pred).item()
            sample_count += X.shape[0]
            y_ = y.detach().cpu().numpy()
            # Binarize the heatmap prediction at 0.5 before scoring.
            y_pred_ = (y_pred.detach().cpu().numpy() > 0.5).astype('float32')
            tp, tn, fp1, fp2, fn = outcome(y_pred_, y_)
            TP += tp
            TN += tn
            FP1 += fp1
            FP2 += fp2
            FN += fn
            accuracy, precision, recall = evaluation(TP, TN, FP1, FP2, FN)
            pbar.set_description(
                'Val loss: {:.6f} | TP: {}, TN: {}, FP1: {}, FP2: {}, FN: {} | '
                'Accuracy: {:.4f}, Precision: {:.4f}, Recall: {:.4f}'.format(
                    loss_sum / sample_count, TP, TN, FP1, FP2, FN,
                    accuracy, precision, recall))
    # Guard the harmonic mean against precision == recall == 0.
    denom = precision + recall
    F1 = 2 * (precision * recall) / denom if denom > 0 else 0.0
    print("F1-score: {}".format(F1))
    return loss_sum / len(val_loader) if len(val_loader) > 0 else 0.0
def parse_opt():
    """Build the CLI and return the parsed validation options."""
    ap = ArgumentParser()
    ap.add_argument('--data', type=str, default=ROOT / 'data/match/test.yaml', help='Path to dataset.')
    ap.add_argument('--weights', type=str, default=ROOT / 'best.pt', help='Path to trained model weights.')
    ap.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[288, 512], help='image size h,w')
    ap.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
    ap.add_argument('--project', default=ROOT / 'runs/val', help='save results to project/name')
    return ap.parse_args()
def main(opt):
    """Validate a trained TrackNet checkpoint on the dataset's val split.

    Resolves paths from `opt`, creates the output directory, loads the
    weights onto the available device, and runs the validation loop.

    Args:
        opt: argparse.Namespace from parse_opt() (data, weights, imgsz,
             batch_size, project).

    Raises:
        FileNotFoundError: if the weights file does not exist.
    """
    d_save_dir = str(opt.project)
    f_weights = str(opt.weights)
    batch_size = opt.batch_size
    f_data = str(opt.data)
    imgsz = opt.imgsz

    data_dict = check_dataset(f_data)
    val_path = data_dict['val']  # only the val split is needed here

    # exist_ok avoids the check-then-create race of the exists()/makedirs pair.
    os.makedirs(d_save_dir, exist_ok=True)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = TrackNet().to(device)

    # Raise instead of assert: asserts are stripped under `python -O`.
    if not os.path.exists(f_weights):
        raise FileNotFoundError(f"{f_weights} is invalid")
    print("load pretrain weights {}".format(f_weights))
    # map_location lets CUDA-saved checkpoints load on CPU-only machines.
    model.load_state_dict(torch.load(f_weights, map_location=device))

    val_loader = create_dataloader(val_path, imgsz, batch_size=batch_size)
    validation_loop(device, model, val_loader, d_save_dir)
# Script entry point: parse CLI options, then run validation.
if __name__ == '__main__':
    main(parse_opt())