From ddeff906b63b4a532385a6edb1b70a4a88383f75 Mon Sep 17 00:00:00 2001 From: Peter Somers <37300408+psomers3@users.noreply.github.com> Date: Mon, 7 Feb 2022 17:55:31 +0100 Subject: [PATCH 1/2] More robust random seeding using a numpy generator Only initialize the seed once on creation. Maybe an option to reset it could be added. --- keras_preprocessing/image/iterator.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/keras_preprocessing/image/iterator.py b/keras_preprocessing/image/iterator.py index c62b1d3a..8eecc62a 100644 --- a/keras_preprocessing/image/iterator.py +++ b/keras_preprocessing/image/iterator.py @@ -33,6 +33,7 @@ def __init__(self, n, batch_size, shuffle, seed): self.n = n self.batch_size = batch_size self.seed = seed + self.numpy_generator = np.random.default_rnq(seed) self.shuffle = shuffle self.batch_index = 0 self.total_batches_seen = 0 @@ -43,7 +44,10 @@ def __init__(self, n, batch_size, shuffle, seed): def _set_index_array(self): self.index_array = np.arange(self.n) if self.shuffle: - self.index_array = np.random.permutation(self.n) + if self.seed is not None: + self.index_array = self.numpy_generator.permutation(self.n) + else: + self.index_array = np.random.permutation(self.n) def __getitem__(self, idx): if idx >= len(self): @@ -51,8 +55,6 @@ def __getitem__(self, idx): 'but the Sequence ' 'has length {length}'.format(idx=idx, length=len(self))) - if self.seed is not None: - np.random.seed(self.seed + self.total_batches_seen) self.total_batches_seen += 1 if self.index_array is None: self._set_index_array() @@ -73,8 +75,6 @@ def _flow_index(self): # Ensure self.batch_index is 0. self.reset() while 1: - if self.seed is not None: - np.random.seed(self.seed + self.total_batches_seen) if self.batch_index == 0: self._set_index_array() From 22e9bd98ba566cd8780ff4b5e740e1b5dadec7ea Mon Sep 17 00:00:00 2001 From: Peter Somers <37300408+psomers3@users.noreply.github.com> Date: Mon, 7 Feb 2022 18:40:35 +0100 Subject: [PATCH 2/2] cleaner generator usage --- keras_preprocessing/image/iterator.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/keras_preprocessing/image/iterator.py b/keras_preprocessing/image/iterator.py index 8eecc62a..d810ef9d 100644 --- a/keras_preprocessing/image/iterator.py +++ b/keras_preprocessing/image/iterator.py @@ -44,10 +44,7 @@ def __init__(self, n, batch_size, shuffle, seed): def _set_index_array(self): self.index_array = np.arange(self.n) if self.shuffle: - if self.seed is not None: - self.index_array = self.numpy_generator.permutation(self.n) - else: - self.index_array = np.random.permutation(self.n) + self.index_array = self.numpy_generator.permutation(self.n) def __getitem__(self, idx): if idx >= len(self):