diff --git a/README.md b/README.md
index a6dfbc48..f15f6b84 100644
--- a/README.md
+++ b/README.md
@@ -415,6 +415,11 @@ These metrics need the model to generate an output. They are therefore slower.
 - `maj_at_4_math` (Lighteval): Majority choice evaluation, using the math normalisation for the predictions and gold
 - `quasi_exact_match_gsm8k` (Harness): Fraction of instances where the normalized prediction matches the normalized gold (normalization done for gsm8k, where latex symbols, units, etc are removed)
 - `maj_at_8_gsm8k` (Lighteval): Majority choice evaluation, using the gsm8k normalisation for the predictions and gold
+- LLM-as-Judge:
+  - `llm_judge_gpt3p5`: Can be used for any generative task; the model's output is scored by a GPT-3.5 judge through the OpenAI API
+  - `llm_judge_llama_3_405b`: Can be used for any generative task; the model's output is scored by a Llama 3.1 405B judge through the HuggingFace Inference API (via the OpenAI client)
+  - `llm_judge_multi_turn_gpt3p5`: Can be used for any generative task; the model's output is scored by a GPT-3.5 judge through the OpenAI API. It is meant for multi-turn tasks like mt-bench.
+  - `llm_judge_multi_turn_llama_3_405b`: Can be used for any generative task; the model's output is scored by a Llama 3.1 405B judge through the HuggingFace Inference API (via the OpenAI client). It is meant for multi-turn tasks like mt-bench.
 ### Metrics for specific tasks
 To keep compatibility with the Harness for some specific tasks, we ported their evaluations more or less as such. They include `drop` (for the DROP dataset) and `truthfulqa_mc_metrics` (for TruthfulQA). In general, except for tasks where the dataset has very different formatting than usual (another language, programming language, math, ...), we want to use standard implementations of the above metrics. It makes little sense to have 10 different versions of an exact match depending on the task. However, most of the above metrics are parametrizable so that you can change the normalization applied easily for experimental purposes.
diff --git a/pyproject.toml b/pyproject.toml
index 95f74147..e301d7af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -92,7 +92,7 @@ tests = ["pytest==7.4.0"]
 dev = ["lighteval[accelerate,quality,tests]"]
 extended_tasks = [
   "langdetect", # ifeval
-  "openai", # mt-bench
+  "openai", # llm as a judge using openai models
 ]
 [project.urls]
diff --git a/src/lighteval/logging/evaluation_tracker.py b/src/lighteval/logging/evaluation_tracker.py
index 370758d0..e6d1846f 100644
--- a/src/lighteval/logging/evaluation_tracker.py
+++ b/src/lighteval/logging/evaluation_tracker.py
@@ -57,7 +57,10 @@ class EnhancedJSONEncoder(json.JSONEncoder):
     def default(self, o):
         if is_dataclass(o):
-            return asdict(o)
+            try:
+                return asdict(o)
+            except Exception:
+                return str(o)
         if callable(o):
             return o.__name__
         if isinstance(o, Enum):
diff --git a/src/lighteval/metrics/llm_as_judge.py b/src/lighteval/metrics/llm_as_judge.py
index 5b70e9d5..ff3dff3b 100644
--- a/src/lighteval/metrics/llm_as_judge.py
+++ b/src/lighteval/metrics/llm_as_judge.py
@@ -25,56 +25,57 @@ import json
 import re
 import time
