-
Notifications
You must be signed in to change notification settings - Fork 19
/
predict_yolo3.py
executable file
·150 lines (116 loc) · 5.41 KB
/
predict_yolo3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#! /usr/bin/env python
import time
import os
import argparse
import json
import cv2
import sys
sys.path += [os.path.abspath('keras-yolo3-master')]
from utils.utils import get_yolo_boxes, makedirs
from utils.bbox import draw_boxes
from tensorflow.keras.models import load_model
from tqdm import tqdm
import numpy as np
def _annotate_batch(images, predict, draw):
    """Predict boxes for a batch of images and draw them onto each image in place.

    Args:
        images: list of images as returned by OpenCV (``cv2.imread`` /
            ``VideoCapture.read``); each one is modified in place.
        predict: callable mapping a list of images to one box list per image.
        draw: callable that draws one image's boxes onto that image.
    """
    for image, boxes in zip(images, predict(images)):
        draw(image, boxes)


def _detect_webcam(predict, draw):
    """Show live detections from the first webcam until Esc is pressed or frames stop.

    Raises:
        IOError: if webcam 0 cannot be opened.
    """
    video_reader = cv2.VideoCapture(0)
    if not video_reader.isOpened():
        video_reader.release()
        raise IOError('Could not open webcam 0')

    batch_size = 1
    images = []
    try:
        while True:
            ret_val, image = video_reader.read()
            if ret_val:
                images.append(image)
            # Process a full batch, or whatever is still pending once frames stop.
            if len(images) == batch_size or (not ret_val and images):
                _annotate_batch(images, predict, draw)
                for annotated in images:
                    cv2.imshow('video with bboxes', annotated)
                images = []
            if not ret_val:
                # No frame could be grabbed (presumably the camera was closed or
                # disconnected); stop instead of polling it forever.
                break
            if cv2.waitKey(1) == 27:
                break  # esc to quit
    finally:
        video_reader.release()
        cv2.destroyAllWindows()


def _detect_video(predict, draw, input_path, output_path):
    """Annotate every frame of a video file and write the result into output_path.

    The output file keeps the input's base name.

    Encoding:
        - Codec is the 'MPEG' FourCC.
        - Frame rate is the source's, falling back to 50 fps when OpenCV
          cannot report one.

    Raises:
        IOError: if the input video cannot be opened.
    """
    video_out = os.path.join(output_path, os.path.basename(input_path))
    video_reader = cv2.VideoCapture(input_path)
    if not video_reader.isOpened():
        video_reader.release()
        raise IOError('Could not open video: {}'.format(input_path))

    nb_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_h = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_w = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))
    fps = video_reader.get(cv2.CAP_PROP_FPS)
    if not fps > 0:  # also catches 0 and NaN when the container has no fps
        fps = 50.0
    video_writer = cv2.VideoWriter(video_out,
                                   cv2.VideoWriter_fourcc(*'MPEG'),
                                   fps,
                                   (frame_w, frame_h))

    batch_size = 1
    start_point = 0  # percentage of the video to skip before detecting
    show_window = False
    images = []

    def flush():
        """Annotate the pending frames, optionally display them, and write them out."""
        _annotate_batch(images, predict, draw)
        for annotated in images:
            if show_window:
                cv2.imshow('video with bboxes', annotated)
            video_writer.write(annotated)
        images.clear()

    try:
        for frame_idx in tqdm(range(nb_frames)):
            ret_val, image = video_reader.read()
            if not ret_val:
                # CAP_PROP_FRAME_COUNT may be inaccurate; stop at the real end of
                # the stream instead of feeding None frames to the model.
                break
            if (frame_idx + 1) / nb_frames > start_point / 100.:
                images.append(image)
                if len(images) == batch_size:
                    flush()
            if show_window and cv2.waitKey(1) == 27:
                break  # esc to quit
        if images:
            flush()  # last, possibly partial, batch
    finally:
        if show_window:
            cv2.destroyAllWindows()
        video_reader.release()
        video_writer.release()


