Skip to content

Commit

Permalink
add examples of samples (does not include few shot), and add robustne…
Browse files Browse the repository at this point in the history
…ss to logger
  • Loading branch information
clefourrier committed Jul 17, 2024
1 parent 4550cb7 commit 2a6da98
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 22 deletions.
41 changes: 34 additions & 7 deletions src/lighteval/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@
# SOFTWARE.

import argparse
import os

from lighteval.parsers import parser_accelerate, parser_nanotron
from lighteval.tasks.registry import Registry
from lighteval.tasks.registry import Registry, taskinfo_selector


# Default value of the `--cache_dir` CLI option; `os.getenv` yields None when HF_HOME is not set.
CACHE_DIR = os.getenv("HF_HOME")


def cli_evaluate():
Expand All @@ -40,25 +44,48 @@ def cli_evaluate():
parser_b = subparsers.add_parser("nanotron", help="use nanotron as backend for evaluation.")
parser_nanotron(parser_b)

parser.add_argument("--list-tasks", action="store_true", help="List available tasks")
# values which should always be set
parser.add_argument(
"--cache_dir", type=str, default=CACHE_DIR, help="Cache directory used to store datasets and models"
)

# utils functions
parser.add_argument("--list-tasks", action="store_true", help="List available tasks")
parser.add_argument(
"--tasks-examples",
type=str,
default=None,
help="Id of tasks or path to a text file with a list of tasks (e.g. 'original|mmlu:abstract_algebra|5') for which you want to manually inspect samples.",
)
args = parser.parse_args()

if args.subcommand == "accelerate":
from lighteval.main_accelerate import main as main_accelerate

main_accelerate(args)
return

if args.subcommand == "nanotron":
elif args.subcommand == "nanotron":
from lighteval.main_nanotron import main as main_nanotron

main_nanotron(args.checkpoint_config_path, args.lighteval_override, args.cache_dir)
return

if args.list_tasks:
elif args.list_tasks:
Registry(cache_dir="").print_all_tasks()
return

elif args.tasks_examples:
print(f"Loading the tasks dataset to cache folder: {args.cache_dir}")
print(
"All examples will be displayed without few shot, as few shot sample construction requires loading a model and using its tokenizer."
)
task_names_list, _ = taskinfo_selector(args.tasks_examples)
task_dict = Registry(cache_dir=args.cache_dir).get_task_dict(task_names_list)
for name, task in task_dict.items():
print("-" * 10, name, "-" * 10)
for sample in task.eval_docs()[:10]:
print(sample)

else:
print("You did not provide any argument. Exiting")


if __name__ == "__main__":
Expand Down
19 changes: 14 additions & 5 deletions src/lighteval/logging/hierarchical_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import sys
import time
from datetime import timedelta
from logging import Logger
from typing import Any, Callable

from lighteval.utils import is_accelerate_available, is_nanotron_available
Expand All @@ -37,8 +38,6 @@

logger = get_logger(__name__, log_level="INFO")
else:
from logging import Logger

logger = Logger(__name__, level="INFO")

from colorama import Fore, Style
Expand Down Expand Up @@ -76,6 +75,7 @@ def log(self, x: Any) -> None:


# Singleton through which hlog / hlog_warn / hlog_err emit their messages.
HIERARCHICAL_LOGGER = HierarchicalLogger()
# Plain standard-library logger the hlog* helpers fall back to when
# HIERARCHICAL_LOGGER.log raises a RuntimeError.
BACKUP_LOGGER = Logger(__name__, level="INFO")


# Exposed public methods
Expand All @@ -84,23 +84,32 @@ def hlog(x: Any) -> None:
Logs a string version of x through the singleton [`HierarchicalLogger`].
"""
HIERARCHICAL_LOGGER.log(x)
try:
HIERARCHICAL_LOGGER.log(x)
except RuntimeError:
BACKUP_LOGGER.warning(x)


def hlog_warn(x: Any) -> None:
    """Warning logger.

    Logs a string version of x, which will appear in a yellow color, through the singleton [`HierarchicalLogger`].
    If that logger raises a `RuntimeError`, the same colored message is emitted as a warning
    through `BACKUP_LOGGER` instead, so the message is not lost.

    Args:
        x: Object to log; it is converted with `str()` before being colored.
    """
    message = Fore.YELLOW + str(x) + Style.RESET_ALL
    try:
        # The call must live inside the `try` only: an extra unguarded call would log the
        # message twice and let the RuntimeError escape before the fallback can run.
        HIERARCHICAL_LOGGER.log(message)
    except RuntimeError:
        # NOTE(review): presumably raised when the underlying (accelerate) logger is used
        # before its state is initialized -- confirm against HierarchicalLogger.log.
        BACKUP_LOGGER.warning(message)


def hlog_err(x: Any) -> None:
    """Error logger.

    Logs a string version of x, which will appear in a red color, through the singleton [`HierarchicalLogger`].
    If that logger raises a `RuntimeError`, the same colored message is emitted at ERROR level
    through `BACKUP_LOGGER` instead, so the message is not lost.

    Args:
        x: Object to log; it is converted with `str()` before being colored.
    """
    message = Fore.RED + str(x) + Style.RESET_ALL
    try:
        # The call must live inside the `try` only: an extra unguarded call would log the
        # message twice and let the RuntimeError escape before the fallback can run.
        HIERARCHICAL_LOGGER.log(message)
    except RuntimeError:
        # Report errors at ERROR level (not WARNING) so the fallback keeps the severity
        # this helper advertises.
        # NOTE(review): RuntimeError is presumably raised when the underlying (accelerate)
        # logger is used before its state is initialized -- confirm.
        BACKUP_LOGGER.error(message)


class htrack_block:
Expand Down
1 change: 0 additions & 1 deletion src/lighteval/main_accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
hlog_warn("Using either accelerate or text-generation to run this script is advised.")

TOKEN = os.getenv("HF_TOKEN")
CACHE_DIR = os.getenv("HF_HOME")

if is_accelerate_available():
from accelerate import Accelerator, InitProcessGroupKwargs
Expand Down
9 changes: 0 additions & 9 deletions src/lighteval/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ def parser_accelerate(parser=None):
parser.add_argument(
"--public_run", default=False, action="store_true", help="Push results and details to a public repo"
)
parser.add_argument(
"--cache_dir", type=str, default=CACHE_DIR, help="Cache directory used to store datasets and models"
)
parser.add_argument(
"--results_org",
type=str,
Expand Down Expand Up @@ -120,9 +117,3 @@ def parser_nanotron(parser=None):
type=str,
help="Path to an optional YAML or python Lighteval config to override part of the checkpoint Lighteval config",
)
parser.add_argument(
"--cache-dir",
type=str,
default=None,
help="Cache directory",
)

0 comments on commit 2a6da98

Please sign in to comment.