-from typing import Optional
+from typing import Any, Optional
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from lighteval.logging.hierarchical_logger import hlog_warn
-from lighteval.utils import NO_OPENAI_ERROR_MSG, is_openai_available
-class JudgeOpenAI:
+class JudgeLM:
     """
-    A class representing a judge for evaluating answers using the OpenAI API.
+    A class representing a judge for evaluating answers using either the OpenAI API or the Transformers library.
     Args:
-        model (str): The name of the OpenAI model to use.
-        seed (int): The seed value for generating random responses.
-        temperature (float): The temperature value for controlling the randomness of the responses.
+        model (str): The name of the model to use.
         templates_path (str): The path to the JSON file containing the templates for prompts.
+        multi_turn (bool): Whether to use multi-turn prompts.
+        url (Optional[str]): The base URL of an OpenAI-compatible API, if not using the default OpenAI endpoint.
+        api_key (Optional[str]): The API key to use (either an OpenAI or a HuggingFace token).

     Attributes:
-        client: An instance of the OpenAI client.
-        model (str): The name of the OpenAI model.
-        seed (int): The seed value, passed to the API when generating responses.
-        temperature (float): The temperature value, passed to the API when generating responses.
+        model (str): The name of the model.
         templates (dict): A dictionary containing the templates for prompts.
         one_score_pattern (re.Pattern): A regular expression pattern for extracting scores from the response.
         one_score_pattern_backup (re.Pattern): A backup regular expression pattern for extracting scores.
-        API_MAX_RETRY (int): The maximum number of API retries.
-        API_RETRY_SLEEP (int): The sleep time between API retries.
-        max_tokens (int): The maximum number of tokens allowed in the response.
+        API_MAX_RETRY (int): The maximum number of retries for the API.
+        API_RETRY_SLEEP (int): The sleep time between retries.
+        client (Optional[OpenAI]): The OpenAI client.
+        pipe (Optional[pipeline]): The Transformers pipeline.
+        use_transformers (bool): Whether to run the judge locally with the Transformers library.
+        url (Optional[str]): The base URL of an OpenAI-compatible API, if not using the default OpenAI endpoint.
+        api_key (Optional[str]): The API key to use (either an OpenAI or a HuggingFace token).

     Methods:
-        evaluate_answer: Evaluates an answer using the OpenAI API.
+        evaluate_answer: Evaluates an answer using the OpenAI API or the Transformers library.
         __get_prompts_multi_turn: Generates prompts for multi-turn conversations.
         __get_prompts_single_turn: Generates prompts for single-turn conversations.
         __process_judge_response: Processes the judge's response and extracts the score.
+        __call_api: Calls the OpenAI-compatible API to get the judge's response.
+        __lazy_load_client: Lazily loads the OpenAI client or the Transformers pipeline.
""" def __init__( self, model: str, - seed: int, - temperature: float, templates_path: str, - openai_api_key: str, multi_turn: bool = False, + url: Optional[str] = None, + api_key: Optional[str] = None, ): - self.client = None # loaded lazily - self.openai_api_key = openai_api_key - self.model = model - self.seed = seed - self.temperature = temperature self.multi_turn = multi_turn + self.model = model data = [] with open(templates_path, "r") as f: @@ -89,40 +90,59 @@ def __init__( # the second is for the backup case: [score] self.one_score_pattern = re.compile(r"\[\[(\d+\.?\d*)\]\]") self.one_score_pattern_backup = re.compile(r"\[(\d+\.?\d*)\]") + self.API_MAX_RETRY = 3 + self.API_RETRY_SLEEP = 1 + + self.client = None + self.pipe = None + + self.use_transformers = url is None and api_key is None + + self.url = url + self.api_key = api_key + + def __lazy_load_client(self): + if self.use_transformers: + if self.pipe is None: + transformers_model = AutoModelForCausalLM.from_pretrained( + self.model, torch_dtype=torch.bfloat16, trust_remote_code=False, device_map="cuda" + ) + tokenizer = AutoTokenizer.from_pretrained(self.model) + self.pipe = pipeline( + "text-generation", + model=transformers_model, + tokenizer=tokenizer, + max_new_tokens=50, + ) + else: + if self.client is None: + from openai import OpenAI - self.API_MAX_RETRY = 16 - self.API_RETRY_SLEEP = 10 - self.max_tokens = 2048 + if self.url is None: + self.client = OpenAI(api_key=self.api_key) + else: + self.client = OpenAI(base_url=self.url, api_key=self.api_key) def evaluate_answer( self, questions: list[str], answers: list[str], references: list[str] - ) -> tuple[int, list[dict[str, str]], str]: + ) -> tuple[list[int], list[list[dict[str, str]]], list[str | None | Any]]: """ - Evaluates an answer using the OpenAI API. + Evaluates an answer using either Transformers or OpenAI API. Args: questions (list[str]): A list of questions (can be a list because of multi-turn conversations) answers (list[str]): A list of answers, one for each question. references (list[str]): A list of reference answers, one for each question (sometimes not available) - single_turn (bool): Indicates whether the conversation is single-turn or multi-turn. Returns: A tuple containing the score, prompts, and judgment. - - Raises: - Exception: If an error occurs during the API call. 
""" - if self.client is None: - if not is_openai_available(): - raise ImportError(NO_OPENAI_ERROR_MSG) - - from openai import OpenAI - - self.client = OpenAI(api_key=self.openai_api_key) + # lazy loading of the pipeline + self.__lazy_load_client() prompts = [ self.__get_prompts_single_turn( - questions[0], answers[0], references[0] if references is not None and len(references) > 0 else None + questions[0], answers[0], references[0] if references and len(references) > 0 else None ) ] @@ -132,28 +152,15 @@ def evaluate_answer( ) prompts.append(prompts_multi_turn) - responses = [] + judgments = [] for prompt in prompts: - for _ in range(self.API_MAX_RETRY): - try: - response = self.client.chat.completions.create( - model=self.model, - seed=self.seed, - temperature=self.temperature, - messages=prompt, - max_tokens=self.max_tokens, - n=1, - ) - responses.append(response) - break - except Exception as e: - hlog_warn(f"{type(e), e}") - time.sleep(self.API_RETRY_SLEEP) - - if len(responses) == 0: - raise Exception("Failed to get response from the API") - - judgments = [response.choices[0].message.content for response in responses] + if self.client is not None: + response = self.__call_api(prompt) + else: + response = self.pipe(prompt)[0]["generated_text"] + response = response[-1]["content"] + judgments.append(response) + scores = [self.__process_judge_response(judgment) for judgment in judgments] return scores, prompts, judgments @@ -235,3 +242,18 @@ def __process_judge_response(self, judgment: str) -> int: rating = -1 return rating + + def __call_api(self, prompt): + for _ in range(self.API_MAX_RETRY): + try: + response = self.client.chat.completions.create( + model=self.model, + messages=prompt, + max_tokens=512, + n=1, + ) + return response.choices[0].message.content + except Exception as e: + hlog_warn(f"{type(e), e}") + time.sleep(self.API_RETRY_SLEEP) + raise Exception("Failed to get response from the API") diff --git a/src/lighteval/metrics/metrics.py b/src/lighteval/metrics/metrics.py index 8b06e45c..068e086e 100644 --- a/src/lighteval/metrics/metrics.py +++ b/src/lighteval/metrics/metrics.py @@ -228,9 +228,9 @@ class Metrics(Enum): corpus_level_fn=np.mean, higher_is_better=True, ) - llm_judge_multi_turn_openai = SampleLevelMetricGrouping( + llm_judge_multi_turn_gpt3p5 = SampleLevelMetricGrouping( metric_name=["single_turn", "multi_turn"], - higher_is_better=True, + higher_is_better={"single_turn": True, "multi_turn": True}, category=MetricCategory.LLM_AS_JUDGE_MULTI_TURN, use_case=MetricUseCase.SUMMARIZATION, sample_level_fn=JudgeLLM( @@ -243,9 +243,24 @@ class Metrics(Enum): "multi_turn": np.mean, }, ) - llm_judge_openai = SampleLevelMetricGrouping( + llm_judge_multi_turn_llama_3_405b = SampleLevelMetricGrouping( + metric_name=["single_turn", "multi_turn"], + higher_is_better={"single_turn": True, "multi_turn": True}, + category=MetricCategory.LLM_AS_JUDGE_MULTI_TURN, + use_case=MetricUseCase.SUMMARIZATION, + sample_level_fn=JudgeLLM( + judge_model_name="meta-llama/Meta-Llama-3.1-405B-Instruct-FP8", + template_path=os.path.join(os.path.dirname(__file__), "judge_prompts.jsonl"), + multi_turn=True, + ).compute, + corpus_level_fn={ + "single_turn": np.mean, + "multi_turn": np.mean, + }, + ) + llm_judge_gpt3p5 = SampleLevelMetricGrouping( metric_name=["judge_score"], - higher_is_better=True, + higher_is_better={"judge_score": True}, category=MetricCategory.LLM_AS_JUDGE, use_case=MetricUseCase.SUMMARIZATION, sample_level_fn=JudgeLLM( @@ -257,6 +272,20 @@ class Metrics(Enum): "judge_score": 
np.mean,
         },
     )
+    llm_judge_llama_3_405b = SampleLevelMetricGrouping(
+        metric_name=["judge_score"],
+        higher_is_better={"judge_score": True},
+        category=MetricCategory.LLM_AS_JUDGE,
+        use_case=MetricUseCase.SUMMARIZATION,
+        sample_level_fn=JudgeLLM(
+            judge_model_name="meta-llama/Meta-Llama-3.1-405B-Instruct-FP8",
+            template_path=os.path.join(os.path.dirname(__file__), "judge_prompts.jsonl"),
+            multi_turn=False,
+        ).compute,
+        corpus_level_fn={
+            "judge_score": np.mean,
+        },
+    )
     loglikelihood_acc = SampleLevelMetric(
         metric_name="acc",
         sample_level_fn=LoglikelihoodAcc().compute,
diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py
index a240166b..c729655c 100644
--- a/src/lighteval/metrics/metrics_sample.py
+++ b/src/lighteval/metrics/metrics_sample.py
@@ -29,6 +29,7 @@ import nltk
 import numpy as np
+from huggingface_hub import HfApi
 from nltk.metrics.distance import edit_distance
 from nltk.tokenize import word_tokenize
 from nltk.tokenize.treebank import TreebankWordTokenizer
@@ -40,7 +41,7 @@ from lighteval.metrics.imports.bert_scorer import BERTScorer
 from lighteval.metrics.imports.data_stats_metric import DataStatsMetric
 from lighteval.metrics.imports.summac import SummaCZS
-from lighteval.metrics.llm_as_judge import JudgeOpenAI
+from lighteval.metrics.llm_as_judge import JudgeLM
 from lighteval.metrics.normalizations import remove_braces, remove_braces_and_strip
 from lighteval.tasks.requests import Doc
 from lighteval.utils import as_list
@@ -626,22 +627,33 @@ def edit_similarity(self, s1, s2):
 class JudgeLLM:
-    available_models = ["gpt-3.5-turbo", "gpt-4o", "gpt-4-turbo", "gpt-4"]
+    available_models_openai = ["gpt-3.5-turbo", "gpt-4o", "gpt-4-turbo", "gpt-4"]

-    def __init__(self, judge_model_name: str, template_path: str, multi_turn: bool = False):
-        if judge_model_name not in self.available_models:
-            raise ValueError(f"{judge_model_name} not in available models for llm as a judge metric")
+    def __init__(
+        self, judge_model_name: str, template_path: str, multi_turn: bool = False, use_transformers: bool = False
+    ) -> None:
+        if judge_model_name in self.available_models_openai:
+            api_key = os.getenv("OPENAI_API_KEY")
+            url = None
+        elif not use_transformers:
+            api_key = os.getenv("HF_TOKEN")
+            url = "https://api-inference.huggingface.co/v1/"
+        else:
+            api = HfApi()
+            # list_models returns an iterator, so materialize it before checking that the model exists
+            models = list(api.list_models(model_name=judge_model_name))
+            url = None
+            api_key = None
+            if not models:
+                raise ValueError(f"{judge_model_name} not in available models for llm as a judge metric")

-        OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
         self.multi_turn = multi_turn
-
-        self.judge = JudgeOpenAI(
+        self.judge = JudgeLM(
             model=judge_model_name,
-            seed=42,
-            temperature=0.0,
             templates_path=template_path,
-            openai_api_key=OPENAI_API_KEY,
             multi_turn=multi_turn,
+            api_key=api_key,
+            url=url,
         )

     def compute(self, predictions: list[str], formatted_doc: Doc, **kwargs) -> dict[str, float]:
diff --git a/src/lighteval/tasks/extended/mt_bench/main.py b/src/lighteval/tasks/extended/mt_bench/main.py
index 77b8f3ee..03bff898 100644
--- a/src/lighteval/tasks/extended/mt_bench/main.py
+++ b/src/lighteval/tasks/extended/mt_bench/main.py
@@ -23,6 +23,7 @@ # ruff: noqa: F405, F403, F401, I001
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.requests import Doc
+from lighteval.metrics.metrics import Metrics
 def mt_bench_prompt(line, task_name: str = None):
@@ -55,7 +56,7 @@ def mt_bench_prompt(line, task_name: str = None):
     evaluation_splits=["train"],
few_shots_split="", few_shots_select="random", - metric=["llm_judge_multi_turn_openai"], + metric=[Metrics.llm_judge_multi_turn_gpt3p5], generation_size=1024, stop_sequence=[], ) diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index 70135742..07120b71 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -22,7 +22,6 @@ import collections import inspect -import os import random from dataclasses import asdict, dataclass from multiprocessing import Pool @@ -56,7 +55,7 @@ RequestType, TaskExampleId, ) -from lighteval.utils import NO_OPENAI_ERROR_MSG, as_list, is_openai_available +from lighteval.utils import as_list if TYPE_CHECKING: @@ -191,16 +190,6 @@ def __init__( # noqa: C901 current_categories = [metric.category for metric in self.metrics] self.has_metric_category = {category: (category in current_categories) for category in MetricCategory} - if ( - self.has_metric_category[MetricCategory.LLM_AS_JUDGE] - or self.has_metric_category[MetricCategory.LLM_AS_JUDGE_MULTI_TURN] - ): - if not is_openai_available(): - raise ImportError(NO_OPENAI_ERROR_MSG) - if os.getenv("OPENAI_API_KEY") is None: - raise ValueError( - "Using llm as judge metric but no OPEN_API_KEY were found, please set it with: export OPEN_API_KEY={yourkey}" - ) # We assume num_samples always contains 1 (for base generative evals) self.num_samples = [1] diff --git a/src/lighteval/tasks/requests.py b/src/lighteval/tasks/requests.py index 6dd30786..2bd69023 100644 --- a/src/lighteval/tasks/requests.py +++ b/src/lighteval/tasks/requests.py @@ -136,6 +136,7 @@ class GreedyUntilMultiTurnRequest(Request): stop_sequence: str generation_size: int request_type = RequestType.GREEDY_UNTIL_MULTI_TURN + use_logits: bool = False class TaskExampleId(NamedTuple):