-
Notifications
You must be signed in to change notification settings - Fork 45
/
Copy pathmain.py
288 lines (230 loc) · 10 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
import copy
import math
import os
from argparse import ArgumentParser
from itertools import permutations

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from grokfast import *
class Block(nn.Module):
"""Causal transformer block
"""
def __init__(self, dim, num_heads):
super().__init__()
self.ln_1 = nn.LayerNorm(dim)
self.ln_2 = nn.LayerNorm(dim)
self.attn = nn.MultiheadAttention(dim, num_heads)
self.mlp = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim),
)
def forward(self, x):
attn_mask = torch.full(
(len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype
)
attn_mask = torch.triu(attn_mask, diagonal=1)
attn_mask[torch.isnan(attn_mask)] = 0.0 # fixes all 'nan' on 'mps' device
x = self.ln_1(x)
a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
x = x + a
m = self.mlp(self.ln_2(x))
x = x + m
return x
class Decoder(nn.Module):
    """Decoder-only causal transformer: token + learned position embeddings,
    a stack of ``Block``s, a final LayerNorm and a bias-free output projection.

    Args:
        dim: embedding width.
        num_layers: number of stacked ``Block`` modules.
        num_heads: attention heads per block.
        num_tokens: vocabulary size (embedding rows and number of output logits).
        seq_len: number of learned position embeddings, i.e. the maximum
            input length.

    ``forward`` takes integer token ids shaped ``(seq_len, batch)``; the
    sequence axis comes first, because positions are ``arange(x.shape[0])``.
    It returns logits shaped ``(seq_len, batch, num_tokens)``.
    """

    def __init__(self, dim=128, num_layers=2, num_heads=4, num_tokens=97, seq_len=5):
        super().__init__()
        self.token_embeddings = nn.Embedding(num_tokens, dim)
        self.position_embeddings = nn.Embedding(seq_len, dim)
        self.layers = nn.ModuleList([Block(dim, num_heads) for _ in range(num_layers)])
        self.ln_f = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_tokens, bias=False)

    def forward(self, x):
        # (seq_len, 1) position ids, broadcast across the batch axis below.
        seq_positions = torch.arange(x.shape[0], device=x.device)[:, None]
        hidden = self.token_embeddings(x)
        hidden = hidden + self.position_embeddings(seq_positions).expand_as(hidden)
        for block in self.layers:
            hidden = block(hidden)
        return self.head(self.ln_f(hidden))
def multiplication_mod_p_data(p, eq_token, op_token):
    """Build every equation ``x ◦ y = (x * y) mod p`` for ``0 <= x < p`` and ``0 < y < p``.

    Each equation is encoded as the 5-token sequence
    ``x, op_token, y, eq_token, result``.

    Args:
        p: the modulus; operand tokens are ``0 .. p - 1``.
        eq_token: token id used for ``=``.
        op_token: token id used for ``◦``.

    Returns:
        An int64 tensor of shape ``(5, p * (p - 1))``. Each column is one
        equation, ordered with ``x`` as the slow-varying index and ``y`` as
        the fast-varying one.
    """
    x, y = torch.cartesian_prod(torch.arange(p), torch.arange(1, p)).T
    eq = torch.full_like(x, eq_token)
    op = torch.full_like(x, op_token)
    result = (x * y) % p
    # "All of our experiments used a small transformer trained on datasets of
    # equations of the form a◦b = c, where each of “a”, “◦”, “b”, “=”, and “c”
    # is a separate token"
    return torch.stack([x, op, y, eq, result])
def _plot_curves(steps, train_values, val_values, ylabel, path):
    """Plot train/val curves against optimization steps (log-scaled x-axis) and save to ``path``."""
    plt.plot(steps, train_values, label="train")
    plt.plot(steps, val_values, label="val")
    plt.legend()
    plt.title("Modular Multiplication (training on 50% of data)")
    plt.xlabel("Optimization Steps")
    plt.ylabel(ylabel)
    plt.xscale("log", base=10)
    plt.grid()
    plt.savefig(path, dpi=150)
    plt.close()


