-
Notifications
You must be signed in to change notification settings - Fork 17
/
model.py
195 lines (184 loc) · 8.9 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
from __future__ import print_function
import tensorflow as tf
from qrnn import QRNN_layer
from tensorflow.contrib.layers import xavier_initializer
from tensorflow.contrib.layers import flatten, fully_connected
import numpy as np
import os
def scalar_summary(name, x):
    """Create a scalar summary op called ``name`` that records tensor ``x``.

    Tries ``tf.summary.scalar`` first and, if that raises AttributeError,
    falls back to the legacy top-level ``tf.scalar_summary`` so the module
    works with either TensorFlow summary API.
    """
    try:
        return tf.summary.scalar(name, x)
    except AttributeError:
        # Legacy TensorFlow API: summary ops lived directly under `tf`.
        return tf.scalar_summary(name, x)
def histogram_summary(name, x):
    """Create a histogram summary op called ``name`` that records tensor ``x``.

    Tries ``tf.summary.histogram`` first and, if that raises AttributeError,
    falls back to the legacy top-level ``tf.histogram_summary`` so the module
    works with either TensorFlow summary API.
    """
    try:
        return tf.summary.histogram(name, x)
    except AttributeError:
        # Legacy TensorFlow API: summary ops lived directly under `tf`.
        return tf.histogram_summary(name, x)
