Add support for variable image-size data and a data generator for low memory consumption #190

Open · wants to merge 2 commits into base: master
191 changes: 132 additions & 59 deletions retrain_yolo.py
@@ -4,20 +4,25 @@
import argparse

import os

from PIL import ImageOps
import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
from keras import backend as K
from keras.layers import Input, Lambda, Conv2D
from keras.models import load_model, Model
from keras import regularizers
from keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping

from yad2k.models.keras_yolo import (preprocess_true_boxes, yolo_body,
yolo_eval, yolo_head, yolo_loss)
from yad2k.utils.draw_boxes import draw_boxes

import h5py
import io
from yolo_data_gen import *

# Args
argparser = argparse.ArgumentParser(
description="Retrain or 'fine-tune' a pretrained YOLOv2 model for your own data.")
@@ -46,42 +51,34 @@
(7.88282, 3.52778), (9.77052, 9.16828)))

def _main(args):

data_path = os.path.expanduser(args.data_path)
classes_path = os.path.expanduser(args.classes_path)
anchors_path = os.path.expanduser(args.anchors_path)

class_names = get_classes(classes_path)
anchors = get_anchors(anchors_path)

data = np.load(data_path) # custom data saved as a numpy file.
# has 2 arrays: an object array 'boxes' (variable length of boxes in each image)
# and an array of images 'images'

image_data, boxes = process_data(data['images'], data['boxes'])


dataset = h5py.File(data_path,'r+')
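# data_path now points to the HDF5 file written by voc_to_hdf5.py, with 'train' and 'val' groups of JPEG-encoded images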

anchors = YOLO_ANCHORS

detectors_mask, matching_true_boxes = get_detector_mask(boxes, anchors)
#detectors_mask, matching_true_boxes = get_detector_mask(boxes, anchors)

model_body, model = create_model(anchors, class_names)

train(
model,
class_names,
anchors,
image_data,
boxes,
detectors_mask,
matching_true_boxes
)

draw(model_body,
class_names,
anchors,
image_data,
image_set='val', # assumes training/validation split is 0.9
weights_name='trained_stage_3_best.h5',
save_all=False)
train(model, class_names, anchors, dataset)  # was: image_data, boxes, detectors_mask, matching_true_boxes

# TODO use data generator for draw as well


# draw(model_body,
# class_names,
# anchors,
# image_data,
# image_set='all', # assumes test set is 0.9
# weights_name='trained_stage_3_best.h5',
# save_all=True)


def get_classes(classes_path):
@@ -101,15 +98,69 @@ def get_anchors(anchors_path):
else:
Warning("Could not open anchors file, using default.")
return YOLO_ANCHORS



# Exactly the same as process_data, but handles images of different sizes in the dataset
def scale_data(images, boxes=None):
'''Decodes JPEG-encoded images, resizes them to 416x416 and normalizes box coordinates per image'''
img_shape = (416,416)
images = [PIL.Image.open(io.BytesIO(i)) for i in images]


# Box preprocessing.
if boxes is not None:
# Original boxes stored as 1D list of class, x_min, y_min, x_max, y_max.
boxes = [box.reshape((-1, 5)) for box in boxes]
# Get box parameters as x_center, y_center, box_width, box_height, class.
boxes_xy = [0.5 * (box[:, 3:5] + box[:, 1:3]) for box in boxes]
boxes_wh = [box[:, 3:5] - box[:, 1:3] for box in boxes]

# get the original size of each image and convert the box coordinates and width/height to relative units
processed_images = []
for i, img in enumerate(images):
orig_size = np.array([img.width, img.height])
boxes_xy[i] = boxes_xy[i] / orig_size
boxes_wh[i] = boxes_wh[i] / orig_size
img = img.resize(img_shape, PIL.Image.BICUBIC)
img = np.array(img, dtype=np.float)
processed_images.append(img / 255.)

boxes = [np.concatenate((boxes_xy[i], boxes_wh[i], box[:, 0:1]), axis=1) for i, box in enumerate(boxes)]

# find the max number of boxes
max_boxes = 0
for boxz in boxes:
if boxz.shape[0] > max_boxes:
max_boxes = boxz.shape[0]

# add zero pad for training
for i, boxz in enumerate(boxes):
if boxz.shape[0] < max_boxes:
zero_padding = np.zeros( (max_boxes-boxz.shape[0], 5), dtype=np.float32)
boxes[i] = np.vstack((boxz, zero_padding))

return np.array(processed_images), np.array(boxes)

else:
processed_images = [i.resize(img_shape, PIL.Image.BICUBIC) for i in images]  # resize_image is not defined in this file; resize with PIL as in the boxes branch
processed_images = [np.array(image, dtype=np.float) for image in processed_images]
processed_images = [image/255. for image in processed_images]
return np.array(processed_images)
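# For orientation, a minimal sketch of calling scale_data on a slice of the
# HDF5 file written by voc_to_hdf5.py (the filename is an assumption):
#
#   import h5py
#   f = h5py.File('pascal_voc_07_12.h5', 'r')
#   idx = [0, 1, 2, 3]  # h5py fancy indexing wants increasing indices
#   image_data, boxes = scale_data(f['train']['images'][idx], f['train']['boxes'][idx])
#   image_data.shape  # (4, 416, 416, 3), pixels scaled to [0, 1]
#   boxes.shape       # (4, max_boxes, 5): x_center, y_center, w, h, class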


def process_data(images, boxes=None):
'''processes the data'''
images = [PIL.Image.fromarray(i) for i in images]
#images = [PIL.Image.fromarray(i) for i in images]
images = [PIL.Image.open(io.BytesIO(i)) for i in images]
orig_size = np.array([images[0].width, images[0].height])
orig_size = np.expand_dims(orig_size, axis=0)

print(type(images[0]))
# Image preprocessing.
processed_images = [i.resize((416, 416), PIL.Image.BICUBIC) for i in images]
#processed_images = [resize_image(i,416,416,False) for i in images]

processed_images = [np.array(image, dtype=np.float) for image in processed_images]
processed_images = [image/255. for image in processed_images]

@@ -119,7 +170,7 @@ def process_data(images, boxes=None):
boxes = [box.reshape((-1, 5)) for box in boxes]
# Get extents as y_min, x_min, y_max, x_max, class for comparision with
# model output.
boxes_extents = [box[:, [2, 1, 4, 3, 0]] for box in boxes]
#boxes_extents = [box[:, [2, 1, 4, 3, 0]] for box in boxes]

# Get box parameters as x_center, y_center, box_width, box_height, class.
boxes_xy = [0.5 * (box[:, 3:5] + box[:, 1:3]) for box in boxes]
@@ -204,7 +255,7 @@ def create_model(anchors, class_names, load_pretrained=True, freeze_body=True):
if freeze_body:
for layer in topless_yolo.layers:
layer.trainable = False
final_layer = Conv2D(len(anchors)*(5+len(class_names)), (1, 1), activation='linear')(topless_yolo.output)
final_layer = Conv2D(len(anchors)*(5+len(class_names)), (1, 1), activation='linear', kernel_regularizer=regularizers.l2(5e-4))(topless_yolo.output)
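# the added L2 penalty (5e-4) on the new head's kernel weights regularizes fine-tuning on small datasets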

model_body = Model(image_input, final_layer)

@@ -227,7 +278,7 @@ def create_model(anchors, class_names, load_pretrained=True, freeze_body=True):

return model_body, model

def train(model, class_names, anchors, image_data, boxes, detectors_mask, matching_true_boxes, validation_split=0.1):
def train(model, class_names, anchors, dataset):  # formerly (model, class_names, anchors, image_data, boxes, detectors_mask, matching_true_boxes, validation_split=0.1)
'''
retrain/fine-tune the model

@@ -243,17 +294,28 @@ def train(model, class_names, anchors, image_data, boxes, detectors_mask, matching_true_boxes, validation_split=0.1):
}) # This is a hack to use the custom loss function in the last layer.


logging = TensorBoard()
logging = TensorBoard()  # log_dir='./train_logs', histogram_freq=1, write_graph=False, write_images=True
checkpoint = ModelCheckpoint("trained_stage_3_best.h5", monitor='val_loss',
save_weights_only=True, save_best_only=True)
early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=15, verbose=1, mode='auto')

model.fit([image_data, boxes, detectors_mask, matching_true_boxes],
np.zeros(len(image_data)),
validation_split=validation_split,
batch_size=32,
epochs=5,
callbacks=[logging])
batch_size = 8
dataTrain = dataset['train']
dataVal = dataset['val']
train_set_size = dataTrain.attrs['dataset_size']
val_set_size = dataVal.attrs['dataset_size']
training_generator = DataGenerator(dataTrain, train_set_size, batch_size=batch_size)
validation_generator = DataGenerator(dataVal, val_set_size, batch_size=batch_size, is_train=0)
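# each generator is a keras.utils.Sequence: fit_generator pulls one decoded batch at a time, so the full image set never sits in memory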
# model.fit([image_data, boxes, detectors_mask, matching_true_boxes],
# np.zeros(len(image_data)),
# validation_split=validation_split,
# batch_size=8,
# epochs=5,
# callbacks=[logging])
model.fit_generator(generator=training_generator,
validation_data=validation_generator,
use_multiprocessing=False,
epochs=5, verbose=1, callbacks=[logging])
model.save_weights('trained_stage_1.h5')

model_body, model = create_model(anchors, class_names, load_pretrained=False, freeze_body=False)
@@ -265,22 +327,33 @@ def train(model, class_names, anchors, image_data, boxes, detectors_mask, matching_true_boxes, validation_split=0.1):
'yolo_loss': lambda y_true, y_pred: y_pred
}) # This is a hack to use the custom loss function in the last layer.


model.fit([image_data, boxes, detectors_mask, matching_true_boxes],
np.zeros(len(image_data)),
validation_split=0.1,
batch_size=8,
epochs=30,
callbacks=[logging])


# model.fit([image_data, boxes, detectors_mask, matching_true_boxes],
# np.zeros(len(image_data)),
# validation_split=validation_split,
# batch_size=8,
# epochs=30,
# callbacks=[logging])
training_generator = DataGenerator(dataTrain, train_set_size, batch_size=batch_size)
validation_generator = DataGenerator(dataVal, val_set_size, batch_size=batch_size, is_train=0)
model.fit_generator(generator=training_generator,
validation_data=validation_generator,
use_multiprocessing=False,
epochs=30, verbose=1, callbacks=[logging])
model.save_weights('trained_stage_2.h5')

model.fit([image_data, boxes, detectors_mask, matching_true_boxes],
np.zeros(len(image_data)),
validation_split=0.1,
batch_size=8,
epochs=30,
callbacks=[logging, checkpoint, early_stopping])

training_generator = DataGenerator(dataTrain, train_set_size, batch_size=batch_size)
validation_generator = DataGenerator(dataVal, val_set_size, batch_size=batch_size, is_train=0)
model.fit_generator(generator=training_generator,
validation_data=validation_generator,
use_multiprocessing=False,
epochs=30, verbose=1, callbacks=[logging, checkpoint, early_stopping])
# model.fit([image_data, boxes, detectors_mask, matching_true_boxes],
# np.zeros(len(image_data)),
# validation_split=validation_split,
# batch_size=8,
# epochs=30,
# callbacks=[logging, checkpoint, early_stopping])

model.save_weights('trained_stage_3.h5')

@@ -308,7 +381,7 @@ def draw(model_body, class_names, anchors, image_data, image_set='val',
yolo_outputs = yolo_head(model_body.output, anchors, len(class_names))
input_image_shape = K.placeholder(shape=(2, ))
boxes, scores, classes = yolo_eval(
yolo_outputs, input_image_shape, score_threshold=0.07, iou_threshold=0)
yolo_outputs, input_image_shape, score_threshold=0.7, iou_threshold=0.7)

# Run prediction on overfit image.
sess = K.get_session() # TODO: Remove dependence on Tensorflow session.
@@ -328,15 +401,15 @@

# Plot image with predicted boxes.
image_with_boxes = draw_boxes(image_data[i][0], out_boxes, out_classes,
class_names, out_scores)
class_names, out_scores, os.path.join(out_path, str(i)+'.jpg'))
# Save the image:
if save_all or (len(out_boxes) > 0):
if save_all:
image = PIL.Image.fromarray(image_with_boxes)
image.save(os.path.join(out_path,str(i)+'.png'))
image.save(os.path.join(out_path,str(i)+'.jpg'))

# To display (pauses the program):
# plt.imshow(image_with_boxes, interpolation='nearest')
# plt.show()
plt.imshow(image_with_boxes, interpolation='nearest')
plt.show()



5 changes: 4 additions & 1 deletion voc_conversion_scripts/voc_to_hdf5.py
@@ -189,7 +189,10 @@ def _main(args):
add_to_dataset(voc_path, '2012', val_ids, val_images, val_boxes)
print('Processing Pascal VOC 2007 test set.')
add_to_dataset(voc_path, '2007', test_ids, test_images, test_boxes)


train_group.attrs['dataset_size'] = total_train_ids
val_group.attrs['dataset_size'] = len(val_ids)
test_group.attrs['dataset_size'] = len(test_ids)
print('Closing HDF5 file.')
voc_h5file.close()
print('Done.')
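# As a quick check, the stored sizes can be read back the same way train() and
# DataGenerator consume them; a minimal sketch (the filename is an assumption):
import h5py

with h5py.File('pascal_voc_07_12.h5', 'r') as f:
    for split in ('train', 'val', 'test'):
        print(split, f[split].attrs['dataset_size'])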
71 changes: 71 additions & 0 deletions yolo_data_gen.py
@@ -0,0 +1,71 @@
import numpy as np
import keras
import PIL
from PIL import ImageOps
from retrain_yolo import *  # provides scale_data, get_detector_mask, YOLO_ANCHORS
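# NOTE: retrain_yolo also imports this module, so the import above is circular; it resolves only when retrain_yolo.py is run as a script rather than imported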
class DataGenerator(keras.utils.Sequence):
'''Generates data for Keras, one batch at a time from an HDF5 group'''

def getTrainData(self,dataset,indexes):
train_img_list = dataset['images'][indexes]
train_boxes = dataset['boxes'][indexes]
image_data, boxes = scale_data(train_img_list, train_boxes)
return (image_data, boxes)

def getValData(self,dataset,indexes):
val_img_list = dataset['val_img_list'][indexes]
val_padded_txt_list = dataset['val_padded_txt_list'][indexes]
val_label_length_list = dataset['val_label_length_list'][indexes]
val_input_length_list = dataset['val_input_length_list'][indexes]
return (val_img_list,val_padded_txt_list,val_label_length_list,val_input_length_list)
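# NOTE: getValData references keys from a different dataset layout and is never called; __data_generation always uses getTrainData, so validation batches also go through scale_data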

def __init__(self, hdf5_dataset, data_set_size, is_train=1, batch_size=8, shuffle=True):
'Initialization'
self.is_train = is_train
self.batch_size = batch_size
self.shuffle = shuffle
self.hdf5_dataset = hdf5_dataset
self.data_set_size = data_set_size
self.indexes = np.arange(data_set_size)
no_of_batches = int(np.floor(self.data_set_size / self.batch_size))
self.no_of_batches = no_of_batches if self.data_set_size % self.batch_size == 0 else no_of_batches+1
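# equivalent to ceil(data_set_size / batch_size): a smaller final batch is kept rather than dropped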
self.on_epoch_end()

def __len__(self):
'Denotes the number of batches per epoch'
return self.no_of_batches

def __getitem__(self, index):
'Generate one batch of data'
# Generate indexes of the batch
if index == self.no_of_batches - 1:
indexes = self.indexes[index*self.batch_size:]
else:
indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

# Generate data: load this batch from the HDF5 file
X, y = self.__data_generation(indexes)

return X, y

def on_epoch_end(self):
'Updates indexes after each epoch'

if self.shuffle:
np.random.shuffle(self.indexes)

def __data_generation(self, indexes):
'Generates data containing batch_size samples'
# indices must be in increasing order for h5py fancy indexing
indexes.sort()
# Initialization
image_data, boxes = self.getTrainData(self.hdf5_dataset,indexes)
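# detector masks are computed per batch here (instead of once for the whole dataset in _main), which is the main memory saving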

detectors_mask, matching_true_boxes = get_detector_mask(boxes, YOLO_ANCHORS)
# Store sample
X = [image_data,boxes,detectors_mask,matching_true_boxes]
y = np.zeros(image_data.shape[0])

return X, y
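# A minimal smoke test for the generator, assuming the HDF5 layout written by
# voc_to_hdf5.py (the filename is an assumption):
import h5py

f = h5py.File('pascal_voc_07_12.h5', 'r')
gen = DataGenerator(f['train'], f['train'].attrs['dataset_size'], batch_size=8)
X, y = gen[0]  # one batch via __getitem__
image_data, boxes, detectors_mask, matching_true_boxes = X
print(len(gen), image_data.shape, y.shape)  # no_of_batches, (8, 416, 416, 3), (8,)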