# coding: utf-8
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import attr
from config import config
import nltk
import spacy
from functools import partial
from nltk.tokenize import word_tokenize
from nltk.tag import pos_tag
# nltk.download('wordnet')
from nltk.corpus import wordnet as wn
from get_NE_list import NE_list
import numpy as np
from unbuffered import Unbuffered
sys.stdout = Unbuffered(sys.stdout)
# from pywsd.lesk import simple_lesk as disambiguate

nlp = spacy.load('en_core_web_sm')

# Penn TreeBank POS tags:
# http://www.ling.upenn.edu/courses/Fall_2003/ling001/penn_treebank_pos.html
supported_pos_tags = [
    'CC',    # coordinating conjunction, like "and but neither versus whether yet so"
    # 'CD',    # Cardinal number, like "mid-1890 34 forty-two million dozen"
    # 'DT',    # Determiner, like "all an both those"
    # 'EX',    # Existential there, like "there"
    # 'FW',    # Foreign word
    # 'IN',    # Preposition or subordinating conjunction, like "among below into"
    'JJ',    # Adjective, like "second ill-mannered"
    'JJR',   # Adjective, comparative, like "colder"
    'JJS',   # Adjective, superlative, like "cheapest"
    # 'LS',    # List item marker, like "A B C D"
    # 'MD',    # Modal, like "can must shouldn't"
    'NN',    # Noun, singular or mass
    'NNS',   # Noun, plural
    'NNP',   # Proper noun, singular
    'NNPS',  # Proper noun, plural
    # 'PDT',   # Predeterminer, like "all both many"
    # 'POS',   # Possessive ending, like "'s"
    # 'PRP',   # Personal pronoun, like "hers herself ours they theirs"
    # 'PRP$',  # Possessive pronoun, like "hers his mine ours"
    'RB',    # Adverb
    'RBR',   # Adverb, comparative, like "lower heavier"
    'RBS',   # Adverb, superlative, like "best biggest"
    # 'RP',    # Particle, like "aboard about across around"
    # 'SYM',   # Symbol
    # 'TO',    # to
    # 'UH',    # Interjection, like "wow goody"
    'VB',    # Verb, base form
    'VBD',   # Verb, past tense
    'VBG',   # Verb, gerund or present participle
    'VBN',   # Verb, past participle
    'VBP',   # Verb, non-3rd person singular present
    'VBZ',   # Verb, 3rd person singular present
    # 'WDT',   # Wh-determiner, like "that what whatever which whichever"
    # 'WP',    # Wh-pronoun, like "that who"
    # 'WP$',   # Possessive wh-pronoun, like "whose"
    # 'WRB',   # Wh-adverb, like "however wherever whenever"
]

@attr.s
class SubstitutionCandidate:
    token_position = attr.ib()
    similarity_rank = attr.ib()
    original_token = attr.ib()
    candidate_word = attr.ib()

def vsm_similarity(doc, original, synonym):
    window_size = 3
    start = max(0, original.i - window_size)
    return doc[start: original.i + window_size].similarity(synonym)
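
# Note: the small 'en_core_web_sm' pipeline ships without static word vectors,
# so Doc.similarity() may fall back to context-sensitive tensors and emit a
# "no word vectors loaded" warning; a medium or large model ('en_core_web_md'
# / 'en_core_web_lg') should give more meaningful similarity scores here.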

def _get_wordnet_pos(spacy_token):
    '''Wordnet POS tag'''
    pos = spacy_token.tag_[0].lower()
    if pos in ['r', 'n', 'v']:  # adv, noun, verb
        return pos
    elif pos == 'j':
        return 'a'  # adj

def _synonym_prefilter_fn(token, synonym):
    '''
    Similarity heuristics go here
    '''
    if (len(synonym.text.split()) > 2 or       # the synonym produced is a phrase
            synonym.lemma == token.lemma or    # token and synonym are the same
            synonym.tag != token.tag or        # the POS tags of token and synonym differ
            token.text.lower() == 'be'):       # the token is 'be'
        return False
    else:
        return True

def _generate_synonym_candidates(token, token_position, rank_fn=None):
    '''
    Generate synonym candidates.
    For each token in the doc, the list of WordNet synonyms is expanded.
    :return candidates: a list of SubstitutionCandidate instances, e.g.
            SubstitutionCandidate(token_position=0, similarity_rank=10,
            original_token=Soft, candidate_word='subdued')
    '''
    if rank_fn is None:
        rank_fn = vsm_similarity

    candidates = []
    if token.tag_ in supported_pos_tags:
        wordnet_pos = _get_wordnet_pos(token)  # 'r', 'a', 'n', 'v' or None
        wordnet_synonyms = []

        synsets = wn.synsets(token.text, pos=wordnet_pos)
        for synset in synsets:
            wordnet_synonyms.extend(synset.lemmas())

        synonyms = []
        for wordnet_synonym in wordnet_synonyms:
            spacy_synonym = nlp(wordnet_synonym.name().replace('_', ' '))[0]
            synonyms.append(spacy_synonym)

        synonyms = filter(partial(_synonym_prefilter_fn, token), synonyms)

        candidate_set = set()
        for _, synonym in enumerate(synonyms):
            candidate_word = synonym.text
            if candidate_word in candidate_set:  # avoid repetition
                continue
            candidate_set.add(candidate_word)
            candidate = SubstitutionCandidate(
                token_position=token_position,
                similarity_rank=None,
                original_token=token,
                candidate_word=candidate_word)
            candidates.append(candidate)
    return candidates

def _compile_perturbed_tokens(doc, accepted_candidates):
    '''
    Traverse the list of accepted candidates and do the token substitutions.
    '''
    candidate_by_position = {}
    for candidate in accepted_candidates:
        candidate_by_position[candidate.token_position] = candidate

    final_tokens = []
    for position, token in enumerate(doc):
        word = token.text
        if position in candidate_by_position:
            candidate = candidate_by_position[position]
            word = candidate.candidate_word.replace('_', ' ')
        final_tokens.append(word)

    return final_tokens
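
# Illustrative example (names and sentence chosen here, not from the original code):
#   _compile_perturbed_tokens(nlp('a soft story'),
#                             [SubstitutionCandidate(1, None, None, 'subdued')])
#   returns ['a', 'subdued', 'story']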

def PWWS(
        doc,
        true_y,
        dataset,
        word_saliency_list=None,
        rank_fn=None,
        heuristic_fn=None,        # Defined in adversarial_tools.py
        halt_condition_fn=None,   # Defined in adversarial_tools.py
        verbose=True):

    # defined in Eq.(8)
    def softmax(x):
        exp_x = np.exp(x)
        softmax_x = exp_x / np.sum(exp_x)
        return softmax_x
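
    # Eq.(8) of the PWWS paper normalizes the word saliency vector with a
    # softmax, softmax(x)_i = exp(x_i) / sum_j exp(x_j), so the scores can be
    # used as the weights phi(S(x))_i in the scoring function H of Eq.(7).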

    heuristic_fn = heuristic_fn or (lambda _, candidate: candidate.similarity_rank)
    halt_condition_fn = halt_condition_fn or (lambda perturbed_text: False)
    perturbed_doc = doc
    perturbed_text = perturbed_doc.text

    substitute_count = 0       # how many substitutions are used in the doc
    substitute_tuple_list = []  # save the information of each substitute word

    word_saliency_array = np.array([word_tuple[2] for word_tuple in word_saliency_list])
    word_saliency_array = softmax(word_saliency_array)

    NE_candidates = NE_list.L[dataset][true_y]
    NE_tags = list(NE_candidates.keys())
    use_NE = True  # whether to use NEs as substitutes
    max_len = config.word_max_len[dataset]

    # for each word w_i in x, use WordNet to build a synonym set L_i
    for (position, token, word_saliency, tag) in word_saliency_list:
        if position >= max_len:
            break

        candidates = []
        if use_NE:
            NER_tag = token.ent_type_
            if NER_tag in NE_tags:
                candidate = SubstitutionCandidate(position, 0, token, NE_candidates[NER_tag])
                candidates.append(candidate)
            else:
                candidates = _generate_synonym_candidates(token=token, token_position=position, rank_fn=rank_fn)
        else:
            candidates = _generate_synonym_candidates(token=token, token_position=position, rank_fn=rank_fn)

        if len(candidates) == 0:
            continue
        perturbed_text = perturbed_doc.text

        # The substitute word selection method R(w_i;L_i) defined in Eq.(4)
        sorted_candidates = zip(map(partial(heuristic_fn, doc.text), candidates), candidates)
        # Sort by the return value of heuristic_fn, i.e. \Delta P defined in Eq.(4)
        sorted_candidates = list(sorted(sorted_candidates, key=lambda t: t[0]))

        # delta_p_star is defined in Eq.(5); substitute is w_i^*
        delta_p_star, substitute = sorted_candidates.pop()

        # delta_p_star * word_saliency_array[position] equals H(x, x_i^*, w_i) defined in Eq.(7)
        substitute_tuple_list.append(
            (position, token.text, substitute, delta_p_star * word_saliency_array[position], token.tag_))

    # sort all the words w_i in x in descending order based on H(x, x_i^*, w_i)
    sorted_substitute_tuple_list = sorted(substitute_tuple_list, key=lambda t: t[3], reverse=True)

    # replace w_i in x^(i-1) with w_i^* to craft x^(i)
    NE_count = 0  # how many NE substitutions are used in the doc
    change_tuple_list = []
    for (position, token, substitute, score, tag) in sorted_substitute_tuple_list:
        # if score <= 0:
        #     break
        if nlp(token)[0].ent_type_ in NE_tags:
            NE_count += 1
        change_tuple_list.append((position, token, substitute, score, tag))
        perturbed_text = ' '.join(_compile_perturbed_tokens(perturbed_doc, [substitute]))
        perturbed_doc = nlp(perturbed_text)
        substitute_count += 1
        if halt_condition_fn(perturbed_text):
            if verbose:
                print("use", substitute_count, "substitution; use", NE_count, 'NE')
            sub_rate = substitute_count / len(doc)
            NE_rate = NE_count / substitute_count
            return perturbed_text, sub_rate, NE_rate, change_tuple_list

    if verbose:
        print("use", substitute_count, "substitution; use", NE_count, 'NE')
    sub_rate = substitute_count / len(doc)
    NE_rate = NE_count / substitute_count
    return perturbed_text, sub_rate, NE_rate, change_tuple_list
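

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original script: the full PWWS()
    # attack needs a word_saliency_list, heuristic_fn and halt_condition_fn
    # built in adversarial_tools.py, so this only exercises the two
    # self-contained helpers on a toy sentence (the sentence and the demo_*
    # names below are purely illustrative).
    demo_doc = nlp('The film tells a soft and gentle story')
    demo_candidates = []
    for position, demo_token in enumerate(demo_doc):
        demo_candidates.extend(
            _generate_synonym_candidates(token=demo_token, token_position=position))
    print('generated', len(demo_candidates), 'synonym candidates')
    if demo_candidates:
        # apply the first candidate substitution and print the perturbed text
        print(' '.join(_compile_perturbed_tokens(demo_doc, [demo_candidates[0]])))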