diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py index e605a064929760..e2bb8d5e688f2d 100755 --- a/examples/pytorch/question-answering/run_qa.py +++ b/examples/pytorch/question-answering/run_qa.py @@ -627,6 +627,14 @@ def post_processing_function(examples, features, predictions, stage="eval"): references = [{"id": str(ex["id"]), "answers": ex[answer_column_name]} for ex in examples] return EvalPrediction(predictions=formatted_predictions, label_ids=references) + if data_args.version_2_with_negative: + accepted_best_metrics = ("exact", "f1", "HasAns_exact", "HasAns_f1") + else: + accepted_best_metrics = ("exact_match", "f1") + + if training_args.load_best_model_at_end and training_args.metric_for_best_model not in accepted_best_metrics: + warnings.warn(f"--metric_for_best_model should be set to one of {accepted_best_metrics}") + metric = evaluate.load( "squad_v2" if data_args.version_2_with_negative else "squad", cache_dir=model_args.cache_dir )