class QRNN_lm(object):
    """ Implement the Language Model from https://arxiv.org/abs/1611.01576

    Graph: word embeddings -> ``qrnn_layers`` stacked QRNN layers -> linear
    projection to ``vocab_size`` scores, trained with SGD on the mean
    per-token cross-entropy.
    """

    def __init__(self, args, infer=False, test=False):
        """Build the model graph, loss, summaries and training op.

        Args:
            args: configuration object; this constructor reads its
                ``batch_size``, ``seq_len``, ``vocab_size``, ``emb_dim``,
                ``zoneout``, ``dropout``, ``qrnn_size``, ``qrnn_layers``,
                ``learning_rate`` and ``grad_clip`` attributes.
            infer: build a one-word-at-a-time graph (batch_size and seq_len
                forced to 1), as used by ``sample``; no dropout is applied.
            test: build an evaluation graph; zoneout and dropout are zeroed.
        """
        self.batch_size = args.batch_size
        self.seq_len = args.seq_len
        if infer:
            # Sampling feeds a single word of a single sequence per step.
            self.batch_size = 1
            self.seq_len = 1
        self.infer = infer
        self.vocab_size = args.vocab_size
        self.emb_dim = args.emb_dim
        self.zoneout = args.zoneout
        self.dropout = args.dropout
        if test:
            self.zoneout = self.dropout = 0
        self.test = test
        self.qrnn_size = args.qrnn_size
        self.qrnn_layers = args.qrnn_layers
        # Input word ids and ground-truth target ids, both
        # [batch_size, seq_len].
        self.words_in = tf.placeholder(tf.int32, [self.batch_size,
                                                  self.seq_len])
        self.words_gtruth = tf.placeholder(tf.int32, [self.batch_size,
                                                      self.seq_len])
        self.logits, self.output = self.inference()
        self.loss = self.lm_loss(self.logits, self.words_gtruth)
        self.loss_summary = scalar_summary('loss', self.loss)
        self.perp_summary = scalar_summary('perplexity', tf.exp(self.loss))
        # Set up the optimizer. The learning rate is a non-trainable
        # variable (presumably re-assigned by the training loop).
        self.lr = tf.Variable(args.learning_rate, trainable=False)
        self.lr_summary = scalar_summary('lr', self.lr)
        tvars = tf.trainable_variables()
        # Clip each gradient tensor's norm individually. Variables that the
        # loss does not depend on get a None gradient, which is passed
        # through untouched (clip_by_norm cannot take None).
        grads = [None if grad is None
                 else tf.clip_by_norm(grad, args.grad_clip)
                 for grad in tf.gradients(self.loss, tvars)]
        self.opt = tf.train.GradientDescentOptimizer(self.lr)
        self.train_op = self.opt.apply_gradients(list(zip(grads, tvars)))

    def inference(self):
        """Build embeddings -> stacked QRNN layers -> output projection.

        Side effects: rebinds ``self.initial_states``, ``self.last_states``
        and ``self.qrnns`` to fresh lists with one entry per QRNN layer (so
        ``sample`` can feed and read recurrent state), and registers a
        histogram and a mean-value scalar summary of each layer's last state.

        Returns:
            (logits, output): 2-D ``[-1, vocab_size]`` output scores, one row
            per flattened (batch, timestep) position, and their softmax.
        """
        # Dropout belongs to training graphs only. Test graphs already have
        # self.dropout == 0, and inference graphs (used by `sample`) must
        # skip it too; the former `not test or not infer` guard let
        # infer-only graphs drop units while sampling.
        use_dropout = not (self.test or self.infer) and self.dropout > 0
        keep_prob = 1. - self.dropout
        # Keep track of recurrent states to re-initialize them when needed.
        self.initial_states = []
        self.last_states = []
        self.qrnns = []
        # NOTE(review): the original indentation was lost; everything below,
        # including the 'output_softmax' projection, is assumed to live in
        # the QRNN_LM variable scope. Confirm against existing checkpoints.
        with tf.variable_scope('QRNN_LM'):
            word_W = tf.get_variable(
                "word_W", [self.vocab_size, self.emb_dim],
                initializer=tf.random_uniform_initializer(minval=-.05,
                                                          maxval=.05))
            # One lookup on the [batch_size, seq_len] id matrix yields the
            # [batch_size, seq_len, emb_dim] embeddings directly. This
            # replaces a per-timestep split/lookup/concat loop that relied
            # on the positional tf.split(dim, n, x) / tf.concat(dim, xs)
            # signatures removed in TensorFlow 1.0.
            qrnn_h = tf.nn.embedding_lookup(word_W, self.words_in)
            if use_dropout:
                qrnn_h = tf.nn.dropout(qrnn_h, keep_prob,
                                       name='dout_word_emb')
            for qrnn_l in range(self.qrnn_layers):
                qrnn_ = QRNN_layer(self.qrnn_size, pool_type='fo',
                                   zoneout=self.zoneout,
                                   name='QRNN_layer{}'.format(qrnn_l),
                                   infer=self.infer)
                qrnn_h, last_state = qrnn_(qrnn_h)
                if use_dropout:
                    # Elementwise dropout, then pin the static 3-D shape.
                    qrnn_h = tf.reshape(
                        tf.nn.dropout(qrnn_h, keep_prob,
                                      name='dout_qrnn{}'.format(qrnn_l)),
                        [self.batch_size, -1, self.qrnn_size])
                self.last_states.append(last_state)
                histogram_summary('qrnn_state_{}'.format(qrnn_l),
                                  last_state)
                scalar_summary('qrnn_avg_state_{}'.format(qrnn_l),
                               tf.reduce_mean(last_state))
                self.initial_states.append(qrnn_.initial_state)
                self.qrnns.append(qrnn_)
            # Flatten (batch, time) so every position is one projection row;
            # lm_loss flattens the targets in the same row-major order.
            qrnn_h_f = tf.reshape(qrnn_h, [-1, self.qrnn_size])
            logits = fully_connected(
                qrnn_h_f,
                self.vocab_size,
                activation_fn=None,
                weights_initializer=tf.random_uniform_initializer(
                    minval=-.05, maxval=.05),
                scope='output_softmax')
            output = tf.nn.softmax(logits)
        return logits, output

    def lm_loss(self, logits, words_gtruth):
        """Return the mean per-token cross-entropy of ``logits``.

        Args:
            logits: ``[batch_size * seq_len, vocab_size]`` scores.
            words_gtruth: ``[batch_size, seq_len]`` int target ids.

        Returns:
            Scalar tensor: the per-sequence summed loss, averaged over the
            batch and divided by ``seq_len``.
        """
        f_words_gtruth = tf.reshape(words_gtruth,
                                    [self.batch_size * self.seq_len])
        # Named arguments: TF >= 1.0 rejects positional (logits, labels).
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=f_words_gtruth, logits=logits)
        loss = tf.reduce_sum(tf.reshape(loss, [self.batch_size, -1]), 1)
        return tf.reduce_mean(loss) / self.seq_len

    @staticmethod
    def _sample_temperature(preds, temperature=1.0):
        """Sample an index from probability vector ``preds`` at ``temperature``.

        Probabilities are re-weighted as ``p ** (1 / temperature)`` and
        renormalized. The log-domain maximum is subtracted before
        exponentiating so small temperatures cannot underflow every weight
        to 0 (which would give NaN probabilities).
        """
        preds = np.asarray(preds, dtype='float64')
        # log(0) = -inf is expected for zero-probability entries.
        with np.errstate(divide='ignore'):
            log_preds = np.log(preds) / temperature
        log_preds = log_preds - np.max(log_preds)
        exp_preds = np.exp(log_preds)
        probs = exp_preds / np.sum(exp_preds)
        return np.argmax(np.random.multinomial(1, probs, 1))

    def sample(self, sess, num_words, vocab, first_word='hello',
               temperature=0.75):
        """Generate ``num_words`` words, feeding each sample back as input.

        Expects a graph whose ``words_in`` placeholder is [1, 1] (i.e. one
        built with ``infer=True``). Each layer's last state is fed back as
        its initial state for the next step.

        Args:
            sess: TensorFlow session holding this model's variables.
            num_words: number of words to sample.
            vocab: dict with ``'word2idx'`` and ``'idx2word'`` mappings;
                ``'<unk>'`` must be in ``word2idx`` when ``first_word`` is not.
            first_word: seed word (lowercased before lookup).
            temperature: sampling temperature; lower is greedier.

        Returns:
            (text, qrnn_activations): the space-joined word stream starting
            with ``first_word``, and one list per QRNN layer holding that
            layer's ``Z[0][0]`` at every step.
        """
        word2idx = vocab['word2idx']
        idx2word = vocab['idx2word']
        # make sure it's lowercase
        first_word = first_word.lower()
        curr_word = np.zeros((1, 1), dtype=np.int32)
        try:
            curr_word[0, 0] = word2idx[first_word]
            print('First word idx: ', curr_word[0, 0])
        except KeyError:
            print('First word {} is not in vocab, '
                  'setting <unk>'.format(first_word))
            curr_word[0, 0] = word2idx['<unk>']
        z_tensors = [qrnn_.Z for qrnn_ in self.qrnns]
        qrnn_activations = [[] for _ in self.qrnns]
        prev_states = [sess.run(qrnn_.initial_state)
                       for qrnn_ in self.qrnns]
        out_stream = [first_word]
        print('---Sampling LM from first word \"{}\"---'.format(first_word))
        for _ in range(num_words):
            print(idx2word[curr_word[0, 0]], end=' ')
            fdict = {self.words_in: curr_word}
            fdict.update(zip(self.initial_states, prev_states))
            # Fetch order and unpack order must match: the old code unpacked
            # layer-0's Z as the recurrent states.
            output, z_values, states = sess.run(
                [self.output, z_tensors, self.last_states],
                feed_dict=fdict)
            # Assumes Z is laid out as [batch, time, ...] - TODO confirm.
            for layer_acts, z_value in zip(qrnn_activations, z_values):
                layer_acts.append(z_value[0][0])
            curr_word[0, 0] = self._sample_temperature(output[0],
                                                       temperature)
            out_stream.append(idx2word[curr_word[0, 0]])
            # Each layer's final state seeds the next step.
            prev_states = list(states)
        print('')
        return ' '.join(out_stream), qrnn_activations

    def save(self, sess, save_filename, global_step):
        """Write a checkpoint of ``sess``'s variables to ``save_filename``.

        The ``tf.train.Saver`` is created on first use, so it covers the
        variables that exist at that moment.
        """
        if not hasattr(self, 'saver'):
            self.saver = tf.train.Saver()
        print('Saving checkpoint...')
        self.saver.save(sess, save_filename, global_step)

    def load(self, sess, save_path):
        """Restore the latest checkpoint recorded in directory ``save_path``.

        Returns:
            True if a checkpoint was found and restored, False otherwise.
        """
        if not hasattr(self, 'saver'):
            self.saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(save_path)
        if ckpt and ckpt.model_checkpoint_path:
            # Re-anchor the checkpoint basename under save_path.
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            print('Loading checkpoint {}...'.format(ckpt_name))
            self.saver.restore(sess, os.path.join(save_path, ckpt_name))
            return True
        else:
            return False