-
Notifications
You must be signed in to change notification settings - Fork 6
/
pytorch_utils.py
49 lines (39 loc) · 1.42 KB
/
pytorch_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import torch.nn
def count_parameters(model, trainable_only=True):
    """Return the total number of scalar elements in ``model``'s parameters.

    Args:
        model: Object exposing ``parameters()`` that yields tensors
            (for example a ``torch.nn.Module``).
        trainable_only: When true (the default, which matches the
            historical behaviour of this helper) only parameters whose
            ``requires_grad`` flag is set are counted. When false, every
            parameter is counted, including frozen ones.

    Returns:
        int: Sum of ``numel()`` over the selected parameters; ``0`` when
        nothing is selected.
    """
    return sum(
        param.numel()
        for param in model.parameters()
        if param.requires_grad or not trainable_only
    )
def save_checkpoint(state, is_best, filename='/output/checkpoint.pth.tar'):
    """Save the ``'model'`` entry of ``state`` when a new best is achieved.

    Nothing is written when ``is_best`` is false; a message is printed in
    both cases.

    Args:
        state: Mapping with a ``'model'`` key. Only ``state['model']`` is
            serialized; any other entries are ignored.
        is_best: Truthy when the current result is a new best and the
            checkpoint should be written.
        filename: Destination handed to ``torch.save``. When it is a path
            (``str`` or ``os.PathLike``) its parent directory is created
            if missing; file-like objects are passed through untouched.

    Raises:
        KeyError: If ``is_best`` is truthy and ``state`` has no ``'model'``.
        OSError: If the parent directory or the file cannot be written.
    """
    import os  # local import so this function does not rely on file-level imports

    if not is_best:
        print("=> Validation Accuracy did not improve")
        return

    print("=> Saving a new best")
    # torch.save does not create missing directories, so a fresh
    # environment without the target folder would otherwise fail here.
    if isinstance(filename, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(filename))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    torch.save(state['model'], filename)  # save checkpoint
class ContrastiveLoss(torch.nn.Module):
    """Pairwise contrastive loss over two batches of embeddings.

    With ``d`` the Euclidean distance between paired rows of ``x0`` and
    ``x1`` and ``y`` the pair label, the per-pair term is::

        y * d**2 + (1 - y) * max(margin - d, 0)**2

    and the returned value is the sum over the batch divided by ``2 * N``.
    Pairs labelled ``y == 1`` are therefore pulled together, while pairs
    labelled ``y == 0`` are pushed apart until they are at least
    ``margin`` away (cf. the contrastive loss of Hadsell, Chopra & LeCun,
    2006, which uses the opposite label convention).

    Args:
        margin: Distance beyond which ``y == 0`` pairs stop contributing.
        eps: Lower bound applied to the squared distance before taking the
            square root. The derivative of ``sqrt`` is infinite at zero,
            which would otherwise yield NaN gradients whenever a pair of
            embeddings is identical. Use a larger value for
            reduced-precision dtypes in which the default underflows.
    """

    def __init__(self, margin=1.0, eps=1e-12):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.eps = eps

    def check_type_forward(self, in_types):
        """Validate the ``(x0, x1, y)`` triple given to :meth:`forward`.

        Args:
            in_types: Sequence of exactly three tensors ``(x0, x1, y)``.

        Raises:
            AssertionError: If the triple does not consist of two 2-D
                tensors of identical shape and a 1-D tensor with one
                entry per row, or if the batch is empty.
        """
        # NOTE: asserts are kept (rather than raising ValueError) so the
        # exception type seen by existing callers does not change.
        assert len(in_types) == 3, "expected (x0, x1, y)"
        x0_type, x1_type, y_type = in_types

        assert x0_type.shape == x1_type.shape, "x0 and x1 must have the same shape"
        assert x1_type.shape[0] == y_type.shape[0], "y needs one label per pair"
        assert x1_type.shape[0] > 0, "batch must not be empty"

        assert x0_type.dim() == 2, "x0 must be 2-D"
        assert x1_type.dim() == 2, "x1 must be 2-D"
        assert y_type.dim() == 1, "y must be 1-D"

    def forward(self, x0, x1, y):
        """Compute the mean contrastive loss for a batch of pairs.

        Args:
            x0: Tensor of shape ``(N, D)``.
            x1: Tensor of shape ``(N, D)``.
            y: Tensor of shape ``(N,)``; ``1`` marks a similar pair and
                ``0`` a dissimilar one. Bool and integer labels are
                converted to the dtype of the distances.

        Returns:
            Scalar tensor: ``sum(per-pair terms) / 2 / N``.

        Raises:
            AssertionError: If the inputs fail :meth:`check_type_forward`.
        """
        self.check_type_forward((x0, x1, y))

        # Euclidean distance between paired rows.
        diff = x0 - x1
        dist_sq = torch.sum(torch.pow(diff, 2), 1)
        # Clamp only inside the root: values above eps are untouched, and
        # for (near-)zero distances the clamp blocks the infinite sqrt
        # gradient that would otherwise turn into NaN during backward.
        dist = torch.sqrt(torch.clamp(dist_sq, min=self.eps))

        # ``1 - y`` is not defined for bool tensors; normalise label dtype.
        if not torch.is_floating_point(y):
            y = y.to(dist.dtype)

        # Hinge on the margin: pairs farther apart than margin contribute 0.
        margin_dist = torch.clamp(self.margin - dist, min=0.0)

        per_pair = y * dist_sq + (1 - y) * torch.pow(margin_dist, 2)
        return torch.sum(per_pair) / 2.0 / x0.shape[0]