Add maj@k metric (#158)

Co-authored-by: Nathan Habib <[email protected]>

* added review change

---------

Co-authored-by: Nathan Habib <[email protected]>
clefourrier and NathanHB authored Apr 30, 2024
1 parent 9806093 commit 0a455c4
Showing 16 changed files with 270 additions and 213 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -350,6 +350,7 @@ These metrics need the model to generate an output. They are therefore slower.
- `f1_score`: Average F1 score in terms of word overlap between the model output and gold without normalisation
- `f1_score_macro`: Corpus level macro F1 score
- `f1_score_micro`: Corpus level micro F1 score
- `maj_at_5` and `maj_at_8`: Model majority vote. Takes n (5 or 8) generations from the model and uses the most frequent one as the prediction (a standalone sketch of this voting logic follows the metrics list below).
- Summarization:
- `rouge` (Harness): Average ROUGE score [(Lin, 2004)](https://aclanthology.org/W04-1013/)
- `rouge1` (HELM): Average ROUGE score [(Lin, 2004)](https://aclanthology.org/W04-1013/) based on 1-gram overlap.
@@ -376,7 +377,9 @@ These metrics need the model to generate an output. They are therefore slower.
- `edit_similarity`: average Levenshtein edit similarity (normalized by length of longer sequence) between model generation and reference.
- Math:
- `quasi_exact_match_math` (HELM): Fraction of instances where the normalized prediction matches the normalized gold (normalization done for math, where latex symbols, units, etc are removed)
- `maj_at_4_math` (Lighteval): Majority choice evaluation, using the math normalisation for the predictions and gold
- `quasi_exact_match_gsm8k` (Harness): Fraction of instances where the normalized prediction matches the normalized gold (normalization done for gsm8k, where latex symbols, units, etc are removed)
- `maj_at_8_gsm8k` (Lighteval): Majority choice evaluation, using the gsm8k normalisation for the predictions and gold
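Below is a minimal, self-contained sketch of the maj@k logic described in the `maj_at_*` entries above. This is an editor's illustration rather than lighteval code: the helper name `majority_at_k` and the optional `normalize` hook are assumptions standing in for the normalisers mentioned in the list.

```python
from collections import Counter
from typing import Callable, Optional


def majority_at_k(
    golds: list[str],
    predictions: list[str],
    k: int,
    normalize: Optional[Callable[[str], str]] = None,
) -> int:
    """Return 1 if the most frequent of the first k predictions matches the single gold, else 0."""
    if len(golds) != 1:
        raise ValueError("maj@k assumes a single gold answer")
    norm = normalize or (lambda s: s)
    gold = norm(golds[0].strip())
    answers = [norm(p.strip()) for p in predictions[:k]]
    majority = Counter(answers).most_common(1)[0][0]  # most frequent sampled answer
    return int(majority == gold)


# Three of the five samples agree with the gold, so the majority vote scores 1.
print(majority_at_k(["42"], ["42", "41", "42", "43", "42"], k=5))
```

The `maj_at_4_math` and `maj_at_8_gsm8k` variants listed under Math plug the math and gsm8k normalisers into the same vote before the comparison.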

### Metrics for specific tasks
To keep compatibility with the Harness for some specific tasks, we ported their evaluations more or less as such. They include `drop` (for the DROP dataset) and `truthfulqa_mc_metrics` (for TruthfulQA). In general, except for tasks where the dataset has a very different formatting than usual (another language, programming language, math, ...), we want to use standard implementations of the above metrics. It makes little sense to have 10 different versions of an exact match depending on the task. However, most of the above metrics are parametrizable so that you can change the normalization applied easily for experimental purposes.
3 changes: 1 addition & 2 deletions src/lighteval/data.py
@@ -29,7 +29,6 @@
from lighteval.logging.hierarchical_logger import hlog_warn
from lighteval.tasks.requests import (
GreedyUntilRequest,
GreedyUntilWithLogitsRequest,
LoglikelihoodRequest,
LoglikelihoodRollingRequest,
LoglikelihoodSingleTokenRequest,
@@ -205,7 +204,7 @@ def _sorting_criteria(self, request: LoglikelihoodSingleTokenRequest) -> int:


class GenerativeTaskDataset(DynamicBatchDataset):
def _sorting_criteria(self, request: GreedyUntilRequest | GreedyUntilWithLogitsRequest) -> int:
def _sorting_criteria(self, request: GreedyUntilRequest) -> int:
"""
Collate function for generating batches.
6 changes: 2 additions & 4 deletions src/lighteval/evaluator.py
@@ -82,12 +82,10 @@ def evaluate(  # noqa: C901
full_resps = lm.loglikelihood(requests, override_bs=override_bs)
elif request_type == RequestType.LOGLIKELIHOOD_SINGLE_TOKEN:
full_resps = lm.loglikelihood_single_token(requests, override_bs=override_bs)
elif request_type == RequestType.GREEDY_UNTIL:
full_resps = lm.greedy_until(requests, override_bs=override_bs)
elif request_type == RequestType.GREEDY_UNTIL_WITH_LOGITS:
full_resps = lm.greedy_until_with_logits(requests, override_bs=override_bs)
elif request_type == RequestType.LOGLIKELIHOOD_ROLLING:
full_resps = lm.loglikelihood_rolling(requests, override_bs=override_bs)
elif request_type == RequestType.GREEDY_UNTIL:
full_resps = lm.greedy_until(requests, override_bs=override_bs)
elif request_type == RequestType.GREEDY_UNTIL_MULTI_TURN:
full_resps = lm.greedy_until_multi_turn(requests, override_bs=override_bs)
else:
5 changes: 4 additions & 1 deletion src/lighteval/logging/info_loggers.py
@@ -350,7 +350,10 @@ def log(
):
pred_saved = True
pass # should we log something?
if task.has_metric_category[MetricCategory.GENERATIVE]:
if (
task.has_metric_category[MetricCategory.GENERATIVE]
or task.has_metric_category[MetricCategory.GENERATIVE_SAMPLING]
):
detail.gold = doc.get_golds()
pred_saved = True
if task.has_metric_category[MetricCategory.GENERATIVE_LOGPROB]:
53 changes: 30 additions & 23 deletions src/lighteval/metrics/__init__.py
@@ -66,16 +66,21 @@ def apply_perplexity_metric(results: list[ModelReturn], formatted_doc: Doc, metr
return results, outputs


def apply_generative_metric(results: list[ModelReturn], formatted_doc: Doc, metrics: list[str], output_regex=None):
def apply_generative_metric(
results: list[ModelReturn], formatted_doc: Doc, metrics: list[str], output_regex=None, max_num_samples=1
):
outputs = {}

# Post processing prediction
pred_raw = results.pop(0).result
if output_regex is not None:
pred = next(iter(re.findall(output_regex, pred_raw)), "")
else:
pred = pred_raw
pred = as_list(pred)
preds_raw = as_list(results.pop(0).result)
preds = []

for pred_raw in preds_raw:
if output_regex is not None:
pred = next(iter(re.findall(output_regex, pred_raw)), "")
else:
pred = pred_raw
preds.append(pred)

# Extracting gold
try:
@@ -87,23 +92,28 @@ def apply_generative_metric(results: list[ModelReturn], formatted_doc: Doc, metr
# if "label_to_choices" in formatted_doc:
if formatted_doc.specific is not None and "label_to_choices" in formatted_doc.specific:
# Helm predicts on labels keys (A/B/C/D), but computes metrics on choices
pred = [formatted_doc.specific["label_to_choices"].get(p) for p in pred]
preds = [formatted_doc.specific["label_to_choices"].get(p) for p in preds]
golds = [formatted_doc.specific["label_to_choices"][g] for g in golds]

for metric in metrics:
if Metrics[metric].value.category == MetricCategory.GENERATIVE:
outputs.update(Metrics[metric].value.compute(golds=golds, predictions=pred, formatted_doc=formatted_doc))

return results, outputs


def apply_generative_logprob_metric(results: list[ModelReturn], formatted_doc: Doc, metrics: list[str]):
# Applied to no metric atm, but we have the model side logic
outputs = {}

for metric in metrics:
outputs.update(
Metrics[metric].value.compute(
golds=golds,
predictions=as_list(preds[0]) if max_num_samples > 0 else preds,
formatted_doc=formatted_doc,
)
)
if Metrics[metric].value.category == MetricCategory.GENERATIVE_LOGPROB:
outputs.update(Metrics[metric].value.compute(results=results, formatted_doc=formatted_doc))
outputs.update(
Metrics[metric].value.compute(
golds=golds,
predictions=as_list(preds[0]) if max_num_samples > 0 else preds,
formatted_doc=formatted_doc,
)
)
if Metrics[metric].value.category == MetricCategory.GENERATIVE_SAMPLING:
outputs.update(Metrics[metric].value.compute(golds=golds, predictions=preds, formatted_doc=formatted_doc))

return results, outputs
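To make the new flow in `apply_generative_metric` above easier to follow, here is an editorial sketch (not lighteval code): each of the up-to-k sampled generations is optionally filtered through `output_regex`; single-prediction categories (`GENERATIVE`, `GENERATIVE_LOGPROB`) then score only the first processed sample, while `GENERATIVE_SAMPLING` metrics such as maj@k receive the full list.

```python
import re
from typing import Optional


def postprocess_predictions(raw_preds: list[str], output_regex: Optional[str] = None) -> list[str]:
    """Apply the optional answer-extraction regex to every sampled generation."""
    processed = []
    for raw in raw_preds:
        if output_regex is not None:
            # Keep the first regex match, or an empty string when nothing matches
            processed.append(next(iter(re.findall(output_regex, raw)), ""))
        else:
            processed.append(raw)
    return processed


preds = postprocess_predictions(
    ["The answer is 12.", "Answer: 12", "I think it is 13."],
    output_regex=r"\d+",
)
single_sample = preds[:1]  # what GENERATIVE / GENERATIVE_LOGPROB metrics receive
all_samples = preds        # what GENERATIVE_SAMPLING metrics (maj@k) receive
print(single_sample, all_samples)  # ['12'] ['12', '12', '13']
```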

@@ -153,10 +163,7 @@ def apply_llm_as_judge_metric(results: list[ModelReturn], formatted_doc: Doc, me
predictions = results.pop(0).result

for metric in metrics:
if (
Metrics[metric].value.category == MetricCategory.LLM_AS_JUDGE_MULTI_TURN
or Metrics[metric].value.category == MetricCategory.LLM_AS_JUDGE
):
if Metrics[metric].value.category in [MetricCategory.LLM_AS_JUDGE_MULTI_TURN, MetricCategory.LLM_AS_JUDGE]:
outputs.update(Metrics[metric].value.compute(predictions=predictions, formatted_doc=formatted_doc))

return results, outputs
37 changes: 37 additions & 0 deletions src/lighteval/metrics/metrics.py
@@ -41,6 +41,7 @@
F1_score,
JudgeLLM,
LoglikelihoodAcc,
MajAtK,
Recall,
StringDistance,
acc_golds_likelihood,
@@ -326,6 +327,42 @@ class Metrics(Enum):
corpus_level_fn=matthews_corrcoef,
higher_is_better=True,
)
maj_at_4_math = SampleLevelMetric(
metric="maj@4",
sample_level_fn=MajAtK(
k=4, strip_strings=True, normalize_pred=math_normalizer, normalize_gold=math_normalizer_gold
).compute,
category=MetricCategory.GENERATIVE_SAMPLING,
use_case=MetricUseCase.MATH,
corpus_level_fn=np.mean,
higher_is_better=True,
)
maj_at_5 = SampleLevelMetric(
metric="maj@5",
sample_level_fn=MajAtK(k=5).compute,
category=MetricCategory.GENERATIVE_SAMPLING,
use_case=MetricUseCase.ACCURACY,
corpus_level_fn=np.mean,
higher_is_better=True,
)
maj_at_8 = SampleLevelMetric(
metric="maj@8",
sample_level_fn=MajAtK(k=8).compute,
category=MetricCategory.GENERATIVE_SAMPLING,
use_case=MetricUseCase.ACCURACY,
corpus_level_fn=np.mean,
higher_is_better=True,
)
maj_at_8_gsm8k = SampleLevelMetric(
metric="maj@8",
sample_level_fn=MajAtK(
k=8, strip_strings=True, normalize_pred=gsm8k_normalizer, normalize_gold=gsm8k_normalizer
).compute,
category=MetricCategory.GENERATIVE_SAMPLING,
use_case=MetricUseCase.MATH,
corpus_level_fn=np.mean,
higher_is_better=True,
)
mrr = SampleLevelMetric(
metric="mrr",
sample_level_fn=MRR().compute,
86 changes: 86 additions & 0 deletions src/lighteval/metrics/metrics_sample.py
@@ -675,3 +675,89 @@ def compute(self, predictions: list[str], formatted_doc: Doc, **kwargs) -> dict[
"user_prompt": messages[0],
"judgement": judgements[0],
}


class MajAtK:
def __init__(
self,
k: int,
normalize_gold: callable = None,
normalize_pred: callable = None,
strip_strings: bool = False,
type_exact_match: str = "full",
):
"""An exact match class.
Args:
normalize_gold (callable, optional): Function to use to normalize the reference strings.
Defaults to None if no normalization is applied.
normalize_pred (callable, optional): Function to use to normalize the predicted strings.
Defaults to None if no normalization is applied.
strip_strings (bool, optional): Whether to strip both reference and predictions. Defaults to False.
type_exact_match (str, optional): Defines what type of match to apply (post normalization if present).
Can be any of `prefix`, `suffix` or `full`. Defaults to "full".
`prefix` checks if the prediction starts with the gold,
`suffix` if the prediction ends with the gold,
`full` if the prediction and gold are equal
"""
self.k = k
self.normalize_gold = normalize_gold
self.normalize_pred = normalize_pred
self.strip_strings = strip_strings

if type_exact_match not in ["prefix", "suffix", "full"]:
# todo: we could add a set exact match
raise ValueError(
f"type_exact_match (used in parametrized_exact_match) must be one of prefix, suffix, or full. Was {type_exact_match} instead."
)
self.type_exact_match = type_exact_match

    def compute(self, golds: list[str], predictions: list[str], **kwargs) -> int:
"""Computes the metric over a list of golds and predictions for one single sample.
It applies normalisation (if needed) to model prediction and gold, and takes the most frequent answer of all the available ones,
then compares it to the gold.
Args:
golds (list[str]): Reference targets
predictions (list[str]): k predicted strings
Returns:
            int: 1 if the most frequent of the first k predictions matches the gold, 0 otherwise.
"""
if len(golds) > 1:
raise Exception("Cannot compute maj@k with several golds")

gold = self.get_processed_gold(golds[0])
all_answers = []
for pred in predictions[: self.k]:
all_answers.append(self.get_processed_pred(pred=pred))
majority_prediction = max(all_answers, key=all_answers.count)
return self.compute_score(majority_prediction, gold)

    def get_processed_gold(self, gold: str) -> str:
if self.strip_strings:
gold = gold.strip()

if self.normalize_gold:
gold = self.normalize_gold(gold)

return gold

    def get_processed_pred(self, pred: str) -> str:
if not pred:
return ""

if self.strip_strings:
pred = pred.strip()

if self.normalize_pred:
pred = self.normalize_pred(pred)

return pred

def compute_score(self, pred: str, gold: str) -> int:
if self.type_exact_match == "prefix":
return 1 if pred.startswith(gold) else 0
if self.type_exact_match == "suffix":
return 1 if pred.endswith(gold) else 0
return 1 if gold == pred else 0
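A quick usage sketch of the new class, based only on the constructor and `compute` signature shown above; the import path follows the file location, and the example inputs are invented for illustration.

```python
# Hypothetical standalone usage of MajAtK; the strings are made-up examples.
from lighteval.metrics.metrics_sample import MajAtK

maj_at_3 = MajAtK(k=3, strip_strings=True)

score = maj_at_3.compute(
    golds=["paris"],
    predictions=["paris ", "paris", "london"],  # two of the first three samples agree
)
print(score)  # 1: the majority answer "paris" exactly matches the gold
```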
3 changes: 2 additions & 1 deletion src/lighteval/metrics/utils.py
@@ -28,9 +28,10 @@ class MetricCategory(Enum):
TARGET_PERPLEXITY = auto()
PERPLEXITY = auto()
GENERATIVE = auto()
GENERATIVE_LOGPROB = auto()
GENERATIVE_SAMPLING = auto()
LLM_AS_JUDGE_MULTI_TURN = auto()
LLM_AS_JUDGE = auto()
GENERATIVE_LOGPROB = auto()
MULTICHOICE = auto()
MULTICHOICE_ONE_TOKEN = auto()
IGNORED = auto()
Expand Down
28 changes: 0 additions & 28 deletions src/lighteval/models/abstract_model.py
@@ -36,7 +36,6 @@
from lighteval.tasks.requests import (
GreedyUntilMultiTurnRequest,
GreedyUntilRequest,
GreedyUntilWithLogitsRequest,
LoglikelihoodRequest,
LoglikelihoodRollingRequest,
LoglikelihoodSingleTokenRequest,
@@ -83,31 +82,6 @@ def max_length(self) -> int:
def disable_tqdm(self) -> bool:
raise NotImplementedError

def greedy_until_with_logits(
self,
requests: list[GreedyUntilWithLogitsRequest],
override_bs: Optional[int] = None,
) -> list[GenerateReturn]:
"""
Generates sequences greedily until a stopping condition is met,
returning both the generated sequences and the logits.
Args:
requests (list[tuple[str, dict]]): A list of input requests,
where each request is a tuple containing a prompt string and a dictionary of additional parameters.
disable_tqdm (bool, optional): Whether to disable the tqdm progress bar. Defaults to False.
override_bs (Optional[int], optional): Overrides the batch size for generation. Defaults to None.
Returns:
list[GenerateReturn]: A list of GenerateReturn objects,
where each object contains the generated sequence and the corresponding logits.
"""
return self.greedy_until(
requests=requests,
override_bs=override_bs,
returns_logits=True,
)

def greedy_until_multi_turn( # noqa: C901
self, requests: list[GreedyUntilMultiTurnRequest], override_bs: Optional[int] = None
) -> GenerateMultiTurnReturn:
@@ -118,15 +92,13 @@ def greedy_until_multi_turn(  # noqa: C901
def greedy_until(
self,
requests: list[GreedyUntilRequest],
returns_logits: bool = False,
override_bs: Optional[int] = None,
) -> list[GenerateReturn]:
"""
Generates responses using a greedy decoding strategy until certain ending conditions are met.
Args:
requests (list[Request]): list of requests containing the context and ending conditions.
returns_logits (bool, optional): Whether to return the logits of the generated responses. Defaults to False.
disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False.
override_bs (int, optional): Override the batch size for generation. Defaults to None.