def _apply_grad_filter(model, grads, args, trigger):
    """Run the Grokfast gradient filter selected by ``args.filter`` after ``backward()``.

    Args:
        model: the network whose gradients are filtered.
        grads: filter state returned by the previous call (``None`` on the first step).
        args: parsed CLI namespace; reads ``filter``, ``window_size``, ``alpha`` and ``lamb``.
        trigger: forwarded to ``gradfilter_ma`` only (set by the ``--two_stage`` ablation).

    Returns:
        The updated filter state, to be passed back in on the next step.
        ``grads`` is returned unchanged when ``args.filter == "none"``.

    Raises:
        ValueError: for an unknown filter name.
    """
    if args.filter == "none":
        return grads
    if args.filter == "ma":
        return gradfilter_ma(model, grads=grads, window_size=args.window_size, lamb=args.lamb, trigger=trigger)
    if args.filter == "ema":
        return gradfilter_ema(model, grads=grads, alpha=args.alpha, lamb=args.lamb)
    raise ValueError(f"Invalid gradient filter type `{args.filter}`")


def main(args):
    """Train a small causal transformer on modular multiplication, optionally with Grokfast.

    Builds all ``x * y (mod p)`` equations and splits them 50/50 into train and
    validation sets at random. Trains for ``args.budget`` optimization steps,
    rounded down to whole epochs. Periodically writes accuracy/loss plots and a
    results dict to ``results/``, named after ``args.label``.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).

    Raises:
        ValueError: if ``args.filter`` is not one of ``none``, ``ma``, ``ema``.
    """
    # Fail fast instead of on the first training batch.
    if args.filter not in ("none", "ma", "ema"):
        raise ValueError(f"Invalid gradient filter type `{args.filter}`")

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # plt.savefig / torch.save below write into results/ and fail if it is missing.
    os.makedirs("results", exist_ok=True)

    # Tokens for <op> and <=>. It's not clear why <=> is needed at all since it
    # has no effect on the output, but we'll leave it in to best follow the
    # paper.
    eq_token = args.p
    op_token = args.p + 1

    # "We trained a standard decoder-only transformer (Vaswani et al., 2017)
    # with causal attention masking, and calculated loss and accuracy only on
    # the answer part of the equation. For all experiments we used a
    # transformer with 2 layers, width 128, and 4 attention heads"
    model = Decoder(
        dim=128, num_layers=2, num_heads=4, num_tokens=args.p + 2, seq_len=5
    ).to(device)
    nparams = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(model)
    print(f'Total number of parameters: {nparams}')

    # Random 50/50 split of the equations (columns) into train / validation.
    data = multiplication_mod_p_data(args.p, eq_token, op_token)
    num_examples = data.shape[1]
    num_train = num_examples // 2
    perm = torch.randperm(num_examples)
    # Explicit slicing rather than perm.split(num_train): split() would yield a
    # third chunk (breaking the unpacking) if the dataset size were ever odd.
    train_idx, valid_idx = perm[:num_train], perm[num_train:]
    train_data, valid_data = data[:, train_idx], data[:, valid_idx]

    # "For most experiments we used AdamW optimizer with learning rate 10^-3,
    # weight decay 1, β1 = 0.9, β2 = 0.98" -- here the optimizer class and its
    # hyper-parameters come from the command line.
    optimizer = getattr(torch.optim, args.optimizer)(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
        betas=(args.beta1, args.beta2),
    )
    # Linear learning-rate warmup over the first 10 updates.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda update: 1 if update > 10 else update / 10
    )

    steps_per_epoch = math.ceil(train_data.shape[1] / args.batch_size)
    num_epochs = int(args.budget) // steps_per_epoch
    its, train_acc, val_acc, train_loss, val_loss = [], [], [], [], []
    grads = None  # Grokfast filter state, threaded through every training step.
    step = 0
    # For logging network weights.
    net_its, nets = [], []

    for epoch in tqdm(range(num_epochs)):
        # Randomly shuffle the training equations every epoch.
        train_data = train_data[:, torch.randperm(train_data.shape[1])]

        for split_data, is_train in ((train_data, True), (valid_data, False)):
            model.train(is_train)
            total_loss = 0
            total_acc = 0
            # torch.split is faster than a DataLoader over an in-memory tensor.
            for batch in torch.split(split_data, args.batch_size, dim=1):
                batch = batch.to(device)
                with torch.set_grad_enabled(is_train):
                    logits = model(batch[:-1])
                    # Loss only on the answer part of the equation (last token).
                    loss = F.cross_entropy(logits[-1], batch[-1])
                    total_loss += loss.item() * batch.shape[-1]

                if is_train:
                    model.zero_grad()
                    loss.backward()
                    trigger = args.two_stage and step < 500
                    grads = _apply_grad_filter(model, grads, args, trigger)
                    optimizer.step()
                    scheduler.step()
                    step += 1

                acc = (logits[-1].argmax(-1) == batch[-1]).float().mean()
                total_acc += acc.item() * batch.shape[-1]

            epoch_acc = total_acc / split_data.shape[-1]
            epoch_loss = total_loss / split_data.shape[-1]
            if is_train:
                train_acc.append(epoch_acc)
                train_loss.append(epoch_loss)
                its.append(step)
            else:
                val_acc.append(epoch_acc)
                val_loss.append(epoch_loss)

        # Always save after the final epoch, so results are not lost when
        # num_epochs is not a multiple of 100.
        is_last_epoch = epoch == num_epochs - 1
        if args.save_weights:
            # Every epoch for the first 501 epochs, then every 100th.
            do_save = epoch <= 500 or (epoch + 1) % 100 == 0 or is_last_epoch
        else:
            do_save = (epoch + 1) % 100 == 0 or is_last_epoch

        if do_save:
            # x = optimization steps completed when each point was recorded.
            # The old arange(...) * steps_per_epoch started at 0, which cannot
            # be shown on a log axis, and lagged one epoch behind.
            _plot_curves(its, train_acc, val_acc, "Accuracy", f"results/acc_{args.label}.png")
            _plot_curves(its, train_loss, val_loss, "Loss", f"results/loss_{args.label}.png")

            results = {
                'its': its,
                'train_acc': train_acc,
                'train_loss': train_loss,
                'val_acc': val_acc,
                'val_loss': val_loss,
            }
            if args.save_weights:
                net_its.append(epoch)
                nets.append(copy.deepcopy(model.state_dict()))
                results['net_its'] = net_its
                results['net'] = nets

            torch.save(results, f"results/res_{args.label}.pt")
