-
Notifications
You must be signed in to change notification settings - Fork 0
/
slim.py
76 lines (64 loc) · 2.32 KB
/
slim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from patch import Patch
from tqdm import tqdm
import pdb
import torch.nn as nn
from utils_inject import *
def compute_l1_loss(model, start_layer_idx):
    """Return the mean L1 norm of ``slim_coef`` over trainable Patch layers.

    Iterates every submodule of ``model``; for each ``Patch`` instance whose
    (1-based) position exceeds ``start_layer_idx``, accumulates
    ``||slim_coef||_1`` and the coefficient count, then averages.

    Args:
        model: a torch ``nn.Module`` containing ``Patch`` submodules.
        start_layer_idx: number of leading Patch layers to exclude
            from the regularizer.

    Returns:
        A scalar tensor — mean absolute slim coefficient; ``0.`` when no
        Patch layer is past ``start_layer_idx``.
    """
    l1_loss = 0.
    N = 0
    cnt = 0
    for m in model.modules():
        if isinstance(m, Patch):
            cnt += 1
            if cnt > start_layer_idx:
                l1_loss += torch.norm(m.slim_coef, 1)
                N += len(m.slim_coef)
    # BUGFIX: the original divided unconditionally; with no Patch layer past
    # start_layer_idx, N == 0 raised ZeroDivisionError.
    if N == 0:
        return torch.tensor(0.)
    return l1_loss / N
def slim(args, model, tokenizer, inputs, gold_set):
    """Train the per-layer ``slim`` masking coefficients of ``model``.

    Freezes every parameter except the ``slim`` coefficients past
    ``args.start_mask_layer``, then optimizes an LM loss plus an L1 sparsity
    penalty with Adam. Every 10 epochs a checkpoint is evaluated (sparsity,
    optional layer-wise score against ``gold_set``) and training stops early
    once the L1 loss falls below ``args.stop_loss``.

    Args:
        args: namespace with ``lr``, ``epoch``, ``lambda_l1``, ``threshold``,
            ``stop_loss``, ``ratio``, ``save_ckpt``, ``out_dir`` and
            optionally ``start_mask_layer``.
        model: model containing parameters whose names include ``"slim"``.
        tokenizer: unused here; kept for interface compatibility.
        inputs: dict of tensors fed as ``model(**inputs)``; must yield ``.loss``.
        gold_set: ground-truth set for scoring, or falsy for a dummy score.

    Returns:
        The stacked, clamped-to-[0, 1] slim parameters (CPU tensor); also
        saved to ``<out_dir>/slim.pt``.
    """
    model.eval()
    start_layer_idx = args.start_mask_layer if hasattr(args, 'start_mask_layer') else 0

    # Select tunable parameters: only "slim" coefficients past start_layer_idx
    # get gradients; everything else is frozen. All slim params are collected
    # so checkpoints can stack the full per-layer coefficient matrix.
    cnt = 0
    params = []
    for n, p in model.named_parameters():
        if "slim" in n:
            cnt += 1
            if cnt > start_layer_idx:
                p.requires_grad = True
                print(n)
            else:
                p.requires_grad = False
            params.append(p)
        else:
            p.requires_grad = False
    print("-"*100)

    optimizer = torch.optim.Adam(params, lr=args.lr)

    # training
    scores, reg_losses, lm_losses = [], [], []
    # BUGFIX: sparsity was only assigned inside the 10-epoch checkpoint branch;
    # with fewer than 10 epochs (or early exit) save_records raised NameError.
    sparsity = 0.0
    for i in range(args.epoch):
        optimizer.zero_grad()
        outputs = model(**inputs)
        l1_loss = compute_l1_loss(model, start_layer_idx)
        lm_loss = outputs.loss
        loss = lm_loss + args.lambda_l1 * l1_loss

        # Checkpoint every 10 epochs: report losses/sparsity, score against
        # the gold set, optionally save, and test the early-stopping criterion.
        if (i+1) % 10 == 0:
            ckpt_params = torch.stack(params).clamp(min=0.0, max=1.0)
            sparsity = (ckpt_params[start_layer_idx:] < args.threshold).float().mean().item()
            print(i+1, f'lm loss: {lm_loss.item():.3f}, l1 loss: {l1_loss.item():.2f}')
            print(' Sparsity:', sparsity)
            if gold_set:
                score = get_layerwise_scores(ckpt_params, gold_set, args.ratio)
            else:
                score = 0  # dummy
            if args.save_ckpt: save_params(args, ckpt_params, f'{i+1}.pt')
            scores.append(score)
            lm_losses.append(lm_loss.item())
            reg_losses.append(l1_loss.item())
            if l1_loss < args.stop_loss: break

        loss.backward()
        optimizer.step()

    params = torch.stack(params).clamp(min=0.0, max=1.0).detach().cpu()
    torch.save(params, os.path.join(args.out_dir, 'slim.pt'))
    save_records(args, scores, np.array(reg_losses), np.array(lm_losses), sparsity)
    return params