-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathAdversarialUtils.py
184 lines (163 loc) · 8.66 KB
/
AdversarialUtils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
# -*- coding:utf-8 -*-
import torch
class FGM:
    """Fast Gradient Method (FGM) adversarial training helper.

    After a normal backward pass, `attack()` adds a single gradient-ascent
    perturbation of L2 magnitude `epsilon` to every parameter whose name
    contains `emb_name` (intended to be the word-embedding weights), and
    `restore()` puts the original weights back.
    """

    def __init__(self, model):
        """
        Args:
            model: the model (possibly DataParallel-wrapped) whose embedding
                parameters will be perturbed.
        """
        self.model = model
        self.backup = {}  # param name -> original data, saved by attack()

    def attack(self, epsilon=8 / 255, emb_name='module.bert.embeddings.word_embeddings.weight'):
        """Perturb matching embedding parameters along their gradient.

        Args:
            epsilon: L2 magnitude of the perturbation.
            emb_name: substring identifying the embedding parameter(s);
                change it to match your model's embedding parameter name.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                self.backup[name] = param.data.clone()
                # Parameters may have no gradient yet (e.g. before the first
                # backward pass) -- skip them instead of crashing.
                if param.grad is None:
                    continue
                norm = torch.norm(param.grad)
                # Guard against zero AND non-finite norms: a NaN norm passes
                # `norm != 0` and would poison the weights with NaNs.
                if norm != 0 and not torch.isnan(norm):
                    r_at = epsilon * param.grad / norm
                    param.data.add_(r_at)

    def restore(self, emb_name='module.bert.embeddings.word_embeddings.weight'):
        """Restore the embedding parameters saved by the last attack().

        Args:
            emb_name: must be the same substring that was passed to attack().
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
def fgm_use_bert_adv(fgm, model, input_ids, attention_mask, token_type_ids, y, criterion, args):
    """Run one FGM adversarial forward/backward pass for a BERT-style model.

    Perturbs the embedding weights, does an adversarial forward pass, and
    backpropagates so the adversarial gradient accumulates on top of the
    gradients already present from the clean pass, then restores the
    original embedding weights.
    """
    fgm.attack()
    adv_logits = model(input_ids, attention_mask, token_type_ids)[0]
    adv_loss = criterion(adv_logits, y) / args.accum_iter
    # Accumulate the adversarial gradient on top of the normal gradients.
    adv_loss.backward()
    # Put the original embedding weights back.
    fgm.restore()
class PGD:
    """Projected Gradient Descent (PGD) adversarial training helper.

    Performs K small gradient-ascent steps on the embedding weights, after
    each step projecting the accumulated perturbation back into an L2 ball
    of radius `epsilon` around the original weights. Also backs up and
    restores all parameter gradients so the final adversarial backward pass
    can accumulate on top of the clean gradients.
    """

    def __init__(self, model):
        """
        Args:
            model: the model (possibly DataParallel-wrapped) to perturb.
        """
        self.model = model
        self.emb_backup = {}   # param name -> original embedding data
        self.grad_backup = {}  # param name -> gradient saved by backup_grad()

    def attack(self, epsilon=8 / 255, alpha=10 / 255, emb_name='module.bert.embeddings.word_embeddings.weight',
               is_first_attack=False):
        """One ascent step on the embeddings, projected into the epsilon-ball.

        Args:
            epsilon: radius of the L2 ball around the original embeddings.
            alpha: step size of a single ascent step.
            emb_name: substring identifying the embedding parameter(s);
                change it to match your model's embedding parameter name.
            is_first_attack: when True, back up the original data (only the
                first of the K steps should do this).
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                if is_first_attack:
                    self.emb_backup[name] = param.data.clone()
                # Skip parameters without a gradient instead of crashing.
                if param.grad is None:
                    continue
                norm = torch.norm(param.grad)
                # Guard against zero AND non-finite norms (NaN passes != 0).
                if norm != 0 and not torch.isnan(norm):
                    r_at = alpha * param.grad / norm
                    param.data.add_(r_at)
                # Always project back into the epsilon-ball around the backup
                # (a no-op when the perturbation is already inside it).
                param.data = self.project(name, param.data, epsilon)

    def restore(self, emb_name='module.bert.embeddings.word_embeddings.weight'):
        """Restore the embedding parameters saved by the first attack().

        Args:
            emb_name: must be the same substring that was passed to attack().
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.emb_backup
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def project(self, param_name, param_data, epsilon):
        """Clip the perturbation so its L2 norm is at most epsilon."""
        r = param_data - self.emb_backup[param_name]
        if torch.norm(r) > epsilon:
            r = epsilon * r / torch.norm(r)
        return self.emb_backup[param_name] + r

    def backup_grad(self):
        """Save a copy of every parameter gradient (clean gradients)."""
        for name, param in self.model.named_parameters():
            # Unused parameters have grad None before/after backward -- the
            # original `param.grad.clone()` crashed on them; skip instead.
            if param.requires_grad and param.grad is not None:
                self.grad_backup[name] = param.grad.clone()

    def restore_grad(self):
        """Restore the gradients saved by backup_grad()."""
        for name, param in self.model.named_parameters():
            # Only restore what was actually backed up (params whose grad
            # was None were skipped in backup_grad()).
            if param.requires_grad and name in self.grad_backup:
                param.grad = self.grad_backup[name]
