-
Notifications
You must be signed in to change notification settings - Fork 0
/
patch.py
90 lines (77 loc) · 2.82 KB
/
patch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Modified from https://github.com/EleutherAI/knowledge-neurons/blob/main/knowledge_neurons/patch.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from tqdm import tqdm
import pickle
import pdb
from utils import get_attributes, set_attributes
class Patch(torch.nn.Module):
    """Wrap a sub-module and post-process its output in ``forward``.

    Exactly one mode is applied, checked in this order:

    1. ``replacement_activations`` given ("knowledge neurons"): the output at
       the last sequence position is overwritten with these activations.
    2. ``onehot_coef`` given ("zero-out"): the output is multiplied by
       ``onehot_coef.unsqueeze(1)``.
    3. Otherwise ("slimming"): the output is scaled by the learnable
       ``slim_coef`` clipped to ``[0, 1]``; this mode needs
       ``intermediate_size`` to have been passed.

    Args:
        ff_layer: Module being wrapped; stored as ``self.module``.
        intermediate_size: If given, creates the learnable ``slim_coef``
            parameter of this length, initialised to ones.
        replacement_activations: Values written into the last position.
        onehot_coef: Multiplicative coefficients; presumably shaped
            (batch, intermediate) so ``unsqueeze(1)`` broadcasts over the
            sequence axis -- TODO confirm with callers.
    """

    def __init__(
        self,
        ff_layer: nn.Module,
        intermediate_size: int = None,
        replacement_activations: torch.Tensor = None,
        onehot_coef: torch.Tensor = None,
    ):
        super().__init__()
        self.module = ff_layer
        if intermediate_size is not None:  # slimming
            self.slim_coef = nn.Parameter(torch.ones(intermediate_size))
        self.acts = replacement_activations
        self.onehot_coef = onehot_coef

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden_states = self.module(x)
        if self.acts is not None:  # knowledge neurons
            # Patch the last token; indexing assumes a 3-D output
            # (batch, seq, intermediate). Writes in place, as upstream does.
            hidden_states[:, -1, :] = self.acts
        elif self.onehot_coef is not None:  # zero-out
            hidden_states = hidden_states * self.onehot_coef.unsqueeze(1)
        else:  # slimming
            slim_coef = getattr(self, "slim_coef", None)
            if slim_coef is None:
                raise ValueError(
                    "Patch has no slim_coef: pass intermediate_size (or "
                    "replacement_activations / onehot_coef) when constructing it"
                )
            # Out-of-place on purpose: an in-place multiply would mutate the
            # wrapped module's output, which autograd may have saved for
            # backward (nn.ReLU saves its result), so training slim_coef
            # would fail with "modified by an inplace operation". Cast the
            # coefficient so the result keeps the activation dtype, matching
            # what the in-place version produced.
            coef = torch.clip(slim_coef, 0, 1).to(hidden_states.dtype)
            hidden_states = hidden_states * coef
        return hidden_states
def patch_ff_layer(
    model: nn.Module,
    ff_attrs: str,
    intermediate_size: int = None,
    replacement_activations: torch.Tensor = None,
    onehot_coef: torch.Tensor = None,
):
    """Replace the sub-module at ``ff_attrs`` with a ``Patch`` wrapping it.

    The module currently stored at ``ff_attrs`` (a dotted attribute path,
    presumably resolved by ``get_attributes``) is wrapped in a ``Patch`` built
    from the remaining arguments and written back to the same path. Use
    ``unpatch_ff_layer`` to restore the original module.

    Args:
        model: Model containing the sub-module to patch.
        ff_attrs: Attribute path of the sub-module inside ``model``.
        intermediate_size: Forwarded to ``Patch`` (enables slimming).
        replacement_activations: Forwarded to ``Patch`` (last-token override).
        onehot_coef: Forwarded to ``Patch`` (multiplicative zero-out mask).
    """
    wrapper = Patch(
        get_attributes(model, ff_attrs),
        intermediate_size=intermediate_size,
        replacement_activations=replacement_activations,
        onehot_coef=onehot_coef,
    )
    set_attributes(model, ff_attrs, wrapper)
def unpatch_ff_layer(
    model: nn.Module,
    ff_attrs: str,
):
    """Undo ``patch_ff_layer``: put the wrapped module back at ``ff_attrs``.

    Args:
        model: Model containing the patched sub-module.
        ff_attrs: Attribute path previously passed to ``patch_ff_layer``.

    Raises:
        TypeError: If the module at ``ff_attrs`` is not a ``Patch``.
    """
    ff_layer = get_attributes(model, ff_attrs)
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, after which a non-Patch layer would fall through to
    # ``ff_layer.module`` and fail obscurely or swap in an unrelated child.
    if not isinstance(ff_layer, Patch):
        raise TypeError(
            f"Can't unpatch a layer that hasn't been patched ({ff_attrs!r})"
        )
    set_attributes(model, ff_attrs, ff_layer.module)
def patch_slim(model):
    """Wrap each layer's FFN activation in a slimming ``Patch``.

    For every index in ``range(model.config.n_layer)``, the module at
    ``<transformer_layer>.<idx>.<ffn_act>`` (fragments read from
    ``model.attr_dict``) is replaced by a ``Patch`` whose learnable
    ``slim_coef`` has length ``model.inner_dim``.
    """
    layer_count = model.config.n_layer
    for idx in range(layer_count):
        layers_root = model.attr_dict["transformer_layer"]
        act_name = model.attr_dict["ffn_act"]
        patch_ff_layer(
            model,
            f"{layers_root}.{idx}.{act_name}",
            intermediate_size=model.inner_dim,
        )
def reinit_slim(model):
    """Fill every layer's ``slim_coef`` with ones.

    For each index in ``range(model.config.n_layer)``, looks up
    ``<transformer_layer>.<idx>.<ffn_act>.slim_coef`` (fragments read from
    ``model.attr_dict``) and resets it in place with ``init.ones_``. The
    activation modules must already carry ``slim_coef`` (e.g. after
    ``patch_slim``).
    """
    layer_count = model.config.n_layer
    for idx in range(layer_count):
        layers_root = model.attr_dict["transformer_layer"]
        act_name = model.attr_dict["ffn_act"]
        init.ones_(get_attributes(model, f"{layers_root}.{idx}.{act_name}.slim_coef"))