Skip to content

Commit

Permalink
Added eps to contrastive loss sqrt
Browse files Browse the repository at this point in the history
  • Loading branch information
adambielski committed Jan 10, 2019
1 parent 244f18f commit 69c1073
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ class ContrastiveLoss(nn.Module):
def __init__(self, margin):
super(ContrastiveLoss, self).__init__()
self.margin = margin
self.eps = 1e-9

def forward(self, output1, output2, target, size_average=True):
distances = (output2 - output1).pow(2).sum(1) # squared distances
losses = 0.5 * (target.float() * distances +
(1 + -1 * target).float() * F.relu(self.margin - distances.sqrt()).pow(2))
(1 + -1 * target).float() * F.relu(self.margin - (distances + self.eps).sqrt()).pow(2))
return losses.mean() if size_average else losses.sum()


Expand Down

0 comments on commit 69c1073

Please sign in to comment.