forked from saprmarks/dictionary_learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
standard.py
188 lines (157 loc) · 6.6 KB
/
standard.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""
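Implements the standard sparse autoencoder (SAE) training scheme: an L2 reconstruction
loss plus an L1 sparsity penalty on the features, with optional resampling of dead neurons.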
"""
import torch as t
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from trainers.trainer import SAETrainer
from config import DEBUG
from dictionary import AutoEncoder
from collections import namedtuple
class ConstrainedAdam(t.optim.Adam):
    """
    A variant of Adam where some of the parameters are constrained to have unit norm.

    Constrained parameters are handled column-wise (norms are taken over
    ``dim=0``). Before the Adam update, the component of each column's gradient
    that is parallel to the column is projected out. After the update, every
    column is rescaled back to unit L2 norm.

    Args:
        params: all parameters to optimize (forwarded to ``torch.optim.Adam``).
        constrained_params: the subset of parameters whose columns are kept at
            unit norm. They must also appear in ``params`` to be updated.
        lr: learning rate.
    """

    def __init__(self, params, constrained_params, lr):
        super().__init__(params, lr=lr)
        # materialize: constrained_params is often a one-shot generator
        # such as ``module.parameters()``
        self.constrained_params = list(constrained_params)

    def step(self, closure=None):
        """
        Perform one constrained Adam step.

        If ``closure`` is given, it is evaluated first (with grad enabled), so
        that the gradient projection is applied to the gradients it produces
        rather than being overwritten by them.

        Returns:
            The value returned by ``closure`` if one was given, else ``None``.
        """
        loss = None
        if closure is not None:
            with t.enable_grad():
                loss = closure()

        with t.no_grad():
            for p in self.constrained_params:
                if p.grad is None:
                    # parameter did not take part in this backward pass
                    continue
                normed_p = p / p.norm(dim=0, keepdim=True)
                # project away the parallel component of the gradient
                p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p

        # the closure (if any) has already been evaluated above
        super().step()

        with t.no_grad():
            for p in self.constrained_params:
                # renormalize the constrained parameters
                p /= p.norm(dim=0, keepdim=True)
        return loss
class StandardTrainer(SAETrainer):
    """
    Standard SAE training scheme.

    Minimizes ``l2_loss + l1_penalty * l1_loss``, where ``l2_loss`` is the
    batch-mean L2 norm of the reconstruction error and ``l1_loss`` is the
    batch-mean L1 norm of the features. Optimization uses ``ConstrainedAdam``,
    which keeps the decoder's parameters at unit column norm.

    The learning rate is warmed up linearly over ``warmup_steps`` steps. When
    ``resample_steps`` is set, the warmup restarts every ``resample_steps``
    steps. At those steps, neurons that have been inactive for more than
    ``resample_steps / 2`` steps are resampled.
    """

    # Record type returned by ``loss(..., logging=True)``; built once, not per call.
    _LossLog = namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])

    def __init__(self,
                 dict_class=AutoEncoder,
                 activation_dim=512,
                 dict_size=64*512,
                 lr=1e-3,
                 l1_penalty=1e-1,
                 warmup_steps=1000,  # lr warmup period at start of training and after each resample
                 resample_steps=None,  # how often to resample neurons
                 seed=None,
                 device=None,
                 layer=None,
                 lm_name=None,
                 wandb_name='StandardTrainer',
                 submodule_name=None,
                 ):
        super().__init__(seed)

        assert layer is not None and lm_name is not None
        self.layer = layer
        self.lm_name = lm_name
        self.submodule_name = submodule_name

        if seed is not None:
            t.manual_seed(seed)
            t.cuda.manual_seed_all(seed)

        # initialize dictionary
        self.ae = dict_class(activation_dim, dict_size)

        self.lr = lr
        self.l1_penalty = l1_penalty
        self.warmup_steps = warmup_steps
        self.wandb_name = wandb_name

        if device is None:
            self.device = 'cuda' if t.cuda.is_available() else 'cpu'
        else:
            self.device = device
        self.ae.to(self.device)

        self.resample_steps = resample_steps
        if self.resample_steps is not None:
            # how many steps since each neuron was last activated?
            self.steps_since_active = t.zeros(self.ae.dict_size, dtype=int).to(self.device)
        else:
            self.steps_since_active = None

        # decoder parameters are the ones kept at unit column norm
        self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr)

        def warmup_fn(step):
            # with resampling enabled, the warmup restarts every resample_steps steps
            if resample_steps is not None:
                step = step % resample_steps
            # a non-positive warmup period means "no warmup" (avoids ZeroDivisionError)
            if warmup_steps <= 0:
                return 1.
            return min(step / warmup_steps, 1.)

        self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn)

    def resample_neurons(self, deads, activations):
        """
        Re-initialize dead dictionary neurons from poorly reconstructed inputs.

        At most ``activations.shape[0]`` dead neurons (the first ones in
        ``deads``) are resampled. The inputs are drawn without replacement, with
        probability proportional to each input's reconstruction-error norm. For
        each resampled neuron:

        - its encoder row is set to the drawn input times 0.2 times the mean
          encoder-row norm of the living neurons;
        - its decoder column is set to the unit-normalized input;
        - its encoder bias is set to zero.

        The matching Adam moment estimates are then zeroed.

        Args:
            deads: boolean mask over dictionary neurons; True marks a dead
                neuron. It is not modified.
            activations: batch of inputs to sample from. Presumably shaped
                (batch, activation_dim); TODO confirm against callers.
        """
        with t.no_grad():
            n_dead = int(deads.sum().item())
            if n_dead == 0:
                return
            print(f"resampling {n_dead} neurons")

            # compute loss for each activation
            losses = (activations - self.ae(activations)).norm(dim=-1)

            # sample inputs to create encoder/decoder weights from
            n_resample = min(n_dead, losses.shape[0])
            indices = t.multinomial(losses, num_samples=n_resample, replacement=False)
            sampled_vecs = activations[indices]

            # get norm of the living neurons; if none are alive, fall back to
            # the mean over all encoder rows instead of a NaN mean of nothing
            alive = ~deads
            if alive.any():
                alive_norm = self.ae.encoder.weight[alive].norm(dim=-1).mean()
            else:
                alive_norm = self.ae.encoder.weight.norm(dim=-1).mean()

            # resample only the first n_resample dead neurons; copy the mask so
            # the caller's tensor is left untouched
            deads = deads.clone()
            deads[deads.nonzero()[n_resample:]] = False
            self.ae.encoder.weight[deads] = sampled_vecs * alive_norm * 0.2
            self.ae.decoder.weight[:, deads] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T
            self.ae.encoder.bias[deads] = 0.

            # reset Adam moments for the resampled neurons, looked up by
            # parameter rather than by position in the optimizer's param list
            self._reset_adam_state(self.ae.encoder.weight, deads)
            self._reset_adam_state(self.ae.encoder.bias, deads)
            self._reset_adam_state(self.ae.decoder.weight, deads, columns=True)

    def _reset_adam_state(self, param, mask, columns=False):
        """Zero Adam's ``exp_avg``/``exp_avg_sq`` for the rows (or columns) of ``param`` in ``mask``."""
        state = self.optimizer.state.get(param)
        if not state:
            # no optimizer step has touched this parameter yet
            return
        for key in ('exp_avg', 'exp_avg_sq'):
            buf = state.get(key)
            if buf is None:
                continue
            if columns:
                buf[:, mask] = 0.
            else:
                buf[mask] = 0.

    def loss(self, x, logging=False, **kwargs):
        """
        Compute the training loss for batch ``x``.

        When resampling is enabled, this also updates ``steps_since_active``.
        A neuron's counter is incremented if its feature is zero for every
        input in the batch; otherwise it is reset to 0. This update runs
        whether or not ``logging`` is set.

        Args:
            x: input batch. Assumed (batch, activation_dim), with features ``f``
                of shape (batch, dict_size); TODO confirm.
            logging: if True, return a ``LossLog`` record instead of the loss.
            **kwargs: accepted and ignored.

        Returns:
            The scalar loss tensor. If ``logging`` is set, instead a
            ``LossLog(x, x_hat, f, losses)`` where ``losses`` maps
            ``'l2_loss'``, ``'mse_loss'``, ``'sparsity_loss'`` and ``'loss'``
            to Python floats.
        """
        x_hat, f = self.ae(x, output_features=True)
        l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean()
        l1_loss = f.norm(p=1, dim=-1).mean()

        if self.steps_since_active is not None:
            # update steps_since_active
            deads = (f == 0).all(dim=0)
            self.steps_since_active[deads] += 1
            self.steps_since_active[~deads] = 0

        loss = l2_loss + self.l1_penalty * l1_loss

        if not logging:
            return loss
        return self._LossLog(
            x, x_hat, f,
            {
                'l2_loss': l2_loss.item(),
                'mse_loss': (x - x_hat).pow(2).sum(dim=-1).mean().item(),
                'sparsity_loss': l1_loss.item(),
                'loss': loss.item(),
            }
        )

    def update(self, step, activations):
        """
        Run one optimization step on ``activations``.

        If resampling is enabled and ``step`` is a multiple of
        ``resample_steps``, the step is followed by resampling every neuron
        inactive for more than ``resample_steps / 2`` steps.
        """
        activations = activations.to(self.device)

        self.optimizer.zero_grad()
        loss = self.loss(activations)
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()

        if self.resample_steps is not None and step % self.resample_steps == 0:
            self.resample_neurons(self.steps_since_active > self.resample_steps / 2, activations)

    @property
    def config(self):
        """Hyperparameters and metadata describing this trainer run."""
        return {
            # report the dictionary class actually in use, not a fixed name
            'dict_class': type(self.ae).__name__,
            'trainer_class': 'StandardTrainer',
            'activation_dim': self.ae.activation_dim,
            'dict_size': self.ae.dict_size,
            'lr': self.lr,
            'l1_penalty': self.l1_penalty,
            'warmup_steps': self.warmup_steps,
            'resample_steps': self.resample_steps,
            'device': self.device,
            'layer': self.layer,
            'lm_name': self.lm_name,
            'wandb_name': self.wandb_name,
            'submodule_name': self.submodule_name,
        }