Merge pull request #631 from KevinMusgrave/dev

v2.1.2

Kevin Musgrave authored May 27, 2023
2 parents c57ebdd + eb43a83 commit 6d91b49
Showing 5 changed files with 18 additions and 5 deletions.
5 changes: 3 additions & 2 deletions docs/samplers.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Samplers
Samplers are just extensions of the torch.utils.data.Sampler class, i.e. they are passed to a PyTorch Dataloader. The purpose of samplers is to determine how batches should be formed. This is also where any offline pair or triplet miners should exist.
Samplers are just extensions of the torch.utils.data.Sampler class, i.e. they are passed to a PyTorch Dataloader (specifically as the _sampler_ argument, unless otherwise mentioned).
The purpose of samplers is to determine how batches should be formed. This is also where any offline pair or triplet miners should exist.
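The docs text above describes the Sampler contract: a DataLoader only needs the sampler to iterate over dataset indices and to report a length. A torch-free sketch of that contract (the class name and seeding are illustrative assumptions, not library code):

```python
import random

class ShuffleSampler:
    """Minimal sketch of the Sampler contract: iterate over dataset
    indices and define __len__. An instance of a class like this is
    what gets passed as DataLoader(dataset, sampler=...)."""

    def __init__(self, dataset_size, seed=0):
        self.dataset_size = dataset_size
        self.seed = seed

    def __iter__(self):
        # Yield every dataset index exactly once, in a seeded random order.
        order = list(range(self.dataset_size))
        random.Random(self.seed).shuffle(order)
        return iter(order)

    def __len__(self):
        return self.dataset_size
```

Because batch formation is driven entirely by the order of indices the sampler yields, this is also the natural place to enforce constraints such as "m samples per class" or precomputed pairs/triplets.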


## MPerClassSampler
@@ -87,4 +88,4 @@ samplers.FixedSetOfTriplets(labels, num_triplets)
**Parameters**:

* **labels**: The list of labels for your dataset, i.e. the labels[x] should be the label of the xth element in your dataset.
* **num_triplets**: The number of triplets to create.
* **num_triplets**: The number of triplets to create.
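Per the parameters above, `FixedSetOfTriplets(labels, num_triplets)` precomputes a fixed pool of (anchor, positive, negative) index triplets, which is the offline triplet mining the samplers doc refers to. A simplified pure-Python sketch of that idea (the function name and seeding are assumptions; this is not the library's implementation):

```python
import random
from collections import defaultdict

def fixed_set_of_triplets(labels, num_triplets, seed=0):
    """Precompute (anchor, positive, negative) index triplets,
    where labels[x] is the label of the x-th dataset element."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_label[lab].append(idx)
    # Only labels with >= 2 samples can supply an anchor/positive pair.
    usable = [lab for lab, idxs in by_label.items() if len(idxs) >= 2]
    triplets = []
    while len(triplets) < num_triplets:
        lab = rng.choice(usable)
        a, p = rng.sample(by_label[lab], 2)   # distinct same-label indices
        neg_lab = rng.choice([l for l in by_label if l != lab])
        n = rng.choice(by_label[neg_lab])     # index with a different label
        triplets.append((a, p, n))
    return triplets
```

Because the pool is fixed up front, every epoch draws from the same triplets, unlike online miners that re-mine within each batch.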
3 changes: 2 additions & 1 deletion src/pytorch_metric_learning/__init__.py
@@ -1 +1,2 @@
__version__ = "2.1.1"
__version__ = "2.1.2"

2 changes: 2 additions & 0 deletions src/pytorch_metric_learning/miners/triplet_margin_miner.py
@@ -37,6 +37,8 @@ def mine(self, embeddings, labels, ref_emb, ref_labels):
ap_dist - an_dist if self.distance.is_inverted else an_dist - ap_dist
)

self.set_stats(ap_dist, an_dist, triplet_margin)

if self.type_of_triplets == "easy":
threshold_condition = triplet_margin > self.margin
else:
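The `self.set_stats(ap_dist, an_dist, triplet_margin)` call added above records summary statistics on the miner as it runs. A minimal sketch of that pattern (the attribute names match the new test in this commit; the `collect_stats` guard is an assumption mirroring the `WITH_COLLECT_STATS` flag the test checks):

```python
class StatsRecorder:
    """Sketch of the stat-recording pattern: when collection is enabled,
    store mean anchor-positive distance, mean anchor-negative distance,
    and mean triplet margin as attributes for later inspection."""

    collect_stats = True  # assumed guard, analogous to WITH_COLLECT_STATS

    def set_stats(self, ap_dist, an_dist, triplet_margin):
        if self.collect_stats:
            self.pos_pair_dist = sum(ap_dist) / len(ap_dist)
            self.neg_pair_dist = sum(an_dist) / len(an_dist)
            self.avg_triplet_margin = sum(triplet_margin) / len(triplet_margin)
```

The payoff is debuggability: after each forward pass you can read these attributes to see, for example, whether positive pairs are actually closer than negative pairs as training progresses.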
4 changes: 2 additions & 2 deletions src/pytorch_metric_learning/samplers/hierarchical_sampler.py
@@ -2,14 +2,14 @@
from collections import defaultdict

import torch
from torch.utils.data.sampler import BatchSampler
from torch.utils.data.sampler import Sampler

from ..utils import common_functions as c_f


# Inspired by
# https://github.com/kunhe/Deep-Metric-Learning-Baselines/blob/master/datasets.py
class HierarchicalSampler(BatchSampler):
class HierarchicalSampler(Sampler):
def __init__(
self,
labels,
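The change above swaps the base class from `BatchSampler` to `Sampler`, but `HierarchicalSampler` still yields whole batches of indices rather than single indices, so it is the "unless otherwise mentioned" case in the docs text: such samplers are typically passed via the DataLoader's _batch_sampler_ argument. A torch-free sketch of a batch-yielding sampler (the class name and batch-construction details are illustrative assumptions, not the library's code):

```python
import random
from collections import defaultdict

class TwoPerClassBatchSampler:
    """Yields whole batches of dataset indices: each batch contains
    `classes_per_batch` distinct classes with 2 samples per class,
    in the spirit of hierarchical/per-class batch construction."""

    def __init__(self, labels, classes_per_batch, num_batches, seed=0):
        self.by_label = defaultdict(list)
        for idx, lab in enumerate(labels):
            self.by_label[lab].append(idx)
        self.classes_per_batch = classes_per_batch
        self.num_batches = num_batches
        self.rng = random.Random(seed)

    def __iter__(self):
        for _ in range(self.num_batches):
            batch = []
            # Pick distinct classes, then 2 distinct indices per class.
            for lab in self.rng.sample(list(self.by_label), self.classes_per_batch):
                batch.extend(self.rng.sample(self.by_label[lab], 2))
            yield batch  # a whole batch, not a single index

    def __len__(self):
        return self.num_batches
```

Yielding lists of indices is what distinguishes a batch sampler from a plain sampler; the DataLoader then uses each yielded list as one batch verbatim.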
9 changes: 9 additions & 0 deletions tests/miners/test_triplet_margin_miner.py
@@ -103,3 +103,12 @@ def test_empty_output(self):
self.assertTrue(len(a) == 0)
self.assertTrue(len(p) == 0)
self.assertTrue(len(n) == 0)

@unittest.skipUnless(WITH_COLLECT_STATS, "WITH_COLLECT_STATS is false")
def test_recordable_attributes(self):
miner = TripletMarginMiner()
emb, labels = torch.randn(32, 32), torch.randint(0, 3, size=(32,))
miner(emb, labels)
self.assertNotEqual(miner.avg_triplet_margin, 0)
self.assertNotEqual(miner.pos_pair_dist, 0)
self.assertNotEqual(miner.neg_pair_dist, 0)
