-
Notifications
You must be signed in to change notification settings - Fork 0
/
visualize_anet.py
100 lines (85 loc) · 3.91 KB
/
visualize_anet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import json
import h5py
import random
import numpy as np
from matplotlib.pyplot import imshow
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
# Load the two output files stored in DETECTED_SGG_DIR (the Scene-Graph-Benchmark
# test output directory for ANet).
DETECTED_SGG_DIR = '/home/tqsang/scene_graph/Scene-Graph-Benchmark.pytorch/checkpoints/test_anet_output'


def _load_json(path):
    """Read and return the JSON document at ``path``, closing the file afterwards."""
    with open(path, encoding='utf-8') as f:
        return json.load(f)


custom_prediction = _load_json(DETECTED_SGG_DIR + '/custom_prediction.json')
custom_data_info = _load_json(DETECTED_SGG_DIR + '/custom_data_info.json')
def draw_single_box(pic, box, color='red', draw_info=None):
    """Draw one bounding box, plus an optional text tag, directly onto ``pic``.

    Args:
        pic: PIL image to draw on; it is modified in place.
        box: indexable whose items 0..3 are x1, y1, x2, y2; each is truncated with int().
        color: outline colour of the box and fill colour of the tag background.
        draw_info: optional tag text. When truthy, a fixed 50x10 filled rectangle is
            painted at the box's top-left corner and the text is written on top of it.
    """
    left, top, right, bottom = map(int, (box[0], box[1], box[2], box[3]))
    canvas = ImageDraw.Draw(pic)
    canvas.rectangle(((left, top), (right, bottom)), outline=color)
    if not draw_info:
        return
    # The tag background has a fixed size; the text is not measured, so long labels
    # extend past it.
    canvas.rectangle(((left, top), (left + 50, top + 10)), fill=color)
    canvas.text((left, top), draw_info)
def print_list(name, input_list, scores=None):
    """Print each item of ``input_list`` on its own line as ``<name> <index>: <item>``.

    Args:
        name: prefix written at the start of every line.
        input_list: items to print.
        scores: optional sequence parallel to ``input_list`` (list or NumPy array).
            When given, ``; score: <score>`` is appended to each line; it must be at
            least as long as ``input_list``.
    """
    # `is not None` rather than `!= None`: with a NumPy array, `==`/`!=` compare
    # element-wise and the resulting array has no single truth value.
    with_scores = scores is not None
    for i, item in enumerate(input_list):
        line = f'{name} {i}: {item}'
        if with_scores:
            line += f'; score: {scores[i]}'
        print(line)
def draw_image(image_idx, img_path, boxes, box_labels, rel_labels, box_scores=None, rel_scores=None,
               out_dir='test_anet'):
    """Draw labelled boxes on one image, save it as JPEG, and print its boxes and relations.

    Args:
        image_idx: index used for the output file name ``<out_dir>/<image_idx>.jpg``.
        img_path: path of the source image.
        boxes: x1, y1, x2, y2 boxes to draw. They are drawn on the image after it is
            resized with get_size(), so they are presumably in that resized frame
            (TODO confirm against the detector's test-time transform).
        box_labels: one string per box; the tag drawn is ``<i>_<label>``.
        rel_labels: relation strings to print.
        box_scores: optional scores printed next to ``box_labels``.
        rel_scores: optional scores printed next to ``rel_labels``.
        out_dir: directory for the JPEG; created if missing. Default 'test_anet'.
    """
    import os  # local import: the script header does not import os

    # Open the file once and close it; convert to RGB so palette/RGBA/grayscale inputs
    # can be saved as JPEG and still show coloured boxes.
    with Image.open(img_path) as src:
        pic = src.convert('RGB').resize(get_size(src.size))

    for i, (box, label) in enumerate(zip(boxes, box_labels)):
        draw_single_box(pic, box, draw_info=str(i) + '_' + label)

    os.makedirs(out_dir, exist_ok=True)
    pic.save(os.path.join(out_dir, '{}.jpg'.format(image_idx)))

    print('*' * 50)
    print_list('box_labels', box_labels, box_scores)
    print('*' * 50)
    print_list('rel_labels', rel_labels, rel_scores)
def get_size(image_size, min_size=600, max_size=1000):
    """Compute the (width, height) to resize an image to, preserving aspect ratio.

    The shorter side is scaled to ``min_size``. If that would make the longer side
    exceed ``max_size``, the shorter side is reduced instead, so the longer side ends
    up at (approximately, due to rounding and truncation) ``max_size``.

    Args:
        image_size: (width, height) of the original image, e.g. PIL ``Image.size``.
        min_size: target length of the shorter side. Default 600.
        max_size: upper bound for the longer side, or None for no bound. Default 1000.

    Returns:
        (width, height) tuple. The input size is returned unchanged when its shorter
        side already equals the target length.
    """
    w, h = image_size
    size = min_size
    if max_size is not None:
        short_side = float(min(w, h))
        long_side = float(max(w, h))
        if long_side / short_side * size > max_size:
            size = int(round(max_size * short_side / long_side))

    if (w <= h and w == size) or (h <= w and h == size):
        return (w, h)
    if w < h:
        return (size, int(size * h / w))
    return (int(size * w / h), size)
# parameters
box_topk = 30  # select top k bounding boxes
rel_topk = 20  # select top k relationships
num_images = 100  # visualize at most this many images

# The vocabularies and file list are the same for every image, so read them once.
ind_to_classes = custom_data_info['ind_to_classes']
ind_to_predicates = custom_data_info['ind_to_predicates']
idx_to_files = custom_data_info['idx_to_files']

# Clamp to the number of listed files so a smaller output set does not raise
# IndexError part-way through.
for image_idx in range(min(num_images, len(idx_to_files))):
    image_path = idx_to_files[image_idx]
    prediction = custom_prediction[str(image_idx)]

    boxes = prediction['bbox'][:box_topk]
    box_scores = prediction['bbox_scores'][:box_topk]
    box_labels = [ind_to_classes[label] for label in prediction['bbox_labels'][:box_topk]]

    # Keep relations whose subject and object boxes both survived the box_topk cut,
    # formatted as "<subject> <predicate> <object>", stopping after rel_topk of them.
    rel_labels = []
    rel_scores = []
    for pair, predicate, score in zip(
        prediction['rel_pairs'], prediction['rel_labels'], prediction['rel_scores']
    ):
        if len(rel_labels) >= rel_topk:
            break
        subj, obj = pair[0], pair[1]
        if subj < box_topk and obj < box_topk:
            rel_labels.append(box_labels[subj] + ' ' + ind_to_predicates[predicate] + ' ' + box_labels[obj])
            rel_scores.append(score)

    draw_image(image_idx, image_path, boxes, box_labels, rel_labels,
               box_scores=box_scores, rel_scores=rel_scores)