diff --git a/tools/im2rec.py b/tools/im2rec.py index da3a1dddc87c..717ad9d6c760 100644 --- a/tools/im2rec.py +++ b/tools/im2rec.py @@ -90,6 +90,19 @@ def write_list(path_out, image_list): line += '%s\n' % item[1] fout.write(line) + +def each_class_to_beginning(image_list: list): + """Take off one photo of each class""" + images = {} + for elt in image_list: + cls = elt[-1] + if cls not in images: + images[cls] = elt + unique_classes = list(images.values()) + for elt in unique_classes: + image_list.remove(elt) + return unique_classes + image_list + def make_list(args): """Generates .lst file. Parameters @@ -101,6 +114,7 @@ def make_list(args): if args.shuffle is True: random.seed(100) random.shuffle(image_list) + image_list = each_class_to_beginning(image_list) N = len(image_list) chunk_size = (N + args.chunks - 1) // args.chunks for i in range(args.chunks): @@ -115,10 +129,10 @@ def make_list(args): write_list(args.prefix + str_chunk + '.lst', chunk) else: if args.test_ratio: - write_list(args.prefix + str_chunk + '_test.lst', chunk[:sep_test]) + write_list(args.prefix + str_chunk + '_test.lst', chunk[sep:sep+sep_test]) if args.train_ratio + args.test_ratio < 1.0: - write_list(args.prefix + str_chunk + '_val.lst', chunk[sep_test + sep:]) - write_list(args.prefix + str_chunk + '_train.lst', chunk[sep_test:sep_test + sep]) + write_list(args.prefix + str_chunk + '_val.lst', chunk[sep+sep_test:]) + write_list(args.prefix + str_chunk + '_train.lst', chunk[:sep]) def read_list(path_in): """Reads the .lst file and generates corresponding iterator.