From 5949df1c059a53d98a6004d5bfc93708e5ec6c4a Mon Sep 17 00:00:00 2001 From: Chanran Kim Date: Mon, 29 Jun 2020 22:08:43 +0900 Subject: [PATCH] To solve issue #13637 in keras repo (#270) --- .../image/dataframe_iterator.py | 7 +++- tests/image/dataframe_iterator_test.py | 42 +++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/keras_preprocessing/image/dataframe_iterator.py b/keras_preprocessing/image/dataframe_iterator.py index 801039e4..24b230b8 100644 --- a/keras_preprocessing/image/dataframe_iterator.py +++ b/keras_preprocessing/image/dataframe_iterator.py @@ -8,6 +8,7 @@ import warnings import numpy as np +from collections import OrderedDict from .iterator import BatchFromFilesMixin, Iterator from .utils import validate_filename @@ -249,7 +250,8 @@ def remove_classes(labels, classes): ) if classes: - classes = set(classes) # sort and prepare for membership lookup + # prepare for membership lookup + classes = list(OrderedDict.fromkeys(classes).keys()) df[y_col] = df[y_col].apply(lambda x: remove_classes(x, classes)) else: classes = set() @@ -258,7 +260,8 @@ def remove_classes(labels, classes): classes.update(v) else: classes.add(v) - return df.dropna(subset=[y_col]), sorted(classes) + classes = sorted(classes) + return df.dropna(subset=[y_col]), classes def _filter_valid_filepaths(self, df, x_col): """Keep only dataframe rows with valid filenames diff --git a/tests/image/dataframe_iterator_test.py b/tests/image/dataframe_iterator_test.py index ede444ee..9c2a06a0 100644 --- a/tests/image/dataframe_iterator_test.py +++ b/tests/image/dataframe_iterator_test.py @@ -647,5 +647,47 @@ def test_dataframe_iterator_with_subdirs(all_test_images, tmpdir): assert set(df_iterator.filenames) == set(filenames) +def test_dataframe_iterator_classes_indices_order(all_test_images, tmpdir): + # save the images in the paths + count = 0 + filenames = [] + for test_images in all_test_images: + for im in test_images: + filename = 'image-{}.png'.format(count) + im.save(str(tmpdir / filename)) + filenames.append(filename) + count += 1 + + # Test the class_indices without classes input + generator = image_data_generator.ImageDataGenerator() + label_opt = ['a', 'b', ['a'], ['b'], ['a', 'b'], ['b', 'a']] + df_f = pd.DataFrame({ + "filename": filenames, + "class": ['a', 'b'] + [random.choice(label_opt) for _ in filenames[:-2]] + }) + flow_forward_iter = generator.flow_from_dataframe(df_f, str(tmpdir)) + label_rev = ['b', 'a', ['b'], ['a'], ['b', 'a'], ['a', 'b']] + df_r = pd.DataFrame({ + "filename": filenames, + "class": ['b', 'a'] + [random.choice(label_rev) for _ in filenames[:-2]] + }) + flow_backward_iter = generator.flow_from_dataframe(df_r, str(tmpdir)) + + # check class_indices + assert flow_forward_iter.class_indices == flow_backward_iter.class_indices + + # Test the class_indices with classes input + generator_2 = image_data_generator.ImageDataGenerator() + df_f2 = pd.DataFrame([['data/A.jpg', 'A'], ['data/B.jpg', 'B']], + columns=['filename', 'class']) + flow_forward = generator_2.flow_from_dataframe(df_f2, classes=['A', 'B']) + df_b2 = pd.DataFrame([['data/A.jpg', 'A'], ['data/B.jpg', 'B']], + columns=['filename', 'class']) + flow_backward = generator_2.flow_from_dataframe(df_b2, classes=['B', 'A']) + + # check class_indices + assert flow_forward.class_indices != flow_backward.class_indices + + if __name__ == '__main__': pytest.main([__file__])