forked from tensorflow/models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
beam_search.py
155 lines (127 loc) · 5.18 KB
/
beam_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Beam search module.
Beam search takes the top K results from the model, then predicts K results
for each of the previous K results, giving K*K candidates. It picks the top K
of those K*K candidates and repeats until a certain number of results are
fully decoded.
"""
import tensorflow as tf
# Shared flags object from tf.flags; flag values are read as attributes on it
# (this file reads FLAGS.normalize_by_length in BeamSearch._BestHyps).
FLAGS = tf.flags.FLAGS
# Boolean flag, default True. When set, BeamSearch._BestHyps ranks hypotheses
# by log_prob divided by their token count instead of by raw log_prob.
tf.flags.DEFINE_bool('normalize_by_length', True, 'Whether normalize')
class Hypothesis(object):
    """A partial decoding kept on the beam during beam search.

    Instances expose three plain attributes:
      tokens: sequence of tokens decoded so far (start token included).
      log_prob: accumulated log probability of `tokens`.
      state: decoder state to feed to the decoder for the next step.
    """

    def __init__(self, tokens, log_prob, state):
        """Hypothesis constructor.

        Args:
          tokens: start tokens for decoding.
          log_prob: log prob of the start tokens, usually 0.0 (i.e. the log
            of probability 1).
          state: decoder initial states.
        """
        self.tokens = tokens
        self.log_prob = log_prob
        self.state = state

    def Extend(self, token, log_prob, new_state):
        """Extend the hypothesis with result from latest step.

        The current hypothesis is left untouched; a new one is returned.

        Args:
          token: latest token from decoding.
          log_prob: log prob of the latest decoded tokens.
          new_state: decoder output state. Fed to the decoder for next step.

        Returns:
          New Hypothesis with the results from latest step. Its `tokens` is
          always a fresh list, whatever sequence type this hypothesis holds.
        """
        # list(...) copies the tokens, so the parent is never aliased, and it
        # lets `tokens` be any sequence (a bare `tuple + [token]` would raise
        # TypeError and an ndarray would add element-wise instead of append).
        extended_tokens = list(self.tokens) + [token]
        return Hypothesis(extended_tokens, self.log_prob + log_prob,
                          new_state)

    @property
    def latest_token(self):
        """The most recently decoded token (last element of `tokens`)."""
        return self.tokens[-1]

    def __str__(self):
        return ('Hypothesis(log prob = %.4f, tokens = %s)' % (self.log_prob,
                                                              self.tokens))
class BeamSearch(object):
    """Beam search decoder built on a model's encode / decode-top-k calls."""

    def __init__(self, model, beam_size, start_token, end_token, max_steps):
        """Creates BeamSearch object.

        Args:
          model: Seq2SeqAttentionModel. Must provide `encode_top_state` and
            `decode_topk`, which are the only model methods used here.
          beam_size: int.
          start_token: int, id of the token to start decoding with
          end_token: int, id of the token that completes an hypothesis
          max_steps: int, upper limit on the size of the hypothesis
        """
        self._model = model
        self._beam_size = beam_size
        self._start_token = start_token
        self._end_token = end_token
        self._max_steps = max_steps

    def BeamSearch(self, sess, enc_inputs, enc_seqlen):
        """Performs beam search for decoding.

        Args:
          sess: tf.Session, session
          enc_inputs: ndarray of shape (enc_length, 1), the document ids to
            encode
          enc_seqlen: ndarray of shape (1), the length of the sequence

        Returns:
          hyps: list of Hypothesis, the best hypotheses found by beam search,
              ordered by score
        """
        # Run the encoder and extract the outputs and final state.
        enc_top_states, dec_in_state = self._model.encode_top_state(
            sess, enc_inputs, enc_seqlen)
        # Replicate the initial states K times for the first step. The list
        # holds K references to one Hypothesis; that is safe because Extend
        # returns a new object rather than mutating.
        hyps = [Hypothesis([self._start_token], 0.0, dec_in_state)
                ] * self._beam_size
        results = []

        steps = 0
        while steps < self._max_steps and len(results) < self._beam_size:
            latest_tokens = [h.latest_token for h in hyps]
            states = [h.state for h in hyps]

            topk_ids, topk_log_probs, new_states = self._model.decode_topk(
                sess, latest_tokens, enc_top_states, states)

            # Extend each hypothesis.
            all_hyps = []
            # The first step takes the best K results from first hyps (all K
            # initial hyps are identical, so only one is expanded). Following
            # steps take the best K results from K*K hyps.
            num_beam_source = 1 if steps == 0 else len(hyps)
            # `range` rather than the Python-2-only `xrange`, so the loop
            # runs on both Python 2 and Python 3.
            for i in range(num_beam_source):
                h, ns = hyps[i], new_states[i]
                # Indexing below requires decode_topk to return at least
                # 2 * beam_size candidates per source hypothesis.
                for j in range(self._beam_size * 2):
                    all_hyps.append(
                        h.Extend(topk_ids[i, j], topk_log_probs[i, j], ns))

            # Filter and collect any hypotheses that have the end token.
            hyps = []
            for h in self._BestHyps(all_hyps):
                if h.latest_token == self._end_token:
                    # Pull the hypothesis off the beam if the end token is
                    # reached.
                    results.append(h)
                else:
                    # Otherwise continue to extend the hypothesis.
                    hyps.append(h)
                if (len(hyps) == self._beam_size or
                        len(results) == self._beam_size):
                    break

            steps += 1

        # Out of steps: fall back on the unfinished hypotheses as results.
        if steps == self._max_steps:
            results.extend(hyps)

        return self._BestHyps(results)

    def _BestHyps(self, hyps):
        """Sort the hyps based on log probs and length.

        Args:
          hyps: A list of hypothesis.

        Returns:
          hyps: A list of sorted hypothesis in reverse log_prob order. When
            FLAGS.normalize_by_length is set, the sort key is log_prob
            divided by the number of tokens instead of raw log_prob.
        """
        # This length normalization is only effective for the final results.
        if FLAGS.normalize_by_length:
            return sorted(hyps, key=lambda h: h.log_prob / len(h.tokens),
                          reverse=True)
        else:
            return sorted(hyps, key=lambda h: h.log_prob, reverse=True)