def pgd_use_bert_adv(pgd, model, input_ids, attention_mask, token_type_ids, y, criterion, args):
    """Run K-step PGD adversarial training for one batch.

    Backs up the clean gradients, performs args.k_pdg ascent steps on the
    embeddings (zeroing intermediate gradients), restores the clean
    gradients before the final step so its backward() accumulates on top of
    them, and finally restores the original embedding weights.
    """
    pgd.backup_grad()
    last_step = args.k_pdg - 1
    for step in range(args.k_pdg):
        # The first step also backs up the embedding weights.
        pgd.attack(is_first_attack=(step == 0))
        if step == last_step:
            # Final step: bring back the clean gradients so the adversarial
            # gradient is accumulated on top of them.
            pgd.restore_grad()
        else:
            model.zero_grad()
        adv_logits = model(input_ids, attention_mask, token_type_ids)[0]
        adv_loss = criterion(adv_logits, y) / args.accum_iter
        adv_loss.backward()
    # Restore the original embedding weights.
    pgd.restore()
class FreeLB(object):
    """FreeLB adversarial training (Zhu et al., ICLR 2020).

    Runs `adv_K` perturbed forward/backward passes per batch, each time
    adding a perturbation `delta` to the word-embedding output, taking a
    gradient-ascent step on `delta`, and projecting it back into a norm
    ball. Model gradients accumulate across all K passes.
    """

    def __init__(self, adv_K, adv_lr, adv_init_mag, adv_max_norm=2e-1, adv_norm_type='l2', base_model='bert'):
        """
        Args:
            adv_K: number of adversarial steps per batch.
            adv_lr: step size of the delta update.
            adv_init_mag: magnitude used to initialize delta; 0 means the
                first step uses the clean (delta == 0) gradients.
            adv_max_norm: radius delta is projected into (0 disables it).
            adv_norm_type: 'l2' or 'linf'.
            base_model: attribute name of the backbone that owns the
                embeddings (e.g. model.bert.embeddings.word_embeddings).
        """
        self.adv_K = adv_K
        self.adv_lr = adv_lr
        self.adv_max_norm = adv_max_norm
        self.adv_init_mag = adv_init_mag
        self.adv_norm_type = adv_norm_type
        self.base_model = base_model

    def _word_embeddings(self, model, input_ids):
        """Embed input_ids, unwrapping (Distributed)DataParallel if needed."""
        if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
            model = model.module
        return getattr(model, self.base_model).embeddings.word_embeddings(input_ids)

    def attack(self, model, inputs, gradient_accumulation_steps=1):
        """Run adv_K perturbed passes; return (loss, logits) of the last one.

        Note: mutates `inputs` -- sets 'input_ids' to None and adds an
        'inputs_embeds' entry.

        Raises:
            ValueError: if adv_norm_type is neither 'l2' nor 'linf'.
        """
        input_ids = inputs['input_ids']
        # BUG FIX: the original used model.module in BOTH branches of the
        # DataParallel check, crashing on a plain (unwrapped) model.
        embeds_init = self._word_embeddings(model, input_ids)
        if self.adv_init_mag > 0:
            # Random init: the first step uses adversarial (delta != 0)
            # gradients instead of the clean ones. Pad positions are masked.
            input_mask = inputs['attention_mask'].to(embeds_init)
            input_lengths = torch.sum(input_mask, 1)
            if self.adv_norm_type == "l2":
                delta = torch.zeros_like(embeds_init).uniform_(-1, 1) * input_mask.unsqueeze(2)
                dims = input_lengths * embeds_init.size(-1)
                mag = self.adv_init_mag / torch.sqrt(dims)
                delta = (delta * mag.view(-1, 1, 1)).detach()
            elif self.adv_norm_type == "linf":
                delta = torch.zeros_like(embeds_init).uniform_(-self.adv_init_mag, self.adv_init_mag)
                delta = delta * input_mask.unsqueeze(2)
        else:
            delta = torch.zeros_like(embeds_init)  # start from zero perturbation
        loss, logits = None, None
        for astep in range(self.adv_K):
            delta.requires_grad_()
            inputs['inputs_embeds'] = delta + embeds_init  # perturbed embeddings
            inputs['input_ids'] = None  # force the model to use inputs_embeds
            outputs = model(**inputs)
            loss, logits = outputs[:2]  # transformers models return a tuple
            loss = loss.mean()  # average over multi-GPU parallel training
            loss = loss / gradient_accumulation_steps
            loss.backward()  # accumulates into the model grads across steps
            delta_grad = delta.grad.clone().detach()  # keep a copy of delta's grad
            if self.adv_norm_type == "l2":
                denorm = torch.norm(delta_grad.view(delta_grad.size(0), -1), dim=1).view(-1, 1, 1)
                denorm = torch.clamp(denorm, min=1e-8)  # avoid division by zero
                delta = (delta + self.adv_lr * delta_grad / denorm).detach()  # ascent step
                if self.adv_max_norm > 0:
                    # Project delta back into the L2 ball of radius adv_max_norm.
                    delta_norm = torch.norm(delta.view(delta.size(0), -1).float(), p=2, dim=1).detach()
                    exceed_mask = (delta_norm > self.adv_max_norm).to(embeds_init)
                    reweights = (self.adv_max_norm / delta_norm * exceed_mask + (1 - exceed_mask)).view(-1, 1, 1)
                    delta = (delta * reweights).detach()
            elif self.adv_norm_type == "linf":
                # p=inf: normalize by the largest absolute gradient entry.
                denorm = torch.norm(delta_grad.view(delta_grad.size(0), -1), dim=1, p=float("inf")).view(-1, 1, 1)
                denorm = torch.clamp(denorm, min=1e-8)
                delta = (delta + self.adv_lr * delta_grad / denorm).detach()
                if self.adv_max_norm > 0:
                    delta = torch.clamp(delta, -self.adv_max_norm, self.adv_max_norm).detach()
            else:
                raise ValueError("Norm type {} not specified.".format(self.adv_norm_type))
            embeds_init = self._word_embeddings(model, input_ids)
        return loss, logits
# https://github.com/lonePatient/TorchBlocks/blob/master/torchblocks/callback/adversarial.py
def freeLB_use_bert_adv(freelb, model, input_ids, attention_mask, token_type_ids, y, criterion, args):
    """Run FreeLB adversarial training for one batch.

    Returns the (loss, prediction_scores) pair produced by freelb.attack().
    Note that `criterion` is unused here: FreeLB expects the model itself to
    compute the loss (via the 'masked_lm_labels' entry).
    """
    batch = dict(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask,
        masked_lm_labels=y,
    )
    return freelb.attack(model, batch, args.accum_iter)
# if args.do_adv:
# inputs = {
# "input_ids": input_ids,
# "bbox": layout,
# "token_type_ids": segment_ids,
# "attention_mask": input_mask,
# "masked_lm_labels": lm_label_ids
# }
# loss, prediction_scores = freelb.attack(model, inputs)
# loss.backward()
# optimizer.step()
# scheduler.step()
# model.zero_grad()