attention_my.py
import tensorflow as tf
from tensorflow.keras import backend as K

logger = tf.get_logger()


class AttentionLayer(tf.keras.layers.Layer):
    """
    This class implements Bahdanau attention (https://arxiv.org/pdf/1409.0473.pdf).
    It introduces three sets of trainable weights: W_a, U_a, and V_a.
    """

    def __init__(self, **kwargs):
        super(AttentionLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        assert isinstance(input_shape, list)
        # Create the trainable weight variables for this layer.
        # input_shape[0] is the encoder output shape (batch, en_seq_len, en_latent_dim);
        # input_shape[1] is the decoder output shape (batch, de_seq_len, de_latent_dim).
        self.W_a = self.add_weight(name='W_a',
                                   shape=tf.TensorShape((input_shape[0][2], input_shape[0][2])),
                                   initializer='uniform',
                                   trainable=True)
        self.U_a = self.add_weight(name='U_a',
                                   shape=tf.TensorShape((input_shape[1][2], input_shape[0][2])),
                                   initializer='uniform',
                                   trainable=True)
        self.V_a = self.add_weight(name='V_a',
                                   shape=tf.TensorShape((input_shape[0][2], 1)),
                                   initializer='uniform',
                                   trainable=True)
        super(AttentionLayer, self).build(input_shape)  # Be sure to call this at the end
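
    # With these weights the layer scores each encoder step j against a decoder
    # step i as e_ij = V_a^T tanh(s_j . W_a + h_i . U_a), where s_j is an encoder
    # output and h_i a single decoder output, then normalises the scores with a
    # softmax over the encoder axis (see energy_step below).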

    def call(self, inputs):
        """
        inputs: [encoder_output_sequence, decoder_output_sequence]
        """
        assert isinstance(inputs, list)
        encoder_out_seq, decoder_out_seq = inputs
        logger.debug(f"encoder_out_seq.shape = {encoder_out_seq.shape}")
        logger.debug(f"decoder_out_seq.shape = {decoder_out_seq.shape}")
        def energy_step(inputs, states):
            """ Step function for computing the energy for a single decoder state.
            inputs: a single decoder output, shape (batch_size, de_latent_dim)
            states: the state list carried by K.rnn; the constants (here the
                    full encoder sequence) are appended at the end
            """
            logger.debug("Running energy computation step")
            if not isinstance(states, (list, tuple)):
                raise TypeError(f"States must be an iterable. Got {states} of type {type(states)}")
            encoder_full_seq = states[-1]

            # Computing S.Wa where S = [s0, s1, ..., si]
            # <= (batch_size, en_seq_len, latent_dim)
            W_a_dot_s = K.dot(encoder_full_seq, self.W_a)

            # Computing hj.Ua
            U_a_dot_h = K.expand_dims(K.dot(inputs, self.U_a), 1)  # <= (batch_size, 1, latent_dim)
            logger.debug(f"U_a_dot_h.shape = {U_a_dot_h.shape}")

            # tanh(S.Wa + hj.Ua)
            # <= (batch_size, en_seq_len, latent_dim)
            Ws_plus_Uh = K.tanh(W_a_dot_s + U_a_dot_h)
            logger.debug(f"Ws_plus_Uh.shape = {Ws_plus_Uh.shape}")

            # softmax(va.tanh(S.Wa + hj.Ua))
            # <= (batch_size, en_seq_len)
            e_i = K.squeeze(K.dot(Ws_plus_Uh, self.V_a), axis=-1)
            e_i = K.softmax(e_i)
            logger.debug(f"ei.shape = {e_i.shape}")
            return e_i, [e_i]
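
        # Note: inside energy_step the sum W_a_dot_s + U_a_dot_h broadcasts a
        # (batch_size, 1, latent_dim) tensor over a (batch_size, en_seq_len,
        # latent_dim) one, so one decoder state is scored against every encoder
        # step in a single shot.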

        def context_step(inputs, states):
            """ Step function for computing the context vector ci from the
            attention weights ei """
            logger.debug("Running attention vector computation step")
            if not isinstance(states, (list, tuple)):
                raise TypeError(f"States must be an iterable. Got {states} of type {type(states)}")
            encoder_full_seq = states[-1]

            # <= (batch_size, latent_dim)
            c_i = K.sum(encoder_full_seq * K.expand_dims(inputs, -1), axis=1)
            logger.debug(f"ci.shape = {c_i.shape}")
            return c_i, [c_i]
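
        # Each context vector is a weighted sum of the encoder outputs,
        # c_i = sum_j e_ij * s_j, so it lives in the encoder's latent space.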

        # We don't maintain states between steps when computing attention:
        # attention is stateless, so we pass fake initial states to K.rnn.
        fake_state_c = K.sum(encoder_out_seq, axis=1)  # <= (batch_size, latent_dim)
        fake_state_e = K.sum(encoder_out_seq, axis=2)  # <= (batch_size, en_seq_len)

        # Computing energy outputs
        # e_outputs => (batch_size, de_seq_len, en_seq_len)
        _, e_outputs, _ = K.rnn(
            energy_step, decoder_out_seq, [fake_state_e], constants=[encoder_out_seq]
        )

        # Computing context vectors
        # c_outputs => (batch_size, de_seq_len, latent_dim)
        _, c_outputs, _ = K.rnn(
            context_step, e_outputs, [fake_state_c], constants=[encoder_out_seq]
        )
        return c_outputs, e_outputs

    def compute_output_shape(self, input_shape):
        """ Shapes of the outputs produced by the layer """
        # The context vectors follow the encoder latent dim (input_shape[0][2]);
        # the attention energies span the encoder sequence length (input_shape[0][1]).
        return [
            tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[0][2])),
            tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[0][1]))
        ]
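

# A minimal usage sketch (not part of the original file; the input names and
# dimensions below are hypothetical): the layer takes the encoder and decoder
# output sequences and returns the context vectors and the attention energies.
if __name__ == "__main__":
    en_seq_len, de_seq_len, latent_dim = 20, 15, 64
    encoder_out = tf.keras.Input(shape=(en_seq_len, latent_dim), name="encoder_out")
    decoder_out = tf.keras.Input(shape=(de_seq_len, latent_dim), name="decoder_out")
    attn_out, attn_energies = AttentionLayer(name="attention_layer")([encoder_out, decoder_out])
    # attn_out: (batch, de_seq_len, latent_dim); attn_energies: (batch, de_seq_len, en_seq_len)
    model = tf.keras.Model([encoder_out, decoder_out], [attn_out, attn_energies])
    model.summary()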