Emulate ImageNet folder structure #3

Open · wants to merge 1 commit into master
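
In short, the commit replaces the flat imagenet_images/ output folder with an ImageNet-style layout under -data_root: one wnid-named subfolder per class under train/, one image per class moved into validation/ under an ILSVRC2012_val_NNNNNN.JPEG name, and a synset_labels.txt at the root listing the scraped wnids. A rough sketch of the resulting tree, assuming a hypothetical -data_root of ./data and a hypothetical wnid n01440764, would be:

./data/
    train/
        n01440764/
            n01440764_0001.JPEG
            n01440764_0002.JPEG
            ...
    validation/
        ILSVRC2012_val_000001.JPEG
        ...
    synset_labels.txt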
downloader.py (57 changes: 39 additions, 18 deletions)
@@ -15,7 +15,7 @@
 parser.add_argument('-scrape_only_flickr', default=True, type=lambda x: (str(x).lower() == 'true'))
 parser.add_argument('-number_of_classes', default = 10, type=int)
 parser.add_argument('-images_per_class', default = 10, type=int)
-parser.add_argument('-data_root', default='' , type=str)
+parser.add_argument('-data_root', type=str)
 parser.add_argument('-use_class_list', default=False,type=lambda x: (str(x).lower() == 'true'))
 parser.add_argument('-class_list', default=[], nargs='*')
 parser.add_argument('-debug', default=False,type=lambda x: (str(x).lower() == 'true'))
@@ -25,15 +25,7 @@
 args, args_other = parser.parse_known_args()

 if args.debug:
-    logging.basicConfig(filename='imagenet_scarper.log', level=logging.DEBUG)
-
-if len(args.data_root) == 0:
-    logging.error("-data_root is required to run downloader!")
-    exit()
-
-if not os.path.isdir(args.data_root):
-    logging.error(f'folder {args.data_root} does not exist! please provide existing folder in -data_root arg!')
-    exit()
+    logging.basicConfig(filename='imagenet_scraper.log', level=logging.DEBUG)


 IMAGENET_API_WNID_TO_URLS = lambda wnid: f'http://www.image-net.org/api/text/imagenet.synset.geturls?wnid={wnid}'
@@ -54,7 +46,7 @@
     for item in args.class_list:
         classes_to_scrape.append(item)
         if item not in class_info_dict:
-            logging.error(f'Class {item} not found in ImageNete')
+            logging.error(f'Class {item} not found in ImageNet')
             exit()

 elif args.use_class_list == False:
@@ -79,12 +71,14 @@
         classes_to_scrape.append(potential_class_pool[idx])


-print("Picked the following clases:")
+print("Picked the following classes:")
 print([ class_info_dict[class_wnid]['class_name'] for class_wnid in classes_to_scrape ])

-imagenet_images_folder = os.path.join(args.data_root, 'imagenet_images')
-if not os.path.isdir(imagenet_images_folder):
-    os.mkdir(imagenet_images_folder)
+train_folder = os.path.join(args.data_root, "train")
+os.makedirs(train_folder, exist_ok=True)
+
+validation_folder = os.path.join(args.data_root, "validation")
+os.makedirs(validation_folder, exist_ok=True)


 scraping_stats = dict(
@@ -313,6 +307,8 @@ def finish(status):
         return finish('success')


+validation_index = 1
+
 for class_wnid in classes_to_scrape:

     class_name = class_info_dict[class_wnid]["class_name"]
@@ -322,9 +318,11 @@ def finish(status):
     time.sleep(0.05)
     resp = requests.get(url_urls)

-    class_folder = os.path.join(imagenet_images_folder, class_name)
-    if not os.path.exists(class_folder):
-        os.mkdir(class_folder)
+    # The key begins with a dollar sign
+    label = class_wnid.replace("$", "n")
+
+    class_folder = os.path.join(train_folder, label)
+    os.makedirs(class_folder, exist_ok=True)

     class_images.value = 0
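
For context on the label line added above: the keys of class_info_dict apparently begin with a dollar sign, so replacing it with "n" yields the WordNet-style id used as the train/ subfolder name. A minimal sketch, assuming a hypothetical key of the form "$01440764":

# Hypothetical class_info_dict key, shown only to illustrate the mapping
class_wnid = "$01440764"
label = class_wnid.replace("$", "n")
print(label)  # -> n01440764, used as the train/ subfolder name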

@@ -336,3 +334,26 @@ def finish(status):
     print(f"Multiprocessing workers: {args.multiprocessing_workers}")
     with Pool(processes=args.multiprocessing_workers) as p:
         p.map(get_image,urls)
+
+    # Rename images
+    images = os.listdir(class_folder)
+    for i in range(0, len(images)):
+        # Copy the old name
+        old_name = images[i][:]
+        images[i] = f"{label}_{i:04d}.JPEG"
+        os.rename(os.path.join(class_folder, old_name),
+                  os.path.join(class_folder, images[i]))
+
+
+    # Use one of the images for validation
+    validation_file = images[0]
+    os.rename(os.path.join(class_folder, validation_file),
+              os.path.join(validation_folder, f"ILSVRC2012_val_{validation_index:06d}.JPEG"))
+
+    validation_index += 1
+
+# Create a labels file for validation
+with open(os.path.join(args.data_root, "synset_labels.txt"), "w") as f:
+    for class_wnid in classes_to_scrape:
+        label = class_wnid.replace("$", "n")
+        f.write(f"{label}\n")
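
With these changes, a typical run might look like: python downloader.py -data_root ./data -number_of_classes 10 -images_per_class 50 (the flags come from the argparse section above; the values and the ./data path are only examples). A minimal sketch for sanity-checking the layout the script should leave behind under that hypothetical root:

import os

data_root = "./data"  # hypothetical path matching the example command above

# Each scraped class becomes a wnid-named subfolder under train/
for wnid in sorted(os.listdir(os.path.join(data_root, "train"))):
    count = len(os.listdir(os.path.join(data_root, "train", wnid)))
    print(f"{wnid}: {count} training images")

# One renamed image per class ends up in validation/
print(sorted(os.listdir(os.path.join(data_root, "validation"))))

# synset_labels.txt lists one wnid per validation image, in scrape order
with open(os.path.join(data_root, "synset_labels.txt")) as f:
    print(f.read().splitlines())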