From d70f5d9f4f2d58833b50d0fa3caafeb7cfcd2f0d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Robert=20Heum=C3=BCller?=
Date: Fri, 21 Apr 2023 12:47:03 +0200
Subject: [PATCH 1/3] Update samplers.md

---
 docs/samplers.md | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/docs/samplers.md b/docs/samplers.md
index b8166166..4ad9ecde 100644
--- a/docs/samplers.md
+++ b/docs/samplers.md
@@ -1,5 +1,6 @@
 # Samplers
-Samplers are just extensions of the torch.utils.data.Sampler class, i.e. they are passed to a PyTorch Dataloader. The purpose of samplers is to determine how batches should be formed. This is also where any offline pair or triplet miners should exist.
+Samplers are just extensions of the torch.utils.data.Sampler class, i.e. they are passed to a PyTorch DataLoader (specifically as the _sampler_ argument, unless otherwise mentioned).
+The purpose of samplers is to determine how batches should be formed. This is also where any offline pair or triplet miners should exist.
 
 ## MPerClassSampler
 
@@ -87,4 +88,4 @@ samplers.FixedSetOfTriplets(labels, num_triplets)
 **Parameters**:
 
 * **labels**: The list of labels for your dataset, i.e. the labels[x] should be the label of the xth element in your dataset.
-* **num_triplets**: The number of triplets to create.
\ No newline at end of file
+* **num_triplets**: The number of triplets to create.

From e6b3527519b888d33e772372af252bbc43ef51f6 Mon Sep 17 00:00:00 2001
From: KevinMusgrave
Date: Fri, 26 May 2023 12:20:08 -0400
Subject: [PATCH 2/3] Fix bug where set_stats wasn't being called in
 TripletMarginMiner

---
 src/pytorch_metric_learning/__init__.py   | 2 +-
 .../miners/triplet_margin_miner.py         | 2 ++
 tests/miners/test_triplet_margin_miner.py  | 9 +++++++++
 3 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/src/pytorch_metric_learning/__init__.py b/src/pytorch_metric_learning/__init__.py
index 58039f50..4eabd0b3 100644
--- a/src/pytorch_metric_learning/__init__.py
+++ b/src/pytorch_metric_learning/__init__.py
@@ -1 +1 @@
-__version__ = "2.1.1"
+__version__ = "2.1.2"
diff --git a/src/pytorch_metric_learning/miners/triplet_margin_miner.py b/src/pytorch_metric_learning/miners/triplet_margin_miner.py
index 4ddfca66..5a9b2fbd 100644
--- a/src/pytorch_metric_learning/miners/triplet_margin_miner.py
+++ b/src/pytorch_metric_learning/miners/triplet_margin_miner.py
@@ -37,6 +37,8 @@ def mine(self, embeddings, labels, ref_emb, ref_labels):
             ap_dist - an_dist if self.distance.is_inverted else an_dist - ap_dist
         )
 
+        self.set_stats(ap_dist, an_dist, triplet_margin)
+
         if self.type_of_triplets == "easy":
             threshold_condition = triplet_margin > self.margin
         else:
diff --git a/tests/miners/test_triplet_margin_miner.py b/tests/miners/test_triplet_margin_miner.py
index 95f946ba..489e4e99 100644
--- a/tests/miners/test_triplet_margin_miner.py
+++ b/tests/miners/test_triplet_margin_miner.py
@@ -103,3 +103,12 @@ def test_empty_output(self):
         self.assertTrue(len(a) == 0)
         self.assertTrue(len(p) == 0)
         self.assertTrue(len(n) == 0)
+
+    @unittest.skipUnless(WITH_COLLECT_STATS, "WITH_COLLECT_STATS is false")
+    def test_recordable_attributes(self):
+        miner = TripletMarginMiner()
+        emb, labels = torch.randn(32, 32), torch.randint(0, 3, size=(32,))
+        miner(emb, labels)
+        self.assertNotEqual(miner.avg_triplet_margin, 0)
+        self.assertNotEqual(miner.pos_pair_dist, 0)
+        self.assertNotEqual(miner.neg_pair_dist, 0)
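A minimal sketch of the behavior this patch restores, mirroring the new test above. The `collect_stats=True` argument and the toy sizes are illustrative assumptions, not part of the patch itself:

```python
# Sketch of the fixed behavior; assumes pytorch-metric-learning is installed
# and stats collection is enabled via collect_stats=True.
import torch
from pytorch_metric_learning.miners import TripletMarginMiner

miner = TripletMarginMiner(margin=0.2, type_of_triplets="all", collect_stats=True)
embeddings = torch.randn(32, 128)         # 32 embeddings, 128 dimensions
labels = torch.randint(0, 3, size=(32,))  # 3 classes

anchors, positives, negatives = miner(embeddings, labels)  # mined triplet indices

# Before this patch, mine() never called set_stats, so these stayed at 0:
print(miner.avg_triplet_margin, miner.pos_pair_dist, miner.neg_pair_dist)
```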
From 45835fae72614bb5698562f30fd52318a20b58c7 Mon Sep 17 00:00:00 2001
From: KevinMusgrave
Date: Fri, 26 May 2023 13:54:07 -0400
Subject: [PATCH 3/3] Made HierarchicalSampler extend Sampler instead of
 BatchSampler

---
 src/pytorch_metric_learning/samplers/hierarchical_sampler.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/pytorch_metric_learning/samplers/hierarchical_sampler.py b/src/pytorch_metric_learning/samplers/hierarchical_sampler.py
index ba0f99e7..74e70f85 100644
--- a/src/pytorch_metric_learning/samplers/hierarchical_sampler.py
+++ b/src/pytorch_metric_learning/samplers/hierarchical_sampler.py
@@ -2,14 +2,14 @@
 from collections import defaultdict
 
 import torch
-from torch.utils.data.sampler import BatchSampler
+from torch.utils.data.sampler import Sampler
 
 from ..utils import common_functions as c_f
 
 
 # Inspired by
 # https://github.com/kunhe/Deep-Metric-Learning-Baselines/blob/master/datasets.py
-class HierarchicalSampler(BatchSampler):
+class HierarchicalSampler(Sampler):
     def __init__(
         self,
         labels,
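Even after this change, HierarchicalSampler yields whole batches of indices, so it is the "unless otherwise mentioned" exception from the docs patch earlier in this series: it belongs in the DataLoader's `batch_sampler` argument rather than `sampler`. The sketch below uses made-up toy data and only the required constructor arguments; `MPerClassSampler` is shown for contrast as the ordinary `sampler=` case:

```python
# Usage sketch with hypothetical toy data; assumes pytorch-metric-learning
# is installed.
import numpy as np
import torch
from pytorch_metric_learning import samplers

# 400 samples, 20 classes, 5 super-classes (4 classes per super-class).
# labels is 2D: column 0 = class (inner label), column 1 = super-class (outer).
classes = np.arange(400) % 20
labels = np.stack([classes, classes % 5], axis=1)

dataset = torch.utils.data.TensorDataset(
    torch.randn(400, 128), torch.from_numpy(classes)
)

# HierarchicalSampler yields whole batches, so it goes to batch_sampler=
hierarchical = samplers.HierarchicalSampler(
    labels=labels,
    batch_size=32,
    samples_per_class=4,
)
loader = torch.utils.data.DataLoader(dataset, batch_sampler=hierarchical)

# Most other samplers yield single indices and go to sampler=, e.g.:
m_per_class = samplers.MPerClassSampler(
    classes, m=4, batch_size=32, length_before_new_iter=len(classes)
)
loader2 = torch.utils.data.DataLoader(
    dataset, batch_size=32, sampler=m_per_class, drop_last=True
)
```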