Last PR to make custom tasks work for everyone #23

Merged: 13 commits, Feb 8, 2024
2 changes: 1 addition & 1 deletion run_evals_accelerate.py
@@ -71,7 +71,7 @@ def get_parser():
parser.add_argument("--override_batch_size", type=int, default=-1)
parser.add_argument("--dataset_loading_processes", type=int, default=1)
parser.add_argument(
"--custom_tasks_file",
"--custom_tasks",
type=str,
default=None,
help="Path to a file with custom tasks (a TASK list of dict and potentially prompt formating functions)",
2 changes: 1 addition & 1 deletion run_evals_nanotron.py
@@ -20,7 +20,7 @@ def get_parser():
parser.add_argument(
"--cache-dir",
type=str,
default="",
default=None,
help="Cache directory",
)

12 changes: 8 additions & 4 deletions src/lighteval/logging/info_loggers.py
@@ -69,8 +69,12 @@ class GeneralConfigLogger:

def __init__(self) -> None:
"""Stores the current lighteval commit for reproducibility, and starts the evaluation timer."""
- repo = git.Repo(os.path.dirname(__file__).split("src")[0])
- self.lighteval_sha = repo.git.rev_parse("HEAD")
+ try:
+     repo = git.Repo(os.path.dirname(__file__).split("src")[0])
+ except git.InvalidGitRepositoryError:
+     repo = None
+
+ self.lighteval_sha = repo.git.rev_parse("HEAD") if repo is not None else "?"
self.start_time = time.perf_counter()

def log_args_info(
@@ -543,5 +547,5 @@ def log(self, task_dict: dict[str, LightevalTask]) -> None:
self.tasks_configs = {name: task.cfg for name, task in task_dict.items()}

def log_num_docs(self, task_name: str, original_num_docs: int, effective_num_docs: int) -> None:
self.tasks_configs[task_name]["original_num_docs"] = original_num_docs
self.tasks_configs[task_name]["effective_num_docs"] = effective_num_docs
self.tasks_configs[task_name].original_num_docs = original_num_docs
self.tasks_configs[task_name].effective_num_docs = effective_num_docs
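A note on the GeneralConfigLogger hunk above: wrapping the git.Repo lookup lets the logger keep working when lighteval is not executed from a git checkout (for instance when it is installed as a package), falling back to "?" instead of raising. A minimal standalone sketch of the same pattern, assuming GitPython is available; get_commit_sha is a hypothetical helper name, not part of the PR:

import git

def get_commit_sha(repo_root: str) -> str:
    # Return the current HEAD commit hash, or "?" when repo_root is not
    # inside a git repository (e.g. a pip-installed copy of the package).
    try:
        repo = git.Repo(repo_root)
    except git.InvalidGitRepositoryError:
        return "?"
    return repo.git.rev_parse("HEAD")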
2 changes: 1 addition & 1 deletion src/lighteval/main_accelerate.py
@@ -60,7 +60,7 @@ def main(args):
with accelerator.main_process_first() if accelerator is not None else nullcontext():
task_names_list, few_shots_dict = taskinfo_selector(args.tasks)
task_dict = Registry(cache_dir=env_config.cache_dir).get_task_dict(
- task_names_list, custom_tasks_file=args.custom_tasks_file
+ task_names_list, custom_tasks=args.custom_tasks
)
# Loading all the dataset in a distributed manner
LightevalTask.load_datasets(task_dict.values(), args.dataset_loading_processes)
8 changes: 4 additions & 4 deletions src/lighteval/main_nanotron.py
@@ -38,7 +38,7 @@
def main(
checkpoint_config_path: str,
lighteval_config_path: Optional[str] = None,
- cache_dir: str = None,
+ cache_dir: Optional[str] = None,
config_cls: Type = Config,
model_config_cls: Optional[Type] = None,
model_cls: Optional[Type] = None,
@@ -109,14 +109,14 @@ def main(
with htrack_block("Tasks loading"):
with local_ranks_zero_first():
tasks_selection = lighteval_config.tasks.tasks
- if lighteval_config.tasks.custom_tasks_file:
-     _, tasks_groups_dict = get_custom_tasks(lighteval_config.tasks.custom_tasks_file)
+ if lighteval_config.tasks.custom_tasks:
+     _, tasks_groups_dict = get_custom_tasks(lighteval_config.tasks.custom_tasks)
if tasks_groups_dict and lighteval_config.tasks.tasks in tasks_groups_dict:
tasks_selection = tasks_groups_dict[lighteval_config.tasks.tasks]

task_names_list, few_shots_dict = taskinfo_selector(tasks_selection)
task_dict = Registry(cache_dir=cache_dir).get_task_dict(
- task_names_list, custom_tasks_file=lighteval_config.tasks.custom_tasks_file
+ task_names_list, custom_tasks=lighteval_config.tasks.custom_tasks
)
# Loading all the dataset in a distributed manner
LightevalTask.load_datasets(task_dict.values(), lighteval_config.tasks.dataset_loading_processes)
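For context on the hunk above: get_custom_tasks appears to return a pair whose second element is a dict mapping task-group names to task selections, so a group name passed as lighteval_config.tasks.tasks expands into its member tasks before taskinfo_selector runs. A small sketch of that resolution with made-up group and task names (the real task-string format is not shown in this diff):

def resolve_task_selection(tasks, tasks_groups_dict):
    # Expand a known group name into its task list; otherwise keep the
    # original selection untouched.
    if tasks_groups_dict and tasks in tasks_groups_dict:
        return tasks_groups_dict[tasks]
    return tasks

# Hypothetical usage:
groups = {"my_group": ["my_task_a", "my_task_b"]}
assert resolve_task_selection("my_group", groups) == ["my_task_a", "my_task_b"]
assert resolve_task_selection("my_task_c", groups) == "my_task_c"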
6 changes: 2 additions & 4 deletions src/lighteval/models/base_model.py
@@ -22,9 +22,7 @@
LoglikelihoodSingleTokenRequest,
Request,
)
- from lighteval.utils import (
-     is_accelerate_available,
- )
+ from lighteval.utils import as_list, is_accelerate_available
from lighteval.utils_parallelism import find_executable_batch_size


@@ -342,7 +340,7 @@ def greedy_until(
list[GenerateReturn]: list of generated responses.
"""
for request in requests:
- request.stop_sequence = request.stop_sequence + [self.tokenizer.eos_token]
+ request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]
request.tokenized_context = self.tok_encode(request.context)

dataset = GenerativeTaskDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
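A note on the greedy_until change above: with the new LightevalTaskConfig, stop sequences can arrive as a tuple (they are converted to tuples in __post_init__, shown below), and a bare string is also conceivable, so concatenating directly with [self.tokenizer.eos_token] could fail; as_list normalizes the value to a list first. The helper comes from lighteval.utils; the following is an illustrative stand-in under that assumption, not the actual source:

def as_list(item):
    # Normalize a single value or a tuple into a plain list so it can be
    # concatenated with another list such as [eos_token].
    if isinstance(item, list):
        return item
    if isinstance(item, (tuple, set)):
        return list(item)
    return [item]

assert as_list(("</s>",)) + ["<eos>"] == ["</s>", "<eos>"]
assert as_list("</s>") + ["<eos>"] == ["</s>", "<eos>"]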
97 changes: 76 additions & 21 deletions src/lighteval/tasks/lighteval_task.py
@@ -1,8 +1,9 @@
import collections
import random
+ from dataclasses import dataclass
from multiprocessing import Pool
from pathlib import Path
- from typing import TYPE_CHECKING, List, Optional, Tuple
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union

from datasets import load_dataset

@@ -39,8 +40,62 @@
from lighteval.logging.evaluation_tracker import EvaluationTracker


+ @dataclass
+ class LightevalTaskConfig:
+     name: str
+     prompt_function: str
+     hf_repo: str
+     hf_subset: str
+     metric: Tuple[Union[str, Metrics]]
+     hf_avail_splits: Optional[Tuple[str]] = None
+     evaluation_splits: Optional[Tuple[str]] = None
+     few_shots_split: Optional[str] = None
+     few_shots_select: Optional[str] = None
+     generation_size: int = -1
+     stop_sequence: Optional[Tuple[str]] = None
+     output_regex: Optional[str] = None
+
+     frozen: bool = False
+     suite: Optional[Tuple[str]] = None  # we use this to know if we should use a custom lighteval or bigcode task
+
+     def as_dict(self):
+         return {
+             "name": self.name,
+             "prompt_function": self.prompt_function,
+             "hf_repo": self.hf_repo,
+             "hf_subset": self.hf_subset,
+             "metric": tuple(str(m) for m in self.metric),
+             "hf_avail_splits": self.hf_avail_splits,
+             "evaluation_splits": self.evaluation_splits,
+             "few_shots_split": self.few_shots_split,
+             "few_shots_select": self.few_shots_select,
+             "generation_size": self.generation_size,
+             "stop_sequence": self.stop_sequence,
+             "output_regex": self.output_regex,
+             "frozen": self.frozen,
+             "suite": self.suite,
+         }
+
+     def __post_init__(self):
+         if self.suite is None:
+             self.suite = ["custom"]
+         if self.hf_avail_splits is None:
+             self.hf_avail_splits = ["train", "validation", "test"]
+         if self.evaluation_splits is None:
+             self.evaluation_splits = ["validation"]
+         if self.stop_sequence is None:
+             self.stop_sequence = ["\n"]
+
+         # Convert list to tuple for hashing
+         self.metric = tuple(self.metric)
+         self.hf_avail_splits = tuple(self.hf_avail_splits) if self.hf_avail_splits else None
+         self.evaluation_splits = tuple(self.evaluation_splits) if self.evaluation_splits else None
+         self.suite = tuple(self.suite) if self.suite else None
+         self.stop_sequence = tuple(self.stop_sequence) if self.stop_sequence else None


class LightevalTask:
- def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom_tasks_module=None):
+ def __init__(self, name: str, cfg: LightevalTaskConfig, cache_dir: Optional[str] = None, custom_tasks_module=None):
"""
Initialize a LightEval task.

@@ -60,8 +115,8 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom
self._cfg = cfg

# Dataset info
self.hf_repo = cfg["hf_repo"]
self.hf_subset = cfg["hf_subset"]
self.hf_repo = cfg.hf_repo
self.hf_subset = cfg.hf_subset
self.dataset_path = self.hf_repo
self.dataset_config_name = self.hf_subset
self.dataset = None # Delayed download
@@ -70,22 +125,22 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom
self._docs = None

# Managing splits and few shot
self.all_available_splits = as_list(cfg["hf_avail_splits"])
if cfg.get("evaluation_splits", None) is None:
self.all_available_splits = as_list(cfg.hf_avail_splits)
if cfg.evaluation_splits is None:
raise ValueError(f"The evaluation split for task {self.name} is None. Please select a valid split.")

self.evaluation_split = as_list(cfg["evaluation_splits"])
if cfg.get("few_shots_split", None) is not None:
self.fewshot_split = as_list(cfg["few_shots_split"])
self.evaluation_split = as_list(cfg.evaluation_splits)
if cfg.few_shots_split is not None:
self.fewshot_split = as_list(cfg.few_shots_split)
else:
self.fewshot_split = as_list(self.get_first_possible_fewshot_splits())
self.fewshot_sampler = FewShotSampler(
few_shots_select=cfg["few_shots_select"], few_shots_split=self.fewshot_split
few_shots_select=cfg.few_shots_select, few_shots_split=self.fewshot_split
)

# Metrics
self.metrics = as_list(cfg["metric"])
self.suite = as_list(cfg["suite"])
self.metrics = as_list(cfg.metric)
self.suite = as_list(cfg.suite)
ignored = [metric for metric in self.metrics if Metrics[metric].value.category == MetricCategory.IGNORED]
if len(ignored) > 0:
hlog_warn(f"[WARNING] Not implemented yet: ignoring the metric {' ,'.join(ignored)} for task {self.name}.")
@@ -95,20 +150,20 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom
# Data processing
# to use once prompt formatting is managed as a module
if custom_tasks_module is None:
self.formatter = getattr(tasks_prompt_formatting, cfg["prompt_function"])
elif hasattr(custom_tasks_module, cfg["prompt_function"]):
self.formatter = getattr(tasks_prompt_formatting, cfg.prompt_function)
elif hasattr(custom_tasks_module, cfg.prompt_function):
# If we have a prompt in both the custom_tasks_module and our tasks_prompt_formatting
# We take the prompt from the custom_tasks_module
if hasattr(tasks_prompt_formatting, cfg["prompt_function"]):
if hasattr(tasks_prompt_formatting, cfg.prompt_function):
hlog_warn(
f"Be careful you are using custom prompt function {cfg['prompt_function']} and not the default one."
f"Be careful you are using custom prompt function {cfg.prompt_function} and not the default one."
)
self.formatter = getattr(custom_tasks_module, cfg["prompt_function"])
self.formatter = getattr(custom_tasks_module, cfg.prompt_function)
else:
self.formatter = getattr(tasks_prompt_formatting, cfg["prompt_function"])
self.generation_size = cfg["generation_size"]
self.stop_sequence = cfg["stop_sequence"]
self.output_regex = cfg["output_regex"]
self.formatter = getattr(tasks_prompt_formatting, cfg.prompt_function)
self.generation_size = cfg.generation_size
self.stop_sequence = cfg.stop_sequence
self.output_regex = cfg.output_regex

# Save options
self.save_queries: bool = False
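To illustrate the new config object, a custom task definition could look like the sketch below. Every field value is a made-up placeholder; only the field names and the list-to-tuple conversion come from the dataclass added above, and how such a config is registered by the custom tasks module is not shown in this diff:

from lighteval.tasks.lighteval_task import LightevalTaskConfig

# Hypothetical example; dataset, metric and function names are placeholders.
my_task = LightevalTaskConfig(
    name="my_custom_task",
    prompt_function="my_prompt_fn",  # looked up by name on the custom tasks module
    hf_repo="my-org/my-dataset",
    hf_subset="default",
    metric=["exact_match"],
    evaluation_splits=["test"],
    generation_size=32,
    stop_sequence=["\n"],
    suite=["custom"],
)

# __post_init__ converts the list-valued fields to tuples.
assert my_task.metric == ("exact_match",)
assert my_task.evaluation_splits == ("test",)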