-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathKITTI-ObjectDetection.py
171 lines (100 loc) · 3.69 KB
/
KITTI-ObjectDetection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
#!/usr/bin/env python
# coding: utf-8
# In[ ]:
# Install the third-party packages this notebook export relies on.
# `get_ipython` only exists inside an IPython/Jupyter session; when the file
# is executed with a plain `python` interpreter the name is undefined, so
# fall back to invoking pip for the running interpreter instead of crashing.
_pip_packages = 'torch torchvision torchaudio diffusers transformers tensorflow_datasets'
try:
    _ipython_shell = get_ipython()
except NameError:
    _ipython_shell = None

if _ipython_shell is not None:
    _ipython_shell.system(f'pip install {_pip_packages}')
else:
    import subprocess
    import sys

    # check=False: like the IPython shell escape, a failed install is
    # reported on the console but does not raise here.
    subprocess.run(
        [sys.executable, '-m', 'pip', 'install', *_pip_packages.split()],
        check=False,
    )
# In[ ]:
# NOTE: the first call takes a significant amount of time.
import tensorflow_datasets as tfds

# `with_info=True` makes the loader return a (splits, metadata) pair.
dataset, info = tfds.load(name='kitti', with_info=True)
# In[1]:
# print(info)
# In[21]:
### Documentation on the dataset
# https://datasetninja.com/kitti-object-detection#object-distribution
# In[3]:
# Pull the three splits out of the loaded `dataset` mapping by key.
train_dataset, test_dataset, validation_dataset = (
    dataset[split_name] for split_name in ('train', 'test', 'validation')
)
# In[6]:
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# Function to display images and annotations
def show_images_with_annotations(dataset, num_images):
    """Plot the first ``num_images`` examples with their boxes outlined in red.

    Args:
        dataset: Dataset object supporting ``take(n)``. Each example is read
            as ``example['image']`` and ``example['objects']['bbox']``, and
            both values must expose ``.numpy()``.
        num_images: Number of examples to take from ``dataset`` and draw.
    """
    # Keep the 5-column layout (and at least 5 rows, so small galleries look
    # the same as before), but add rows on demand: with a fixed 5x5 grid
    # plt.subplot raises ValueError as soon as a 26th image is requested.
    n_cols = 5
    n_rows = max(5, -(-num_images // n_cols))  # ceiling division

    plt.figure(figsize=(45, 45))
    for i, example in enumerate(dataset.take(num_images)):
        image = example['image'].numpy()
        bboxes = example['objects']['bbox'].numpy()
        # Image size is the same for every box, so read it once per image.
        height, width = image.shape[:2]

        axis = plt.subplot(n_rows, n_cols, i + 1)
        axis.imshow(image)
        for bbox in bboxes:
            ymin, xmin, ymax, xmax = bbox
            # Convert normalized coordinates to pixel values.
            # NOTE(review): the TFDS KITTI builder may measure y from the
            # bottom edge of the image; if the boxes appear vertically
            # mirrored, flip y here -- confirm against the TFDS source.
            left = int(xmin * width)
            right = int(xmax * width)
            top = int(ymin * height)
            bottom = int(ymax * height)
            rect = patches.Rectangle(
                (left, top),
                right - left,
                bottom - top,
                linewidth=2,
                edgecolor='red',
                facecolor='none',
            )
            axis.add_patch(rect)
        axis.axis("off")
    plt.show()
# Preview the first 5 training examples with their bounding boxes drawn
show_images_with_annotations(dataset=train_dataset, num_images=5)
# In[ ]:
# In[9]:
import tensorflow_datasets as tfds
import tensorflow as tf
# Ask tf.data for the size of each split (same order as the assignment targets)
num_train, num_test, num_validation = (
    tf.data.experimental.cardinality(split).numpy()
    for split in (train_dataset, test_dataset, validation_dataset)
)

# Report the counts in train / validation / test order
print(f'Number of training records: {num_train}')
print(f'Number of validation records: {num_validation}')
print(f'Number of testing records: {num_test}')
# In[11]:
def preprocess(data):
    """Turn one raw example into an ``(image, labels)`` training pair.

    Args:
        data: Example mapping; ``data['image']`` and
            ``data['objects']['type']`` are read.

    Returns:
        A tuple ``(image, labels)``:
            image: float32 tensor resized with padding to 128x128 and
                divided by 255.
            labels: length-8 float vector holding 1.0 at every class index
                that occurs at least once in the example, 0.0 elsewhere.
    """
    image = tf.image.resize_with_pad(data['image'], 128, 128)
    image = tf.cast(image, tf.float32) / 255.0

    # One one-hot row per annotated object. Summing the rows gives per-class
    # *counts* (e.g. 3.0 for three objects of one class), which is not a valid
    # target for a sigmoid output trained with binary cross-entropy, so clamp
    # the sum to 1. Summing first also keeps an example with no objects at
    # all zeros.
    per_object = tf.one_hot(data['objects']['type'], depth=8)
    labels = tf.minimum(tf.reduce_sum(per_object, axis=0), 1.0)
    return image, labels
# In[12]:
batch_size = 128

# Apply the same preprocess-then-batch pipeline to every split. The tuple of
# source datasets is evaluated before any name is rebound.
train_dataset, validation_dataset, test_dataset = (
    split.map(preprocess).batch(batch_size)
    for split in (train_dataset, validation_dataset, test_dataset)
)
# In[17]:
from tensorflow.keras import layers, metrics, models

# Small CNN: three conv layers (two followed by 2x2 max-pooling), then a dense
# head with one sigmoid unit per class so each class is predicted
# independently (multi-label).
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(128, 128, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(8, activation='sigmoid'),  # 8 classes for 'type'
])

# The 'accuracy' shortcut is resolved by Keras from the output shape; with an
# 8-unit output it becomes categorical (argmax) accuracy, which is misleading
# for multi-label targets. Use per-class binary accuracy explicitly and keep
# the reported metric name 'accuracy' so logs and history keys are unchanged.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[metrics.BinaryAccuracy(name='accuracy')],
)
# In[18]:
from tensorflow.keras.callbacks import EarlyStopping
# Watch the validation loss (lower is better) with a patience of 3 epochs.
early_stopping_callback = EarlyStopping(monitor='val_loss', mode='min', patience=3, verbose=1)
# In[19]:
# Train for up to 5 epochs, validating after each one; the early-stopping
# callback may end training sooner.
model.fit(
    x=train_dataset,
    epochs=5,
    validation_data=validation_dataset,
    callbacks=[early_stopping_callback],
)
# In[20]:
# Score the trained model on the held-out test split.
(test_loss, test_acc) = model.evaluate(x=test_dataset)
print(f'Test accuracy: {test_acc}')
# In[ ]: