-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmodel.py
executable file
·202 lines (171 loc) · 7.69 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import torch
import torch.nn as nn
import torch.nn.functional as F
from beam import Beam
class EncoderRNN(nn.Module):
    """Encode a padded batch of token ids with a single-layer bidirectional GRU.

    Args:
        num_input: embedding size.
        num_hidden: GRU hidden size per direction; outputs have 2 * num_hidden features.
        num_token: vocabulary size of the embedding table.
        padding_idx: passed to ``nn.Embedding`` as ``padding_idx``.
        emb_dropout: dropout probability applied to the embeddings.
        hid_dropout: dropout probability applied to the GRU outputs.
    """

    def __init__(self, num_input, num_hidden, num_token, padding_idx, emb_dropout, hid_dropout):
        super(EncoderRNN, self).__init__()
        self.num_hidden = num_hidden
        # Attribute names and creation order define state_dict keys and parameter
        # order (init_hidden relies on next(self.parameters())); keep them stable.
        self.emb = nn.Embedding(num_token, num_input, padding_idx=padding_idx)
        self.bi_gru = nn.GRU(num_input, num_hidden, 1, batch_first=True, bidirectional=True)
        self.enc_emb_dp = nn.Dropout(emb_dropout)
        self.enc_hid_dp = nn.Dropout(hid_dropout)

    def init_hidden(self, batch_size):
        """Return an all-zero GRU state of shape (2, batch_size, num_hidden), one slice per direction."""
        weight = next(self.parameters())
        # new_zeros places the state on the same device / dtype as the parameters.
        return weight.new_zeros(2, batch_size, self.num_hidden)

    def forward(self, input, mask):
        """Encode ``input`` and return per-position outputs plus a sentence summary.

        Args:
            input: (batch, seq_len) token ids.
            mask: (batch, seq_len). Its row sums are used as sequence lengths and its
                width as the padded length, so it presumably holds 1 for real tokens and
                0 for trailing padding — confirm with callers. Every row needs at least
                one valid token (pack_padded_sequence rejects zero lengths). Rows may
                come in any length order.

        Returns:
            output: (batch, seq_len, 2 * num_hidden). Padded positions are zero, because
                pad_packed_sequence zero-fills and dropout keeps zeros.
            hidden: (batch, 2 * num_hidden), the final forward- and backward-direction
                states concatenated.
        """
        hidden = self.init_hidden(input.size(0))
        embedded = self.enc_emb_dp(self.emb(input))
        lengths = mask.sum(1).tolist()
        total_length = mask.size(1)
        # enforce_sorted=False: PyTorch sorts rows by length internally and restores the
        # caller's order for both the padded output and the final hidden state, so
        # batches no longer have to be pre-sorted by descending length.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )
        output, hidden = self.bi_gru(packed, hidden)
        output = nn.utils.rnn.pad_packed_sequence(
            output, batch_first=True, total_length=total_length
        )[0]
        output = self.enc_hid_dp(output)
        # For a 1-layer bidirectional GRU, h_n[0] is the forward and h_n[1] the backward state.
        hidden = torch.cat([hidden[0], hidden[1]], dim=-1)
        return output, hidden
class Attention(nn.Module):
    """Additive attention: score_j = a2o(tanh(s2s(context_j) + h2s(hidden))).

    Args:
        num_hidden: size of the query vector ``hidden``.
        ncontext: size of each context vector.
        natt: size of the intermediate attention space.
    """

    def __init__(self, num_hidden, ncontext, natt):
        super(Attention, self).__init__()
        # Sub-module names and order define state_dict keys; keep them stable.
        self.h2s = nn.Linear(num_hidden, natt)
        self.s2s = nn.Linear(ncontext, natt)
        self.a2o = nn.Linear(natt, 1)

    def forward(self, hidden, mask, context):
        """Return the attention-weighted sum of ``context``.

        Args:
            hidden: (batch, num_hidden) query.
            mask: (batch, src_len); truthy entries may be attended to. Any dtype is
                accepted and converted to bool.
            context: (batch, src_len, ncontext). It does not need to be contiguous.

        Returns:
            (batch, ncontext) weighted sum of context vectors.
        """
        # nn.Linear acts on the last dimension, so no flatten/view (and no contiguity
        # requirement) is needed. The query projection broadcasts over source positions.
        attn_h = self.s2s(context) + self.h2s(hidden).unsqueeze(1)
        logit = self.a2o(torch.tanh(attn_h)).squeeze(-1)
        mask = mask.bool()
        # Masking is skipped only if the whole mask is empty, which avoids an all -inf
        # batch (and NaN softmax) in that degenerate case. Out-of-place masked_fill keeps
        # autograd informed, unlike the former in-place write through `.data`.
        if mask.any():
            logit = logit.masked_fill(~mask, float("-inf"))
        weights = F.softmax(logit, dim=1)
        return torch.bmm(weights.unsqueeze(1), context).squeeze(1)
