-
Notifications
You must be signed in to change notification settings - Fork 4
/
data_generator.py
28 lines (22 loc) · 1.12 KB
/
data_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import tensorflow as tf
def data_generator(path, batch_size):
    """Build an endlessly repeating, batched TFRecord input pipeline.

    Files matching ``path`` are read as TFRecords, each record is parsed
    into its ``'data'`` tensor, and the tensors are batched and prefetched
    onto ``/gpu:0``.

    Args:
        path: File pattern (glob) handed to ``tf.data.Dataset.list_files``.
        batch_size: Number of parsed examples per batch.

    Returns:
        A ``tf.data.Dataset`` that repeats indefinitely and yields float32
        batches whose per-example shape is ``(256, 256, 10)``.
    """
    files = tf.data.Dataset.list_files(path)

    def parse_function(example_proto):
        """Parse one serialized example and return its 'data' tensor."""
        features = tf.io.parse_single_example(
            example_proto,
            features={
                'data': tf.io.FixedLenFeature(
                    shape=(256, 256, 10),
                    dtype=tf.float32),
            })
        return features['data']

    def data_iterator(tfrecords):
        """Turn a dataset of TFRecord file names into device-side batches."""
        dataset = tf.data.TFRecordDataset(tfrecords, num_parallel_reads=12)
        dataset = dataset.map(map_func=parse_function, num_parallel_calls=12)
        # Repeat forever; the consumer decides how many steps to draw.
        dataset = dataset.repeat(-1)
        # NOTE: record-level shuffling is intentionally not applied here.
        dataset = dataset.batch(batch_size)
        # Host-side prefetch goes first so batch preparation overlaps with
        # the consumer while still leaving the device prefetch last.
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
        # prefetch_to_device must be the final transformation in the
        # pipeline; no further dataset ops may be chained after it.
        dataset = dataset.apply(
            tf.data.experimental.prefetch_to_device("/gpu:0"))
        return dataset

    return data_iterator(files)