datasets.py
import torch
import cv2
import numpy as np
import os
import glob
from xml.etree import ElementTree as et
from config import (
    CLASSES, RESIZE_TO, TRAIN_DIR, VALID_DIR, BATCH_SIZE
)
from torch.utils.data import Dataset, DataLoader
from custom_utils import collate_fn, get_train_transform, get_valid_transform
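
# NOTE: the constants imported from `config` above are defined in a separate file that
# is not shown here. For orientation only, they presumably look roughly like the hedged
# sketch below (every name and value here is an assumption, not the actual config.py):
#
#   CLASSES = ['__background__', 'class_1', 'class_2']  # assumed: index 0 is background
#   RESIZE_TO = 512        # assumed square resize used for training/validation
#   TRAIN_DIR = 'data/train'
#   VALID_DIR = 'data/valid'
#   BATCH_SIZE = 4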

# the dataset class
class CustomDataset(Dataset):
    def __init__(self, dir_path, width, height, classes, transforms=None):
        self.transforms = transforms
        self.dir_path = dir_path
        self.height = height
        self.width = width
        self.classes = classes

        # get all the image paths in sorted order
        self.image_paths = glob.glob(f"{self.dir_path}/*.jpg")
        self.all_images = [image_path.split(os.path.sep)[-1] for image_path in self.image_paths]
        self.all_images = sorted(self.all_images)

    def __getitem__(self, idx):
        # capture the image name and the full image path
        image_name = self.all_images[idx]
        image_path = os.path.join(self.dir_path, image_name)

        # read the image and convert BGR to RGB color format
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image_resized = cv2.resize(image, (self.width, self.height))
        image_resized /= 255.0

        # capture the corresponding XML file for getting the annotations
        annot_filename = image_name[:-4] + '.xml'
        annot_file_path = os.path.join(self.dir_path, annot_filename)

        boxes = []
        labels = []
        tree = et.parse(annot_file_path)
        root = tree.getroot()

        # get the height and width of the original (un-resized) image
        image_width = image.shape[1]
        image_height = image.shape[0]

        # box coordinates in the XML files are extracted and scaled to the resized image
        for member in root.findall('object'):
            # map the current object name to the `classes` list to get
            # the label index and append it to the `labels` list
            labels.append(self.classes.index(member.find('name').text))

            # xmin = left edge x-coordinate
            xmin = int(member.find('bndbox').find('xmin').text)
            # xmax = right edge x-coordinate
            xmax = int(member.find('bndbox').find('xmax').text)
            # ymin = top edge y-coordinate
            ymin = int(member.find('bndbox').find('ymin').text)
            # ymax = bottom edge y-coordinate
            ymax = int(member.find('bndbox').find('ymax').text)

            # resize the bounding boxes according to the desired `width`, `height`
            xmin_final = (xmin/image_width)*self.width
            xmax_final = (xmax/image_width)*self.width
            ymin_final = (ymin/image_height)*self.height
            ymax_final = (ymax/image_height)*self.height

            boxes.append([xmin_final, ymin_final, xmax_final, ymax_final])

        # bounding boxes to tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # area of the bounding boxes
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # no crowd instances
        iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
        # labels to tensor
        labels = torch.as_tensor(labels, dtype=torch.int64)

        # prepare the final `target` dictionary
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["area"] = area
        target["iscrowd"] = iscrowd
        image_id = torch.tensor([idx])
        target["image_id"] = image_id

        # apply the image transforms
        if self.transforms:
            sample = self.transforms(image=image_resized,
                                     bboxes=target['boxes'],
                                     labels=labels)
            image_resized = sample['image']
            target['boxes'] = torch.Tensor(sample['bboxes'])

        return image_resized, target

    def __len__(self):
        return len(self.all_images)
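
# NOTE (assumption): the `target` dictionary assembled in __getitem__ (boxes, labels,
# area, iscrowd, image_id) matches the format the torchvision detection models expect,
# so samples from this dataset can be fed to e.g. Faster R-CNN during training.
# A minimal sketch, assuming a torchvision model is used downstream:
#
#   import torchvision
#   model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')
#   images, targets = [image_tensor], [target_dict]  # hypothetical single-sample batch
#   loss_dict = model(images, targets)               # training mode returns the losses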

# prepare the final datasets and data loaders
def create_train_dataset():
    train_dataset = CustomDataset(TRAIN_DIR, RESIZE_TO, RESIZE_TO, CLASSES, get_train_transform())
    return train_dataset

def create_valid_dataset():
    valid_dataset = CustomDataset(VALID_DIR, RESIZE_TO, RESIZE_TO, CLASSES, get_valid_transform())
    return valid_dataset

def create_train_loader(train_dataset, num_workers=0):
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn
    )
    return train_loader

def create_valid_loader(valid_dataset, num_workers=0):
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn
    )
    return valid_loader
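
# NOTE: `get_train_transform`, `get_valid_transform` and `collate_fn` come from
# custom_utils.py, which is not shown here. Judging from the keyword arguments used in
# __getitem__ (image=, bboxes=, labels=), the transforms are presumably Albumentations
# pipelines with Pascal VOC style bounding-box handling, and the collate function keeps
# variable-size samples un-stacked. A hedged sketch, not the actual implementation:
#
#   import albumentations as A
#   from albumentations.pytorch import ToTensorV2
#
#   def get_train_transform():
#       return A.Compose([
#           A.HorizontalFlip(p=0.5),
#           ToTensorV2(p=1.0),
#       ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})
#
#   def get_valid_transform():
#       return A.Compose([
#           ToTensorV2(p=1.0),
#       ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})
#
#   def collate_fn(batch):
#       # keep images and targets paired without stacking tensors of different shapes
#       return tuple(zip(*batch))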

# execute datasets.py with the Python command from the terminal
# to visualize sample images
# USAGE: python datasets.py
if __name__ == '__main__':
    # sanity check of the Dataset pipeline with sample visualization
    dataset = CustomDataset(
        TRAIN_DIR, RESIZE_TO, RESIZE_TO, CLASSES
    )
    print(f"Number of training images: {len(dataset)}")

    # function to visualize a single sample
    def visualize_sample(image, target):
        # the dataset returns an RGB image; convert back to BGR for OpenCV display
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        for box_num in range(len(target['boxes'])):
            box = target['boxes'][box_num]
            label = CLASSES[target['labels'][box_num]]
            cv2.rectangle(
                image,
                (int(box[0]), int(box[1])), (int(box[2]), int(box[3])),
                (0, 255, 0), 2
            )
            cv2.putText(
                image, label, (int(box[0]), int(box[1]-5)),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2
            )
        cv2.imshow('Image', image)
        cv2.waitKey(0)

    NUM_SAMPLES_TO_VISUALIZE = 5
    for i in range(NUM_SAMPLES_TO_VISUALIZE):
        image, target = dataset[i]
        visualize_sample(image, target)
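
# Typical downstream usage of the factory functions above (a hedged sketch; the actual
# training script is not part of this file):
#
#   train_dataset = create_train_dataset()
#   valid_dataset = create_valid_dataset()
#   train_loader = create_train_loader(train_dataset, num_workers=4)
#   valid_loader = create_valid_loader(valid_dataset, num_workers=4)
#   images, targets = next(iter(train_loader))  # tuples of per-image tensors and target dicts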