class DecoderRNN(nn.Module):
    """Single decoding step: two stacked GRU cells around an attention read, plus a tanh readout.

    Args:
        num_input: size of the target-token embedding fed in at each step.
        num_hidden: decoder state size.
        enc_ncontext: size of each encoder context vector.
        natt: attention projection size.
        nreadout: size of the readout vector returned by ``forward``.
        readout_dropout: dropout probability applied to the readout.
    """

    def __init__(self, num_input, num_hidden, enc_ncontext, natt, nreadout, readout_dropout):
        super().__init__()
        # Names and creation order below fix the state_dict layout; do not rename or reorder.
        self.gru1 = nn.GRUCell(num_input, num_hidden)
        self.gru2 = nn.GRUCell(enc_ncontext, num_hidden)
        self.enc_attn = Attention(num_hidden, enc_ncontext, natt)
        self.embedding2out = nn.Linear(num_input, nreadout)
        self.hidden2out = nn.Linear(num_hidden, nreadout)
        self.c2o = nn.Linear(enc_ncontext, nreadout)
        self.readout_dp = nn.Dropout(readout_dropout)

    def forward(self, emb, hidden, enc_mask, enc_context):
        """Advance the decoder by one step.

        Args:
            emb: (batch, num_input) embedding of the current target token.
            hidden: (batch, num_hidden) previous decoder state.
            enc_mask: attention mask over encoder positions.
            enc_context: encoder outputs to attend over.

        Returns:
            (readout, state): readout is (batch, nreadout) after tanh and dropout; state
            is the new (batch, num_hidden) decoder state.
        """
        # First cell folds the current token into the state.
        state = self.gru1(emb, hidden)
        # The intermediate state is the attention query.
        context_vec = self.enc_attn(state, enc_mask, enc_context)
        # Second cell folds the attended context into the state.
        state = self.gru2(context_vec, state)
        pre_activation = self.embedding2out(emb) + self.hidden2out(state) + self.c2o(context_vec)
        readout = self.readout_dp(torch.tanh(pre_activation))
        return readout, state
