Question about ref_labels, ref_emb and computing training pairs for different sets of embeddings #495
-
Hi, I am working on generating textual embeddings for queries. When I generate pairs of positives and negatives, I don't want to generate them for queries/queries or doc_embeddings/doc_embeddings pairs, only for queries --> doc_embeddings, since my doc_embeddings are already precomputed. query_terms, document_embeddings, and labels all have the same length, and for every query the document_embedding at the same position is the positive match. There can be other positives, since different queries can have the same document_embedding (they will share the same label).
I was looking at ref_labels and ref_emb for this. Anchors are query_embeddings; positives and negatives are contained in document_embeddings; labels is the same array for both. I am expecting the loss to be computed over a (batch_size, batch_size) grid of query-to-document pairs. This would be an example of the simplest thing I am trying to achieve:
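Something like this minimal sketch (random tensors stand in for my real data, and ContrastiveLoss is just a placeholder for whatever loss I end up using):
import torch
from pytorch_metric_learning.losses import ContrastiveLoss
# Stand-ins: the queries would really come from my encoder,
# and the doc embeddings are precomputed and fixed.
query_embeddings = torch.randn(32, 128, requires_grad=True)
doc_embeddings = torch.randn(32, 128)
labels = torch.arange(32)  # query i's positive is doc i
loss_fn = ContrastiveLoss()
# Pairs should only span the two sets: never query/query or doc/doc
loss = loss_fn(query_embeddings, labels, ref_emb=doc_embeddings,
               ref_labels=torch.clone(labels))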
I would also like to know how easy it would be to integrate this with miners, if this is the only way of generating pairs. Thanks in advance.
-
What about it looks incorrect?
These approaches should work:
import torch
from pytorch_metric_learning.losses import TripletMarginLoss
from pytorch_metric_learning.miners import TripletMarginMiner
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
# query_embeddings, doc_embeddings, and labels are as you described
loss_fn = TripletMarginLoss()
ref_labels = torch.clone(labels)
# All pairs (which get converted to triplets if you're using TripletMarginLoss)
pairs = lmu.get_all_pairs_indices(labels, ref_labels=ref_labels)
loss = loss_fn(query_embeddings, indices_tuple=pairs, ref_emb=doc_embeddings)
# All triplets
triplets = lmu.get_all_triplets_indices(labels, ref_labels=ref_labels)
loss = loss_fn(query_embeddings, indices_tuple=triplets, ref_emb=doc_embeddings)
# All triplets (default behavior of TripletMarginLoss)
loss = loss_fn(query_embeddings, labels=labels, ref_emb=doc_embeddings, ref_labels=ref_labels)
# Using a miner
miner = TripletMarginMiner()
triplets = miner(query_embeddings, labels, ref_emb=doc_embeddings, ref_labels=ref_labels)
loss = loss_fn(query_embeddings, indices_tuple=triplets, ref_emb=doc_embeddings)
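For reference, here is a self-contained version of the first approach (random tensors; the shapes are purely illustrative):
query_embeddings = torch.randn(8, 64, requires_grad=True)  # anchors
doc_embeddings = torch.randn(8, 64)  # precomputed, stays fixed
labels = torch.arange(8)  # query i's positive is doc i
ref_labels = torch.clone(labels)
pairs = lmu.get_all_pairs_indices(labels, ref_labels=ref_labels)
loss = loss_fn(query_embeddings, indices_tuple=pairs, ref_emb=doc_embeddings)
loss.backward()  # gradients flow only into query_embeddings
Note that torch.clone matters: if ref_labels is the same tensor object as labels, the diagonal (query i compared with doc i) is treated as a self-comparison and removed, so cloning keeps those positive pairs. And since ref_labels is provided, anchors always come from query_embeddings and positives/negatives from doc_embeddings, so query/query and doc/doc pairs never occur.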