-
Notifications
You must be signed in to change notification settings - Fork 1
/
create_tfrecord.py
151 lines (123 loc) · 5.3 KB
/
create_tfrecord.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
""" TFRecord generator
Generates tfrecord from MOT17 folders (containing video frames) and det.txt file (containing detections coordinates for each frame)
usage: create_tfrecord.py [VIDEOS_DIR] [OUTPUT_PATH] [LABELS_PATH]
required arguments:
VIDEOS_DIR
Path to the folder containing MOT17 folders (where the input images and det.txt files are stored)
OUTPUT_PATH
Path of output TFRecord (.record) file.
LABELS_PATH, --labels_path LABELS_PATH
Path to the labels (.pbtxt) file.
"""
import os
import argparse
import io
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow.compat.v1 as tf
from PIL import Image
from object_detection.utils import dataset_util, label_map_util
# Command-line interface: three required positional arguments.
# They are parsed at module level so that the functions below can read them
# through the module-wide ``args`` namespace.
parser = argparse.ArgumentParser()
parser.add_argument(
    "videos_dir",
    type=str,
    help="Path to the folder containing MOT17 folders (where the input images and det.txt files are stored)",
)
parser.add_argument(
    "output_path",
    type=str,
    help="Path of output TFRecord (.record) file.",
)
parser.add_argument(
    "labels_path",
    type=str,
    help="Path to the labels (.pbtxt) file.",
)
args = parser.parse_args()
def create_boxes_dict(video_path):
    '''Creates a dictionary of the boxes listed in det/det.txt, grouped by frame.

    Parameters:
    ----------
    video_path : str
        The path to video MOT17 folder (it must contain det/det.txt)

    Returns
    -------
    frame_boxes_dict : dictionary
        Maps each frame index (int, first field of a line) to the list of its
        boxes. Each box is the list of fields 3 to 6 of the line (x, y, w, h)
        rounded to integers. Frames that have no line in det.txt have no entry.
    '''
    frame_boxes_dict = {}
    det_path = os.path.join(video_path, "det", "det.txt")
    with open(det_path, "r") as det_file:
        # iterate over the file directly instead of loading all lines first
        for line in det_file:
            # a blank line (e.g. at the end of the file) holds no detection
            # and cannot be converted to numbers
            if not line.strip():
                continue
            # only the first six fields are used: frame, id, x, y, w, h
            parsed_line = [round(float(elem)) for elem in line.split(",")[:6]]
            frame_index = parsed_line[0]
            # append x,y,w,h to dict
            frame_boxes_dict.setdefault(frame_index, []).append(parsed_line[2:6])
    return frame_boxes_dict
# Parsed label maps, keyed by the path of the .pbtxt file, so that the file is
# read once per run instead of once per image.
_CATEGORY_INDEX_CACHE = {}


def _get_category_index(labels_path):
    '''Returns the category index of the given label map, loading it on first use only.'''
    category_index = _CATEGORY_INDEX_CACHE.get(labels_path)
    if category_index is None:
        category_index = label_map_util.create_category_index_from_labelmap(labels_path)
        _CATEGORY_INDEX_CACHE[labels_path] = category_index
    return category_index


def create_tf_example(boxes, image_path, image_name):
    '''Creates a tf example for the given image and its detection boxes

    Every box is labelled with the category of id 1 of the label map given
    on the command line (args.labels_path).

    Parameters:
    ----------
    boxes : array
        The array containing the coordinates (x,y,w,h) of each box
    image_path : str
        The path to the image file
    image_name : str
        The image file name, stored as both filename and source id

    Returns
    -------
    tf_example
        A tf.train.Example holding the encoded image, its size and the boxes
        normalized by the image width and height
    '''
    category_index = _get_category_index(args.labels_path)

    with tf.gfile.GFile(image_path, 'rb') as fid:
        encoded_jpg = fid.read()
    # the image is opened only to read its dimensions
    with Image.open(io.BytesIO(encoded_jpg)) as image:
        width, height = image.size

    filename = image_name.encode('utf8')
    image_format = b'jpg'

    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []
    for box in boxes:
        xmin, ymin, w, h = box
        xmax = xmin + w
        ymax = ymin + h
        # coordinates are stored as fractions of the image size
        # NOTE(review): a box reaching outside the image gives values outside
        # [0, 1]; they are written unchanged - confirm the consumer accepts it
        xmins.append(xmin / width)
        xmaxs.append(xmax / width)
        ymins.append(ymin / height)
        ymaxs.append(ymax / height)
        classes_text.append(category_index[1]["name"].encode('utf8'))
        classes.append(category_index[1]["id"])

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example
def main():
    '''Writes one tf example per frame with detections, for every MOT17 folder, to the output TFRecord.

    Reads args.videos_dir and args.output_path. Only the sub-folders whose name
    starts with "MOT17-" are processed; for each of them the ".jpg" files of
    img1/ that have at least one box in det/det.txt are written.
    '''
    # os.listdir returns entries in arbitrary order: sort them so that the
    # order of the records is the same from one run to the next
    videos = sorted(os.listdir(args.videos_dir))

    # the context manager closes the writer even when a frame fails half-way
    with tf.python_io.TFRecordWriter(args.output_path) as writer:
        for video in videos:
            if not video.startswith("MOT17-"):
                continue
            video_path = os.path.join(args.videos_dir, video, "")
            # skip stray files (e.g. an archive) named like a video folder
            if not os.path.isdir(video_path):
                continue
            frame_boxes_dict = create_boxes_dict(video_path)

            # select images and write corresponding tf examples to tfrecord
            images_dir_path = os.path.join(video_path, "img1", "")
            images = sorted(os.listdir(images_dir_path))
            for image_name in images:
                # image names of the form "000001.jpg", with always less than 10000 images per video
                if not image_name.endswith(".jpg"):
                    continue
                image_id = int(image_name[-8:-4])
                boxes = frame_boxes_dict.get(image_id)
                if boxes is None:
                    continue
                image_path = os.path.join(images_dir_path, image_name)
                # add video id to image name in order to avoid conflict between image names of different videos
                # NOTE(review): the id is taken from characters 6-7 of the folder
                # name only, so folders sharing that number (e.g. "MOT17-02-DPM"
                # and "MOT17-02-SDP") produce identical names - confirm this is acceptable
                new_image_name = str(int(video[6:8])) + image_name
                tf_example = create_tf_example(boxes, image_path, new_image_name)
                writer.write(tf_example.SerializeToString())

    print('Successfully created the TFRecord file: {}'.format(args.output_path))
# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()