Skip to content

Commit

Permalink
LightevalTask now only takes functions
Browse files Browse the repository at this point in the history
  • Loading branch information
clefourrier committed Jul 5, 2024
1 parent e10a84c commit c927d14
Showing 1 changed file with 7 additions and 40 deletions.
47 changes: 7 additions & 40 deletions src/lighteval/tasks/lighteval_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

from datasets import load_dataset

import lighteval.tasks.tasks_prompt_formatting as tasks_prompt_formatting
from lighteval.few_shot_manager import FewShotSampler
from lighteval.logging.hierarchical_logger import hlog, hlog_warn
from lighteval.metrics import (
Expand Down Expand Up @@ -71,7 +70,7 @@ class LightevalTaskConfig:
Arguments:
name (str): Short name of the evaluation task.
suite (list[str]): Evaluation suites to which the task belongs.
prompt_function (FormatterType|str): Name of the function used to create the [`Doc`] samples from each line of the evaluation dataset.
prompt_function (FormatterType): Function used to create the [`Doc`] samples from each line of the evaluation dataset.
hf_repo (str): Path of the hub dataset repository containing the evaluation information.
hf_subset (str): Subset used for the current task, will be default if none is selected.
hf_avail_splits (list[str]): All the available splits in the evaluation dataset
Expand All @@ -91,7 +90,7 @@ class LightevalTaskConfig:
"""

name: str
prompt_function: FormatterType | str
prompt_function: FormatterType
hf_repo: str
hf_subset: str
metric: Tuple[Union[str, Metrics]]
Expand Down Expand Up @@ -132,38 +131,6 @@ def __post_init__(self):
self.stop_sequence = tuple(self.stop_sequence) if self.stop_sequence is not None else None


def load_prompt_function(prompt_function: str, custom_tasks_module: list | None) -> FormatterType:
"""
Tries to load the prompt function defined as string.
Arguments:
prompt_function (str): Name of the prompt function to load.
custom_tasks_module (list): List of custom modules to search for the prompt function.
Returns:
FormatterType: The prompt function.
"""

if custom_tasks_module is None:
return getattr(tasks_prompt_formatting, prompt_function)

formatter = []
for module in custom_tasks_module:
if hasattr(module, prompt_function):
formatter.append(getattr(module, prompt_function))

if len(formatter) == 0: # Default version
return getattr(tasks_prompt_formatting, prompt_function)
elif len(formatter) == 1:
# If we have a prompt in both the module and our tasks_prompt_formatting
# We take the prompt from the module
if hasattr(tasks_prompt_formatting, prompt_function):
hlog_warn(f"Be careful you are using custom prompt function {prompt_function} and not the default one.")
return formatter[0]
else:
raise Exception(
f"You defined the prompt function {prompt_function} several times in the different custom modules you are loading."
)


class LightevalTask:
def __init__( # noqa: C901
self, name: str, cfg: LightevalTaskConfig, cache_dir: Optional[str] = None, custom_tasks_module: list = None
Expand Down Expand Up @@ -237,11 +204,11 @@ def __init__( # noqa: C901
self.num_samples = [1] + [
int(metric.replace("maj_at_", "").split("_")[0]) for metric in self.metrics if "maj_at_" in metric
]
self.formatter: FormatterType
if isinstance(cfg.prompt_function, str):
self.formatter = load_prompt_function(cfg.prompt_function, custom_tasks_module)
else:
self.formatter = cfg.prompt_function
if not isinstance(cfg.prompt_function, FormatterType):
raise TypeError(
f"Prompt formatting function ({str(cfg.prompt_function)}) should have been passed as a callable, was {type(cfg.prompt_function)} instead."
)
self.formatter = cfg.prompt_function

self.generation_size = cfg.generation_size
self.stop_sequence = cfg.stop_sequence
Expand Down

0 comments on commit c927d14

Please sign in to comment.