diff --git a/downloader.py b/downloader.py
index efa65f1b..c9d7a0fd 100755
--- a/downloader.py
+++ b/downloader.py
@@ -15,7 +15,7 @@
 parser.add_argument('-scrape_only_flickr', default=True, type=lambda x: (str(x).lower() == 'true'))
 parser.add_argument('-number_of_classes', default = 10, type=int)
 parser.add_argument('-images_per_class', default = 10, type=int)
-parser.add_argument('-data_root', default='' , type=str)
+parser.add_argument('-data_root', type=str, required=True)
 parser.add_argument('-use_class_list', default=False,type=lambda x: (str(x).lower() == 'true'))
 parser.add_argument('-class_list', default=[], nargs='*')
 parser.add_argument('-debug', default=False,type=lambda x: (str(x).lower() == 'true'))
@@ -25,15 +25,7 @@
 args, args_other = parser.parse_known_args()
 
 if args.debug:
-    logging.basicConfig(filename='imagenet_scarper.log', level=logging.DEBUG)
-
-if len(args.data_root) == 0:
-    logging.error("-data_root is required to run downloader!")
-    exit()
-
-if not os.path.isdir(args.data_root):
-    logging.error(f'folder {args.data_root} does not exist! please provide existing folder in -data_root arg!')
-    exit()
+    logging.basicConfig(filename='imagenet_scraper.log', level=logging.DEBUG)
 
 IMAGENET_API_WNID_TO_URLS = lambda wnid: f'http://www.image-net.org/api/text/imagenet.synset.geturls?wnid={wnid}'
 
@@ -54,7 +46,7 @@
     for item in args.class_list:
         classes_to_scrape.append(item)
         if item not in class_info_dict:
-            logging.error(f'Class {item} not found in ImageNete')
+            logging.error(f'Class {item} not found in ImageNet')
             exit()
 
 elif args.use_class_list == False:
@@ -79,12 +71,14 @@
         classes_to_scrape.append(potential_class_pool[idx])
 
-print("Picked the following clases:")
+print("Picked the following classes:")
 print([ class_info_dict[class_wnid]['class_name'] for class_wnid in classes_to_scrape ])
 
-imagenet_images_folder = os.path.join(args.data_root, 'imagenet_images')
-if not os.path.isdir(imagenet_images_folder):
-    os.mkdir(imagenet_images_folder)
+train_folder = os.path.join(args.data_root, "train")
+os.makedirs(train_folder, exist_ok=True)
+
+validation_folder = os.path.join(args.data_root, "validation")
+os.makedirs(validation_folder, exist_ok=True)
 
 
 scraping_stats = dict(
@@ -313,6 +307,8 @@ def finish(status):
         return finish('success')
 
+validation_index = 1
+
 for class_wnid in classes_to_scrape:
 
     class_name = class_info_dict[class_wnid]["class_name"]
 
@@ -322,9 +318,11 @@ def finish(status):
     time.sleep(0.05)
     resp = requests.get(url_urls)
 
-    class_folder = os.path.join(imagenet_images_folder, class_name)
-    if not os.path.exists(class_folder):
-        os.mkdir(class_folder)
+    # wnid keys in the class list begin with a dollar sign; swap it for ImageNet's 'n' prefix
+    label = class_wnid.replace("$", "n")
+
+    class_folder = os.path.join(train_folder, label)
+    os.makedirs(class_folder, exist_ok=True)
 
     class_images.value = 0
 
@@ -336,3 +334,26 @@ def finish(status):
     print(f"Multiprocessing workers: {args.multiprocessing_workers}")
     with Pool(processes=args.multiprocessing_workers) as p:
         p.map(get_image,urls)
+
+    # Rename the downloaded images to <wnid>_<index>.JPEG
+    images = os.listdir(class_folder)
+    for i in range(len(images)):
+        # Remember the original file name before overwriting the list entry
+        old_name = images[i]
+        images[i] = f"{label}_{i:04d}.JPEG"
+        os.rename(os.path.join(class_folder, old_name),
+                  os.path.join(class_folder, images[i]))
+
+
+    # Move one image per class into the validation set, ILSVRC2012-style
+    validation_file = images[0]
+    os.rename(os.path.join(class_folder, validation_file),
+              os.path.join(validation_folder, f"ILSVRC2012_val_{validation_index:06d}.JPEG"))
+
+    validation_index += 1
+
+# Write the validation labels, one wnid per line in validation-file order
+with open(os.path.join(args.data_root, "synset_labels.txt"), "w") as f:
+    for class_wnid in classes_to_scrape:
+        label = class_wnid.replace("$", "n")
+        f.write(f"{label}\n")
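
For reference, a sketch of the layout this patch should produce, given an illustrative invocation such as `python downloader.py -data_root ./data -number_of_classes 2 -images_per_class 3` (the wnids below are examples, not output from a real run; actual counts depend on how many URLs resolve):

    data/
        train/
            n01440764/
                n01440764_0001.JPEG
                n01440764_0002.JPEG
            n01443537/
                n01443537_0001.JPEG
                n01443537_0002.JPEG
        validation/
            ILSVRC2012_val_000001.JPEG
            ILSVRC2012_val_000002.JPEG
        synset_labels.txt

Since image `_0000` of each class is moved to `validation/` and `validation_index` advances once per class in `classes_to_scrape` order, line N of `synset_labels.txt` labels `ILSVRC2012_val_{N:06d}.JPEG`.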