Fix a tiny bug in DROP metric #229

Merged: 2 commits, Jul 18, 2024
37 changes: 26 additions & 11 deletions src/lighteval/metrics/harness_compatibility/drop.py
@@ -22,17 +22,30 @@

import re
import string
from typing import List, Set, Tuple

import numpy as np
from scipy.optimize import linear_sum_assignment


def drop_metrics(predictions: list[str], formatted_doc, **kwargs): # noqa: C901
"""F1 score from bag of words: comes from Harness Drop
"""F1 score from bag of words: comes from Harness Drop. DROP offers two metrics,
a quasi exact match and a numeracy-focused F1 score. Quasi in the sense that it
applies some normalization before matching, and numeracy-focused in the sense that
the F1 score is set to 0 if there is a number mismatch between the target and the
prediction. The F1 score is computed over the intersection of the target's and the
prediction's BoW representations, with the twist that if the answer and/or the
prediction consists of multiple spans, an optimal 1-1 matching is found between the
two sets of spans (based on that same BoW overlap) and the average F1 over the
matched pairs is returned. DROP also accepts multiple gold answers, in which case
the maximum F1 / exact match between the prediction and the different answers is taken.

For more information, refer to section 5 of the DROP paper (https://aclanthology.org/N19-1246/).

Todo: this code is really hard to follow, simplify when possible
"""

def _answer_to_bags(answer):
def _answer_to_bags(answer: List[str]) -> Tuple[List[str], List[Set[str]]]:
if isinstance(answer, (list, tuple)):
raw_spans = answer
else:
@@ -45,23 +58,25 @@ def _answer_to_bags(answer):
token_bags.append(set(normalized_span.split()))
return normalized_spans, token_bags

def _get_metrics(predicted, gold):
def _get_metrics(predicted: List[str], gold: List[str]):
"""
Takes a predicted answer and a gold answer (that are both either a string or a list of
strings), and returns exact match and the DROP F1 metric for the prediction. If you are
writing a script for evaluating objects in memory (say, the output of predictions during
validation, or while training), this is the function you want to call, after using
:func:`answer_json_to_strings` when reading the gold answer from the released data file.
"""
predicted_bags = _answer_to_bags(predicted)
gold_bags = _answer_to_bags(gold)
pred_normalized_spans, pred_bags = _answer_to_bags(predicted)
gold_normalized_spans, gold_bags = _answer_to_bags(gold)

if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]):
if set(pred_normalized_spans) == set(gold_normalized_spans) and len(pred_normalized_spans) == len(
gold_normalized_spans
):
exact_match = 1.0
else:
exact_match = 0.0

f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1])
f1_per_bag = _align_bags(pred_bags, gold_bags)
f1 = np.mean(f1_per_bag)
f1 = round(f1, 2)
return exact_match, f1
@@ -73,7 +88,7 @@ def _is_number(text):
except ValueError:
return False

def _match_numbers_if_present(gold_bag, predicted_bag):
def _match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]):
gold_numbers = set()
predicted_numbers = set()
for word in gold_bag:
@@ -86,7 +101,7 @@ def _match_numbers_if_present(gold_bag, predicted_bag):
return True
return False

def _align_bags(predicted, gold):
def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> np.ndarray:
"""
Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
between them and gets maximum metric values over all the answers.
@@ -136,7 +151,7 @@ def _fix_number(text):
def _tokenize(text):
return re.split(" |-", text)

def _normalize(answer):
def _normalize(answer: str):
tokens = [
_white_space_fix(_remove_articles(_fix_number(_remove_punc(token.lower())))) for token in _tokenize(answer)
]
@@ -147,9 +162,9 @@ def _normalize(answer):
max_em = 0
max_f1 = 0
for gold_answer in formatted_doc.specific["golds_no_preprocessing"]:
exact_match, f1_score = _get_metrics(predictions, gold_answer)
if isinstance(gold_answer, list):
gold_answer = gold_answer[0]
exact_match, f1_score = _get_metrics(predictions, gold_answer)
if gold_answer.strip():
max_em = max(max_em, exact_match)
max_f1 = max(max_f1, f1_score)
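
For readers unfamiliar with the metric, here is a minimal sketch of the normalization and bag-of-words step the docstring describes. The names normalize, answer_to_bags, and ARTICLES are illustrative stand-ins, not the file's helpers; the real _normalize additionally runs each token through _fix_number and _remove_punc, whose bodies are collapsed in this diff.

import re
import string

ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)

def normalize(answer: str) -> str:
    # Lowercase, split on spaces and hyphens, drop punctuation and articles,
    # and collapse whitespace. (Number canonicalization is omitted here.)
    tokens = []
    for token in re.split(" |-", answer.lower()):
        token = "".join(ch for ch in token if ch not in string.punctuation)
        tokens.append(ARTICLES.sub(" ", token))
    return " ".join(" ".join(tokens).split())

def answer_to_bags(spans: list[str]) -> list[set[str]]:
    # One bag (set) of normalized tokens per answer span.
    return [set(normalize(span).split()) for span in spans]

print(answer_to_bags(["The Eiffel Tower", "324 meters"]))
# [{'eiffel', 'tower'}, {'324', 'meters'}]  (set ordering may vary)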
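
And a sketch of the number-gated, optimally aligned F1 that _match_numbers_if_present and _align_bags implement. Their bodies are collapsed in this diff, so the function names bag_f1, numbers_match, and align_bags below are stand-ins illustrating the idea (score every gold/predicted span pair, zero out pairs whose numbers disagree, then take the best 1-1 assignment with scipy's linear_sum_assignment, which the file imports), not a copy of the actual code.

import numpy as np
from scipy.optimize import linear_sum_assignment

def bag_f1(pred_bag: set[str], gold_bag: set[str]) -> float:
    # Standard precision/recall F1 over the token intersection of the two bags.
    overlap = len(pred_bag & gold_bag)
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_bag)
    recall = overlap / len(gold_bag)
    return 2 * precision * recall / (precision + recall)

def numbers_match(gold_bag: set[str], pred_bag: set[str]) -> bool:
    # True when the gold span has no numbers, or shares at least one number
    # with the prediction (a simplified stand-in for the _is_number check).
    def nums(bag):
        return {w for w in bag if w.replace(".", "", 1).isdigit()}
    gold_numbers = nums(gold_bag)
    return not gold_numbers or bool(gold_numbers & nums(pred_bag))

def align_bags(predicted: list[set[str]], gold: list[set[str]]) -> np.ndarray:
    # Score every (gold, predicted) pair, zeroing pairs that fail the number
    # gate, then keep the 1-1 assignment that maximizes total F1. Unmatched
    # spans keep a score of 0 in this sketch, so extra or missing spans drag
    # the mean down.
    scores = np.zeros((len(gold), len(predicted)))
    for g, gold_bag in enumerate(gold):
        for p, pred_bag in enumerate(predicted):
            if numbers_match(gold_bag, pred_bag):
                scores[g, p] = bag_f1(pred_bag, gold_bag)
    rows, cols = linear_sum_assignment(scores, maximize=True)
    per_span = np.zeros(max(len(gold), len(predicted)))
    for g, p in zip(rows, cols):
        per_span[g] = scores[g, p]
    return per_span  # drop_metrics averages this array to get the final F1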
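
Finally, the aggregation step this fix touches: the prediction is scored against every gold answer and the best exact match and best F1 are kept, with list-valued golds unwrapped to their first span before scoring. A hedged sketch of that loop, with get_metrics passed in as a stand-in for the _get_metrics helper:

def best_over_golds(predictions: list[str], golds, get_metrics) -> tuple[float, float]:
    # Mirror of the loop at the end of drop_metrics: unwrap list-valued golds,
    # score each one, and keep the maximum EM and F1 over non-empty golds.
    max_em, max_f1 = 0.0, 0.0
    for gold in golds:
        if isinstance(gold, list):
            gold = gold[0]
        em, f1 = get_metrics(predictions, gold)
        if gold.strip():
            max_em = max(max_em, em)
            max_f1 = max(max_f1, f1)
    return max_em, max_f1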