tune_locality_weights_wiki.py (forked from urvashik/knnlm)
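"""Tune locality-dependent reweighting of cached kNN-LM retrieval distances.

Fits one affine transform of the neighbor distances per locality class
(derived from the pkg/proj locality indicators) so that the resulting
kNN distribution maximizes the log-likelihood of the gold tokens on
cached WikiText-103 validation retrievals.
"""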
import numpy as np
import torch
from torch import nn
from fairseq import utils


class WeightedDist(nn.Module):
    """Per-locality affine rescaling of kNN retrieval distances.

    Each of the four locality classes encoded by the pkg/proj indicator
    pair gets its own scale and bias; class 0 (neither indicator set)
    keeps only a scale, anchoring the overall softmax temperature.
    """

    def __init__(self):
        super().__init__()
        self.w0 = nn.Parameter(torch.tensor(1.))
        self.w1 = nn.Parameter(torch.tensor(1.))
        self.b1 = nn.Parameter(torch.tensor(0.))
        self.w2 = nn.Parameter(torch.tensor(1.))
        self.b2 = nn.Parameter(torch.tensor(0.))
        self.w3 = nn.Parameter(torch.tensor(1.))
        self.b3 = nn.Parameter(torch.tensor(0.))
    def forward(self, dist, pkg_l, proj_l, idx_mask):
        """Log-probability the kNN distribution assigns to the target token.

        dist:     (bsz, k) retrieval distance scores for the k neighbors
        pkg_l:    (bsz, k) binary package-locality indicator per neighbor
        proj_l:   (bsz, k) binary project-locality indicator per neighbor
        idx_mask: (bsz, k) additive log-space mask selecting the neighbors
                  that count toward the target (0 to keep, -inf to drop)
        """
        dist = dist.cuda()
        pkg_l = pkg_l.cuda()
        proj_l = proj_l.cuda()
        idx_mask = idx_mask.cuda()
        # Fold the two binary indicators into one of 4 locality classes and
        # one-hot encode them: locality_feat has shape (4, bsz, k).
        locality_indicator = proj_l + 2 * pkg_l
        locality_feat = torch.nn.functional.one_hot(
            locality_indicator.long(), num_classes=4).permute(2, 0, 1)
        # Apply the class-specific affine transform to each neighbor's
        # distance score, then normalize over the k neighbors.
        probs = utils.log_softmax(locality_feat[0] * (self.w0 * dist) +
                                  locality_feat[1] * (self.w1 * dist + self.b1) +
                                  locality_feat[2] * (self.w2 * dist + self.b2) +
                                  locality_feat[3] * (self.w3 * dist + self.b3),
                                  dim=-1)
        # Aggregate the probability mass of the selected neighbors.
        return torch.logsumexp(probs + idx_mask, dim=-1)
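
# Quick shape check with synthetic inputs (illustrative only; the random
# tensors below are stand-ins, not the cached retrievals the script uses):
#   m = WeightedDist().cuda()
#   out = m(torch.randn(2, 8), torch.randint(0, 2, (2, 8)),
#           torch.randint(0, 2, (2, 8)), torch.zeros(2, 8))
#   print(out.shape)  # torch.Size([2])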

num_retrieved = 1024  # neighbors retrieved per token (k)
bsz = 810             # number of tokens per tuning batch

# Cached retrieval results for the WikiText-103 validation set,
# flattened to (num_tokens, k).
dists = np.load('saved_tensors/wikitext-103/valid_proj_dist_cache.npy').reshape(-1, num_retrieved)
pkg_locality = np.load('saved_tensors/wikitext-103/valid_pkg_locality_cache.npy').reshape(-1, num_retrieved)
proj_locality = np.load('saved_tensors/wikitext-103/valid_proj_locality_cache.npy').reshape(-1, num_retrieved)
index_masks = np.load('saved_tensors/wikitext-103/valid_proj_index_mask_cache.npy').reshape(-1, num_retrieved)

# Keep the full tensors on CPU; forward() moves each batch to the GPU.
dists = torch.from_numpy(dists).float()
pkg_locality = torch.from_numpy(pkg_locality)
proj_locality = torch.from_numpy(proj_locality)
index_masks = torch.from_numpy(index_masks)

model = WeightedDist().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(2000):
    epoch_loss = 0.0
    num_batches = 0
    for start_idx in range(0, dists.shape[0], bsz):
        num_batches += 1
        outputs = model(dists[start_idx:start_idx + bsz, :],
                        pkg_locality[start_idx:start_idx + bsz, :],
                        proj_locality[start_idx:start_idx + bsz, :],
                        index_masks[start_idx:start_idx + bsz, :])
        # Negative log-likelihood of the gold tokens under the masked
        # kNN distribution.
        loss = torch.mean(-outputs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # Mean batch loss for the epoch, followed by the current parameters.
    print('Epoch:', epoch, 'Loss:', epoch_loss / num_batches)
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.data)
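
# A minimal sketch of persisting the tuned scales and biases for reuse at
# evaluation time; the output path is an assumption, not part of the
# original pipeline.
torch.save(model.state_dict(), 'saved_tensors/wikitext-103/locality_weights.pt')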