def _collect_image_paths(input_path):
    """Return the image files to process.

    The candidate is input_path itself, or every entry of input_path when it
    is a directory (in sorted order). Only names ending in .jpg, .jpeg or
    .png (case-insensitive) are kept.
    """
    if os.path.isdir(input_path):
        candidates = [os.path.join(input_path, name)
                      for name in sorted(os.listdir(input_path))]
    else:
        candidates = [input_path]
    return [path for path in candidates
            if path.lower().endswith(('.jpg', '.jpeg', '.png'))]


def _detect_images(predict, draw, input_path, output_path):
    """Annotate an image (or a directory of images) and record the mean inference time.

    Outputs:
        - Each annotated image is written into output_path under its
          original base name.
        - The mean time spent in ``predict`` is written to
          ``output_path/time.txt``.
    """
    times = []
    for image_path in _collect_image_paths(input_path):
        image = cv2.imread(image_path)
        print(image_path)
        if image is None:
            # cv2.imread returns None instead of raising for unreadable files.
            print('Skipping unreadable image: {}'.format(image_path))
            continue
        start = time.time()
        boxes = predict([image])[0]
        elapsed = time.time() - start  # measured once so print and record agree
        print('Elapsed time = {}'.format(elapsed))
        times.append(elapsed)
        draw(image, boxes)
        cv2.imwrite(os.path.join(output_path, os.path.basename(image_path)),
                    np.uint8(image))

    # np.mean([]) warns and returns nan; write the same 'nan' text without the warning.
    mean_time = np.mean(times) if times else float('nan')
    # 'Tiempo promedio' is Spanish for "average time".
    with open(os.path.join(output_path, 'time.txt'), 'w') as time_file:
        time_file.write('Tiempo promedio:' + str(mean_time))


def _main_(args):
    """Run a trained YOLOv3 model on a webcam, an .mp4 video, or image(s).

    Args:
        args: argparse namespace with these attributes:
            - ``conf``: path to the JSON config.
            - ``input``: any string containing 'webcam', an ``.mp4`` path,
              an image path, or a directory of images.
            - ``output``: output directory.
    """
    config_path = args.conf
    input_path = args.input
    output_path = args.output

    with open(config_path) as config_buffer:
        config = json.load(config_buffer)

    makedirs(output_path)  # project helper; presumably creates the dir if missing

    # Detection parameters.
    net_h, net_w = 416, 416  # a multiple of 32, the smaller the faster
    obj_thresh, nms_thresh = 0.5, 0.3

    # Select the visible GPUs before loading the model (assumes TensorFlow has
    # not initialised its devices yet — confirm). os.environ needs a str here,
    # e.g. "0" or "0,1".
    os.environ['CUDA_VISIBLE_DEVICES'] = config['train']['gpus']
    infer_model = load_model(config['train']['saved_weights_name'])

    anchors = config['model']['anchors']
    labels = config['model']['labels']

    def predict(images):
        """Return one list of predicted boxes per image in *images*."""
        return get_yolo_boxes(infer_model, images, net_h, net_w,
                              anchors, obj_thresh, nms_thresh)

    def draw(image, boxes):
        """Draw *boxes* with their labels onto *image* in place."""
        draw_boxes(image, boxes, labels, obj_thresh)

    if 'webcam' in input_path:  # do detection on the first webcam
        _detect_webcam(predict, draw)
    elif input_path[-4:] == '.mp4':  # do detection on a video
        _detect_video(predict, draw, input_path, output_path)
    else:  # do detection on an image or a set of images
        _detect_images(predict, draw, input_path, output_path)
if __name__ == '__main__':
    argparser = argparse.ArgumentParser(description='Predict with a trained yolo model')
    # --conf and --input are mandatory: _main_ cannot run without them, and
    # argparse reports a missing flag far more clearly than a later
    # TypeError from open(None) or `'webcam' in None` would.
    argparser.add_argument('-c', '--conf', required=True,
                           help='path to configuration file')
    argparser.add_argument('-i', '--input', required=True,
                           help='path to an image, a directory of images, a video, or webcam')
    argparser.add_argument('-o', '--output', default='output/',
                           help='path to output directory')
    args = argparser.parse_args()
    _main_(args)