def format_number_tag(value):
    """Render a numeric hyper-parameter as a file-name-safe tag.

    Values that are integral up to float round-off render as plain integers
    (``5.0`` -> ``"5"``, ``0.3 / 0.1`` -> ``"3"``), matching the previous
    ``int(value)`` tags for integral settings. Any other value keeps its
    significant digits, with ``.`` spelled ``p`` (``2.5`` -> ``"2p5"``). This
    means distinct settings such as 2 and 2.5 no longer collide on one tag, as
    they did under ``int()`` truncation.
    """
    nearest = round(value)
    if math.isclose(value, nearest, rel_tol=1e-9, abs_tol=1e-12):
        return str(int(nearest))
    return f"{value:g}".replace(".", "p")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--label", default="")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--p", type=int, default=97)
    # argparse does not apply `type` to defaults, so keep the default an int.
    parser.add_argument("--budget", type=int, default=int(3e5))
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--beta1", type=float, default=0.9)
    parser.add_argument("--beta2", type=float, default=0.98)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--optimizer", default="Adam")

    # Grokfast. Only filters that main() implements are offered ("fir" used to
    # be listed here but was always rejected).
    parser.add_argument("--filter", type=str, choices=["none", "ma", "ema"], default="none")
    parser.add_argument("--alpha", type=float, default=0.99)
    parser.add_argument("--window_size", type=int, default=100)
    parser.add_argument("--lamb", type=float, default=5.0)

    # Ablation studies
    parser.add_argument("--two_stage", action='store_true')
    parser.add_argument("--save_weights", action='store_true')
    args = parser.parse_args()

    # Build the experiment label from the filter and any non-default optimizer settings.
    filter_str = ('_' if args.label != '' else '') + args.filter
    window_size_str = f'_w{args.window_size}'
    alpha_str = f'_a{args.alpha:.3f}'.replace('.', '')
    lamb_str = f'_l{format_number_tag(args.lamb)}'

    if args.filter == 'none':
        filter_suffix = ''
    elif args.filter == 'ma':
        filter_suffix = window_size_str + lamb_str
    elif args.filter == 'ema':
        filter_suffix = alpha_str + lamb_str
    else:
        raise ValueError(f"Unrecognized filter type {args.filter}")

    optim_suffix = ''
    if args.weight_decay != 0:
        optim_suffix = optim_suffix + f'_wd{args.weight_decay:.1e}'.replace('.', '')
    if args.lr != 1e-3:
        optim_suffix = optim_suffix + f'_lrx{format_number_tag(args.lr / 1e-3)}'

    args.label = args.label + filter_str + filter_suffix + optim_suffix
    print(f'Experiment results saved under name: {args.label}')

    main(args)