Fix a tiny bug in DROP metric (#229)
* Fix the metric

* Apply comment and Ruff
sadra-barikbin authored Jul 18, 2024
1 parent 44f9a46 commit 66ed7a2
Showing 1 changed file with 26 additions and 11 deletions.
37 changes: 26 additions & 11 deletions src/lighteval/metrics/harness_compatibility/drop.py
@@ -22,17 +22,30 @@

import re
import string
+from typing import List, Set, Tuple

import numpy as np
from scipy.optimize import linear_sum_assignment


def drop_metrics(predictions: list[str], formatted_doc, **kwargs): # noqa: C901
"""F1 score from bag of words: comes from Harness Drop
"""F1 score from bag of words: comes from Harness Drop. DROP offers two metrics,
a quasi exact match and a numeracy-focused F1 score. Quasi in the sense that it
does some normalizations before matching and numeracy-focused in the sense that
if there's number mismatch between the target and prediction F1 score is set to 0.
F1 score is computed using the intersection of target and prediction's BoW
representations with the additional spice that if the answer and/or prediction is
comprised of multiple spans, a greedy matching is done between the two sets of spans
(based on the very BoW overlap) and the average over F1 of pairs is returned.
DROP also accepts multiple answers in which case, the maximum of F1/ Exact Match
between prediction and the different answers is taken.
For more information, please refer to the section 5 of the DROP paper (https://aclanthology.org/N19-1246/).
Todo: this code is really hard to follow, simplify when possible
"""

def _answer_to_bags(answer):
def _answer_to_bags(answer: List[str]) -> Tuple[List[str], List[Set[str]]]:
        if isinstance(answer, (list, tuple)):
            raw_spans = answer
        else:
@@ -45,23 +58,25 @@ def _answer_to_bags(answer):
            token_bags.append(set(normalized_span.split()))
        return normalized_spans, token_bags

-    def _get_metrics(predicted, gold):
+    def _get_metrics(predicted: List[str], gold: List[str]):
        """
        Takes a predicted answer and a gold answer (that are both either a string or a list of
        strings), and returns exact match and the DROP F1 metric for the prediction. If you are
        writing a script for evaluating objects in memory (say, the output of predictions during
        validation, or while training), this is the function you want to call, after using
        :func:`answer_json_to_strings` when reading the gold answer from the released data file.
        """
-        predicted_bags = _answer_to_bags(predicted)
-        gold_bags = _answer_to_bags(gold)
+        pred_normalized_spans, pred_bags = _answer_to_bags(predicted)
+        gold_normalized_spans, gold_bags = _answer_to_bags(gold)

-        if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]):
+        if set(pred_normalized_spans) == set(gold_normalized_spans) and len(pred_normalized_spans) == len(
+            gold_normalized_spans
+        ):
            exact_match = 1.0
        else:
            exact_match = 0.0

-        f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1])
+        f1_per_bag = _align_bags(pred_bags, gold_bags)
        f1 = np.mean(f1_per_bag)
        f1 = round(f1, 2)
        return exact_match, f1
@@ -73,7 +88,7 @@ def _is_number(text):
        except ValueError:
            return False

-    def _match_numbers_if_present(gold_bag, predicted_bag):
+    def _match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]):
        gold_numbers = set()
        predicted_numbers = set()
        for word in gold_bag:
@@ -86,7 +101,7 @@ def _match_numbers_if_present(gold_bag, predicted_bag):
            return True
        return False

-    def _align_bags(predicted, gold):
+    def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> np.array:
        """
        Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
        between them and gets maximum metric values over all the answers.
@@ -136,7 +151,7 @@ def _fix_number(text):
    def _tokenize(text):
        return re.split(" |-", text)

-    def _normalize(answer):
+    def _normalize(answer: str):
        tokens = [
            _white_space_fix(_remove_articles(_fix_number(_remove_punc(token.lower())))) for token in _tokenize(answer)
        ]
@@ -147,9 +162,9 @@ def _normalize(answer):
    max_em = 0
    max_f1 = 0
    for gold_answer in formatted_doc.specific["golds_no_preprocessing"]:
-        exact_match, f1_score = _get_metrics(predictions, gold_answer)
        if isinstance(gold_answer, list):
            gold_answer = gold_answer[0]
+        exact_match, f1_score = _get_metrics(predictions, gold_answer)
        if gold_answer.strip():
            max_em = max(max_em, exact_match)
            max_f1 = max(max_f1, f1_score)
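As background for the metric described in the new docstring above, the following is a minimal, self-contained sketch of the normalization and bag-of-words step. It is illustrative only: the names are not lighteval's, and the handling of numbers and articles is simplified relative to the real helpers.

import re
import string


def normalize(answer: str) -> str:
    # Lowercase, split on spaces and hyphens, drop punctuation and articles.
    tokens = []
    for token in re.split(" |-", answer.lower()):
        token = "".join(ch for ch in token if ch not in set(string.punctuation))
        if token not in {"a", "an", "the"} and token.strip():
            tokens.append(token)
    return " ".join(tokens)


def answer_to_bag(span: str) -> set:
    # Bag-of-words representation of a single normalized span.
    return set(normalize(span).split())


# answer_to_bag("The Bears scored 14 points.") -> {"bears", "scored", "14", "points"}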

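The alignment step that _align_bags performs can be illustrated in the same spirit: each gold span's bag is scored against each predicted span's bag with a token-level F1, and scipy's linear_sum_assignment finds the one-to-one pairing that maximizes the summed F1. This sketch is not the library's code and omits the numeracy check that zeroes F1 on a number mismatch.

import numpy as np
from scipy.optimize import linear_sum_assignment


def bag_f1(pred_bag: set, gold_bag: set) -> float:
    # Token-level F1 between two bags of words.
    overlap = len(pred_bag & gold_bag)
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_bag)
    recall = overlap / len(gold_bag)
    return 2 * precision * recall / (precision + recall)


def align_bags(predicted: list, gold: list) -> np.ndarray:
    # Score every (gold, predicted) pair, then solve the assignment problem;
    # negating the scores turns scipy's cost minimizer into a maximizer.
    scores = np.array([[bag_f1(p, g) for p in predicted] for g in gold])
    rows, cols = linear_sum_assignment(-scores)
    max_scores = np.zeros(max(len(gold), len(predicted)))
    for r, c in zip(rows, cols):
        max_scores[r] = scores[r, c]
    return max_scores


# align_bags([{"two", "touchdowns"}], [{"two"}, {"touchdowns"}]).mean() -> ~0.33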