forked from sordonia/rnn-lm
-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathSS_dataset.py
120 lines (97 loc) · 3.1 KB
/
SS_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import numpy as np
import os, gc
import cPickle
import copy
import logging
import threading
import Queue
import collections
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Seed NumPy's *global* random generator as a side effect of importing
# this module.
np.random.seed(1234)
class SSFetcher(threading.Thread):
    """Background thread that fills the parent iterator's queue with batches.

    The fetcher reads the following attributes from ``parent``:
    ``data``, ``data_len``, ``rng``, ``batch_size``, ``max_len``,
    ``use_infinite_loop``, ``exit_flag`` and ``queue``.

    Each item put on ``parent.queue`` is a non-empty list of samples taken
    from ``parent.data`` in shuffled order.  A final ``None`` is put on the
    queue when no more batches will be produced.
    """

    def __init__(self, parent):
        """Create a fetcher bound to ``parent``.

        :param parent: the iterator object owning the data and the queue.
        """
        threading.Thread.__init__(self)
        self.parent = parent
        # Permutation of sample positions; shuffled in place by ``run``.
        self.indexes = np.arange(parent.data_len)

    def run(self):
        """Produce batches until exhausted or until ``parent.exit_flag`` is set.

        Samples longer than ``parent.max_len`` are skipped.  In finite mode
        (``use_infinite_loop`` false) one pass over the data is made, the
        trailing partial batch (if any) is emitted, then ``None`` is queued.
        In infinite mode the indexes are reshuffled after each pass.  If a
        complete pass yields no usable sample (empty data, or every sample
        exceeds ``max_len``), ``None`` is queued and the thread stops rather
        than crashing or spinning forever.
        """
        diter = self.parent
        # Always use the parent's generator so that the sample order is
        # fully determined by the ``rng`` handed to the iterator.
        rng = diter.rng
        rng.shuffle(self.indexes)

        offset = 0
        kept_this_pass = 0  # samples accepted since the last (re)shuffle

        # Collect up to ``batch_size`` samples per batch, skipping samples
        # that are longer than ``max_len``.
        while not diter.exit_flag:
            last_batch = False
            triples = []

            while len(triples) < diter.batch_size:
                if offset == diter.data_len:
                    if not diter.use_infinite_loop or kept_this_pass == 0:
                        # Either a single pass was requested, or a whole
                        # pass produced nothing usable: stop here.
                        last_batch = True
                        break
                    # Infinite loop: reshuffle the indexes and start a
                    # new pass over the data.
                    rng.shuffle(self.indexes)
                    offset = 0
                    kept_this_pass = 0

                s = diter.data[self.indexes[offset]]
                offset += 1

                if len(s) > diter.max_len:
                    continue
                triples.append(s)
                kept_this_pass += 1

            if triples:
                diter.queue.put(triples)

            if last_batch:
                # Sentinel telling the consumer that no more batches follow.
                diter.queue.put(None)
                return
class SSIterator(object):
def __init__(self,
rng,
batch_size,
triple_file=None,
dtype="int32",
can_fit=False,
queue_size=100,
cache_size=100,
shuffle=True,
use_infinite_loop=True,
max_len=1000):
args = locals()
args.pop("self")
self.rng = rng
self.__dict__.update(args)
self.load_files()
self.exit_flag = False
def load_files(self):
self.data = cPickle.load(open(self.triple_file, 'r'))
self.data_len = len(self.data)
logger.debug('Data len is %d' % self.data_len)
def start(self):
self.exit_flag = False
self.queue = Queue.Queue(maxsize=self.queue_size)
self.gather = SSFetcher(self)
self.gather.daemon = True
self.gather.start()
def __del__(self):
if hasattr(self, 'gather'):
self.gather.exitFlag = True
self.gather.join()
def __iter__(self):
return self
def next(self):
if self.exit_flag:
return None
batch = self.queue.get()
if not batch:
self.exit_flag = True
return batch
if __name__ == '__main__':
    # Debug entry point: count the batches produced by a single pass over
    # the pickled data file given as the first command-line argument.
    import sys

    # SSIterator requires both a random generator and a batch size.
    iterator = SSIterator(np.random.RandomState(1234),
                          100,
                          triple_file=sys.argv[1],
                          use_infinite_loop=False)
    iterator.start()

    _cpt = 0
    while True:
        batch = iterator.next()
        if batch is None:
            break
        _cpt += 1

    # Single-argument parenthesised form behaves the same on Python 2 and 3.
    print("Read %d batches" % _cpt)