-
Notifications
You must be signed in to change notification settings - Fork 33
/
losses.py
84 lines (67 loc) · 2.76 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
import torch.nn.functional as F
def kl_loc_loss(pre, post, mask=None):
    """KL(pre || post) locality loss between two sets of logits.

    Args:
        pre: reference logits, shape (N, 1) for binary or (B, T, V) for sequences.
        post: logits to compare against, same number of elements as ``pre``.
        mask: required for the sequence case; 1/0 weights with B*T entries
            selecting which token positions contribute to the average.

    Returns:
        Scalar tensor: mean symmetric binary KL for single-logit inputs, or the
        mask-weighted mean token-level KL for sequence inputs.

    Raises:
        NotImplementedError: for input shapes outside the two supported cases.
    """
    pre = pre.to(torch.float32)
    post = post.to(torch.float32)
    is_sequence = pre.dim() == 3
    flat_pre = pre.view(-1, pre.shape[-1])
    flat_post = post.view(flat_pre.shape)
    assert flat_pre.shape[0] == flat_post.shape[0]

    # Binary case: a single logit per example; accumulate KL over both classes.
    if not is_sequence and flat_pre.shape[-1] == 1:
        pos_term = (pre.sigmoid() * (F.logsigmoid(pre) - F.logsigmoid(post))).mean()
        neg_term = ((-pre).sigmoid() * (F.logsigmoid(-pre) - F.logsigmoid(-post))).mean()
        return pos_term + neg_term

    # Sequence case: per-token categorical KL, averaged over unmasked tokens.
    if is_sequence and flat_pre.shape[-1] > 1:
        assert mask is not None
        flat_mask = mask.view(flat_pre.shape[0])
        token_kl = (
            flat_pre.softmax(-1) * (flat_pre.log_softmax(-1) - flat_post.log_softmax(-1))
        ).sum(-1)
        return (token_kl * flat_mask).sum() / flat_mask.sum()

    raise NotImplementedError
def binary_log_probs(pred, targ):
    """Accuracy and log-probability metrics for single-logit binary predictions.

    Args:
        pred: raw logits.
        targ: 0/1 labels aligned with ``pred``.

    Returns:
        Dict with ``acc``, ``log_prob``, ``prob``, ``nll`` (all scalar tensors)
        and ``n_tokens`` (size of the leading dimension).
    """
    # Flip the logit sign wherever the label is 0, so logsigmoid of the
    # adjusted logit is always the log-probability of the *true* class.
    signs = torch.ones_like(pred)
    signs[targ == 0] *= -1
    true_class_logits = pred * signs
    log_probs = F.logsigmoid(true_class_logits)
    probs = log_probs.exp()
    # A prediction is correct exactly when the true class gets > 0.5 mass.
    accuracy = (probs > 0.5).float().mean()
    return {
        "acc": accuracy,
        "log_prob": log_probs.mean(),
        "prob": probs.mean(),
        "nll": -log_probs.mean(),
        "n_tokens": log_probs.shape[0],
    }
def multiclass_log_probs(pred, targ, shift=True):
    """Masked accuracy and log-probability metrics for multi-class predictions.

    Target positions equal to -100 are treated as padding: they are excluded
    from the token count and from the probability averages.

    Args:
        pred: logits, shape (N, V) or (B, T, V).
        targ: integer class ids (with -100 marking ignored positions).
        shift: when True and ``pred`` is 3-D, drop the last prediction and the
            first target so next-token predictions align with their targets.

    Returns:
        Dict with ``acc`` (whole-sequence accuracy in the 3-D case),
        ``log_prob``, ``prob``, ``n_tokens``, and ``nll``.
    """
    PAD_FILL = 0  # dummy in-vocab id written into ignored positions
    pred, targ = pred.clone(), targ.clone()
    if shift and pred.dim() == 3:
        # Align each prediction with the token it is predicting.
        pred = pred[:, :-1]
        targ = targ[:, 1:]
    keep = targ != -100
    # Filled-in ids are never scored; any valid token id would do here.
    targ[~keep] = PAD_FILL
    token_log_probs = pred.log_softmax(-1).gather(-1, targ.unsqueeze(-1)).squeeze(-1)
    guesses = pred.argmax(-1).masked_fill(~keep, PAD_FILL)
    if pred.dim() == 3:
        # Sequence setting: credit only sequences that are entirely correct
        # (ignored positions match trivially since both sides hold PAD_FILL).
        hits = (guesses == targ).all(-1)
    else:
        hits = guesses == targ
    accuracy = hits.float().mean()
    n_tokens = keep.float().sum()
    avg_log_prob = (token_log_probs * keep.float()).sum() / n_tokens
    avg_prob = (token_log_probs.exp() * keep.float()).sum() / n_tokens
    return {
        "acc": accuracy,
        "log_prob": avg_log_prob,
        "prob": avg_prob,
        "n_tokens": n_tokens,
        "nll": -avg_log_prob,
    }
def masked_log_probs(pred, targ, shift=True):
    """Dispatch to the binary or multi-class metric helper based on shape.

    Args:
        pred: logits; must be 2-D or 3-D. A trailing dimension of 1 selects
            the binary path, anything larger the multi-class path.
        targ: labels, forwarded unchanged to the chosen helper.
        shift: forwarded to ``multiclass_log_probs`` only.

    Raises:
        RuntimeError: if ``pred`` is neither 2-D nor 3-D.
    """
    pred = pred.to(torch.float32)
    ndims = pred.dim()
    if ndims not in (2, 3):
        raise RuntimeError(f"Expected pred to have 2 or 3 dimensions, got {pred.shape}")
    if pred.shape[-1] == 1:
        return binary_log_probs(pred, targ)
    return multiclass_log_probs(pred, targ, shift=shift)