-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
84 lines (66 loc) · 3.06 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import bbox_visualizer as bbv
import torch
import numpy as np
from helpers import *
import cv2
## locate the sample image and derive the ground-truth identifiers from its name
PATH = 'sample_images/images'
# os.listdir returns entries in arbitrary order; sort so that rand_idx
# selects the same file on every machine / filesystem.
FILES = sorted(os.listdir(PATH))
rand_idx = 2  ## index into FILES; can be changed to any valid index (e.g. 0 .. 6)
curr_file = FILES[rand_idx]
image_path = os.path.join(PATH, curr_file)  ## ex: P016_tissue1_6625
# The first two '_'-separated fields (e.g. "P016_tissue1") name the video and
# are used to locate the ground-truth files; the last field, minus its
# extension, is the frame number (e.g. 6625).
file_name = '_'.join(curr_file.split('_')[:2])
frame_idx = int(curr_file.split('_')[-1].split('.')[0])
## load ground truth files
# Space-separated "start end label" tables, one for the left-side tools and one
# for the right-side tools (presumably frame intervals per tool label - confirm).
_interval_opts = dict(header=None, sep=' ', names=['start', 'end', 'label'])
ground_truth_df_left = pd.read_csv(f'sample_videos/tools_left/{file_name}.txt', **_interval_opts)
ground_truth_df_right = pd.read_csv(f'sample_videos/tools_right/{file_name}.txt', **_interval_opts)
# Box annotations for this exact frame (label_index, xcenter, ycenter, w, h),
# converted by the project helper darknetbbox_to_yolo.
ground_truth_bbox = darknetbbox_to_yolo(
    pd.read_csv(
        f'sample_images/bbox_labels/{file_name}_{frame_idx}.txt',
        header=None,
        sep=' ',
        names=["label_index", "xcenter", "ycenter", "w", "h"],
    )
)
## fetch the YOLOv5 code from GitHub and load the custom weights stored in best.pt
model = torch.hub.load(
    'ultralytics/yolov5',
    'custom',
    path='best.pt',
    source='github',
)
## inference settings
model.conf = 0.6  # keep only predictions whose confidence exceeds this threshold
model.max_det = 2  # keep at most two detections per image
size = (640, 640)  # (width, height) passed to cv2.resize before inference
## load the frame into OpenCV and resize it to the model input size
frame = cv2.imread(image_path)
if frame is None:
    # cv2.imread reports failure by returning None instead of raising; fail
    # here with a clear message rather than a cryptic cv2.resize error.
    raise FileNotFoundError(f'could not read image: {image_path}')
frame = cv2.resize(frame, size)
# OpenCV decodes to BGR; the model is fed an RGB copy while `frame` stays BGR
# for drawing and display.
frame_to_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
## inference
output = model(frame_to_rgb)
# NOTE(review): render() draws onto the results object's own images, which are
# not displayed by this script - confirm whether this call is still needed.
output.render()
# One DataFrame of xyxy boxes per input image; index 0 is the single frame passed in.
output_df = output.pandas().xyxy[0]
## convert predictions into parallel lists of boxes and labels, then draw them
# Start empty so a frame with zero detections (neither branch below runs)
# does not leave boxes1/labels unbound.
boxes1 = []
labels = []
try:
    if len(output_df) == 2:
        left_df, right_df = extract_left_right(output_df)
        boxes1 = [df_to_bbox(left_df), df_to_bbox(right_df)]
        labels = [left_df['name'].values[0], right_df['name'].values[0]]
    elif len(output_df) == 1:
        # Wrap the single box in a list, matching the two-detection branch and
        # the list of boxes expected by draw_multiple_rectangles.
        boxes1 = [df_to_bbox(output_df)]
        labels = [output_df['name'].values[0]]
except TypeError:
    ## TODO - handle
    boxes1 = []
    labels = []
# Nothing to draw when there are no predictions.
if boxes1:
    frame = bbv.draw_multiple_rectangles(frame, boxes1, bbox_color=(255, 0, 0))
    frame = bbv.add_multiple_labels(frame, labels, boxes1, text_bg_color=(255, 0, 0))
## ground-truth tool names for this frame, written onto the image as text
_gt_text_style = dict(font_scale=1, text_color_bg=(255, 0, 0))
## Left
real_label_left = extract_label(ground_truth_df_left, frame_idx)
# NOTE(review): the 'left' label is placed at pos (500, 20) and the 'right' label
# at (10, 20) - presumably deliberate (mirrored camera view?); confirm.
draw_text(frame, text=real_label_left, pos=(500, 20), draw='left', **_gt_text_style)
## Right
real_label_right = extract_label(ground_truth_df_right, frame_idx)
draw_text(frame, text=real_label_right, pos=(10, 20), draw='right', **_gt_text_style)
## overlay the ground-truth boxes, each labelled with a "GT:" prefix
left_df, right_df = extract_left_right(ground_truth_bbox)
boxes1 = [df_to_bbox(side) for side in (left_df, right_df)]
labels = ['GT:' + side['name'].values[0] for side in (left_df, right_df)]
# Rectangles first, then labels (top=False, unlike the prediction labels).
frame = bbv.add_multiple_labels(
    bbv.draw_multiple_rectangles(frame, boxes1),
    labels,
    boxes1,
    top=False,
)
## display the annotated frame until a key is pressed
# Optional padding around the image (disabled):
# frame = cv2.copyMakeBorder(frame, 15, 15, 15, 15, borderType=cv2.BORDER_CONSTANT)
cv2.namedWindow("Frame", cv2.WINDOW_NORMAL)
# Show `frame` itself; no separately padded image exists while the border step is disabled.
cv2.imshow('Frame', frame)
cv2.waitKey(0)  # block until a key is pressed so the window does not close immediately
# Close every window opened above.
cv2.destroyAllWindows()