Skip to content

Commit

Permalink
HardNegativePairSelector bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
adambielski authored Mar 15, 2018
1 parent 296164d commit 9512643
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ def get_pairs(self, embeddings, labels):

class HardNegativePairSelector(PairSelector):
"""
Creates all possible positive pairs. For negative pairs, pairs with greatest distance are taken into consideration,
Creates all possible positive pairs. For negative pairs, pairs with smallest distance are taken into consideration,
matching the number of positive pairs.
"""

def __init__(self, cpu=False):
def __init__(self, cpu=True):
super(HardNegativePairSelector, self).__init__()
self.cpu = cpu

Expand All @@ -68,7 +68,7 @@ def get_pairs(self, embeddings, labels):

negative_distances = distance_matrix[negative_pairs[:, 0], negative_pairs[:, 1]]
negative_distances = negative_distances.cpu().data.numpy()
top_negatives = np.argpartition(negative_distances, -len(positive_pairs))[-len(positive_pairs):]
top_negatives = np.argpartition(negative_distances, len(positive_pairs))[:len(positive_pairs)]
top_negative_pairs = negative_pairs[torch.LongTensor(top_negatives)]

return positive_pairs, top_negative_pairs
Expand Down

0 comments on commit 9512643

Please sign in to comment.