Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
To solve issue #13637 in keras repo (#270)
Browse files: browse the repository at this point in the history
  • Loading branch information
seriousran authored Jun 29, 2020
1 parent 422f890 commit 5949df1
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 2 deletions.
7 changes: 5 additions & 2 deletions keras_preprocessing/image/dataframe_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import warnings

import numpy as np
from collections import OrderedDict

from .iterator import BatchFromFilesMixin, Iterator
from .utils import validate_filename
Expand Down Expand Up @@ -249,7 +250,8 @@ def remove_classes(labels, classes):
)

if classes:
classes = set(classes) # sort and prepare for membership lookup
# de-duplicate while keeping the caller's order; also used for membership lookup
classes = list(OrderedDict.fromkeys(classes).keys())
df[y_col] = df[y_col].apply(lambda x: remove_classes(x, classes))
else:
classes = set()
Expand All @@ -258,7 +260,8 @@ def remove_classes(labels, classes):
classes.update(v)
else:
classes.add(v)
return df.dropna(subset=[y_col]), sorted(classes)
classes = sorted(classes)
return df.dropna(subset=[y_col]), classes

def _filter_valid_filepaths(self, df, x_col):
"""Keep only dataframe rows with valid filenames
Expand Down
42 changes: 42 additions & 0 deletions tests/image/dataframe_iterator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,5 +647,47 @@ def test_dataframe_iterator_with_subdirs(all_test_images, tmpdir):
assert set(df_iterator.filenames) == set(filenames)


def test_dataframe_iterator_classes_indices_order(all_test_images, tmpdir):
    """Check how ``class_indices`` reacts to label order and to ``classes``.

    Two expectations are asserted:

    * Without a ``classes`` argument, two dataframes whose labels appear in
      opposite orders must produce equal ``class_indices``.
    * With a ``classes`` argument, passing the same classes in reversed
      order must produce different ``class_indices``.
    """
    # Write every fixture image into tmpdir; the running length of
    # ``filenames`` doubles as the sequential index in the file name.
    filenames = []
    for image_group in all_test_images:
        for image in image_group:
            name = 'image-{}.png'.format(len(filenames))
            image.save(str(tmpdir / name))
            filenames.append(name)

    # Test the class_indices without classes input
    generator = image_data_generator.ImageDataGenerator()

    def build_iterator(leading_labels, label_pool):
        # The two leading labels guarantee both classes are present; the
        # remaining rows are filled with random picks from the pool.
        labels = leading_labels + [
            random.choice(label_pool) for _ in filenames[:-2]
        ]
        frame = pd.DataFrame({
            "filename": filenames,
            "class": labels
        })
        return generator.flow_from_dataframe(frame, str(tmpdir))

    flow_forward_iter = build_iterator(
        ['a', 'b'],
        ['a', 'b', ['a'], ['b'], ['a', 'b'], ['b', 'a']]
    )
    flow_backward_iter = build_iterator(
        ['b', 'a'],
        ['b', 'a', ['b'], ['a'], ['b', 'a'], ['a', 'b']]
    )

    # check class_indices
    assert flow_forward_iter.class_indices == flow_backward_iter.class_indices

    # Test the class_indices with classes input
    generator_2 = image_data_generator.ImageDataGenerator()

    def build_fixed_frame():
        # A fresh dataframe per call so the two iterators never share one.
        return pd.DataFrame([['data/A.jpg', 'A'], ['data/B.jpg', 'B']],
                            columns=['filename', 'class'])

    flow_forward = generator_2.flow_from_dataframe(build_fixed_frame(),
                                                   classes=['A', 'B'])
    flow_backward = generator_2.flow_from_dataframe(build_fixed_frame(),
                                                    classes=['B', 'A'])

    # check class_indices
    assert flow_forward.class_indices != flow_backward.class_indices


# Allow running this test module directly as a script: hand this file's
# path to pytest so only the tests defined here are collected.
if __name__ == '__main__':
    pytest.main([__file__])

0 comments on commit 5949df1

Please sign in to comment.