From f63e775d924edfce6e5086f8befe7b6c15b2516d Mon Sep 17 00:00:00 2001
From: Qibin Liang <46101145+PhyseChan@users.noreply.github.com>
Date: Sun, 8 Oct 2023 14:10:48 +0800
Subject: [PATCH 01/15] [update] scatter labels

---
 .../losses/multilabel_supcon_loss.py         | 13 ++++++--
 .../utils/multilabel_loss_and_miner_utils.py | 21 ++++++-------
 tests/losses/test_multilabel_supcon_loss.py  | 31 +++++++++++++------
 3 files changed, 40 insertions(+), 25 deletions(-)

diff --git a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py
index 512439f9..9e0f6a10 100644
--- a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py
+++ b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py
@@ -10,13 +10,15 @@
 # adapted from https://github.com/HobbitLong/SupContrast
 class MultiSupConLoss(GenericPairLoss):
-    def __init__(self, num_classes, temperature=0.1, **kwargs):
+    def __init__(self, num_classes, temperature=0.1, threshold=0.3, **kwargs):
         super().__init__(mat_based_loss=True, **kwargs)
         self.temperature = temperature
         self.add_to_recordable_attributes(list_of_names=["temperature"], is_stat=False)
         self.num_classes = num_classes
+        self.threshold = threshold

     def _compute_loss(self, mat, pos_mask, neg_mask):
+        print(pos_mask)
         if pos_mask.bool().any() and neg_mask.bool().any():
             # if dealing with actual distances, use negative distances
             if not self.distance.is_inverted:
@@ -24,7 +26,6 @@
             mat = mat / self.temperature
             mat_max, _ = mat.max(dim=1, keepdim=True)
             mat = mat - mat_max.detach()  # for numerical stability
-
             denominator = lmu.logsumexp(
                 mat, keep_mask=(pos_mask + neg_mask).bool(), add_one=False, dim=1
             )
@@ -57,7 +58,13 @@ def mat_based_loss(self, mat, indices_tuple):

     def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
         c_f.labels_or_indices_tuple_required(labels, indices_tuple)
-        indices_tuple = mlmu.convert_to_pairs(indices_tuple, labels, self.num_classes, ref_labels, device=embeddings.device)
+        indices_tuple = mlmu.convert_to_pairs(
+            indices_tuple,
+            labels,
+            self.num_classes,
+            ref_labels,
+            device=embeddings.device,
+            threshold=self.threshold)
         if all(len(x) <= 1 for x in indices_tuple):
             return self.zero_losses()
         mat = self.distance(embeddings, ref_emb)
diff --git a/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py b/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py
index de08e86d..c9bfbba6 100644
--- a/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py
+++ b/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py
@@ -16,7 +16,7 @@ def set_ref_emb(embeddings, labels, ref_emb, ref_labels):
     check_shapes_multilabels(ref_emb, ref_labels)
     return ref_emb, ref_labels

-def convert_to_pairs(indices_tuple, labels, num_classes, ref_labels=None, device=None):
+def convert_to_pairs(indices_tuple, labels, num_classes, ref_labels=None, device=None, threshold=0.3):
     """
     This returns anchor-positive and anchor-negative indices,
     regardless of what the input indices_tuple is
@@ -26,28 +26,28 @@
         labels: a tensor which has the label for each element in a batch
     """
     if indices_tuple is None:
-        return get_all_pairs_indices(labels, num_classes, ref_labels, device=device)
+        return get_all_pairs_indices(labels, num_classes, ref_labels, device=device, threshold=threshold)
     elif len(indices_tuple) == 4:
         return indices_tuple
     else:
         a, p, n = indices_tuple
         return a, p, a, n

-def get_matches_and_diffs(labels, num_classes, ref_labels=None, device=None):
-    matches = jaccard(num_classes, labels, ref_labels, device=device)
+def get_matches_and_diffs(labels, num_classes, ref_labels=None, device=None, threshold=0.3):
+    matches = jaccard(num_classes, labels, ref_labels, device=device, threshold=threshold)
     diffs = matches ^ 1
     if ref_labels is labels:
         matches.fill_diagonal_(0)
     return matches, diffs

-def get_all_pairs_indices(labels, num_classes, ref_labels=None, device=None):
+def get_all_pairs_indices(labels, num_classes, ref_labels=None, device=None, threshold=0.3):
     """
     Given a tensor of labels, this will return 4 tensors.
     The first 2 tensors are the indices which form all positive pairs
     The second 2 tensors are the indices which form all negative pairs
     """
-    matches, diffs = get_matches_and_diffs(labels, num_classes, ref_labels, device)
+    matches, diffs = get_matches_and_diffs(labels, num_classes, ref_labels, device, threshold=threshold)
     a1_idx, p_idx = torch.where(matches)
     a2_idx, n_idx = torch.where(diffs)
     return a1_idx, p_idx, a2_idx, n_idx
@@ -55,12 +55,9 @@
 def jaccard(n_classes, labels, ref_labels=None, threshold=0.3, device=torch.device("cpu")):
     if ref_labels is None:
         ref_labels = labels
-    # convert multilabels to scatter labels
-    labels1 = [torch.nn.functional.one_hot(torch.Tensor(label).long(), n_classes).sum(0) for label in labels]
-    labels2 = [torch.nn.functional.one_hot(torch.Tensor(label).long(), n_classes).sum(0) for label in ref_labels]
-    # stack and convert to float for calculation convenience
-    labels1 = torch.stack(labels1).float()
-    labels2 = torch.stack(labels2).float()
+
+    labels1 = labels.float()
+    labels2 = ref_labels.float()

     # compute jaccard similarity
     # jaccard = intersection / union
diff --git a/tests/losses/test_multilabel_supcon_loss.py b/tests/losses/test_multilabel_supcon_loss.py
index cca3c61a..ebbd18c7 100644
--- a/tests/losses/test_multilabel_supcon_loss.py
+++ b/tests/losses/test_multilabel_supcon_loss.py
@@ -11,11 +11,11 @@
 class TestMultiSupConLoss(unittest.TestCase):
     def test_multi_supcon_loss(self):
-        n_cls = 10
-        n_samples = 16
-        n_dim = 256
-        loss_func = MultiSupConLoss(num_classes=10)
-        xbm_loss_func = CrossBatchMemory4MultiLabel(loss_func, n_dim, memory_size=128)
+        n_cls = 6
+        n_samples = 6
+        n_dim = 5
+        loss_func = MultiSupConLoss(num_classes=n_cls)
+        xbm_loss_func = CrossBatchMemory4MultiLabel(loss_func, n_dim, memory_size=16)

         # # test float32 and float64
         # for dtype in [torch.float32, torch.float64]:
         #     embeddings = torch.randn(n_samples, n_dim, dtype=dtype)
         #     labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)]
         #     loss = loss_func(embeddings, labels)
         #     self.assertTrue(loss >= 0)

         # # test cuda and cpu
         # for device in [torch.device("cpu"),torch.device("cuda")]:
-        #     embeddings = torch.randn(n_samples, n_dim, dtype=dtype, device=device)
+        #     embeddings = torch.randn(n_samples, n_dim, device=device)
         #     labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)]
         #     loss = loss_func(embeddings, labels)
         #     self.assertTrue(loss >= 0)

         # test xbm
-        batchs = 10
-        for b in range(batchs):
-            embeddings = torch.randn(n_samples, n_dim, dtype=torch.float32)
+        # batchs = 4
+        # for b in range(batchs):
+        #     embeddings = torch.randn(n_samples, n_dim, dtype=torch.float32)
+        #     labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)]
+        #     loss = xbm_loss_func(embeddings, labels)
+        #     self.assertTrue(loss == 0)
+
+        # test scatter labels
+        for device in [torch.device("cpu"),torch.device("cuda")]:
+            embeddings = torch.randn(n_samples, n_dim, device=device)
             labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)]
-            loss = xbm_loss_func(embeddings, labels)
+            labels = torch.stack([
+                torch.nn.functional.one_hot(torch.tensor(label), n_cls).sum(dim=0).float()
+                for label in labels
+            ], dim=0)
+            loss = loss_func(embeddings, labels)
             self.assertTrue(loss >= 0)
\ No newline at end of file

From 8ddf96563b15168248913f54eae78b899c9b5514 Mon Sep 17 00:00:00 2001
From: Qibin Liang <46101145+PhyseChan@users.noreply.github.com>
Date: Mon, 9 Oct 2023 10:32:06 +0800
Subject: [PATCH 02/15] [bug] fix jaccard function

---
 .../losses/multilabel_supcon_loss.py         | 18 ++++++------
 .../losses/xbm_multilabel.py                 | 15 ++++++----
 .../utils/multilabel_loss_and_miner_utils.py | 14 +++++-----
 tests/losses/test_multilabel_supcon_loss.py  | 28 +++++++++++++------
 4 files changed, 45 insertions(+), 30 deletions(-)

diff --git a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py
index 9e0f6a10..8a3694f1 100644
--- a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py
+++ b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py
@@ -10,15 +10,17 @@
 # adapted from https://github.com/HobbitLong/SupContrast
 class MultiSupConLoss(GenericPairLoss):
-    def __init__(self, num_classes, temperature=0.1, threshold=0.3, **kwargs):
+    def __init__(self, num_classes, temperature=0.07, threshold=0.3, **kwargs):
         super().__init__(mat_based_loss=True, **kwargs)
         self.temperature = temperature
         self.add_to_recordable_attributes(list_of_names=["temperature"], is_stat=False)
         self.num_classes = num_classes
         self.threshold = threshold

-    def _compute_loss(self, mat, pos_mask, neg_mask):
-        print(pos_mask)
+    def dot_cosine_sim(self, a, b):
+        return a@b.T
+
+    def _compute_loss(self, mat, pos_mask, neg_mask, multi_val):
         if pos_mask.bool().any() and neg_mask.bool().any():
             # if dealing with actual distances, use negative distances
             if not self.distance.is_inverted:
@@ -30,10 +32,10 @@ def _compute_loss(self, mat, pos_mask, neg_mask):
                 mat, keep_mask=(pos_mask + neg_mask).bool(), add_one=False, dim=1
             )
             log_prob = mat - denominator
-            mean_log_prob_pos = (pos_mask * log_prob).sum(dim=1) / (
+            print(multi_val * log_prob,'\n', multi_val)
+            mean_log_prob_pos = (multi_val * log_prob).sum(dim=1) / (
                 pos_mask.sum(dim=1) + c_f.small_val(mat.dtype)
             )
-
             return {
                 "loss": {
                     "losses": -mean_log_prob_pos,
@@ -50,11 +52,11 @@ def get_default_distance(self):
         return CosineSimilarity()

     def mat_based_loss(self, mat, indices_tuple):
-        a1, p, a2, n = indices_tuple
+        a1, p, a2, n, jaccard_mat = indices_tuple
         pos_mask, neg_mask = torch.zeros_like(mat), torch.zeros_like(mat)
         pos_mask[a1, p] = 1
         neg_mask[a2, n] = 1
-        return self._compute_loss(mat, pos_mask, neg_mask)
+        return self._compute_loss(mat, pos_mask, neg_mask, jaccard_mat)

     def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
         c_f.labels_or_indices_tuple_required(labels, indices_tuple)
@@ -67,7 +69,7 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
             threshold=self.threshold)
         if all(len(x) <= 1 for x in indices_tuple):
             return self.zero_losses()
-        mat = self.distance(embeddings, ref_emb)
+        mat = self.dot_cosine_sim(embeddings, ref_emb)
         return self.loss_method(mat, indices_tuple)

     def forward(
diff --git a/src/pytorch_metric_learning/losses/xbm_multilabel.py b/src/pytorch_metric_learning/losses/xbm_multilabel.py
index cf9b9f40..8099b107 100644
---
a/src/pytorch_metric_learning/losses/xbm_multilabel.py +++ b/src/pytorch_metric_learning/losses/xbm_multilabel.py @@ -41,10 +41,13 @@ def forward(self, embeddings, labels, indices_tuple=None, enqueue_mask=None): assert len(embeddings) <= len(self.embedding_memory) self.reset_stats() device = embeddings.device + labels = c_f.to_device(labels, device=device) self.embedding_memory = c_f.to_device( self.embedding_memory, device=device, dtype=embeddings.dtype ) - + self.label_memory = c_f.to_device( + self.label_memory, device=device, dtype=labels.dtype + ) if enqueue_mask is not None: emb_for_queue = embeddings[enqueue_mask] labels_for_queue = labels[enqueue_mask] @@ -79,14 +82,12 @@ def forward(self, embeddings, labels, indices_tuple=None, enqueue_mask=None): def add_to_memory(self, embeddings, labels, batch_size): self.curr_batch_idx = ( torch.arange( - self.queue_idx, self.queue_idx + batch_size + self.queue_idx, self.queue_idx + batch_size, device=labels.device ) % self.memory_size ) self.embedding_memory[self.curr_batch_idx] = embeddings.detach() - # self.label_memory[self.curr_batch_idx] = labels - for i in range(len(self.curr_batch_idx)): - self.label_memory[self.curr_batch_idx[i]] = labels[i] + self.label_memory[self.curr_batch_idx] = labels prev_queue_idx = self.queue_idx self.queue_idx = (self.queue_idx + batch_size) % self.memory_size if (not self.has_been_filled) and (self.queue_idx <= prev_queue_idx): @@ -127,6 +128,8 @@ def reset_queue(self): self.register_buffer( "embedding_memory", torch.zeros(self.memory_size, self.embedding_size) ) - self.label_memory = [[] for i in range(self.memory_size)] + self.register_buffer( + "label_memory", torch.zeros(self.memory_size, self.num_classes) + ) self.has_been_filled = False self.queue_idx = 0 diff --git a/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py b/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py index c9bfbba6..543467ca 100644 --- a/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py +++ b/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py @@ -34,11 +34,12 @@ def convert_to_pairs(indices_tuple, labels, num_classes, ref_labels=None, device return a, p, a, n def get_matches_and_diffs(labels, num_classes, ref_labels=None, device=None, threshold=0.3): - matches = jaccard(num_classes, labels, ref_labels, device=device, threshold=threshold) + jaccard_matrix = jaccard(num_classes, labels, ref_labels, device=device, threshold=threshold) + matches = torch.where(jaccard_matrix > threshold, 1, 0).to(device) diffs = matches ^ 1 if ref_labels is labels: matches.fill_diagonal_(0) - return matches, diffs + return matches, diffs, jaccard_matrix def get_all_pairs_indices(labels, num_classes, ref_labels=None, device=None, threshold=0.3): @@ -47,10 +48,10 @@ def get_all_pairs_indices(labels, num_classes, ref_labels=None, device=None, thr The first 2 tensors are the indices which form all positive pairs The second 2 tensors are the indices which form all negative pairs """ - matches, diffs = get_matches_and_diffs(labels, num_classes, ref_labels, device, threshold=threshold) + matches, diffs, multi_val = get_matches_and_diffs(labels, num_classes, ref_labels, device, threshold=threshold) a1_idx, p_idx = torch.where(matches) a2_idx, n_idx = torch.where(diffs) - return a1_idx, p_idx, a2_idx, n_idx + return a1_idx, p_idx, a2_idx, n_idx, multi_val def jaccard(n_classes, labels, ref_labels=None, threshold=0.3, device=torch.device("cpu")): if ref_labels is None: @@ -65,11 +66,10 @@ 
def jaccard(n_classes, labels, ref_labels=None, threshold=0.3, device=torch.devi labels2_union = labels2.sum(-1) union = labels1_union.unsqueeze(1) + labels2_union.unsqueeze(0) intersection = torch.mm(labels1, labels2.T) - jaccard = intersection / (union - intersection) + jaccard_matrix = intersection / (union - intersection) # return indices of jaccard similarity above threshold - label_matrix = torch.where(jaccard > threshold, 1, 0).to(device) - return label_matrix + return jaccard_matrix def convert_to_triplets(indices_tuple, labels, ref_labels=None, t_per_anchor=100): """ diff --git a/tests/losses/test_multilabel_supcon_loss.py b/tests/losses/test_multilabel_supcon_loss.py index ebbd18c7..841fb279 100644 --- a/tests/losses/test_multilabel_supcon_loss.py +++ b/tests/losses/test_multilabel_supcon_loss.py @@ -40,12 +40,22 @@ def test_multi_supcon_loss(self): # self.assertTrue(loss == 0) # test scatter labels - for device in [torch.device("cpu"),torch.device("cuda")]: - embeddings = torch.randn(n_samples, n_dim, device=device) - labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] - labels = torch.stack([ - torch.nn.functional.one_hot(torch.tensor(label), n_cls).sum(dim=0).float() - for label in labels - ], dim=0) - loss = loss_func(embeddings, labels) - self.assertTrue(loss >= 0) \ No newline at end of file + # for device in [torch.device("cpu"),torch.device("cuda")]: + # embeddings = torch.randn(n_samples, n_dim, device=device) + # labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] + # labels = torch.stack([ + # torch.nn.functional.one_hot(torch.tensor(label), n_cls).sum(dim=0).float() + # for label in labels + # ], dim=0) + # loss = loss_func(embeddings, labels) + # self.assertTrue(loss >= 0) + + # test val + embeddings = torch.tensor([[0.1, 0.3, 0.1], + [0.2, 0.2, -0.1], + [0.1, -0.06, 0.1], + [0.03, -0.13, 0.4]]) + labels = torch.tensor([[1,0,1], [1,0,0], [0,0,1], [0,1,1]]) + loss = loss_func(embeddings, labels) + print(loss) + self.assertTrue(loss ==0) \ No newline at end of file From cabb941560013e9cf1975d6df82ee26d5d095998 Mon Sep 17 00:00:00 2001 From: Qibin Liang <46101145+PhyseChan@users.noreply.github.com> Date: Mon, 9 Oct 2023 10:35:51 +0800 Subject: [PATCH 03/15] [test] xbm --- tests/losses/test_multilabel_supcon_loss.py | 28 +++++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/tests/losses/test_multilabel_supcon_loss.py b/tests/losses/test_multilabel_supcon_loss.py index 841fb279..d8d077bf 100644 --- a/tests/losses/test_multilabel_supcon_loss.py +++ b/tests/losses/test_multilabel_supcon_loss.py @@ -51,11 +51,23 @@ def test_multi_supcon_loss(self): # self.assertTrue(loss >= 0) # test val - embeddings = torch.tensor([[0.1, 0.3, 0.1], - [0.2, 0.2, -0.1], - [0.1, -0.06, 0.1], - [0.03, -0.13, 0.4]]) - labels = torch.tensor([[1,0,1], [1,0,0], [0,0,1], [0,1,1]]) - loss = loss_func(embeddings, labels) - print(loss) - self.assertTrue(loss ==0) \ No newline at end of file + # embeddings = torch.tensor([[0.1, 0.3, 0.1], + # [0.2, 0.2, -0.1], + # [0.1, -0.06, 0.1], + # [0.03, -0.13, 0.4]]) + # labels = torch.tensor([[1,0,1], [1,0,0], [0,0,1], [0,1,1]]) + # loss = loss_func(embeddings, labels) + # print(loss) + # self.assertTrue(loss ==0) + + # test xbm with scatter labels + batchs = 4 + for b in range(batchs): + embeddings = torch.randn(n_samples, n_dim) + labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] + labels = torch.stack([ + 
torch.nn.functional.one_hot(torch.tensor(label), n_cls).sum(dim=0).float() + for label in labels + ], dim=0) + loss = loss_func(embeddings, labels) + self.assertTrue(loss > 0) \ No newline at end of file From da66f4c416bacbfb3365b1cf5523ee3bd190298c Mon Sep 17 00:00:00 2001 From: Qibin Liang <46101145+PhyseChan@users.noreply.github.com> Date: Mon, 9 Oct 2023 10:37:13 +0800 Subject: [PATCH 04/15] [update] remove print --- src/pytorch_metric_learning/losses/multilabel_supcon_loss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py index 8a3694f1..92d3d414 100644 --- a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py +++ b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py @@ -32,7 +32,6 @@ def _compute_loss(self, mat, pos_mask, neg_mask, multi_val): mat, keep_mask=(pos_mask + neg_mask).bool(), add_one=False, dim=1 ) log_prob = mat - denominator - print(multi_val * log_prob,'\n', multi_val) mean_log_prob_pos = (multi_val * log_prob).sum(dim=1) / ( pos_mask.sum(dim=1) + c_f.small_val(mat.dtype) ) From 2f884df13c4c616ac4328a69511bee6bbe4bafce Mon Sep 17 00:00:00 2001 From: Qibin Liang <46101145+PhyseChan@users.noreply.github.com> Date: Mon, 9 Oct 2023 17:04:08 +0800 Subject: [PATCH 05/15] [bug] fix mea_log_prob_pos --- src/pytorch_metric_learning/losses/multilabel_supcon_loss.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py index 92d3d414..6d5076b5 100644 --- a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py +++ b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py @@ -32,7 +32,7 @@ def _compute_loss(self, mat, pos_mask, neg_mask, multi_val): mat, keep_mask=(pos_mask + neg_mask).bool(), add_one=False, dim=1 ) log_prob = mat - denominator - mean_log_prob_pos = (multi_val * log_prob).sum(dim=1) / ( + mean_log_prob_pos = (multi_val * log_prob * pos_mask).sum(dim=1) / ( pos_mask.sum(dim=1) + c_f.small_val(mat.dtype) ) return { @@ -68,7 +68,7 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): threshold=self.threshold) if all(len(x) <= 1 for x in indices_tuple): return self.zero_losses() - mat = self.dot_cosine_sim(embeddings, ref_emb) + mat = self.distance(embeddings, ref_emb) return self.loss_method(mat, indices_tuple) def forward( @@ -90,5 +90,6 @@ def forward( embeddings, labels, indices_tuple, ref_emb, ref_labels ) self.add_embedding_regularization_to_loss_dict(loss_dict, embeddings) + print(loss_dict) return self.reducer(loss_dict, embeddings, labels) From 7f21e85331c231d0e3f61f6599cbf97803f181d7 Mon Sep 17 00:00:00 2001 From: Qibin Liang <46101145+PhyseChan@users.noreply.github.com> Date: Mon, 9 Oct 2023 17:04:37 +0800 Subject: [PATCH 06/15] [update] remove print --- src/pytorch_metric_learning/losses/multilabel_supcon_loss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py index 6d5076b5..62d13bca 100644 --- a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py +++ b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py @@ -90,6 +90,5 @@ def forward( embeddings, labels, indices_tuple, ref_emb, ref_labels ) self.add_embedding_regularization_to_loss_dict(loss_dict, embeddings) - print(loss_dict) return 
self.reducer(loss_dict, embeddings, labels) From 21b9e4b8b6300000be5b42b386ba0666d4fa685c Mon Sep 17 00:00:00 2001 From: Qibin Liang <46101145+PhyseChan@users.noreply.github.com> Date: Mon, 9 Oct 2023 17:06:18 +0800 Subject: [PATCH 07/15] [update] compatibility for xbm --- .../losses/xbm_multilabel.py | 2 +- .../utils/multilabel_loss_and_miner_utils.py | 28 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/pytorch_metric_learning/losses/xbm_multilabel.py b/src/pytorch_metric_learning/losses/xbm_multilabel.py index 8099b107..7f104674 100644 --- a/src/pytorch_metric_learning/losses/xbm_multilabel.py +++ b/src/pytorch_metric_learning/losses/xbm_multilabel.py @@ -107,7 +107,7 @@ def create_indices_tuple( else: indices_tuple = mlmu.get_all_pairs_indices(labels, self.num_classes, L_mem) if do_remove_self_comparisons: - indices_tuple = lmu.remove_self_comparisons( + indices_tuple = mlmu.remove_self_comparisons( indices_tuple, self.curr_batch_idx, self.memory_size ) diff --git a/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py b/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py index 543467ca..206f1af7 100644 --- a/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py +++ b/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py @@ -96,3 +96,31 @@ def get_all_triplets_indices(labels, ref_labels=None): triplets = matches.unsqueeze(2) * diffs.unsqueeze(1) return torch.where(triplets) + +def remove_self_comparisons( + indices_tuple, curr_batch_idx, ref_size, ref_is_subset=False +): + # remove self-comparisons + assert len(indices_tuple) in [4, 5] + s, e = curr_batch_idx[0], curr_batch_idx[-1] + if len(indices_tuple) == 3: + a, p, n = indices_tuple + keep_mask = lmu.not_self_comparisons( + a, p, s, e, curr_batch_idx, ref_size, ref_is_subset + ) + a = a[keep_mask] + p = p[keep_mask] + n = n[keep_mask] + assert len(a) == len(p) == len(n) + return a, p, n + elif len(indices_tuple) == 4: + a1, p, a2, n = indices_tuple + keep_mask = lmu.not_self_comparisons( + a1, p, s, e, curr_batch_idx, ref_size, ref_is_subset + ) + a1 = a1[keep_mask] + p = p[keep_mask] + assert len(a1) == len(p) + assert len(a2) == len(n) + return a1, p, a2, n + From 794e119ed0447d8b3f745900480d6970b6e5ea76 Mon Sep 17 00:00:00 2001 From: Qibin Liang <46101145+PhyseChan@users.noreply.github.com> Date: Mon, 9 Oct 2023 17:07:47 +0800 Subject: [PATCH 08/15] [test] add test cases --- tests/losses/test_multilabel_supcon_loss.py | 72 +++++++++++---------- 1 file changed, 37 insertions(+), 35 deletions(-) diff --git a/tests/losses/test_multilabel_supcon_loss.py b/tests/losses/test_multilabel_supcon_loss.py index d8d077bf..44e715c2 100644 --- a/tests/losses/test_multilabel_supcon_loss.py +++ b/tests/losses/test_multilabel_supcon_loss.py @@ -17,48 +17,50 @@ def test_multi_supcon_loss(self): loss_func = MultiSupConLoss(num_classes=n_cls) xbm_loss_func = CrossBatchMemory4MultiLabel(loss_func, n_dim, memory_size=16) - # # test float32 and float64 - # for dtype in [torch.float32, torch.float64]: - # embeddings = torch.randn(n_samples, n_dim, dtype=dtype) - # labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] - # loss = loss_func(embeddings, labels) - # self.assertTrue(loss >= 0) + # test float32 and float64 + for dtype in [torch.float32, torch.float64]: + embeddings = torch.randn(n_samples, n_dim, dtype=dtype) + labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] + loss = 
loss_func(embeddings, labels) + self.assertTrue(loss >= 0) - # # test cuda and cpu - # for device in [torch.device("cpu"),torch.device("cuda")]: - # embeddings = torch.randn(n_samples, n_dim, device=device) - # labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] - # loss = loss_func(embeddings, labels) - # self.assertTrue(loss >= 0) + # test cuda and cpu + for device in [torch.device("cpu"),torch.device("cuda")]: + embeddings = torch.randn(n_samples, n_dim, device=device) + labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] + loss = loss_func(embeddings, labels) + self.assertTrue(loss >= 0) # test xbm - # batchs = 4 - # for b in range(batchs): - # embeddings = torch.randn(n_samples, n_dim, dtype=torch.float32) - # labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] - # loss = xbm_loss_func(embeddings, labels) - # self.assertTrue(loss == 0) + batchs = 4 + for b in range(batchs): + embeddings = torch.randn(n_samples, n_dim, dtype=torch.float32) + labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] + loss = xbm_loss_func(embeddings, labels) + self.assertTrue(loss >= 0) # test scatter labels - # for device in [torch.device("cpu"),torch.device("cuda")]: - # embeddings = torch.randn(n_samples, n_dim, device=device) - # labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] - # labels = torch.stack([ - # torch.nn.functional.one_hot(torch.tensor(label), n_cls).sum(dim=0).float() - # for label in labels - # ], dim=0) - # loss = loss_func(embeddings, labels) - # self.assertTrue(loss >= 0) + for device in [torch.device("cpu"),torch.device("cuda")]: + embeddings = torch.randn(n_samples, n_dim, device=device) + labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] + labels = torch.stack([ + torch.nn.functional.one_hot(torch.tensor(label), n_cls).sum(dim=0).float() + for label in labels + ], dim=0) + loss = loss_func(embeddings, labels) + self.assertTrue(loss >= 0) # test val - # embeddings = torch.tensor([[0.1, 0.3, 0.1], - # [0.2, 0.2, -0.1], - # [0.1, -0.06, 0.1], - # [0.03, -0.13, 0.4]]) - # labels = torch.tensor([[1,0,1], [1,0,0], [0,0,1], [0,1,1]]) - # loss = loss_func(embeddings, labels) - # print(loss) - # self.assertTrue(loss ==0) + embeddings = torch.tensor([[0.1, 0.3, 0.1], + [0.23, -0.2, -0.1], + [0.1, -0.16, 0.1], + [0.13, -0.13, 0.2]]) + labels = torch.tensor([[1,0,1], [1,0,0], [0,1,1], [0,1,0]]) + affine_net = torch.nn.Linear(3, 3, bias=False) + affine_net.weight.data = torch.ones_like(affine_net.weight.data) + embeddings = affine_net(embeddings) + loss = loss_func(embeddings, labels) + self.assertTrue(loss>0) # test xbm with scatter labels batchs = 4 From 32d3e9459e14449ddde610ff9e89a4f757bd3962 Mon Sep 17 00:00:00 2001 From: Qibin Liang <46101145+PhyseChan@users.noreply.github.com> Date: Mon, 9 Oct 2023 20:24:54 +0800 Subject: [PATCH 09/15] [bug] fix xbm self-comparisons --- .../utils/multilabel_loss_and_miner_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py b/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py index 206f1af7..3b92a203 100644 --- a/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py +++ b/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py @@ -103,7 +103,7 @@ def remove_self_comparisons( # remove 
self-comparisons assert len(indices_tuple) in [4, 5] s, e = curr_batch_idx[0], curr_batch_idx[-1] - if len(indices_tuple) == 3: + if len(indices_tuple) == 4: a, p, n = indices_tuple keep_mask = lmu.not_self_comparisons( a, p, s, e, curr_batch_idx, ref_size, ref_is_subset @@ -113,7 +113,7 @@ def remove_self_comparisons( n = n[keep_mask] assert len(a) == len(p) == len(n) return a, p, n - elif len(indices_tuple) == 4: + elif len(indices_tuple) == 5: a1, p, a2, n = indices_tuple keep_mask = lmu.not_self_comparisons( a1, p, s, e, curr_batch_idx, ref_size, ref_is_subset From b9c1ecd2fb19d27414fc0fa7c841a380e91d4359 Mon Sep 17 00:00:00 2001 From: Qibin Liang <46101145+PhyseChan@users.noreply.github.com> Date: Tue, 10 Oct 2023 12:01:22 +0800 Subject: [PATCH 10/15] [fix] remove num_labels --- .../losses/multilabel_supcon_loss.py | 1 - .../losses/xbm_multilabel.py | 4 +-- .../utils/multilabel_loss_and_miner_utils.py | 28 +++++++++---------- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py index 62d13bca..3b307b07 100644 --- a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py +++ b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py @@ -62,7 +62,6 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): indices_tuple = mlmu.convert_to_pairs( indices_tuple, labels, - self.num_classes, ref_labels, device=embeddings.device, threshold=self.threshold) diff --git a/src/pytorch_metric_learning/losses/xbm_multilabel.py b/src/pytorch_metric_learning/losses/xbm_multilabel.py index 7f104674..a8fea4ad 100644 --- a/src/pytorch_metric_learning/losses/xbm_multilabel.py +++ b/src/pytorch_metric_learning/losses/xbm_multilabel.py @@ -105,7 +105,7 @@ def create_indices_tuple( if self.miner: indices_tuple = self.miner(embeddings, labels, E_mem, L_mem) else: - indices_tuple = mlmu.get_all_pairs_indices(labels, self.num_classes, L_mem) + indices_tuple = mlmu.get_all_pairs_indices(labels, L_mem) if do_remove_self_comparisons: indices_tuple = mlmu.remove_self_comparisons( indices_tuple, self.curr_batch_idx, self.memory_size @@ -113,7 +113,7 @@ def create_indices_tuple( if input_indices_tuple is not None: if len(input_indices_tuple) == 3 and len(indices_tuple) == 4: - input_indices_tuple = mlmu.convert_to_pairs(input_indices_tuple, labels, self.num_classes) + input_indices_tuple = mlmu.convert_to_pairs(input_indices_tuple, labels) elif len(input_indices_tuple) == 4 and len(indices_tuple) == 3: input_indices_tuple = mlmu.convert_to_triplets( input_indices_tuple, labels diff --git a/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py b/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py index 3b92a203..32329cb7 100644 --- a/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py +++ b/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py @@ -16,7 +16,7 @@ def set_ref_emb(embeddings, labels, ref_emb, ref_labels): check_shapes_multilabels(ref_emb, ref_labels) return ref_emb, ref_labels -def convert_to_pairs(indices_tuple, labels, num_classes, ref_labels=None, device=None, threshold=0.3): +def convert_to_pairs(indices_tuple, labels, ref_labels=None, device=None, threshold=0.3): """ This returns anchor-positive and anchor-negative indices, regardless of what the input indices_tuple is @@ -26,15 +26,15 @@ def convert_to_pairs(indices_tuple, labels, num_classes, ref_labels=None, device 
labels: a tensor which has the label for each element in a batch """ if indices_tuple is None: - return get_all_pairs_indices(labels, num_classes, ref_labels, device=device, threshold=threshold) - elif len(indices_tuple) == 4: + return get_all_pairs_indices(labels, ref_labels, device=device, threshold=threshold) + elif len(indices_tuple) == 5: return indices_tuple else: - a, p, n = indices_tuple - return a, p, a, n + a, p, n, jaccard_mat = indices_tuple + return a, p, a, n,jaccard_mat -def get_matches_and_diffs(labels, num_classes, ref_labels=None, device=None, threshold=0.3): - jaccard_matrix = jaccard(num_classes, labels, ref_labels, device=device, threshold=threshold) +def get_matches_and_diffs(labels, ref_labels=None, device=None, threshold=0.3): + jaccard_matrix = jaccard(labels, ref_labels) matches = torch.where(jaccard_matrix > threshold, 1, 0).to(device) diffs = matches ^ 1 if ref_labels is labels: @@ -42,18 +42,18 @@ def get_matches_and_diffs(labels, num_classes, ref_labels=None, device=None, thr return matches, diffs, jaccard_matrix -def get_all_pairs_indices(labels, num_classes, ref_labels=None, device=None, threshold=0.3): +def get_all_pairs_indices(labels, ref_labels=None, device=None, threshold=0.3): """ Given a tensor of labels, this will return 4 tensors. The first 2 tensors are the indices which form all positive pairs The second 2 tensors are the indices which form all negative pairs """ - matches, diffs, multi_val = get_matches_and_diffs(labels, num_classes, ref_labels, device, threshold=threshold) + matches, diffs, multi_val = get_matches_and_diffs(labels, ref_labels, device, threshold=threshold) a1_idx, p_idx = torch.where(matches) a2_idx, n_idx = torch.where(diffs) return a1_idx, p_idx, a2_idx, n_idx, multi_val -def jaccard(n_classes, labels, ref_labels=None, threshold=0.3, device=torch.device("cpu")): +def jaccard(labels, ref_labels=None): if ref_labels is None: ref_labels = labels @@ -104,7 +104,7 @@ def remove_self_comparisons( assert len(indices_tuple) in [4, 5] s, e = curr_batch_idx[0], curr_batch_idx[-1] if len(indices_tuple) == 4: - a, p, n = indices_tuple + a, p, n, _ = indices_tuple keep_mask = lmu.not_self_comparisons( a, p, s, e, curr_batch_idx, ref_size, ref_is_subset ) @@ -112,9 +112,9 @@ def remove_self_comparisons( p = p[keep_mask] n = n[keep_mask] assert len(a) == len(p) == len(n) - return a, p, n + return a, p, n, _ elif len(indices_tuple) == 5: - a1, p, a2, n = indices_tuple + a1, p, a2, n, _ = indices_tuple keep_mask = lmu.not_self_comparisons( a1, p, s, e, curr_batch_idx, ref_size, ref_is_subset ) @@ -122,5 +122,5 @@ def remove_self_comparisons( p = p[keep_mask] assert len(a1) == len(p) assert len(a2) == len(n) - return a1, p, a2, n + return a1, p, a2, n, _ From 41113b9d8344a29595355a76b43751ded9051116 Mon Sep 17 00:00:00 2001 From: Qibin Liang <46101145+PhyseChan@users.noreply.github.com> Date: Tue, 10 Oct 2023 19:47:12 +0800 Subject: [PATCH 11/15] [update] reformat testcase --- tests/losses/test_multilabel_supcon_loss.py | 88 +++++++-------------- 1 file changed, 30 insertions(+), 58 deletions(-) diff --git a/tests/losses/test_multilabel_supcon_loss.py b/tests/losses/test_multilabel_supcon_loss.py index 44e715c2..5de4bbbc 100644 --- a/tests/losses/test_multilabel_supcon_loss.py +++ b/tests/losses/test_multilabel_supcon_loss.py @@ -10,66 +10,38 @@ ) class TestMultiSupConLoss(unittest.TestCase): - def test_multi_supcon_loss(self): - n_cls = 6 - n_samples = 6 - n_dim = 5 - loss_func = MultiSupConLoss(num_classes=n_cls) - xbm_loss_func = 
CrossBatchMemory4MultiLabel(loss_func, n_dim, memory_size=16) - - # test float32 and float64 - for dtype in [torch.float32, torch.float64]: - embeddings = torch.randn(n_samples, n_dim, dtype=dtype) - labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] - loss = loss_func(embeddings, labels) - self.assertTrue(loss >= 0) - - # test cuda and cpu - for device in [torch.device("cpu"),torch.device("cuda")]: - embeddings = torch.randn(n_samples, n_dim, device=device) - labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] - loss = loss_func(embeddings, labels) - self.assertTrue(loss >= 0) - - # test xbm - batchs = 4 - for b in range(batchs): - embeddings = torch.randn(n_samples, n_dim, dtype=torch.float32) - labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] - loss = xbm_loss_func(embeddings, labels) - self.assertTrue(loss >= 0) - - # test scatter labels - for device in [torch.device("cpu"),torch.device("cuda")]: - embeddings = torch.randn(n_samples, n_dim, device=device) - labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] - labels = torch.stack([ - torch.nn.functional.one_hot(torch.tensor(label), n_cls).sum(dim=0).float() - for label in labels - ], dim=0) - loss = loss_func(embeddings, labels) - self.assertTrue(loss >= 0) - - # test val - embeddings = torch.tensor([[0.1, 0.3, 0.1], + def __init__(self, methodName: str = "runTest") -> None: + super().__init__(methodName) + self.n_cls = 3 + self.n_samples = 4 + self.n_dim = 3 + self.n_batchs = 10 + self.xbm_max_size = 1024 + self.loss_func = MultiSupConLoss( + num_classes=self.n_cls, + threshold=0.3) + self.xbm_loss_func = CrossBatchMemory4MultiLabel( + self.loss_func, + self.n_dim, + memory_size=self.xbm_max_size) + # test cases + self.embeddings = torch.tensor([[0.1, 0.3, 0.1], [0.23, -0.2, -0.1], [0.1, -0.16, 0.1], [0.13, -0.13, 0.2]]) - labels = torch.tensor([[1,0,1], [1,0,0], [0,1,1], [0,1,0]]) - affine_net = torch.nn.Linear(3, 3, bias=False) - affine_net.weight.data = torch.ones_like(affine_net.weight.data) - embeddings = affine_net(embeddings) - loss = loss_func(embeddings, labels) - self.assertTrue(loss>0) + self.labels = torch.tensor([[1,0,1], [1,0,0], [0,1,1], [0,1,0]]) + self.test_multisupcon_val_gt = 0.6247 + # xbm test cases + self.test_xbm_multisupcon_val_gt = 2.3841 + + + def test_multisupcon_val(self): + loss = self.loss_func(self.embeddings, self.labels) + print(loss) + self.assertTrue(np.isclose(loss.item(), self.test_multisupcon_val_gt, atol=1e-4)) + def test_xbm_multisupcon_val(self): # test xbm with scatter labels - batchs = 4 - for b in range(batchs): - embeddings = torch.randn(n_samples, n_dim) - labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] - labels = torch.stack([ - torch.nn.functional.one_hot(torch.tensor(label), n_cls).sum(dim=0).float() - for label in labels - ], dim=0) - loss = loss_func(embeddings, labels) - self.assertTrue(loss > 0) \ No newline at end of file + for b in range(self.n_batchs): + loss = self.xbm_loss_func(self.embeddings, self.labels) + self.assertTrue(np.isclose(loss.item(), self.test_xbm_multisupcon_val_gt, atol=1e-4)) \ No newline at end of file From cc66e68f4a3a77187ee2fcbac7d6f1fdd8bcc00f Mon Sep 17 00:00:00 2001 From: Qibin Liang <46101145+PhyseChan@users.noreply.github.com> Date: Tue, 10 Oct 2023 19:59:07 +0800 Subject: [PATCH 12/15] [update] remove dot_cosine_sim --- 
src/pytorch_metric_learning/losses/multilabel_supcon_loss.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py index 3b307b07..11e4df51 100644 --- a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py +++ b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py @@ -17,9 +17,6 @@ def __init__(self, num_classes, temperature=0.07, threshold=0.3, **kwargs): self.num_classes = num_classes self.threshold = threshold - def dot_cosine_sim(self, a, b): - return a@b.T - def _compute_loss(self, mat, pos_mask, neg_mask, multi_val): if pos_mask.bool().any() and neg_mask.bool().any(): # if dealing with actual distances, use negative distances From 0e9e94846b65dacc53556a798493be5b2117f0ea Mon Sep 17 00:00:00 2001 From: Qibin Liang <46101145+PhyseChan@users.noreply.github.com> Date: Tue, 10 Oct 2023 19:59:26 +0800 Subject: [PATCH 13/15] [update] remove print --- tests/losses/test_multilabel_supcon_loss.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/losses/test_multilabel_supcon_loss.py b/tests/losses/test_multilabel_supcon_loss.py index 5de4bbbc..94def73c 100644 --- a/tests/losses/test_multilabel_supcon_loss.py +++ b/tests/losses/test_multilabel_supcon_loss.py @@ -37,8 +37,7 @@ def __init__(self, methodName: str = "runTest") -> None: def test_multisupcon_val(self): loss = self.loss_func(self.embeddings, self.labels) - print(loss) - self.assertTrue(np.isclose(loss.item(), self.test_multisupcon_val_gt, atol=1e-4)) + self.assertTrue(np.isclose(loss.item(), self.test_multisupcon_val_gt, atol=1e-6)) def test_xbm_multisupcon_val(self): # test xbm with scatter labels From eb6647a06940d6264c561484b0ada6c283ffbb20 Mon Sep 17 00:00:00 2001 From: Qibin Liang <46101145+PhyseChan@users.noreply.github.com> Date: Wed, 11 Oct 2023 12:02:53 +0800 Subject: [PATCH 14/15] [update] reconstruct the files --- .../losses/__init__.py | 3 +- .../losses/multilabel_supcon_loss.py | 283 +++++++++++++++++- .../losses/xbm_multilabel.py | 7 +- .../utils/multilabel_loss_and_miner_utils.py | 17 +- tests/losses/test_multilabel_supcon_loss.py | 112 ++++++- 5 files changed, 395 insertions(+), 27 deletions(-) diff --git a/src/pytorch_metric_learning/losses/__init__.py b/src/pytorch_metric_learning/losses/__init__.py index d3b98c94..cfff0813 100644 --- a/src/pytorch_metric_learning/losses/__init__.py +++ b/src/pytorch_metric_learning/losses/__init__.py @@ -35,5 +35,4 @@ from .triplet_margin_loss import TripletMarginLoss from .tuplet_margin_loss import TupletMarginLoss from .vicreg_loss import VICRegLoss -from .multilabel_supcon_loss import MultiSupConLoss -from .xbm_multilabel import CrossBatchMemory4MultiLabel +from .multilabel_supcon_loss import MultiSupConLoss, CrossBatchMemory4MultiLabel \ No newline at end of file diff --git a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py index 11e4df51..a8e226ab 100644 --- a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py +++ b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py @@ -3,14 +3,15 @@ from ..distances import CosineSimilarity from ..reducers import AvgNonZeroReducer from ..utils import common_functions as c_f -from ..utils import multilabel_loss_and_miner_utils as mlmu from ..utils import loss_and_miner_utils as lmu +from ..utils.module_with_records import ModuleWithRecords from .generic_pair_loss import GenericPairLoss - +from 
.base_loss_wrapper import BaseLossWrapper # adapted from https://github.com/HobbitLong/SupContrast +# modified for multi-supcon class MultiSupConLoss(GenericPairLoss): - def __init__(self, num_classes, temperature=0.07, threshold=0.3, **kwargs): + def __init__(self, num_classes, temperature=0.1, threshold=0.3, **kwargs): super().__init__(mat_based_loss=True, **kwargs) self.temperature = temperature self.add_to_recordable_attributes(list_of_names=["temperature"], is_stat=False) @@ -25,6 +26,7 @@ def _compute_loss(self, mat, pos_mask, neg_mask, multi_val): mat = mat / self.temperature mat_max, _ = mat.max(dim=1, keepdim=True) mat = mat - mat_max.detach() # for numerical stability + denominator = lmu.logsumexp( mat, keep_mask=(pos_mask + neg_mask).bool(), add_one=False, dim=1 ) @@ -32,6 +34,7 @@ def _compute_loss(self, mat, pos_mask, neg_mask, multi_val): mean_log_prob_pos = (multi_val * log_prob * pos_mask).sum(dim=1) / ( pos_mask.sum(dim=1) + c_f.small_val(mat.dtype) ) + return { "loss": { "losses": -mean_log_prob_pos, @@ -47,6 +50,8 @@ def get_default_reducer(self): def get_default_distance(self): return CosineSimilarity() + # ==== class methods below are overriden for adaptability to multi-supcon ==== + def mat_based_loss(self, mat, indices_tuple): a1, p, a2, n, jaccard_mat = indices_tuple pos_mask, neg_mask = torch.zeros_like(mat), torch.zeros_like(mat) @@ -56,11 +61,10 @@ def mat_based_loss(self, mat, indices_tuple): def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): c_f.labels_or_indices_tuple_required(labels, indices_tuple) - indices_tuple = mlmu.convert_to_pairs( + indices_tuple = convert_to_pairs( indices_tuple, labels, ref_labels, - device=embeddings.device, threshold=self.threshold) if all(len(x) <= 1 for x in indices_tuple): return self.zero_losses() @@ -80,11 +84,276 @@ def forward( Returns: the loss """ self.reset_stats() - mlmu.check_shapes_multilabels(embeddings, labels) - ref_emb, ref_labels = mlmu.set_ref_emb(embeddings, labels, ref_emb, ref_labels) + check_shapes_multilabels(embeddings, labels) + ref_emb, ref_labels = set_ref_emb(embeddings, labels, ref_emb, ref_labels) loss_dict = self.compute_loss( embeddings, labels, indices_tuple, ref_emb, ref_labels ) self.add_embedding_regularization_to_loss_dict(loss_dict, embeddings) return self.reducer(loss_dict, embeddings, labels) + # ========================================================================= + + +# ================== cross batch memory for multi-supcon ================== +class CrossBatchMemory4MultiLabel(BaseLossWrapper, ModuleWithRecords): + def __init__(self, loss, embedding_size, memory_size=1024, miner=None, **kwargs): + super().__init__(loss=loss, **kwargs) + self.loss = loss + self.miner = miner + self.embedding_size = embedding_size + self.memory_size = memory_size + self.num_classes = loss.num_classes + self.reset_queue() + self.add_to_recordable_attributes( + list_of_names=["embedding_size", "memory_size", "queue_idx"], is_stat=False + ) + + @staticmethod + def supported_losses(): + return [ + "MultiSupConLoss" + ] + + @classmethod + def check_loss_support(cls, loss_name): + if loss_name not in cls.supported_losses(): + raise Exception(f"CrossBatchMemory not supported for {loss_name}") + + def forward(self, embeddings, labels, indices_tuple=None, enqueue_mask=None): + if indices_tuple is not None and enqueue_mask is not None: + raise ValueError("indices_tuple and enqueue_mask are mutually exclusive") + if enqueue_mask is not None: + assert len(enqueue_mask) == 
len(embeddings) + else: + assert len(embeddings) <= len(self.embedding_memory) + self.reset_stats() + device = embeddings.device + labels = c_f.to_device(labels, device=device) + self.embedding_memory = c_f.to_device( + self.embedding_memory, device=device, dtype=embeddings.dtype + ) + self.label_memory = c_f.to_device( + self.label_memory, device=device, dtype=labels.dtype + ) + + if enqueue_mask is not None: + emb_for_queue = embeddings[enqueue_mask] + labels_for_queue = labels[enqueue_mask] + embeddings = embeddings[~enqueue_mask] + labels = labels[~enqueue_mask] + do_remove_self_comparisons = False + else: + emb_for_queue = embeddings + labels_for_queue = labels + do_remove_self_comparisons = True + + queue_batch_size = len(emb_for_queue) + self.add_to_memory(emb_for_queue, labels_for_queue, queue_batch_size) + + if not self.has_been_filled: + E_mem = self.embedding_memory[: self.queue_idx] + L_mem = self.label_memory[: self.queue_idx] + else: + E_mem = self.embedding_memory + L_mem = self.label_memory + + indices_tuple = self.create_indices_tuple( + embeddings, + labels, + E_mem, + L_mem, + indices_tuple, + do_remove_self_comparisons, + ) + loss = self.loss(embeddings, labels, indices_tuple, E_mem, L_mem) + return loss + + def add_to_memory(self, embeddings, labels, batch_size): + self.curr_batch_idx = ( + torch.arange( + self.queue_idx, self.queue_idx + batch_size, device=labels.device + ) + % self.memory_size + ) + self.embedding_memory[self.curr_batch_idx] = embeddings.detach() + self.label_memory[self.curr_batch_idx] = labels.detach() + prev_queue_idx = self.queue_idx + self.queue_idx = (self.queue_idx + batch_size) % self.memory_size + if (not self.has_been_filled) and (self.queue_idx <= prev_queue_idx): + self.has_been_filled = True + + def create_indices_tuple( + self, + embeddings, + labels, + E_mem, + L_mem, + input_indices_tuple, + do_remove_self_comparisons, + ): + if self.miner: + indices_tuple = self.miner(embeddings, labels, E_mem, L_mem) + else: + indices_tuple = get_all_pairs_indices(labels, L_mem) + + if do_remove_self_comparisons: + indices_tuple = remove_self_comparisons( + indices_tuple, self.curr_batch_idx, self.memory_size + ) + + if input_indices_tuple is not None: + if len(input_indices_tuple) == 3 and len(indices_tuple) == 4: + input_indices_tuple = convert_to_pairs(input_indices_tuple, labels) + elif len(input_indices_tuple) == 4 and len(indices_tuple) == 3: + input_indices_tuple = convert_to_triplets( + input_indices_tuple, labels + ) + indices_tuple = c_f.concatenate_indices_tuples( + indices_tuple, input_indices_tuple + ) + + return indices_tuple + + def reset_queue(self): + self.register_buffer( + "embedding_memory", torch.zeros(self.memory_size, self.embedding_size) + ) + self.register_buffer( + "label_memory", torch.zeros(self.memory_size, self.num_classes) + ) + self.has_been_filled = False + self.queue_idx = 0 + +# ========================================================================= + +# compute jaccard similarity +def jaccard(labels, ref_labels=None): + if ref_labels is None: + ref_labels = labels + + labels1 = labels.float() + labels2 = ref_labels.float() + + # compute jaccard similarity + # jaccard = intersection / union + labels1_union = labels1.sum(-1) + labels2_union = labels2.sum(-1) + union = labels1_union.unsqueeze(1) + labels2_union.unsqueeze(0) + intersection = torch.mm(labels1, labels2.T) + jaccard_matrix = intersection / (union - intersection) + + # return indices of jaccard similarity above threshold + return jaccard_matrix + +# 
====== methods below are overriden for adaptability to multi-supcon ====== + +# use jaccard similarity to get matches +def get_matches_and_diffs(labels, ref_labels=None, threshold=0.3): + if ref_labels is None: + ref_labels = labels + jaccard_matrix = jaccard(labels, ref_labels) + matches = torch.where(jaccard_matrix > threshold, 1, 0) + diffs = matches ^ 1 + if ref_labels is labels: + matches.fill_diagonal_(0) + return matches, diffs, jaccard_matrix + +def check_shapes_multilabels(embeddings, labels): + if labels is not None and embeddings.shape[0] != labels.shape[0]: + raise ValueError("Number of embeddings must equal number of labels") + if labels is not None and labels.ndim != 2: + raise ValueError("labels must be a 1D tensor of shape (batch_size,)") + + +def set_ref_emb(embeddings, labels, ref_emb, ref_labels): + if ref_emb is None: + ref_emb, ref_labels = embeddings, labels + check_shapes_multilabels(ref_emb, ref_labels) + return ref_emb, ref_labels + + +def convert_to_pairs(indices_tuple, labels, ref_labels=None, threshold=0.3): + """ + This returns anchor-positive and anchor-negative indices, + regardless of what the input indices_tuple is + Args: + indices_tuple: tuple of tensors. Each tensor is 1d and specifies indices + within a batch + labels: a tensor which has the label for each element in a batch + """ + if indices_tuple is None: + return get_all_pairs_indices(labels, ref_labels, threshold=threshold) + elif len(indices_tuple) == 5: + return indices_tuple + else: + a, p, n, jaccard_mat = indices_tuple + return a, p, a, n,jaccard_mat + + +def get_all_pairs_indices(labels, ref_labels=None, threshold=0.3): + """ + Given a tensor of labels, this will return 4 tensors. + The first 2 tensors are the indices which form all positive pairs + The second 2 tensors are the indices which form all negative pairs + """ + matches, diffs, multi_val = get_matches_and_diffs(labels, ref_labels, threshold=threshold) + a1_idx, p_idx = torch.where(matches) + a2_idx, n_idx = torch.where(diffs) + return a1_idx, p_idx, a2_idx, n_idx, multi_val + + +def convert_to_triplets(indices_tuple, labels, ref_labels=None, t_per_anchor=100): + """ + This returns anchor-positive-negative triplets + regardless of what the input indices_tuple is + """ + if indices_tuple is None: + if t_per_anchor == "all": + return get_all_triplets_indices(labels, ref_labels) + else: + return lmu.get_random_triplet_indices( + labels, ref_labels, t_per_anchor=t_per_anchor + ) + elif len(indices_tuple) == 3: + return indices_tuple + else: + a1, p, a2, n = indices_tuple + p_idx, n_idx = torch.where(a1.unsqueeze(1) == a2) + return a1[p_idx], p[p_idx], n[n_idx] + + +def get_all_triplets_indices(labels, ref_labels=None): + matches, diffs = get_matches_and_diffs(labels, ref_labels) + triplets = matches.unsqueeze(2) * diffs.unsqueeze(1) + return torch.where(triplets) + + +def remove_self_comparisons( + indices_tuple, curr_batch_idx, ref_size, ref_is_subset=False +): + # remove self-comparisons + assert len(indices_tuple) in [4, 5] + s, e = curr_batch_idx[0], curr_batch_idx[-1] + if len(indices_tuple) == 4: + a, p, n, jaccard_mat = indices_tuple + keep_mask = lmu.not_self_comparisons( + a, p, s, e, curr_batch_idx, ref_size, ref_is_subset + ) + a = a[keep_mask] + p = p[keep_mask] + n = n[keep_mask] + assert len(a) == len(p) == len(n) + return a, p, n, jaccard_mat + elif len(indices_tuple) == 5: + a1, p, a2, n, jaccard_mat = indices_tuple + keep_mask = lmu.not_self_comparisons( + a1, p, s, e, curr_batch_idx, ref_size, ref_is_subset + ) + a1 = 
a1[keep_mask] + p = p[keep_mask] + assert len(a1) == len(p) + assert len(a2) == len(n) + return a1, p, a2, n, jaccard_mat + +# ========================================================================= \ No newline at end of file diff --git a/src/pytorch_metric_learning/losses/xbm_multilabel.py b/src/pytorch_metric_learning/losses/xbm_multilabel.py index a8fea4ad..5de41e38 100644 --- a/src/pytorch_metric_learning/losses/xbm_multilabel.py +++ b/src/pytorch_metric_learning/losses/xbm_multilabel.py @@ -1,7 +1,7 @@ import torch from ..utils import common_functions as c_f -# replace the functions of loss_and_miner_utils by multisupcon's +# replace the functions of loss_and_miner_utils for multilabels from ..utils import multilabel_loss_and_miner_utils as mlmu from ..utils import loss_and_miner_utils as lmu from ..utils.module_with_records import ModuleWithRecords @@ -48,6 +48,7 @@ def forward(self, embeddings, labels, indices_tuple=None, enqueue_mask=None): self.label_memory = c_f.to_device( self.label_memory, device=device, dtype=labels.dtype ) + if enqueue_mask is not None: emb_for_queue = embeddings[enqueue_mask] labels_for_queue = labels[enqueue_mask] @@ -68,6 +69,7 @@ def forward(self, embeddings, labels, indices_tuple=None, enqueue_mask=None): else: E_mem = self.embedding_memory L_mem = self.label_memory + indices_tuple = self.create_indices_tuple( embeddings, labels, @@ -87,7 +89,7 @@ def add_to_memory(self, embeddings, labels, batch_size): % self.memory_size ) self.embedding_memory[self.curr_batch_idx] = embeddings.detach() - self.label_memory[self.curr_batch_idx] = labels + self.label_memory[self.curr_batch_idx] = labels.detach() prev_queue_idx = self.queue_idx self.queue_idx = (self.queue_idx + batch_size) % self.memory_size if (not self.has_been_filled) and (self.queue_idx <= prev_queue_idx): @@ -106,6 +108,7 @@ def create_indices_tuple( indices_tuple = self.miner(embeddings, labels, E_mem, L_mem) else: indices_tuple = mlmu.get_all_pairs_indices(labels, L_mem) + if do_remove_self_comparisons: indices_tuple = mlmu.remove_self_comparisons( indices_tuple, self.curr_batch_idx, self.memory_size diff --git a/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py b/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py index 32329cb7..ee211fc7 100644 --- a/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py +++ b/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py @@ -1,6 +1,15 @@ import torch from . 
import loss_and_miner_utils as lmu +def get_matches_and_diffs(labels, ref_labels=None, threshold=0.3): + if ref_labels is None: + ref_labels = labels + jaccard_matrix = jaccard(labels, ref_labels) + matches = torch.where(jaccard_matrix > threshold, 1, 0) + diffs = matches ^ 1 + if ref_labels is labels: + matches.fill_diagonal_(0) + return matches, diffs, jaccard_matrix def check_shapes_multilabels(embeddings, labels): if labels is not None and embeddings.shape[0] != len(labels): raise ValueError("Number of embeddings must equal number of labels") @@ -33,13 +42,7 @@ def convert_to_pairs(indices_tuple, labels, ref_labels=None, device=None, thresh a, p, n, jaccard_mat = indices_tuple return a, p, a, n,jaccard_mat -def get_matches_and_diffs(labels, ref_labels=None, device=None, threshold=0.3): - jaccard_matrix = jaccard(labels, ref_labels) - matches = torch.where(jaccard_matrix > threshold, 1, 0).to(device) - diffs = matches ^ 1 - if ref_labels is labels: - matches.fill_diagonal_(0) - return matches, diffs, jaccard_matrix + def get_all_pairs_indices(labels, ref_labels=None, device=None, threshold=0.3): diff --git a/tests/losses/test_multilabel_supcon_loss.py b/tests/losses/test_multilabel_supcon_loss.py index 94def73c..fc82fc66 100644 --- a/tests/losses/test_multilabel_supcon_loss.py +++ b/tests/losses/test_multilabel_supcon_loss.py @@ -2,13 +2,15 @@ import torch import numpy as np -import random from pytorch_metric_learning.losses import ( MultiSupConLoss, CrossBatchMemory4MultiLabel ) +from ..zzz_testing_utils.testing_utils import angle_to_coord + +from .. import TEST_DEVICE, TEST_DTYPES class TestMultiSupConLoss(unittest.TestCase): def __init__(self, methodName: str = "runTest") -> None: super().__init__(methodName) @@ -17,9 +19,14 @@ def __init__(self, methodName: str = "runTest") -> None: self.n_dim = 3 self.n_batchs = 10 self.xbm_max_size = 1024 + + # multi_supcon self.loss_func = MultiSupConLoss( - num_classes=self.n_cls, + num_classes=self.n_cls, + temperature=0.07, threshold=0.3) + + # xbm self.xbm_loss_func = CrossBatchMemory4MultiLabel( self.loss_func, self.n_dim, @@ -30,17 +37,104 @@ def __init__(self, methodName: str = "runTest") -> None: [0.1, -0.16, 0.1], [0.13, -0.13, 0.2]]) self.labels = torch.tensor([[1,0,1], [1,0,0], [0,1,1], [0,1,0]]) - self.test_multisupcon_val_gt = 0.6247 + + # the gt values are obtained by running the code + # (https://github.com/WolodjaZ/MultiSupContrast/blob/main/losses.py) + + # multi_supcon test cases + self.test_multisupcon_val_gt = { + torch.float16: 3.2836, + torch.float32: 3.2874, + torch.float64: 3.2874, + } # xbm test cases - self.test_xbm_multisupcon_val_gt = 2.3841 + self.test_xbm_multisupcon_val_gt = { + torch.float16: [3.2836, 4.3792, 4.4588, 4.5741, 4.6831, 4.7809, 4.8682, 4.9465, 5.0174, 5.0819], + torch.float32: [3.2874, 4.3779, 4.4577, 4.5730, 4.6820, 4.7798, 4.8671, 4.9455, 5.0163, 5.0808], + torch.float64: [3.2874, 4.3779, 4.4577, 4.5730, 4.6820, 4.7798, 4.8671, 4.9455, 5.0163, 5.0808,] + } def test_multisupcon_val(self): - loss = self.loss_func(self.embeddings, self.labels) - self.assertTrue(np.isclose(loss.item(), self.test_multisupcon_val_gt, atol=1e-6)) + for dtype in TEST_DTYPES: + for device in ["cpu", "cuda"]: + # skip float16 on cpu + if device == "cpu" and dtype == torch.float16: + continue + embedding = self.embeddings.to(device).to(dtype) + label = self.labels.to(device).to(dtype) + loss = self.loss_func(embedding, label) + loss = loss.to("cpu") + self.assertTrue(np.isclose( + loss.item(), + 
self.test_multisupcon_val_gt[dtype], + atol=1e-2 if dtype == torch.float16 else 1e-4)) + def test_xbm_multisupcon_val(self): # test xbm with scatter labels - for b in range(self.n_batchs): - loss = self.xbm_loss_func(self.embeddings, self.labels) - self.assertTrue(np.isclose(loss.item(), self.test_xbm_multisupcon_val_gt, atol=1e-4)) \ No newline at end of file + for dtype in TEST_DTYPES: + for device in ["cpu", "cuda"]: + # skip float16 on cpu + if device == "cpu" and dtype == torch.float16: + continue + self.xbm_loss_func.reset_queue() + for b in range(self.n_batchs): + embedding = self.embeddings.to(device).to(dtype) + label = self.labels.to(device).to(dtype) + loss = self.xbm_loss_func(embedding, label) + loss = loss.to("cpu") + print(loss, self.test_xbm_multisupcon_val_gt[dtype][b], dtype) + self.assertTrue(np.isclose( + loss.item(), + self.test_xbm_multisupcon_val_gt[dtype][b], + atol=1e-2 if dtype == torch.float16 else 1e-4)) + + def test_with_no_valid_pairs(self): + for dtype in TEST_DTYPES: + embedding_angles = [0] + embeddings = torch.tensor( + [angle_to_coord(a) for a in embedding_angles], + requires_grad=True, + dtype=dtype, + ).to( + TEST_DEVICE + ) # 2D embeddings + labels = torch.LongTensor([[0]]) + loss = self.loss_func(embeddings, labels) + loss.backward() + self.assertEqual(loss, 0) + + def test_(self): + for dtype in TEST_DTYPES: + embedding_angles = [0] + embeddings = torch.tensor( + [angle_to_coord(a) for a in embedding_angles], + requires_grad=True, + dtype=dtype, + ).to( + TEST_DEVICE + ) # 2D embeddings + labels = torch.LongTensor([[0]]) + loss = self.loss_func(embeddings, labels) + loss.backward() + self.assertEqual(loss, 0) + + + def test_backward(self): + for dtype in TEST_DTYPES: + embedding_angles = list(range(0, 180, 20))[:4] + embeddings = torch.tensor( + [angle_to_coord(a) for a in embedding_angles], + requires_grad=True, + dtype=dtype, + ).to( + TEST_DEVICE + ) # 2D embeddings + labels = torch.LongTensor([[0, 0, 1, 0, 1, 0, 0], + [1, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 1, 0, 1, 0, 1]]).to(TEST_DEVICE) + + loss = self.loss_func(embeddings, labels) + loss.backward() \ No newline at end of file From d3fe43c4bf6845ca79d52ee03746d055ca26eeaf Mon Sep 17 00:00:00 2001 From: Qibin Liang <46101145+PhyseChan@users.noreply.github.com> Date: Wed, 11 Oct 2023 12:03:40 +0800 Subject: [PATCH 15/15] [update] remove files --- .../losses/xbm_multilabel.py | 138 ------------------ .../utils/multilabel_loss_and_miner_utils.py | 129 ---------------- 2 files changed, 267 deletions(-) delete mode 100644 src/pytorch_metric_learning/losses/xbm_multilabel.py delete mode 100644 src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py diff --git a/src/pytorch_metric_learning/losses/xbm_multilabel.py b/src/pytorch_metric_learning/losses/xbm_multilabel.py deleted file mode 100644 index 5de41e38..00000000 --- a/src/pytorch_metric_learning/losses/xbm_multilabel.py +++ /dev/null @@ -1,138 +0,0 @@ -import torch - -from ..utils import common_functions as c_f -# replace the functions of loss_and_miner_utils for multilabels -from ..utils import multilabel_loss_and_miner_utils as mlmu -from ..utils import loss_and_miner_utils as lmu -from ..utils.module_with_records import ModuleWithRecords -from .base_loss_wrapper import BaseLossWrapper - - -class CrossBatchMemory4MultiLabel(BaseLossWrapper, ModuleWithRecords): - def __init__(self, loss, embedding_size, memory_size=1024, miner=None, **kwargs): - super().__init__(loss=loss, **kwargs) - self.loss = loss - self.miner = 
miner - self.embedding_size = embedding_size - self.memory_size = memory_size - self.num_classes = loss.num_classes - self.reset_queue() - self.add_to_recordable_attributes( - list_of_names=["embedding_size", "memory_size", "queue_idx"], is_stat=False - ) - - @staticmethod - def supported_losses(): - return [ - "MultiSupConLoss" - ] - - @classmethod - def check_loss_support(cls, loss_name): - if loss_name not in cls.supported_losses(): - raise Exception(f"CrossBatchMemory not supported for {loss_name}") - - def forward(self, embeddings, labels, indices_tuple=None, enqueue_mask=None): - if indices_tuple is not None and enqueue_mask is not None: - raise ValueError("indices_tuple and enqueue_mask are mutually exclusive") - if enqueue_mask is not None: - assert len(enqueue_mask) == len(embeddings) - else: - assert len(embeddings) <= len(self.embedding_memory) - self.reset_stats() - device = embeddings.device - labels = c_f.to_device(labels, device=device) - self.embedding_memory = c_f.to_device( - self.embedding_memory, device=device, dtype=embeddings.dtype - ) - self.label_memory = c_f.to_device( - self.label_memory, device=device, dtype=labels.dtype - ) - - if enqueue_mask is not None: - emb_for_queue = embeddings[enqueue_mask] - labels_for_queue = labels[enqueue_mask] - embeddings = embeddings[~enqueue_mask] - labels = labels[~enqueue_mask] - do_remove_self_comparisons = False - else: - emb_for_queue = embeddings - labels_for_queue = labels - do_remove_self_comparisons = True - - queue_batch_size = len(emb_for_queue) - self.add_to_memory(emb_for_queue, labels_for_queue, queue_batch_size) - - if not self.has_been_filled: - E_mem = self.embedding_memory[: self.queue_idx] - L_mem = self.label_memory[: self.queue_idx] - else: - E_mem = self.embedding_memory - L_mem = self.label_memory - - indices_tuple = self.create_indices_tuple( - embeddings, - labels, - E_mem, - L_mem, - indices_tuple, - do_remove_self_comparisons, - ) - loss = self.loss(embeddings, labels, indices_tuple, E_mem, L_mem) - return loss - - def add_to_memory(self, embeddings, labels, batch_size): - self.curr_batch_idx = ( - torch.arange( - self.queue_idx, self.queue_idx + batch_size, device=labels.device - ) - % self.memory_size - ) - self.embedding_memory[self.curr_batch_idx] = embeddings.detach() - self.label_memory[self.curr_batch_idx] = labels.detach() - prev_queue_idx = self.queue_idx - self.queue_idx = (self.queue_idx + batch_size) % self.memory_size - if (not self.has_been_filled) and (self.queue_idx <= prev_queue_idx): - self.has_been_filled = True - - def create_indices_tuple( - self, - embeddings, - labels, - E_mem, - L_mem, - input_indices_tuple, - do_remove_self_comparisons, - ): - if self.miner: - indices_tuple = self.miner(embeddings, labels, E_mem, L_mem) - else: - indices_tuple = mlmu.get_all_pairs_indices(labels, L_mem) - - if do_remove_self_comparisons: - indices_tuple = mlmu.remove_self_comparisons( - indices_tuple, self.curr_batch_idx, self.memory_size - ) - - if input_indices_tuple is not None: - if len(input_indices_tuple) == 3 and len(indices_tuple) == 4: - input_indices_tuple = mlmu.convert_to_pairs(input_indices_tuple, labels) - elif len(input_indices_tuple) == 4 and len(indices_tuple) == 3: - input_indices_tuple = mlmu.convert_to_triplets( - input_indices_tuple, labels - ) - indices_tuple = c_f.concatenate_indices_tuples( - indices_tuple, input_indices_tuple - ) - - return indices_tuple - - def reset_queue(self): - self.register_buffer( - "embedding_memory", torch.zeros(self.memory_size, 
self.embedding_size) - ) - self.register_buffer( - "label_memory", torch.zeros(self.memory_size, self.num_classes) - ) - self.has_been_filled = False - self.queue_idx = 0 diff --git a/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py b/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py deleted file mode 100644 index ee211fc7..00000000 --- a/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py +++ /dev/null @@ -1,129 +0,0 @@ -import torch -from . import loss_and_miner_utils as lmu - -def get_matches_and_diffs(labels, ref_labels=None, threshold=0.3): - if ref_labels is None: - ref_labels = labels - jaccard_matrix = jaccard(labels, ref_labels) - matches = torch.where(jaccard_matrix > threshold, 1, 0) - diffs = matches ^ 1 - if ref_labels is labels: - matches.fill_diagonal_(0) - return matches, diffs, jaccard_matrix -def check_shapes_multilabels(embeddings, labels): - if labels is not None and embeddings.shape[0] != len(labels): - raise ValueError("Number of embeddings must equal number of labels") - if labels is not None: - if isinstance(labels[0], list) or isinstance(labels[0], torch.Tensor): - pass - else: - raise ValueError("labels must be a list of 1d tensors or a list of lists") - -def set_ref_emb(embeddings, labels, ref_emb, ref_labels): - if ref_emb is None: - ref_emb, ref_labels = embeddings, labels - check_shapes_multilabels(ref_emb, ref_labels) - return ref_emb, ref_labels - -def convert_to_pairs(indices_tuple, labels, ref_labels=None, device=None, threshold=0.3): - """ - This returns anchor-positive and anchor-negative indices, - regardless of what the input indices_tuple is - Args: - indices_tuple: tuple of tensors. Each tensor is 1d and specifies indices - within a batch - labels: a tensor which has the label for each element in a batch - """ - if indices_tuple is None: - return get_all_pairs_indices(labels, ref_labels, device=device, threshold=threshold) - elif len(indices_tuple) == 5: - return indices_tuple - else: - a, p, n, jaccard_mat = indices_tuple - return a, p, a, n,jaccard_mat - - - - -def get_all_pairs_indices(labels, ref_labels=None, device=None, threshold=0.3): - """ - Given a tensor of labels, this will return 4 tensors. 
- The first 2 tensors are the indices which form all positive pairs - The second 2 tensors are the indices which form all negative pairs - """ - matches, diffs, multi_val = get_matches_and_diffs(labels, ref_labels, device, threshold=threshold) - a1_idx, p_idx = torch.where(matches) - a2_idx, n_idx = torch.where(diffs) - return a1_idx, p_idx, a2_idx, n_idx, multi_val - -def jaccard(labels, ref_labels=None): - if ref_labels is None: - ref_labels = labels - - labels1 = labels.float() - labels2 = ref_labels.float() - - # compute jaccard similarity - # jaccard = intersection / union - labels1_union = labels1.sum(-1) - labels2_union = labels2.sum(-1) - union = labels1_union.unsqueeze(1) + labels2_union.unsqueeze(0) - intersection = torch.mm(labels1, labels2.T) - jaccard_matrix = intersection / (union - intersection) - - # return indices of jaccard similarity above threshold - return jaccard_matrix - -def convert_to_triplets(indices_tuple, labels, ref_labels=None, t_per_anchor=100): - """ - This returns anchor-positive-negative triplets - regardless of what the input indices_tuple is - """ - if indices_tuple is None: - if t_per_anchor == "all": - return get_all_triplets_indices(labels, ref_labels) - else: - return lmu.get_random_triplet_indices( - labels, ref_labels, t_per_anchor=t_per_anchor - ) - elif len(indices_tuple) == 3: - return indices_tuple - else: - a1, p, a2, n = indices_tuple - p_idx, n_idx = torch.where(a1.unsqueeze(1) == a2) - return a1[p_idx], p[p_idx], n[n_idx] - - -def get_all_triplets_indices(labels, ref_labels=None): - matches, diffs = get_matches_and_diffs(labels, ref_labels) - triplets = matches.unsqueeze(2) * diffs.unsqueeze(1) - return torch.where(triplets) - - -def remove_self_comparisons( - indices_tuple, curr_batch_idx, ref_size, ref_is_subset=False -): - # remove self-comparisons - assert len(indices_tuple) in [4, 5] - s, e = curr_batch_idx[0], curr_batch_idx[-1] - if len(indices_tuple) == 4: - a, p, n, _ = indices_tuple - keep_mask = lmu.not_self_comparisons( - a, p, s, e, curr_batch_idx, ref_size, ref_is_subset - ) - a = a[keep_mask] - p = p[keep_mask] - n = n[keep_mask] - assert len(a) == len(p) == len(n) - return a, p, n, _ - elif len(indices_tuple) == 5: - a1, p, a2, n, _ = indices_tuple - keep_mask = lmu.not_self_comparisons( - a1, p, s, e, curr_batch_idx, ref_size, ref_is_subset - ) - a1 = a1[keep_mask] - p = p[keep_mask] - assert len(a1) == len(p) - assert len(a2) == len(n) - return a1, p, a2, n, _ -
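For reference, the pair construction these patches implement — treating two samples as a positive pair when the Jaccard similarity of their multi-hot label vectors exceeds a threshold, and as a negative pair otherwise — can be illustrated in isolation. The following is a minimal, self-contained sketch of that logic under the same assumptions as the code above (multi-hot label tensors, threshold default 0.3, self-comparisons excluded); the helper names used here (jaccard_similarity, pairs_from_multilabels) are local to this example and are not part of the library API.

    import torch

    def jaccard_similarity(labels, ref_labels=None):
        # labels, ref_labels: (batch, num_classes) multi-hot tensors
        # (assumes every sample has at least one active label)
        if ref_labels is None:
            ref_labels = labels
        labels1 = labels.float()
        labels2 = ref_labels.float()
        intersection = labels1 @ labels2.T
        union = labels1.sum(-1).unsqueeze(1) + labels2.sum(-1).unsqueeze(0) - intersection
        return intersection / union

    def pairs_from_multilabels(labels, threshold=0.3):
        # pairs whose label sets overlap enough (Jaccard > threshold) are positives,
        # everything else is a negative; the diagonal (self-comparisons) is dropped
        sim = jaccard_similarity(labels)
        matches = (sim > threshold).long()
        diffs = matches ^ 1
        matches.fill_diagonal_(0)
        a1, p = torch.where(matches)   # anchor-positive index pairs
        a2, n = torch.where(diffs)     # anchor-negative index pairs
        return a1, p, a2, n, sim

    # same toy labels as in the test case above
    labels = torch.tensor([[1, 0, 1],
                           [1, 0, 0],
                           [0, 1, 1],
                           [0, 1, 0]])
    a1, p, a2, n, sim = pairs_from_multilabels(labels, threshold=0.3)
    print(sim)      # pairwise Jaccard similarities
    print(a1, p)    # e.g. samples 0 and 1 share one of two labels -> positive pair
    print(a2, n)    # disjoint label sets -> negative pairs

In the library itself this logic sits behind MultiSupConLoss, which the tests above call as loss_func(embeddings, labels) with a (batch_size, num_classes) multi-hot label tensor; CrossBatchMemory4MultiLabel wraps the same loss with a cross-batch embedding and label queue of size memory_size.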