-
Notifications
You must be signed in to change notification settings - Fork 5
/
DataSet.py
102 lines (73 loc) · 2.83 KB
/
DataSet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import skimage
import skimage.io
import skimage.transform
import csv
import skipthoughts
import random
import numpy as np
class DataSet(object):
    """In-memory image/tag dataset that serves (real, wrong, tag-vector) batches.

    On construction every qualifying image listed in a CSV tag file is read,
    resized, and stored, and its filtered tag string is encoded with
    ``skipthoughts``. ``next_batch`` then walks the data sequentially,
    reshuffling images, tags and vectors together whenever an epoch ends.
    """

    # How many random draws to try for a "wrong" image before falling back to
    # an exhaustive scan (which can detect that no candidate exists at all).
    _MAX_WRONG_DRAWS = 100

    def __init__(self, imagepath, tagfile, image_size):
        """Load every tagged image from ``tagfile`` and encode its tags.

        Args:
            imagepath: directory holding the images; each row's first CSV
                column is used as the file stem: ``<imagepath>/<ID>.jpg``.
            tagfile: CSV file whose first column is the image ID and whose
                last column is a tab-separated list of entries; only the text
                before each entry's first ``':'`` is used.
            image_size: each image is resized to ``(image_size, image_size)``.

        Rows whose tag texts contain neither ``'hair'`` nor ``'eye'`` (and
        blank rows) are skipped; their images are never read.
        """
        self._images = []
        self._tags = []
        self._tag_vecs = []
        # newline='' is what the csv module expects for files it parses.
        with open(tagfile, 'r', newline='') as f:
            print('Reading images and tags.')
            for line in csv.reader(f, delimiter=','):
                if not line:
                    # csv.reader yields [] for a blank line; line[-1] would raise.
                    continue
                valid_tags = self._filter_tags(line[-1])
                if not valid_tags:
                    continue
                filename = os.path.join(imagepath, line[0] + '.jpg')
                img = skimage.io.imread(filename)
                img = skimage.transform.resize(img, (image_size, image_size))
                self._images.append(img)
                self._tags.append(valid_tags)
        sent2vec = skipthoughts.load_model()
        # Presumably one vector per tag string, in the same order — verify
        # against the skipthoughts API.
        self._tag_vecs = skipthoughts.encode(sent2vec, self._tags)
        # NOTE(review): assumes every resized image has the same channel count;
        # mixed grayscale/RGB inputs would not stack into a regular array.
        self._images = np.array(self._images)
        self._tags = np.array(self._tags)
        self._tag_vecs = np.array(self._tag_vecs)
        self._image_num = len(self._tags)
        self._index_in_epoch = 0
        self._N_epoch = 0

    @staticmethod
    def _filter_tags(tag_field):
        """Return the space-joined hair/eye tag texts from a tab-separated field.

        Each entry's text is the part before its first ``':'``; entries whose
        text contains neither ``'hair'`` nor ``'eye'`` are dropped. Returns
        ``''`` when nothing matches.
        """
        texts = (entry.split(':')[0] for entry in tag_field.split('\t'))
        return ' '.join(text for text in texts if 'hair' in text or 'eye' in text)

    def _shuffle(self):
        """Permute images, tags and vectors in unison and start a new epoch."""
        random_idx = np.arange(0, self._image_num)
        np.random.shuffle(random_idx)
        self._images = self._images[random_idx]
        self._tags = self._tags[random_idx]
        self._tag_vecs = self._tag_vecs[random_idx]
        self._index_in_epoch = 0
        self._N_epoch += 1

    def _draw_wrong_index(self, tag):
        """Return a random index whose tag string does not contain ``tag``.

        Raises:
            ValueError: if no image in the dataset satisfies that condition.
        """
        for _ in range(self._MAX_WRONG_DRAWS):
            candidate = random.randint(0, self._image_num - 1)
            if tag not in self._tags[candidate]:
                return candidate
        # Rejection sampling kept failing: scan exhaustively so a dataset with
        # no valid "wrong" image raises instead of looping forever.
        candidates = [i for i, other in enumerate(self._tags) if tag not in other]
        if not candidates:
            raise ValueError(
                'No image has a tag string that excludes %r; cannot draw a '
                'wrong image.' % (tag,))
        return random.choice(candidates)

    def next_batch(self, batch_size=1):
        """Return the next ``batch_size`` samples, reshuffling at epoch end.

        Returns:
            ``(real_images, wrong_images, vecs)``: three lists of length
            ``batch_size``. ``real_images[i]`` and ``vecs[i]`` belong to the
            same sample; ``wrong_images[i]`` is a randomly drawn image whose
            tag string does not contain that sample's tag string (substring
            test).

        Raises:
            ValueError: if the dataset is empty, or if no image qualifies as
                a wrong image for some sample (e.g. all tags are identical).
        """
        real_images = []
        wrong_images = []
        vecs = []
        for _ in range(batch_size):
            if self._index_in_epoch >= self._image_num:
                self._shuffle()
            if self._image_num == 0:
                raise ValueError('DataSet is empty; cannot draw a batch.')
            idx = self._index_in_epoch
            wrong_idx = self._draw_wrong_index(self._tags[idx])
            real_images.append(self._images[idx])
            wrong_images.append(self._images[wrong_idx])
            vecs.append(self._tag_vecs[idx])
            self._index_in_epoch += 1
        return real_images, wrong_images, vecs

    @property
    def image_num(self):
        """Number of images (and tag strings) held by the dataset."""
        return self._image_num

    @property
    def index_in_epoch(self):
        """Position of the next sample within the current epoch."""
        return self._index_in_epoch

    @property
    def N_epoch(self):
        """Number of completed reshuffles (epoch boundaries crossed)."""
        return self._N_epoch