class AttEncDecRNN(nn.Module):
    """Attentional encoder-decoder with teacher-forced training loss and beam-search decoding.

    ``opt`` supplies the hyper-parameters read in ``__init__``:
    - special target ids: dec_sos, dec_eos, dec_pad; and enc_pad for the source;
    - sizes: enc/dec_num_input, enc/dec_num_hidden, enc/dec_num_token, dec_natt, nreadout;
    - dropouts: enc_emb_dropout, enc_hid_dropout, dec_emb_dropout, readout_dropout.
    """

    def __init__(self, opt):
        super(AttEncDecRNN, self).__init__()
        self.dec_num_hidden = opt.dec_num_hidden
        self.dec_sos = opt.dec_sos
        self.dec_eos = opt.dec_eos
        self.dec_pad = opt.dec_pad
        self.enc_pad = opt.enc_pad
        # Sub-module names and creation order fix state_dict keys; keep them stable.
        self.emb = nn.Embedding(opt.dec_num_token, opt.dec_num_input, padding_idx=opt.dec_pad)
        self.encoder = EncoderRNN(
            opt.enc_num_input,
            opt.enc_num_hidden,
            opt.enc_num_token,
            opt.enc_pad,
            opt.enc_emb_dropout,
            opt.enc_hid_dropout,
        )
        # The encoder is bidirectional, so its context vectors are 2 * enc_num_hidden wide.
        self.decoder = DecoderRNN(
            opt.dec_num_input,
            opt.dec_num_hidden,
            2 * opt.enc_num_hidden,
            opt.dec_natt,
            opt.nreadout,
            opt.readout_dropout,
        )
        self.affine = nn.Linear(opt.nreadout, opt.dec_num_token)
        self.init_affine = nn.Linear(2 * opt.enc_num_hidden, opt.dec_num_hidden)
        self.dec_emb_dp = nn.Dropout(opt.dec_emb_dropout)

    def _encode(self, src, src_mask):
        """Run the encoder and derive the decoder's initial state (shared by forward and beamsearch).

        Returns:
            enc_context: contiguous encoder outputs.
            attn_mask: ``src_mask`` as a bool tensor for the attention layer.
            hidden: tanh(init_affine(mean of encoder outputs over valid positions)).
        """
        enc_context, _ = self.encoder(src, src_mask)
        enc_context = enc_context.contiguous()
        # EncoderRNN zero-fills padded positions, so the plain sum over time divided by
        # the mask length is the mean over real tokens.
        avg_enc_context = enc_context.sum(1) / src_mask.sum(1).unsqueeze(-1)
        attn_mask = src_mask.bool()
        hidden = torch.tanh(self.init_affine(avg_enc_context))
        return enc_context, attn_mask, hidden

    def forward(self, src, src_mask, f_trg, f_trg_mask, b_trg=None, b_trg_mask=None):
        """Compute the teacher-forced cross-entropy loss of ``f_trg`` given ``src``.

        Args:
            src, src_mask: encoder inputs, forwarded to EncoderRNN.
            f_trg: (batch, trg_len) target ids. Column i is fed to the decoder to predict
                column i + 1.
            f_trg_mask: (batch, trg_len); column i weights the loss of predicting f_trg[:, i].
            b_trg, b_trg_mask: accepted for interface compatibility; unused here.

        Returns:
            (loss, w_loss), each of shape (1,):
            - loss: the masked loss summed over time steps, then averaged over the batch.
            - w_loss: the same total divided by f_trg_mask[:, 1:].sum().
            The leading dim is presumably there so data-parallel wrappers can gather
            per-replica values — confirm.

        Raises:
            ValueError: if f_trg has fewer than two columns (nothing to predict).
        """
        if f_trg.size(1) < 2:
            raise ValueError("f_trg needs at least two time steps (a start token and one target)")
        enc_context, attn_mask, hidden = self._encode(src, src_mask)
        loss = 0
        for i in range(f_trg.size(1) - 1):
            dec_input = self.dec_emb_dp(self.emb(f_trg[:, i]))
            output, hidden = self.decoder(dec_input, hidden, attn_mask, enc_context)
            step_loss = F.cross_entropy(self.affine(output), f_trg[:, i + 1], reduction="none")
            loss += step_loss * f_trg_mask[:, i + 1]
        w_loss = loss.sum() / f_trg_mask[:, 1:].sum()
        loss = loss.mean()
        return loss.unsqueeze(0), w_loss.unsqueeze(0)

    def beamsearch(
        self, src, src_mask, beam_size=10, normalize=False, max_len=None, min_len=None
    ):
        """Decode a single source sentence with beam search.

        ``src`` must hold exactly one sentence, because the initial beam has a single
        candidate row. Dropout layers are active unless the module is in eval mode.

        Args:
            src, src_mask: (1, src_len) source ids and mask.
            beam_size: number of hypotheses kept per step.
            normalize: if True, divide each score by its hypothesis length (SOS/EOS
                excluded) when that length is non-zero.
            max_len: maximum number of decoding steps (default 3 * src_len). The last
                step only allows EOS, so every hypothesis terminates.
            min_len: EOS is disallowed while the step index is below this (default
                src_len / 2). If it conflicts with max_len, the final forced-EOS step wins.

        Returns:
            A list of (token_ids, score) pairs sorted by ascending score. token_ids has
            SOS and anything from EOS onward stripped. Scores come from ``Beam`` fed with
            negative log-probabilities, so presumably lower is better — confirm against Beam.
        """
        max_len = src.size(1) * 3 if max_len is None else max_len
        min_len = src.size(1) / 2 if min_len is None else min_len
        enc_context, attn_mask, hidden = self._encode(src, src_mask)

        prev_beam = Beam(beam_size)
        prev_beam.candidates = [[self.dec_sos]]
        prev_beam.scores = [0]

        def f_done(cand):
            return cand[-1] == self.dec_eos

        valid_size = beam_size
        hyp_list = []
        for k in range(max_len):
            tokens = src.new_tensor([cand[-1] for cand in prev_beam.candidates])
            dec_input = self.dec_emb_dp(self.emb(tokens))
            output, hidden = self.decoder(dec_input, hidden, attn_mask, enc_context)
            log_prob = F.log_softmax(self.affine(output), dim=1)
            if k == max_len - 1:
                # Last step: only EOS may be chosen. Use its real log-probability even when
                # k < min_len; otherwise every score would be -inf and nothing could finish.
                eos_prob = log_prob[:, self.dec_eos].clone()
                log_prob[:, :] = -float("inf")
                log_prob[:, self.dec_eos] = eos_prob
            elif k < min_len:
                log_prob[:, self.dec_eos] = -float("inf")
            next_beam = Beam(valid_size)
            done_list, remain_list = next_beam.step(-log_prob, prev_beam, f_done)
            hyp_list.extend(done_list)
            valid_size -= len(done_list)
            if valid_size == 0:
                break
            # remain_list presumably gives, for each surviving candidate, the row of the
            # previous beam it extends; gather the per-row decoder state accordingly.
            remain_ix = src.new_tensor(remain_list)
            enc_context = enc_context.index_select(0, remain_ix)
            attn_mask = attn_mask.index_select(0, remain_ix)
            hidden = hidden.index_select(0, remain_ix)
            prev_beam = next_beam

        score_list = [hyp[1] for hyp in hyp_list]
        # Drop the leading SOS and truncate at the first EOS, if any.
        hyp_list = [
            hyp[0][1 : hyp[0].index(self.dec_eos)] if self.dec_eos in hyp[0] else hyp[0][1:]
            for hyp in hyp_list
        ]
        if normalize:
            score_list = [
                score / len(hyp) if len(hyp) > 0 else score
                for hyp, score in zip(hyp_list, score_list)
            ]
        score = hidden.new_tensor(score_list)
        _, sort_ix = torch.sort(score)
        return [(hyp_list[ix], score[ix].item()) for ix in sort_ix.tolist()]