-
Notifications
You must be signed in to change notification settings - Fork 16
/
data_iterator.py
113 lines (90 loc) · 3.35 KB
/
data_iterator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import numpy as np
import theano
import theano.tensor as T
import sys, getopt
import logging
from state import *
from utils import *
from SS_dataset import *
import itertools
import sys
import pickle
import random
import datetime
logger = logging.getLogger(__name__)
def create_padded_batch(state, x):
    """Pack variable-length sequences into fixed-size padded matrices.

    Each sequence in ``x[0]`` is written into its own column of an
    ``(state['seqlen'], state['bs'])`` matrix. Sequences longer than
    ``state['seqlen']`` are skipped, which leaves their column (and mask
    column) entirely zero.

    Args:
        state: Mapping providing ``'seqlen'`` (number of rows), ``'bs'``
            (number of columns) and ``'eos_sym'`` (value used to fill the
            rows that follow the end of a short sequence).
        x: Indexable whose first element ``x[0]`` is the collection of
            sequences to pack. It must not hold more than ``state['bs']``
            sequences, since each one is addressed by column index.

    Returns:
        A dict with the keys:
            ``'x'``: int32 matrix holding the sequences column-wise.
            ``'x_mask'``: float32 matrix with 1.0 at every position that
                holds a sequence element and 0.0 elsewhere.
            ``'num_preds'``: total number of sequence elements written.
            ``'max_length'``: length of the longest sequence written.
    """
    max_rows = state['seqlen']
    n_cols = state['bs']

    X = np.zeros((max_rows, n_cols), dtype='int32')
    Xmask = np.zeros((max_rows, n_cols), dtype='float32')

    # Keep track of the number of predictions and the longest sequence kept.
    num_preds = 0
    max_length = 0
    for idx, sentence in enumerate(x[0]):
        sent_length = len(sentence)

        # Sequences that do not fit are dropped; their column stays zero.
        if sent_length > max_rows:
            continue

        # Insert the sequence in column idx of X.
        X[:sent_length, idx] = sentence
        # Mark the end of phrase: everything after the sequence is filled
        # with the end-of-sequence symbol (a no-op when the column is full).
        X[sent_length:, idx] = state['eos_sym']
        # The mask is one exactly where sequence elements were written.
        Xmask[:sent_length, idx] = 1.

        max_length = max(max_length, sent_length)
        # num_preds must equal sum(Xmask), for cost purposes.
        num_preds += sent_length

    # Internal consistency check between the counter and the mask.
    assert num_preds == np.sum(Xmask)
    return {'x': X, 'x_mask': Xmask, 'num_preds': num_preds, 'max_length': max_length}
def get_batch_iterator(rng, state):
    """Build the training and validation batch iterators.

    Args:
        rng: Random number generator forwarded as the first positional
            argument to ``SSIterator.__init__``.
        state: Mapping read for ``'bs'``, ``'seqlen'``, ``'sort_k_batches'``,
            ``'train_sentences'`` and ``'valid_sentences'`` (and whatever
            ``create_padded_batch`` reads from it).

    Returns:
        A ``(train_data, valid_data)`` tuple of ``Iterator`` instances. The
        training iterator is created with ``use_infinite_loop=True``, the
        validation iterator with ``use_infinite_loop=False``.
    """

    class Iterator(SSIterator):
        """SSIterator that regroups raw batches into length-sorted, padded ones."""

        def __init__(self, *args, **kwargs):
            SSIterator.__init__(self, rng, *args, **kwargs)
            # Lazily created generator over padded batches; see next().
            self.batch_iter = None

        def get_homogenous_batch_iter(self):
            """Yield padded batches built from groups of raw batches.

            Up to ``state['sort_k_batches']`` raw batches are pulled from
            ``SSIterator.next``, their samples are pooled and (when more
            than one raw batch is grouped) ordered by length, then cut into
            slices of ``state['bs']`` samples that are padded with
            ``create_padded_batch``. The generator ends when a whole group
            comes back empty.
            """
            while True:
                k_batches = state['sort_k_batches']
                batch_size = state['bs']

                groups = []
                for _ in range(k_batches):
                    raw_batch = SSIterator.next(self)
                    if raw_batch:
                        groups.append(raw_batch)

                if not groups:
                    return

                # Keep the pooled samples in a plain list: building a NumPy
                # array from sequences of different lengths is an error on
                # current NumPy releases.
                samples = list(itertools.chain.from_iterable(groups))
                # A list comprehension (not map) so the lengths are a real
                # sequence on Python 3 as well as Python 2.
                lengths = np.asarray([len(sample) for sample in samples])
                if k_batches > 1:
                    order = np.argsort(lengths)
                else:
                    order = np.arange(len(samples))

                for k in range(len(groups)):
                    indices = order[k * batch_size:(k + 1) * batch_size]
                    # Raw batches may hold fewer than batch_size samples, so
                    # the trailing slices can be empty. Padding them would
                    # produce an all-zero batch with num_preds == 0; the
                    # original truthiness check on the returned dict could
                    # never filter those out, so skip them here instead.
                    if len(indices) == 0:
                        continue
                    chunk = [samples[i] for i in indices]
                    yield create_padded_batch(state, [chunk])

        def start(self):
            SSIterator.start(self)
            # Force a fresh generator on the next call to next().
            self.batch_iter = None

        def next(self):
            """Return the next padded batch, or None when none are left."""
            if self.batch_iter is None:
                self.batch_iter = self.get_homogenous_batch_iter()
            try:
                return next(self.batch_iter)
            except StopIteration:
                return None

    train_data = Iterator(
        batch_size=int(state['bs']),
        triple_file=state['train_sentences'],
        queue_size=100,
        use_infinite_loop=True,
        max_len=state['seqlen'])

    valid_data = Iterator(
        batch_size=int(state['bs']),
        triple_file=state['valid_sentences'],
        use_infinite_loop=False,
        queue_size=100,
        max_len=state['seqlen'])

    return train_data, valid_data