bert_aug.py
#!/usr/bin/env python
# encoding=utf-8
import heapq
from collections import defaultdict

import numpy as np
import tensorflow as tf

from bert import modeling, tokenization

print(tf.__version__)
def gather_indexes(sequence_tensor, positions):
"""Gathers the vectors at the specific positions over a minibatch."""
sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
batch_size = sequence_shape[0]
seq_length = sequence_shape[1]
width = sequence_shape[2]
flat_offsets = tf.reshape(
tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
flat_positions = tf.reshape(positions + flat_offsets, [-1])
flat_sequence_tensor = tf.reshape(sequence_tensor,
[batch_size * seq_length, width])
output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
return output_tensor
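
# Illustration (hypothetical numbers): with batch_size=2 and seq_length=4,
# flat_offsets is [[0], [4]], so position 1 of example 0 gathers flat row 1
# and position 1 of example 1 gathers flat row 5 of the flattened
# [batch_size * seq_length, width] tensor.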
def get_masked_lm_output(bert_config, input_tensor, output_weights, positions):
"""Get loss and log probs for the masked LM."""
input_tensor = gather_indexes(input_tensor, positions)
with tf.variable_scope("cls/predictions"):
# We apply one more non-linear transformation before the output layer.
# This matrix is not used after pre-training.
with tf.variable_scope("transform"):
input_tensor = tf.layers.dense(
input_tensor,
units=bert_config.hidden_size,
activation=modeling.get_activation(bert_config.hidden_act),
kernel_initializer=modeling.create_initializer(
bert_config.initializer_range))
input_tensor = modeling.layer_norm(input_tensor)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
output_bias = tf.get_variable(
"output_bias",
shape=[bert_config.vocab_size],
initializer=tf.zeros_initializer())
logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)
return logits
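
# Note: `logits` has shape [num_masked_positions_in_batch, vocab_size]; the
# softmax applied in BertAugmentor.build therefore yields one vocabulary
# distribution per masked position.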
class BertAugmentor(object):
    """Masked-LM based augmentor: fills [MASK] positions in a query with BERT."""

    def __init__(self, model_dir, beam_size=1):
        self.beam_size = beam_size  # at most beam_size candidates per masked sentence
        # Paths to the pretrained BERT config, checkpoint and vocabulary.
        self.bert_config_file = model_dir + 'bert_config.json'
        self.init_checkpoint = model_dir + 'bert_model.ckpt'
        self.bert_vocab_file = model_dir + 'vocab.txt'
self.bert_config = modeling.BertConfig.from_json_file(
self.bert_config_file)
self.token = tokenization.FullTokenizer(vocab_file=self.bert_vocab_file, do_lower_case=False)
self.mask_token = "[MASK]"
self.mask_id = self.token.convert_tokens_to_ids([self.mask_token])[0]
self.cls_token = "[CLS]"
self.cls_id = self.token.convert_tokens_to_ids([self.cls_token])[0]
self.sep_token = "[SEP]"
self.sep_id = self.token.convert_tokens_to_ids([self.sep_token])[0]
        # Build the TF graph.
        self.build()
        # Initialize the session.
        self.build_sess()
    def __del__(self):
        # Destructor: release the session.
        self.close_sess()
def build(self):
# placeholder
self.input_ids = tf.placeholder(
tf.int32, shape=[None, None], name='input_ids')
self.input_mask = tf.placeholder(
tf.int32, shape=[None, None], name='input_masks')
self.segment_ids = tf.placeholder(
tf.int32, shape=[None, None], name='segment_ids')
self.masked_lm_positions = tf.placeholder(
tf.int32, shape=[None, None], name='masked_lm_positions')
        # Instantiate the BERT model (inference only).
self.model = modeling.BertModel(
config=self.bert_config,
is_training=False,
input_ids=self.input_ids,
input_mask=self.input_mask,
token_type_ids=self.segment_ids,
use_one_hot_embeddings=False)
self.masked_logits = get_masked_lm_output(
self.bert_config, self.model.get_sequence_output(), self.model.get_embedding_table(),
self.masked_lm_positions)
self.predict_prob = tf.nn.softmax(self.masked_logits, axis=-1)
        # Load the pretrained BERT weights from the checkpoint.
tvars = tf.trainable_variables()
(assignment, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
tvars, self.init_checkpoint)
tf.train.init_from_checkpoint(self.init_checkpoint, assignment)
def build_sess(self):
self.sess = tf.Session()
self.sess.run(tf.global_variables_initializer())
    def close_sess(self):
        # Closing is currently disabled; the session stays open for reuse.
        # self.sess.close()
        pass
    def predict_single_mask(self, word_ids: list, mask_index: int, prob: float = None):
        """Predict the [MASK] at position `mask_index` of one token-id list and
        return up to `self.beam_size` candidate sequences, each paired with its
        probability; candidates below `prob` (if given) are dropped."""
word_ids_out = []
word_mask = [1] * len(word_ids)
word_segment_ids = [0] * len(word_ids)
fd = {self.input_ids: [word_ids], self.input_mask: [word_mask], self.segment_ids: [
word_segment_ids], self.masked_lm_positions: [[mask_index]]}
mask_probs = self.sess.run(self.predict_prob, feed_dict=fd)
for mask_prob in mask_probs:
mask_prob = mask_prob.tolist()
            max_num_index_list = heapq.nlargest(self.beam_size, range(len(mask_prob)), key=mask_prob.__getitem__)
for i in max_num_index_list:
if prob and mask_prob[i] < prob:
continue
cur_word_ids = word_ids.copy()
cur_word_ids[mask_index] = i
word_ids_out.append([cur_word_ids, mask_prob[i]])
return word_ids_out
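
    # Example (hypothetical values): with beam_size=2, if vocabulary ids 17 and
    # 42 score 0.6 and 0.3 at the masked slot, predict_single_mask returns
    #   [[word_ids_with_17_at_mask_index, 0.6],
    #    [word_ids_with_42_at_mask_index, 0.3]].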
    def predict_batch_mask(self, query_ids: list, mask_indexes: list, prob: float = 0.5):
        """Predict the masked position of each query in a batch and return up to
        `self.beam_size` candidates with their probabilities.
        query_ids: [word_ids1, ...], shape=[batch, query_length]
        mask_indexes: masked position per query, [[mask_id], ...], shape=[batch, 1]
        """
word_ids_out = []
word_mask = [[1] * len(x) for x in query_ids]
        # Single-sentence inputs: segment id 0 everywhere (as in predict_single_mask).
        word_segment_ids = [[0] * len(x) for x in query_ids]
fd = {self.input_ids: query_ids, self.input_mask: word_mask, self.segment_ids:
word_segment_ids, self.masked_lm_positions: mask_indexes}
mask_probs = self.sess.run(self.predict_prob, feed_dict=fd)
for mask_prob, word_ids_, mask_index in zip(mask_probs, query_ids, mask_indexes):
# each query of batch
cur_out = []
mask_prob = mask_prob.tolist()
            max_num_index_list = heapq.nlargest(self.beam_size, range(len(mask_prob)), key=mask_prob.__getitem__)
for i in max_num_index_list:
cur_word_ids = word_ids_.copy()
cur_word_ids[mask_index[0]] = i
cur_out.append([cur_word_ids, mask_prob[i]])
word_ids_out.append(cur_out)
return word_ids_out
    def gen_sen(self, word_ids: list, indexes: list):
        """Given a token-id list containing [MASK] tokens, generate words for the
        masked positions.
        Because queries have different numbers of masks, each query is predicted
        separately, one mask at a time.
        """
out_arr = []
for i, index_ in enumerate(indexes):
if i == 0:
out_arr = self.predict_single_mask(word_ids, index_)
else:
tmp_arr = out_arr.copy()
out_arr = []
for word_ids_, prob in tmp_arr:
cur_arr = self.predict_single_mask(word_ids_, index_)
cur_arr = [[x[0], x[1] * prob] for x in cur_arr]
out_arr.extend(cur_arr)
            # Keep only the top beam_size partial candidates.
out_arr = sorted(out_arr, key=lambda x: x[1], reverse=True)[:self.beam_size]
        # Convert candidate token ids back to tokens.
        for i, (each, _) in enumerate(out_arr):
            out_arr[i][0] = self.token.convert_ids_to_tokens(each)
return out_arr
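
    # Example (hypothetical values): for two masked positions with beam_size=2,
    # the first mask yields two partial candidates; each is expanded again at
    # the second mask with probabilities multiplying (e.g. 0.6 * 0.5 = 0.3),
    # and the resulting candidates are pruned back to the top two before being
    # converted to tokens.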
    def word_insert(self, query):
        """Insert [MASK] tokens at random and let BERT generate their content,
        producing insertion-style augmentations of `query`."""
        out_arr = []
        seg_list = query.split(' ')
        # Candidate insertion points: the start of the sentence and after each word.
        i, index_arr = 1, [1]
        for each in seg_list:
            i += 1
            index_arr.append(i)
        # Convert the query to token ids and add [CLS]/[SEP].
split_tokens = self.token.tokenize(query)
word_ids = self.token.convert_tokens_to_ids(split_tokens)
word_ids.insert(0, self.cls_id)
word_ids.append(self.sep_id)
word_ids_arr, word_index_arr = [], []
        # Randomly insert n [MASK] tokens at each candidate position, 1 <= n <= 3.
for index_ in index_arr:
insert_num = np.random.randint(1, 4)
word_ids_ = word_ids.copy()
word_index = []
for i in range(insert_num):
word_ids_.insert(index_, self.mask_id)
word_index.append(index_ + i)
word_ids_arr.append(word_ids_)
word_index_arr.append(word_index)
        for word_ids, word_index in zip(word_ids_arr, word_index_arr):
            arr_ = self.gen_sen(word_ids, indexes=word_index)
            out_arr.extend(arr_)
        # From all generated sentences, keep the top beam_size by probability;
        # drop [CLS]/[SEP] and join the tokens back into strings.
        out_arr = sorted(out_arr, key=lambda x: x[1], reverse=True)
        out_arr = [" ".join(x[0][1:-1]) for x in out_arr[:self.beam_size]]
return out_arr
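
    # Sketch of the flow (illustrative query, token ids omitted): for
    # "add this track", one candidate sequence is built per insertion point,
    # e.g. "[CLS] add [MASK] [MASK] this track [SEP]"; BERT fills the masks and
    # the top beam_size filled-in sentences are returned as strings.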
    def word_replace(self, query):
        """Mask each word in turn and let BERT generate replacement content."""
        out_arr = []
        seg_list = query.split(' ')
        # Map each maskable token position to the number of tokens to mask (1 per word).
        i, index_map = 1, {}
        for each in seg_list:
            index_map[i] = 1
            i += 1
        # Convert the query to token ids and add [CLS]/[SEP].
split_tokens = self.token.tokenize(query)
word_ids = self.token.convert_tokens_to_ids(split_tokens)
word_ids.insert(0, self.cls_id)
word_ids.append(self.sep_id)
word_ids_arr, word_index_arr = [], []
        # Mask each word position in turn.
for index_, word_len in index_map.items():
word_ids_ = word_ids.copy()
word_index = []
for i in range(word_len):
word_ids_[index_ + i] = self.mask_id
word_index.append(index_ + i)
word_ids_arr.append(word_ids_)
word_index_arr.append(word_index)
        for word_ids, word_index in zip(word_ids_arr, word_index_arr):
            arr_ = self.gen_sen(word_ids, indexes=word_index)
            out_arr.extend(arr_)
        # Keep the top beam_size sentences; drop [CLS]/[SEP] and join the tokens.
        out_arr = sorted(out_arr, key=lambda x: x[1], reverse=True)
        out_arr = [" ".join(x[0][1:-1]) for x in out_arr[:self.beam_size]]
return out_arr
def insert_word2queries(self, query, beam_size=10):
self.beam_size = beam_size
out_map = defaultdict(list)
out_map[query] = self.word_insert(query)
return out_map
def replace_word2queries(self, queries: list, beam_size=10):
self.beam_size = beam_size
out_map = defaultdict(list)
for query in queries:
out_map[query] = self.word_replace(query)
return out_map
def get_words_and_lables_from_sentence(sentence):
    """Split a tagged sentence of "word:label" pairs into a {word: ":label"} dict."""
    words = []
    lables = []
    word_lable_pairs = sentence.split(" ")
    for word_lable_pair in word_lable_pairs:
        sep_idx = word_lable_pair.rfind(":")
        words.append(word_lable_pair[0:sep_idx])
        lables.append(word_lable_pair[sep_idx:])
    return dict(zip(words, lables))
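
# Example: for the pair string "track:B-music_item this:O" the function returns
# {"track": ":B-music_item", "this": ":O"} (each label keeps its leading colon).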
def augment(ori_sentence, aug_num, model_dir=None):
sentence = ori_sentence.split(" <=> ")[0]
intent = ori_sentence.split(" <=> ")[1]
dict_word_lable = get_words_and_lables_from_sentence(sentence)
    if not model_dir:
        raise ValueError("model_dir must be provided")
mask_model = BertAugmentor(model_dir)
mask_model.beam_size = aug_num
words = dict_word_lable.keys()
sentence = " ".join(w for w in words)
insert_result = mask_model.word_insert(sentence)
    line = insert_result[0]
    # Merge WordPiece sub-tokens ("##"-prefixed pieces) back into whole words.
    tmp_line = line.replace(' ##', '')
    new_line = tmp_line.replace('##', '')
new_words = new_line.split(" ")
new_slots = []
for w in new_words:
if w not in list(words):
new_slots.append(":O")
else:
new_slots.append(dict_word_lable[w])
new_dict_w_l = dict(zip(new_words, new_slots))
aug_sentence = ""
for key, values in new_dict_w_l.items():
aug_sentence = aug_sentence + key + values + " "
return aug_sentence + "<=> " + intent
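
# Expected round trip (see the seed sentence below): augment() strips the
# ":label" tags, asks BertAugmentor.word_insert for an augmented sentence,
# re-attaches the original label to every surviving word, tags newly inserted
# words with ":O", and returns "word1:label1 word2:label2 ... <=> intent".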
if __name__ == "__main__":
model_dir = '/home/cici/major/NLPDataAugmentation/wwm_cased_L-24_H-1024_A-16/'
seed_sentence = "I'd:O like:O to:O have:O this:O track:B-music_item onto:O my:B-playlist_owner " \
"Classical:B-playlist Relaxations:I-playlist playlist.:O <=> AddToPlaylist"
print(augment(seed_sentence, 1, model_dir=model_dir))