From b8923aa7ae7be82cf65cdf4641407c358c864a39 Mon Sep 17 00:00:00 2001
From: Nathan Habib <30601243+NathanHB@users.noreply.github.com>
Date: Tue, 8 Oct 2024 13:35:19 +0200
Subject: [PATCH 1/2] Nathan llm judge quickfix (#350)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---------

Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com>
---
 src/lighteval/models/vllm_model.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/lighteval/models/vllm_model.py b/src/lighteval/models/vllm_model.py
index e905f33b3..d07f05a5a 100644
--- a/src/lighteval/models/vllm_model.py
+++ b/src/lighteval/models/vllm_model.py
@@ -101,6 +101,7 @@ def tokenizer(self):
     def cleanup(self):
         destroy_model_parallel()
         del self.model.llm_engine.model_executor.driver_worker
+        self.model = None
         gc.collect()
         ray.shutdown()
         destroy_distributed_environment()

From 78cda93f89be3974347493a676d59f260306b34b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hynek=20Kydl=C3=AD=C4=8Dek?=
Date: Thu, 10 Oct 2024 08:58:22 +0200
Subject: [PATCH 2/2] Selecting tasks using their superset (#308)

---------

Co-authored-by: Hynek Kydlicek
---
 src/lighteval/pipeline.py            |  25 ++-
 src/lighteval/tasks/default_tasks.py |  71 +-------
 src/lighteval/tasks/registry.py      | 253 +++++++++++++++++----------
 tests/tasks/test_registry.py         | 136 ++++++++++++++
 tests/utils.py                       |  17 +-
 5 files changed, 321 insertions(+), 181 deletions(-)
 create mode 100644 tests/tasks/test_registry.py

diff --git a/src/lighteval/pipeline.py b/src/lighteval/pipeline.py
index a051261cf..b273fdf47 100644
--- a/src/lighteval/pipeline.py
+++ b/src/lighteval/pipeline.py
@@ -37,7 +37,7 @@
 from lighteval.models.model_loader import load_model
 from lighteval.models.model_output import ModelResponse
 from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
-from lighteval.tasks.registry import Registry, get_custom_tasks, taskinfo_selector
+from lighteval.tasks.registry import Registry, taskinfo_selector
 from lighteval.tasks.requests import SampleUid
 from lighteval.utils.imports import (
     NO_ACCELERATE_ERROR_MSG,
@@ -166,23 +166,18 @@ def _init_model(self, model_config, model):
             return load_model(config=model_config, env_config=self.pipeline_parameters.env_config)
         return model

-    def _init_tasks_and_requests(self, tasks):
+    def _init_tasks_and_requests(self, tasks: str):
         with htrack_block("Tasks loading"):
             with local_ranks_zero_first() if self.launcher_type == ParallelismManager.NANOTRON else nullcontext():
-                # If some tasks are provided as task groups, we load them separately
-                custom_tasks = self.pipeline_parameters.custom_tasks_directory
-                tasks_groups_dict = None
-                if custom_tasks:
-                    _, tasks_groups_dict = get_custom_tasks(custom_tasks)
-                if tasks_groups_dict and tasks in tasks_groups_dict:
-                    tasks = tasks_groups_dict[tasks]
-
-                # Loading all tasks
-                task_names_list, fewshots_dict = taskinfo_selector(tasks)
-                task_dict = Registry(cache_dir=self.pipeline_parameters.env_config.cache_dir).get_task_dict(
-                    task_names_list, custom_tasks=custom_tasks
+                registry = Registry(
+                    cache_dir=self.pipeline_parameters.env_config.cache_dir,
+                    custom_tasks=self.pipeline_parameters.custom_tasks_directory,
+                )
+                task_names_list, fewshots_dict = taskinfo_selector(tasks, registry)
+                task_dict = registry.get_task_dict(task_names_list)
+                LightevalTask.load_datasets(
+                    list(task_dict.values()), self.pipeline_parameters.dataset_loading_processes
                 )
-                LightevalTask.load_datasets(task_dict.values(), self.pipeline_parameters.dataset_loading_processes)

                 self.evaluation_tracker.task_config_logger.log(task_dict)
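With this hunk applied, the pipeline's task loading boils down to four calls, and custom-task discovery moves inside the Registry. For orientation, a minimal sketch of the same flow outside the Pipeline class; the cache path, task string, and process count are illustrative, not taken from the patch:

    from lighteval.tasks.lighteval_task import LightevalTask
    from lighteval.tasks.registry import Registry, taskinfo_selector

    registry = Registry(cache_dir="/tmp/lighteval_cache", custom_tasks=None)
    task_names, fewshots = taskinfo_selector("lighteval|storycloze:2016|0|0", registry)
    task_dict = registry.get_task_dict(task_names)  # {"suite|task": LightevalTask}
    LightevalTask.load_datasets(list(task_dict.values()), 1)  # one loading process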
diff --git a/src/lighteval/tasks/default_tasks.py b/src/lighteval/tasks/default_tasks.py
index 96799e7d0..198ea9e98 100644
--- a/src/lighteval/tasks/default_tasks.py
+++ b/src/lighteval/tasks/default_tasks.py
@@ -402,34 +402,6 @@
     trust_dataset=True,
     version=0,
 )
-anli_lighteval = LightevalTaskConfig(
-    name="anli",
-    suite=["lighteval", "anli"],
-    prompt_function=prompt.anli,
-    hf_repo="anli",
-    hf_subset="plain_text",
-    hf_avail_splits=[
-        "train_r1",
-        "dev_r1",
-        "train_r2",
-        "dev_r2",
-        "train_r3",
-        "dev_r3",
-        "test_r1",
-        "test_r2",
-        "test_r3",
-    ],
-    evaluation_splits=["test_r1", "test_r2", "test_r3"],
-    few_shots_split=None,
-    few_shots_select=None,
-    generation_size=1,
-    metric=[Metrics.loglikelihood_acc_single_token],
-    stop_sequence=["\n"],
-    output_regex=None,
-    frozen=False,
-    trust_dataset=True,
-    version=0,
-)
 anli_r1_lighteval = LightevalTaskConfig(
     name="anli:r1",
     suite=["lighteval", "anli"],
@@ -2295,7 +2267,7 @@
     version=0,
 )
 bbq_Nationality_helm = LightevalTaskConfig(
-    name="bbq=Nationality",
+    name="bbq:Nationality",
     suite=["helm"],
     prompt_function=prompt.bbq,
     hf_repo="lighteval/bbq_helm",
@@ -11368,47 +11340,6 @@
     trust_dataset=True,
     version=0,
 )
-mmlu_helm = LightevalTaskConfig(
-    name="mmlu",
-    suite=["helm", "helm_general"],
-    prompt_function=prompt.mmlu_helm,
-    hf_repo="lighteval/mmlu",
-    hf_subset="all",
-    hf_avail_splits=["auxiliary_train", "test", "validation", "dev"],
-    evaluation_splits=["test"],
-    few_shots_split="dev",
-    few_shots_select=None,
-    generation_size=5,
-    metric=[
-        Metrics.exact_match,
-        Metrics.quasi_exact_match,
-        Metrics.prefix_exact_match,
-        Metrics.prefix_quasi_exact_match,
-    ],
-    stop_sequence=["\n"],
-    output_regex=None,
-    frozen=False,
-    trust_dataset=True,
-    version=0,
-)
-mmlu_original = LightevalTaskConfig(
-    name="mmlu",
-    suite=["original"],
-    prompt_function=prompt.mmlu_helm,
-    hf_repo="lighteval/mmlu",
-    hf_subset="all",
-    hf_avail_splits=["auxiliary_train", "test", "validation", "dev"],
-    evaluation_splits=["test"],
-    few_shots_split="dev",
-    few_shots_select="sequential",
-    generation_size=5,
-    metric=[Metrics.loglikelihood_acc_single_token],
-    stop_sequence=["\n"],
-    output_regex=None,
-    frozen=False,
-    trust_dataset=True,
-    version=0,
-)
 mmlu_abstract_algebra_original = LightevalTaskConfig(
     name="mmlu:abstract_algebra",
     suite=["original", "mmlu"],
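The aggregate anli and mmlu configs deleted above are not lost functionality: once the registry changes below are in place, the superset syntax fans a bare name out to every registered subtask, as the new tests assert (57 mmlu subsets in the original suite). A short sketch of the equivalence, with an illustrative few-shot count:

    from lighteval.tasks.registry import Registry, taskinfo_selector

    registry = Registry()
    task_names, _ = taskinfo_selector("original|mmlu|5|0", registry)
    assert len(task_names) == 57  # one entry per subset, e.g. original|mmlu:abstract_algebra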
diff --git a/src/lighteval/tasks/registry.py b/src/lighteval/tasks/registry.py
index 69e3b2182..b7d763438 100644
--- a/src/lighteval/tasks/registry.py
+++ b/src/lighteval/tasks/registry.py
@@ -23,11 +23,12 @@
 import collections
 import importlib
 import os
+from functools import lru_cache, partial
 from itertools import groupby
 from pathlib import Path
 from pprint import pformat
 from types import ModuleType
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Union

 from datasets.load import dataset_module_factory
@@ -59,79 +60,75 @@
 TRUNCATE_FEW_SHOTS_DEFAULTS = True


+LazyLightevalTask = Callable[[], LightevalTask]
+
+
 class Registry:
     """
     The Registry class is used to manage the task registry and get task classes.
     """

-    def __init__(self, cache_dir: str):
+    def __init__(self, cache_dir: Optional[str] = None, custom_tasks: Optional[Union[str, Path, ModuleType]] = None):
         """
         Initialize the Registry class.

         Args:
-            cache_dir (str): Directory path for caching.
-
-        Attributes:
-            cache_dir (str): Directory path for caching.
-            TASK_REGISTRY (dict[str, LightevalTask]): A dictionary containing the registered tasks.
+            cache_dir (Optional[str]): Directory path for caching. Defaults to None.
+            custom_tasks (Optional[Union[str, Path, ModuleType]]): Custom tasks to be included in the registry. Can be a string path, Path object, or a module.
+                Each custom tasks module should expose a TASKS_TABLE: a list of LightevalTaskConfig.
+                E.g.:
+                TASKS_TABLE = [
+                    LightevalTaskConfig(
+                        name="custom_task",
+                        suite=["custom"],
+                        ...
+                    )
+                ]
         """
-        self.cache_dir: str = cache_dir
-        self.TASK_REGISTRY: dict[str, LightevalTask] = {**create_config_tasks(cache_dir=cache_dir)}

-    def get_task_class(
-        self, task_name: str, custom_tasks_registry: Optional[dict[str, LightevalTask]] = None
-    ) -> LightevalTask:
+        # Private attributes, not expected to be mutated after initialization
+        self._cache_dir = cache_dir
+        self._custom_tasks = custom_tasks
+
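To make the docstring's contract concrete, here is a minimal custom-tasks module of the shape the new constructor accepts; the module, dataset, and task names are illustrative, and the empty metric list mirrors the test module added later in this patch:

    # my_tasks.py -- hypothetical module, used as Registry(custom_tasks="my_tasks")
    from lighteval.tasks.lighteval_task import LightevalTaskConfig

    TASKS_TABLE = [
        LightevalTaskConfig(
            name="my_task",
            suite=["custom"],
            prompt_function=lambda x: x,  # placeholder; builds a Doc in real use
            hf_repo="org/dataset",        # illustrative dataset id
            hf_subset="default",
            evaluation_splits=["test"],
            metric=[],
        )
    ]

    # Optional named groups: lists of task strings or comma-separated strings
    TASKS_GROUPS = {"smoke": "custom|my_task|0|0"}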
+    def get_task_instance(self, task_name: str):
         """
-        Get the task class based on the task name.
+        Get a task instance based on the task name (suite|task).

         Args:
-            task_name (str): Name of the task.
-            custom_tasks_registry (Optional[dict[str, LightevalTask]]): A dictionary containing custom tasks.
+            task_name (str): Name of the task (suite|task).

         Returns:
-            LightevalTask: Task class.
+            LightevalTask: the instantiated task.

         Raises:
             ValueError: If the task is not found in the task registry or custom task registry.
         """
-        if task_name in self.TASK_REGISTRY:
-            if custom_tasks_registry is not None and task_name in custom_tasks_registry:
-                hlog_warn(
-                    f"One of the tasks you requested ({task_name}) exists both in the default and custom tasks. Selecting the default task."
-                )
-            return self.TASK_REGISTRY[task_name]
-        if custom_tasks_registry is not None and task_name in custom_tasks_registry:
-            return custom_tasks_registry[task_name]
-        hlog_warn(f"{task_name} not found in provided tasks")
-        hlog_warn(pformat(self.TASK_REGISTRY))
-        raise ValueError(
-            f"Cannot find tasks {task_name} in task list or in custom task registry ({custom_tasks_registry})"
-        )
-
-    def get_task_dict(
-        self, task_name_list: List[str], custom_tasks: Optional[Union[str, Path, ModuleType]] = None
-    ) -> Dict[str, LightevalTask]:
-        """
-        Get a dictionary of tasks based on the task name list.
+        task_class = self.task_registry.get(task_name)
+        if task_class is None:
+            hlog_warn(f"{task_name} not found in provided tasks")
+            hlog_warn(pformat(self.task_registry))
+            raise ValueError(f"Cannot find task {task_name} in task list or in custom task registry")

-        Args:
-            task_name_list (List[str]): A list of task names.
-            custom_tasks (Optional[Union[str, ModuleType]]): Path to the custom tasks file or name of a module to import containing custom tasks or the module itself
-            extended_tasks (Optional[str]): The path to the extended tasks group of submodules
+        return task_class()

+    @property
+    @lru_cache
+    def task_registry(self):
+        """
         Returns:
-            Dict[str, LightevalTask]: A dictionary containing the tasks.
+            dict[str, LazyLightevalTask]: A dictionary mapping task names (suite|task) to lazy constructors for the corresponding LightevalTask.

-        Notes:
-            - If custom_tasks is provided, it will import the custom tasks module and create a custom tasks registry.
-            - Each task in the task_name_list will be instantiated with the corresponding task class.
+        Example:
+        {
+            "lighteval|arc_easy": lambda: LightevalTask(name="lighteval|arc_easy", ...)
+        }
         """
+
         # Import custom tasks provided by the user
-        custom_tasks_registry = None
+        custom_tasks_registry = {}
         custom_tasks_module = []
         TASKS_TABLE = []
-        if custom_tasks is not None:
-            custom_tasks_module.append(create_custom_tasks_module(custom_tasks=custom_tasks))
+        if self._custom_tasks is not None:
+            custom_tasks_module.append(create_custom_tasks_module(custom_tasks=self._custom_tasks))
         if can_load_extended_tasks():
             for extended_task_module in AVAILABLE_EXTENDED_TASKS_MODULES:
                 custom_tasks_module.append(extended_task_module)
@@ -140,24 +137,103 @@ def get_task_dict(
         for module in custom_tasks_module:
             TASKS_TABLE.extend(module.TASKS_TABLE)
+            # We don't log the tasks themselves as it makes the logs unreadable
+            hlog(f"Found {len(module.TASKS_TABLE)} custom tasks in {module.__file__}")

         if len(TASKS_TABLE) > 0:
-            custom_tasks_registry = create_config_tasks(meta_table=TASKS_TABLE, cache_dir=self.cache_dir)
-            hlog(custom_tasks_registry)
+            custom_tasks_registry = create_lazy_tasks(meta_table=TASKS_TABLE, cache_dir=self._cache_dir)
+
+        default_tasks_registry = create_lazy_tasks(cache_dir=self._cache_dir)
+        # Check the overlap between default_tasks_registry and custom_tasks_registry
+        intersection = set(default_tasks_registry.keys()).intersection(set(custom_tasks_registry.keys()))
+        if len(intersection) > 0:
+            hlog_warn(
+                f"The following tasks ({intersection}) exist in both the default and custom tasks. Will use the default ones on conflict."
+            )
+
+        # Default tasks should overwrite custom tasks on conflict, so they are merged last
+        return {**custom_tasks_registry, **default_tasks_registry}
+
+    @property
+    @lru_cache
+    def _task_superset_dict(self):
+        """
+        Returns:
+            dict[str, list[str]]: A dictionary where keys are task superset names (suite|task_superset) and values are lists of task subset names (suite|task_superset:subset).
+
+        Example:
+        {
+            "lighteval|mmlu": ["lighteval|mmlu:abstract_algebra", "lighteval|mmlu:college_biology", ...]
+        }
+        """
+        # Note: sorting before groupby is important, as the python implementation of groupby does not
+        # behave like SQL's GROUP BY. For more info see the docs of itertools.groupby
+        superset_dict = {k: list(v) for k, v in groupby(sorted(self.task_registry.keys()), lambda x: x.split(":")[0])}
+        # Only consider supersets with more than one task
+        return {k: v for k, v in superset_dict.items() if len(v) > 1}
+
+    @property
+    @lru_cache
+    def task_groups_dict(self) -> dict[str, list[str]]:
+        """
+        Returns:
+            dict[str, list[str]]: A dictionary where keys are task group names and values are lists of task names (suite|task).
+
+        Example:
+        {
+            "all_custom": ["custom|task1", "custom|task2", "custom|task3"],
+            "group1": ["custom|task1", "custom|task2"],
+        }
+        """
+        if self._custom_tasks is None:
+            return {}
+        custom_tasks_module = create_custom_tasks_module(custom_tasks=self._custom_tasks)
+        tasks_group_dict = {}
+        if hasattr(custom_tasks_module, "TASKS_GROUPS"):
+            tasks_group_dict = custom_tasks_module.TASKS_GROUPS
+
+        # Task groups can be defined as comma-separated strings or as lists of tasks
+        return {k: v if isinstance(v, list) else v.split(",") for k, v in tasks_group_dict.items()}
+
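The sort-before-groupby caveat noted in _task_superset_dict above is easy to trip over; a standalone illustration (plain stdlib, no lighteval assumptions) of how unsorted input silently loses groups:

    from itertools import groupby

    names = ["a|mmlu:x", "a|arc", "a|mmlu:y"]
    # Without sorting, groupby only groups *consecutive* keys: "a|mmlu" is
    # yielded twice, and the second run overwrites the first in the dict,
    # silently dropping "a|mmlu:x".
    ungrouped = {k: list(v) for k, v in groupby(names, lambda x: x.split(":")[0])}
    # {'a|mmlu': ['a|mmlu:y'], 'a|arc': ['a|arc']}
    grouped = {k: list(v) for k, v in groupby(sorted(names), lambda x: x.split(":")[0])}
    # {'a|arc': ['a|arc'], 'a|mmlu': ['a|mmlu:x', 'a|mmlu:y']}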
+ """ # Select relevant tasks given the subset asked for by the user - tasks_dict = {} - for task_name in task_name_list: - task_class = self.get_task_class(task_name, custom_tasks_registry=custom_tasks_registry) - tasks_dict[task_name] = task_class() + return {task_name: self.get_task_instance(task_name) for task_name in task_names} + + def expand_task_definition(self, task_definition: str): + """ + Args: + task_definition (str): Task definition to expand. In format: + - suite|task + - suite|task_superset (e.g lighteval|mmlu, which runs all the mmlu subtasks) + Returns: + list[str]: List of task names (suite|task) + """ - return tasks_dict + # Try if it's a task superset + tasks = self._task_superset_dict.get(task_definition, None) + if tasks is not None: + return tasks + + # Then it must be a single task + return [task_definition] def print_all_tasks(self): """ Print all the tasks in the task registry. """ - tasks_names = list(self.TASK_REGISTRY.keys()) + tasks_names = list(self.task_registry.keys()) tasks_names.sort() for suite, g in groupby(tasks_names, lambda x: x.split("|")[0]): tasks_names = list(g) @@ -182,33 +258,24 @@ def create_custom_tasks_module(custom_tasks: Union[str, Path, ModuleType]) -> Mo dataset_module = dataset_module_factory(str(custom_tasks)) return importlib.import_module(dataset_module.module_path) if isinstance(custom_tasks, (str, Path)): - return importlib.import_module(custom_tasks) + return importlib.import_module(str(custom_tasks)) raise ValueError(f"Cannot import custom tasks from {custom_tasks}") -def get_custom_tasks(custom_tasks: Union[str, ModuleType]) -> Tuple[ModuleType, str]: - """Get all the custom tasks available from the given custom tasks file or module. - - Args: - custom_tasks (Optional[Union[str, ModuleType]]): Path to the custom tasks file or name of a module to import containing custom tasks or the module itself - """ - custom_tasks_module = create_custom_tasks_module(custom_tasks=custom_tasks) - tasks_string = "" - if hasattr(custom_tasks_module, "TASKS_GROUPS"): - tasks_string = custom_tasks_module.TASKS_GROUPS - return custom_tasks_module, tasks_string - - -def taskinfo_selector( - tasks: str, -) -> tuple[list[str], dict[str, list[tuple[int, bool]]]]: +def taskinfo_selector(tasks: str, task_registry: Registry) -> tuple[list[str], dict[str, list[tuple[int, bool]]]]: """ Converts a input string of tasks name to task information usable by lighteval. Args: - tasks (str): A string containing a comma-separated list of tasks in the - format "suite|task|few_shot|truncate_few_shots" or a path to a file + tasks (str): A string containing a comma-separated list of tasks definitions in the + format "task_definition|few_shot|truncate_few_shots" or a path to a file containing a list of tasks. + where task_definition can be: + - path to a file containing a list of tasks (one per line) + - task group defined in TASKS_GROUPS dict in custom tasks file + - task name with few shot in format "suite|task|few_shot|truncate_few_shots" + - task superset in format "suite|task_superset|few_shot|truncate_few_shots" (superset will run all tasks with format "suite|task_superset:{subset}|few_shot|truncate_few_shots") + Returns: tuple[list[str], dict[str, list[tuple[int, bool]]]]: A tuple containing: @@ -217,11 +284,22 @@ def taskinfo_selector( """ few_shot_dict = collections.defaultdict(list) - # We can provide a path to a file with a list of tasks + # We can provide a path to a file with a list of tasks or a string of comma-separated tasks if "." 
@@ -217,11 +284,22 @@ def taskinfo_selector(
     """
     few_shot_dict = collections.defaultdict(list)

-    # We can provide a path to a file with a list of tasks
+    # We can provide a path to a file with a list of tasks or a string of comma-separated tasks
     if "." in tasks and os.path.exists(tasks):
-        tasks = ",".join([line for line in open(tasks, "r").read().splitlines() if not line.startswith("#")])
-
-    for task in tasks.split(","):
+        with open(tasks, "r") as f:
+            tasks_list = [line.strip() for line in f if line.strip() and not line.startswith("#")]
+    else:
+        tasks_list = tasks.split(",")
+
+    # At this point the strings are either task names, superset names, or group names.
+    # Here we deal with group names and map them to the corresponding tasks.
+    expanded_tasks_list: list[str] = []
+    for maybe_task_group in tasks_list:
+        # We either expand the group (if it's a group name), or keep it as is (a task or superset name)
+        expanded_tasks = task_registry.task_groups_dict.get(maybe_task_group, [maybe_task_group])
+        expanded_tasks_list.extend(expanded_tasks)
+
+    for task in expanded_tasks_list:
         try:
             suite_name, task_name, few_shot, truncate_few_shots = tuple(task.split("|"))
             truncate_few_shots = int(truncate_few_shots)
@@ -239,15 +317,17 @@ def taskinfo_selector(
         if suite_name not in DEFAULT_SUITES:
             hlog(f"Suite {suite_name} unknown. This is not normal, unless you are testing adding new evaluations.")

-        # Store few_shot info for each task name (suite|task)
-        few_shot_dict[f"{suite_name}|{task_name}"].append((few_shot, truncate_few_shots))
+        # This adds support for task supersets (e.g. mmlu -> all the mmlu subtasks)
+        for expanded_task in task_registry.expand_task_definition(f"{suite_name}|{task_name}"):
+            # Store few_shot info for each task name (suite|task)
+            few_shot_dict[expanded_task].append((few_shot, truncate_few_shots))

     return sorted(few_shot_dict.keys()), {k: list(set(v)) for k, v in few_shot_dict.items()}


-def create_config_tasks(
+def create_lazy_tasks(
     meta_table: Optional[List[LightevalTaskConfig]] = None, cache_dir: Optional[str] = None
-) -> Dict[str, LightevalTask]:
+) -> Dict[str, LazyLightevalTask]:
     """
     Create configuration tasks based on the provided meta_table.
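The file form handled in the hunk above is one task definition per line, with "#" comment lines skipped; a sketch of its use (the file name is hypothetical, and note that the path must contain a "." and exist on disk to be treated as a file rather than a task string):

    from pathlib import Path

    from lighteval.tasks.registry import Registry, taskinfo_selector

    Path("tasks.txt").write_text(
        "# my evaluation suite\n"
        "lighteval|storycloze:2016|0|0\n"
        "original|mmlu|5|0\n"
    )
    names, fewshots = taskinfo_selector("tasks.txt", Registry())
    # names contains storycloze:2016 plus all 57 expanded mmlu subtasks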
""" - def create_task(name, cfg: LightevalTaskConfig, cache_dir: str): - class LightevalTaskFromConfig(LightevalTask): - def __init__(self): - super().__init__(name, cfg, cache_dir=cache_dir) - - return LightevalTaskFromConfig - if meta_table is None: meta_table = [config for config in vars(default_tasks).values() if isinstance(config, LightevalTaskConfig)] - tasks_with_config = {} + tasks_with_config: dict[str, LightevalTaskConfig] = {} # Every task is renamed suite|task, if the suite is in DEFAULT_SUITE for config in meta_table: if not any(suite in config.suite for suite in DEFAULT_SUITES): @@ -283,4 +356,4 @@ def __init__(self): if suite in DEFAULT_SUITES: tasks_with_config[f"{suite}|{config.name}"] = config - return {task: create_task(task, cfg, cache_dir=cache_dir) for task, cfg in tasks_with_config.items()} + return {task: partial(LightevalTask, task, cfg, cache_dir=cache_dir) for task, cfg in tasks_with_config.items()} diff --git a/tests/tasks/test_registry.py b/tests/tasks/test_registry.py new file mode 100644 index 000000000..7fc1ec088 --- /dev/null +++ b/tests/tasks/test_registry.py @@ -0,0 +1,136 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import pytest + +from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig +from lighteval.tasks.registry import Registry, taskinfo_selector + + +TASKS_TABLE = [ + LightevalTaskConfig( + name="test_task_revision", + # Won't be called, so it can be anything + prompt_function=lambda x: x, # type: ignore + hf_repo="test", + hf_subset="default", + evaluation_splits=["train"], + metric=[], + ) +] + +TASKS_GROUPS = { + "zero_and_one": "custom|test_task_revision|0|0,custom|test_task_revision|1|0", + "all_mmlu": "original|mmlu|3|0", +} + + +def test_custom_task_groups(): + """ + Tests that task info selector correctly handles custom task groups. + """ + registry = Registry(custom_tasks="tests.tasks.test_registry") + tasks, task_info = taskinfo_selector("zero_and_one", registry) + + assert set(tasks) == {"custom|test_task_revision"} + assert all(task in task_info for task in tasks) + assert all(task_info[task] == [(1, False), (0, False)] for task in tasks) + + +def test_custom_tasks(): + """ + Tests that task info selector correctly handles custom tasks. 
+ """ + registry = Registry(custom_tasks="tests.tasks.test_registry") + tasks, task_info = taskinfo_selector("custom|test_task_revision|0|0", registry) + + assert tasks == ["custom|test_task_revision"] + assert task_info["custom|test_task_revision"] == [(0, False)] + + +def test_superset_expansion(): + """ + Tests that task info selector correctly handles supersets. + """ + registry = Registry() + + tasks, task_info = taskinfo_selector("lighteval|storycloze|0|0", registry) + + assert set(tasks) == {"lighteval|storycloze:2016", "lighteval|storycloze:2018"} + assert all(task_info[task] == [(0, False)] for task in tasks) + + +def test_superset_with_subset_task(): + """ + Tests that task info selector correctly handles if both superset and one of subset tasks are provided. + """ + registry = Registry() + + tasks, task_info = taskinfo_selector("original|mmlu|3|0,original|mmlu:abstract_algebra|5|0", registry) + + # We have all mmlu tasks + assert len(tasks) == 57 + # Since it's defined twice + assert task_info["original|mmlu:abstract_algebra"] == [(5, False), (3, False)] + + +def test_task_group_expansion_with_subset_expansion(): + """ + Tests that task info selector correctly handles a group with task superset is provided. + """ + registry = Registry(custom_tasks="tests.tasks.test_registry") + + tasks = taskinfo_selector("all_mmlu", registry)[0] + + assert len(tasks) == 57 + + +def test_invalid_task_creation(): + """ + Tests that tasks info registry correctly raises errors for invalid tasks + """ + registry = Registry() + with pytest.raises(ValueError): + registry.get_task_dict(["custom|task_revision"]) + + +def test_task_duplicates(): + """ + Tests that task info selector correctly handles if duplicate tasks are provided. + """ + registry = Registry() + + tasks, task_info = taskinfo_selector("custom|test_task_revision|0|0,custom|test_task_revision|0|0", registry) + + assert tasks == ["custom|test_task_revision"] + assert task_info["custom|test_task_revision"] == [(0, False)] + + +def test_task_creation(): + """ + Tests that tasks registry correctly creates tasks + """ + registry = Registry() + task_info = registry.get_task_dict(["lighteval|storycloze:2016"])["lighteval|storycloze:2016"] + + assert isinstance(task_info, LightevalTask) + assert task_info.name == "lighteval|storycloze:2016" diff --git a/tests/utils.py b/tests/utils.py index 32ef24318..4584ef89b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,9 +20,11 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
diff --git a/tests/utils.py b/tests/utils.py
index 32ef24318..4584ef89b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,9 +20,11 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.

-from typing import Optional
+from pathlib import Path
+from types import ModuleType
+from typing import Optional, Union
 from unittest.mock import patch

 from transformers import AutoTokenizer

 from lighteval.logging.evaluation_tracker import EvaluationTracker
@@ -34,6 +36,7 @@
 )
 from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
 from lighteval.tasks.lighteval_task import LightevalTask
+from lighteval.tasks.registry import Registry
 from lighteval.tasks.requests import (
     GreedyUntilRequest,
     LoglikelihoodRequest,
@@ -126,11 +129,13 @@ def fake_evaluate_task(
     evaluation_tracker.task_config_logger.log(task_dict)

     # Create a mock Registry class
-    class MockRegistry:
-        def __init__(self, cache_dir=None):
-            self.cache_dir = cache_dir
+    class FakeRegistry(Registry):
+        def __init__(
+            self, cache_dir: Optional[str] = None, custom_tasks: Optional[Union[str, Path, ModuleType]] = None
+        ):
+            super().__init__(cache_dir=cache_dir, custom_tasks=custom_tasks)

-        def get_task_dict(self, task_names_list, custom_tasks=None):
+        def get_task_dict(self, task_names: list[str]):
             return task_dict

     # This is due to the logger complaining that we have not initialised the accelerator
@@ -143,7 +148,7 @@ def get_task_dict(self, task_names_list, custom_tasks=None):
     # This is a bit hacky, because there is no way to run end to end with a
     # dynamic task :(, so we just mock the registry
     task_run_string = f"{task_name}|{n_fewshot}|{n_fewshot_seeds}"
-    with patch("lighteval.pipeline.Registry", MockRegistry):
+    with patch("lighteval.pipeline.Registry", FakeRegistry):
         pipeline = Pipeline(
             tasks=task_run_string,
             pipeline_parameters=PipelineParameters(max_samples=max_samples, launcher_type=ParallelismManager.NONE),