Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added recursive image searching to the retrainer #113

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 109 additions & 54 deletions scripts/retrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,114 @@
# sizes. If you want to adapt this script to work with another model, you will
# need to update these to reflect the values in the network you're using.
#
# MAX_NUM_IMAGES_PER_CLASS is used to scale the hash of a file name into a
# percentage when assigning images to the training/testing/validation sets; a
# folder holding more images than this triggers a warning that some images
# will never be selected.
MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1 # ~134M
# Separator used by create_sub_folder_image_lists() to join a folder's name to
# the labels of its sub folders when searching directories recursively.
CLASS_LABEL_JOINER = "."

def create_sub_folder_image_lists(image_dir, testing_percentage, validation_percentage):
  """Builds image lists for a folder and, recursively, all of its sub folders.

  Every folder in the tree rooted at `image_dir` that directly contains JPEG
  files becomes one label. The images of each label are split into stable
  training, testing, and validation sets based on a hash of the file path, so
  an image stays in the same set when more files are added later.

  Labels are built as follows:
    * The images directly inside `image_dir` are stored under the basename of
      `image_dir`, lower-cased, with every run of characters outside
      [a-z0-9] replaced by a single space.
    * Entries coming from a sub folder keep their own key, prefixed with the
      (unmodified) basename of `image_dir` and CLASS_LABEL_JOINER.
  The 'dir' value of every entry is the folder path relative to the parent of
  `image_dir`.

  Args:
    image_dir: String path to a folder containing images and/or subfolders of
      images.
    testing_percentage: Integer percentage of the images to reserve for tests.
    validation_percentage: Integer percentage of images reserved for validation.

  Returns:
    An OrderedDict with one entry per label. Each entry is a dict with the
    keys 'dir', 'training', 'testing' and 'validation'; the last three hold
    lists of image base names. The dict is empty when no images are found
    anywhere in the tree. Returns None if `image_dir` does not exist.
  """
  if not gfile.Exists(image_dir):
    tf.logging.error("Image directory '" + image_dir + "' not found.")
    return None
  result = collections.OrderedDict()
  dir_name = os.path.basename(image_dir)

  # Recurse into the sub folders first so that their entries precede the entry
  # for the images stored directly in this folder.
  image_dir_contents = [
      os.path.join(image_dir, item)
      for item in gfile.ListDirectory(image_dir)]
  sub_dirs = sorted(item for item in image_dir_contents
                    if gfile.IsDirectory(item))
  for sub_dir in sub_dirs:
    sub_dir_results = create_sub_folder_image_lists(
        sub_dir, testing_percentage, validation_percentage)
    # None signals a real failure. An empty result only means the sub tree
    # holds no images, which the recursive call has already warned about.
    if sub_dir_results is None:
      tf.logging.error(
          "An error occurred finding images in directory '" + sub_dir + "'")
      continue
    for key, entry in sub_dir_results.items():
      # The recursive call reports 'dir' and labels relative to itself, so
      # prepend this folder's name to both.
      entry['dir'] = os.path.join(dir_name, entry['dir'])
      result[dir_name + CLASS_LABEL_JOINER + key] = entry

  # Collect the images stored directly in this folder.
  extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
  file_list = []
  for extension in extensions:
    file_glob = os.path.join(image_dir, '*.' + extension)
    file_list.extend(gfile.Glob(file_glob))
  # On case-insensitive file systems '*.jpg' and '*.JPG' match the same files;
  # drop the duplicates while keeping the first-seen order.
  file_list = list(collections.OrderedDict.fromkeys(file_list))
  if not file_list:
    tf.logging.warning('No files found')
    return result
  if len(file_list) < 20:
    tf.logging.warning(
        'WARNING: Folder has less than 20 images, which may cause issues.')
  elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS:
    tf.logging.warning(
        'WARNING: Folder {} has more than {} images. Some images will '
        'never be selected.'.format(image_dir, MAX_NUM_IMAGES_PER_CLASS))

  label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
  training_images = []
  testing_images = []
  validation_images = []
  for file_name in file_list:
    base_name = os.path.basename(file_name)
    # Ignore anything after '_nohash_' in the file name when deciding which
    # set an image goes into, so the data set creator can keep close
    # variations of the same photo together in one set.
    hash_name = re.sub(r'_nohash_.*$', '', file_name)
    # Hash the name and map the hash onto a 0-100 scale. The assignment then
    # depends only on the file name, so existing files stay in the same set
    # even if more files are subsequently added.
    hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
    percentage_hash = ((int(hash_name_hashed, 16) %
                        (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                       (100.0 / MAX_NUM_IMAGES_PER_CLASS))
    if percentage_hash < validation_percentage:
      validation_images.append(base_name)
    elif percentage_hash < (testing_percentage + validation_percentage):
      testing_images.append(base_name)
    else:
      training_images.append(base_name)
  result[label_name] = {
      'dir': dir_name,
      'training': training_images,
      'testing': testing_images,
      'validation': validation_images,
  }
  return result

def create_image_lists(image_dir, testing_percentage, validation_percentage):
"""Builds a list of training images from the file system.

Analyzes the sub folders in the image directory, splits them into stable
Recursively analyzes the sub folders in the image directory, splits them into stable
training, testing, and validation sets, and returns a data structure
describing the lists of images for each label and their paths.
This is different to create_sub_folder_image_lists as image_dir is not included in the labels.

Args:
image_dir: String path to a folder containing subfolders of images.
Expand All @@ -148,63 +248,18 @@ def create_image_lists(image_dir, testing_percentage, validation_percentage):
sub_dirs = sorted(item for item in sub_dirs
if gfile.IsDirectory(item))
for sub_dir in sub_dirs:
extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
file_list = []
dir_name = os.path.basename(sub_dir)
if dir_name == image_dir:
if sub_dir == image_dir:
continue
tf.logging.info("Looking for images in '" + dir_name + "'")
for extension in extensions:
file_glob = os.path.join(image_dir, dir_name, '*.' + extension)
file_list.extend(gfile.Glob(file_glob))
if not file_list:
tf.logging.warning('No files found')
sub_dir_results = create_sub_folder_image_lists(sub_dir, testing_percentage, validation_percentage)
# Handle errors
if not sub_dir_results:
tf.logging.error("An error occured finding images in directory '" + image_dir + "'")
continue
if len(file_list) < 20:
tf.logging.warning(
'WARNING: Folder has less than 20 images, which may cause issues.')
elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS:
tf.logging.warning(
'WARNING: Folder {} has more than {} images. Some images will '
'never be selected.'.format(dir_name, MAX_NUM_IMAGES_PER_CLASS))
label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
training_images = []
testing_images = []
validation_images = []
for file_name in file_list:
base_name = os.path.basename(file_name)
# We want to ignore anything after '_nohash_' in the file name when
# deciding which set to put an image in, the data set creator has a way of
# grouping photos that are close variations of each other. For example
# this is used in the plant disease data set to group multiple pictures of
# the same leaf.
hash_name = re.sub(r'_nohash_.*$', '', file_name)
# This looks a bit magical, but we need to decide whether this file should
# go into the training, testing, or validation sets, and we want to keep
# existing files in the same set even if more files are subsequently
# added.
# To do that, we need a stable way of deciding based on just the file name
# itself, so we do a hash of that and then use that to generate a
# probability value that we use to assign it.
hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
percentage_hash = ((int(hash_name_hashed, 16) %
(MAX_NUM_IMAGES_PER_CLASS + 1)) *
(100.0 / MAX_NUM_IMAGES_PER_CLASS))
if percentage_hash < validation_percentage:
validation_images.append(base_name)
elif percentage_hash < (testing_percentage + validation_percentage):
testing_images.append(base_name)
else:
training_images.append(base_name)
result[label_name] = {
'dir': dir_name,
'training': training_images,
'testing': testing_images,
'validation': validation_images,
}
# Loop through each result and append the image_dir to the beginning.
for key, value in sub_dir_results.items():
result[key] = value
return result


def get_image_path(image_lists, label_name, index, image_dir, category):
""""Returns a path to an image for a label at the given index.

Expand Down