-
Notifications
You must be signed in to change notification settings - Fork 127
/
Copy path demo_deo.py
59 lines (52 loc) · 2.15 KB
/
demo_deo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import OpenAttack
import nltk
from nltk.sentiment.vader import SentimentIntensityAnalyzer
import numpy as np
import datasets
import transformers
def make_model():
    """Build a rule-based sentiment victim backed by NLTK's VADER analyzer.

    Returns:
        An ``OpenAttack.Classifier`` instance. Its ``get_prob`` returns one
        ``[P(negative), P(positive)]`` row per input sentence, and its
        ``get_pred`` returns the argmax index of each row (0 = negative,
        1 = positive).
    """
    # Smoothing constant added to BOTH class scores, so the ratio is defined
    # when VADER reports no polar words (pos == neg == 0). Such neutral input
    # then maps to P(positive) = 0.5 instead of an overconfident 1.0.
    eps = 1e-6

    class MyClassifier(OpenAttack.Classifier):
        def __init__(self):
            try:
                self.model = SentimentIntensityAnalyzer()
            except LookupError:
                # The VADER lexicon is missing locally: download it, then retry.
                nltk.download('vader_lexicon')
                self.model = SentimentIntensityAnalyzer()

        def get_pred(self, input_):
            """Return the predicted label index (argmax of get_prob) per sentence."""
            return self.get_prob(input_).argmax(axis=1)

        def get_prob(self, input_):
            """Return a (len(input_), 2) float array of [neg, pos] probabilities.

            P(positive) is pos / (neg + pos), computed from VADER's
            ``polarity_scores`` and smoothed symmetrically with ``eps``.
            """
            rows = []
            for sent in input_:
                res = self.model.polarity_scores(sent)
                prob = (res["pos"] + eps) / (res["neg"] + res["pos"] + 2 * eps)
                rows.append([1 - prob, prob])
            # reshape keeps the 2-D contract for an empty batch too; otherwise
            # np.array([]) is 1-D and argmax(axis=1) in get_pred raises.
            return np.array(rows, dtype=float).reshape(-1, 2)

    return MyClassifier()
def dataset_mapping(x):
    """Convert one dataset record into OpenAttack's ``{"x": text, "y": label}`` form.

    ``x["label"]`` is binarised: values strictly greater than 0.5 become 1,
    everything else becomes 0. (For SST this is presumably a float sentiment
    score in [0, 1] — confirm against the dataset version in use.)
    """
    sentence = x["sentence"]
    label = int(x["label"] > 0.5)
    return {"x": sentence, "y": label}
def main():
    """Run the UATAttacker demo against the packaged ``BERT.SST`` victim.

    Steps:
      1. Load the ``BERT.SST`` victim.
      2. Map the first 100 SST training sentences with ``dataset_mapping``.
      3. Let ``UATAttacker.set_triggers`` search for triggers on that data.
      4. Run ``AttackEval`` on the same data with fluency, grammatical-error,
         edit-distance and modification-rate metrics (``visualize=True``).
    """
    print("Build model")
    clsf = OpenAttack.DataManager.loadVictim("BERT.SST")
    # Alternative victim, kept for reference: wrap a local HuggingFace checkpoint.
    #   tokenizer = transformers.AutoTokenizer.from_pretrained("./data/Victim.BERT.SST")
    #   model = transformers.AutoModelForSequenceClassification.from_pretrained(
    #       "./data/Victim.BERT.SST", num_labels=2, output_hidden_states=True)
    #   clsf = OpenAttack.classifiers.TransformersClassifier(
    #       model, tokenizer=tokenizer, max_length=100,
    #       embedding_layer=model.bert.embeddings.word_embeddings)

    dataset = datasets.load_dataset("sst", split="train[:100]").map(function=dataset_mapping)

    print("New Attacker")
    attacker = OpenAttack.attackers.UATAttacker()
    # NOTE(review): set_triggers receives the same `dataset` that is evaluated
    # below, so the reported attack success is likely optimistic. Consider
    # searching triggers on a held-out split.
    attacker.set_triggers(clsf, dataset)

    print("Start attack")
    metrics = [
        OpenAttack.metric.Fluency(),
        OpenAttack.metric.GrammaticalErrors(),
        OpenAttack.metric.EditDistance(),
        OpenAttack.metric.ModificationRate(),
    ]
    attack_eval = OpenAttack.AttackEval(attacker, clsf, metrics=metrics)
    attack_eval.eval(dataset, visualize=True, progress_bar=True)


if __name__ == "__main__":
    main()