-
Notifications
You must be signed in to change notification settings - Fork 2.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add DebiasedMultipleNegativesRankingLoss
to the losses
#3148
base: master
Are you sure you want to change the base?
Conversation
TODO: Prepare PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello!
I'm excited to try this out some more. I created a simple training script with it, but I'm getting losses of 0.
The script:
import argparse
import random
from datasets import load_dataset
import numpy
import torch
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.losses.DebiasedMultipleNegativesRankingLoss import DebiasedMultipleNegativesRankingLoss
from sentence_transformers.losses.MultipleNegativesRankingLoss import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
def main():
# parse the lr & model name
parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=8e-5)
parser.add_argument("--model_name", type=str, default="bert-base-uncased")
parser.add_argument("--loss", type=str, default="debiased-mnrl")
args = parser.parse_args()
lr = args.lr
model_name = args.model_name
loss_name = args.loss
model_shortname = model_name.split("/")[-1]
seed = 12
random.seed(seed)
torch.manual_seed(seed)
numpy.random.seed(seed)
# 1. Load a model to finetune
model = SentenceTransformer(model_name)
# 2. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/natural-questions", split="train")
dataset_dict = dataset.train_test_split(test_size=1_000, seed=seed)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
# 3. Define a loss function
if loss_name == "mnrl":
loss = MultipleNegativesRankingLoss(model)
elif loss_name == "debiased-mnrl":
loss = DebiasedMultipleNegativesRankingLoss(model)
else:
raise ValueError(f"Loss {loss_name} not supported")
run_name = f"{model_shortname}-nq-{loss_name}"
# 4. (Optional) Specify training arguments
args = SentenceTransformerTrainingArguments(
# Required parameter:
output_dir=f"output/{model_shortname}/{run_name}",
# Optional training parameters:
num_train_epochs=1,
per_device_train_batch_size=128,
per_device_eval_batch_size=128,
learning_rate=lr,
warmup_ratio=0.05,
fp16=False, # Set to False if GPU can't handle FP16
bf16=True, # Set to True if GPU supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # (Cached)MultipleNegativesRankingLoss benefits from no duplicates
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=50,
save_strategy="steps",
save_steps=50,
save_total_limit=2,
logging_steps=10,
seed=seed,
run_name=run_name, # Used in `wandb`, `tensorboard`, `neptune`, etc. if installed
)
# 5. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = NanoBEIREvaluator(dataset_names=["MSMARCO", "HotpotQA"])
dev_evaluator(model)
# 6. Create a trainer & train
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
evaluator=dev_evaluator,
)
trainer.train()
# 7. (Optional) Evaluate the trained model on the evaluator after training
dev_evaluator(model)
# 8. Save the model
model.save_pretrained(f"output/{model_shortname}/{run_name}/final")
# 9. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name, private=False)
if __name__ == "__main__":
main()
Which results in:
{'loss': 0.0, 'grad_norm': 0.0, 'learning_rate': 2.0512820512820512e-05, 'epoch': 0.01}
{'loss': 0.0, 'grad_norm': 0.0, 'learning_rate': 4.1025641025641023e-05, 'epoch': 0.03}
{'loss': 0.0, 'grad_norm': 0.0, 'learning_rate': 6.153846153846155e-05, 'epoch': 0.04}
{'loss': 0.0, 'grad_norm': 0.0, 'learning_rate': 7.989145183175034e-05, 'epoch': 0.05}
{'loss': 0.0, 'grad_norm': 0.0, 'learning_rate': 7.880597014925374e-05, 'epoch': 0.06}
Compared to MNRL, which has these logs:
{'loss': 4.0986, 'grad_norm': 19.445810317993164, 'learning_rate': 2.0512820512820512e-05, 'epoch': 0.01}
{'loss': 2.2274, 'grad_norm': 14.608287811279297, 'learning_rate': 4.1025641025641023e-05, 'epoch': 0.03}
{'loss': 1.1188, 'grad_norm': 7.870720863342285, 'learning_rate': 6.153846153846155e-05, 'epoch': 0.04}
{'loss': 0.512, 'grad_norm': 4.837278366088867, 'learning_rate': 7.989145183175034e-05, 'epoch': 0.05}
{'loss': 0.2806, 'grad_norm': 4.354431629180908, 'learning_rate': 7.880597014925374e-05, 'epoch': 0.06}
I believe it's because N_neg * g
is substantially small (127 and a tensor of 2.0612e-09) such that pos_exp / (pos_exp + N_neg * g)
(where pos_exp
is mostly values around 5e+07) is essentially 1.
Could you perhaps look into this?
- Tom Aarsen
# Compute the g estimator with the exponential of the similarities. | ||
N_neg = scores.size(1) - 1 # Number of negatives | ||
g = torch.clamp((1 / (1 - self.tau_plus)) * ((neg_exp / N_neg) - (self.tau_plus * pos_exp)), | ||
min=torch.exp(-torch.tensor(self.scale))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
min=torch.exp(-torch.tensor(self.scale))) | |
min=self.scale) |
The torch.tensor
call results in device mismatches when training on GPUs. Perhaps this can be simplified to just the above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Much better, I totally forgot about keeping the exp
- that's quite important, my bad.
This PR introduces the Debiased Contrastive Loss, from the paper "Debiased Contrastive Learning" (Chuang et al., NeurIPS 2020). The purpose of this loss is to reduce false negative bias, which occurs when negative samples in the dataset are semantically similar to the anchor. Such bias can harm the quality of embeddings and reduce performance in downstream tasks, as shown in the paper's results.
The integration follows the same structure as other losses in the
losses
package, with full documentation and acitation
method to reference the original work. This loss is an improved version ofMultipleNegativesRankingLoss
with an additional hyper-parametertau_plus
that controls the bias correction. Thus, it's compatible with methods likeGenQ
(see Query Generation Example).In this implementation, I focus on the case where$M = 1$ , meaning each anchor has one positive sample. This approach can be extended to handle multiple positive samples $M \geq 1$ , which could be a direction for future development. (Here, $M$ refers to the number of positive examples associated with each anchor)