Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Documentation and typing of functions #5

Merged
merged 15 commits into from
Feb 6, 2024
27 changes: 19 additions & 8 deletions src/lighteval/logging/evaluation_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@


class EnhancedJSONEncoder(json.JSONEncoder):
"""Provides a proper json encoding for the loggers and trackers json dumps.
"""
Provides a proper json encoding for the loggers and trackers json dumps.
Notably manages the json encoding of dataclasses.
"""

Expand All @@ -39,10 +40,16 @@ def default(self, o):


class EvaluationTracker:
"""Keeps track of the overall evaluation process and relevant informations.

The [`EvaluationTracker`] contains specific loggers for experiments details ([`DetailsLogger`]), metrics ([`MetricsLogger`]), task versions ([`VersionsLogger`]) as well as for the general configurations of both the specific task ([`TaskConfigLogger`]) and overall evaluation run ([`GeneralConfigLogger`]).
It compiles the data from these loggers and writes it to files, which can be published to the Hugging Face hub if requested.
"""
Keeps track of the overall evaluation process and relevant informations.

The [`EvaluationTracker`] contains specific loggers for experiments details
([`DetailsLogger`]), metrics ([`MetricsLogger`]), task versions
([`VersionsLogger`]) as well as for the general configurations of both the
specific task ([`TaskConfigLogger`]) and overall evaluation run
([`GeneralConfigLogger`]). It compiles the data from these loggers and
writes it to files, which can be published to the Hugging Face hub if
requested.
"""

details_logger: DetailsLogger
Expand All @@ -53,11 +60,15 @@ class EvaluationTracker:
hub_results_org: str

def __init__(self, hub_results_org: str = "", token: str = "") -> None:
"""Creates all the necessary loggers for evaluation tracking.
"""
Creates all the necessary loggers for evaluation tracking.

Args:
hub_results_org (str): The organisation to push the results to. See more details about the datasets organisation in [`EvaluationTracker.save`]
token (str): Token to use when pushing to the hub. This token should have write access to `hub_results_org`.
hub_results_org (str): The organisation to push the results to. See
more details about the datasets organisation in
[`EvaluationTracker.save`]
token (str): Token to use when pushing to the hub. This token should
have write access to `hub_results_org`.
"""
self.details_logger = DetailsLogger()
self.metrics_logger = MetricsLogger()
Expand Down
33 changes: 30 additions & 3 deletions src/lighteval/logging/info_loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import time
from dataclasses import asdict, dataclass, field
from typing import Union

import git
import numpy as np
Expand Down Expand Up @@ -72,14 +73,40 @@ def __init__(self) -> None:
self.lighteval_sha = repo.git.rev_parse("HEAD")
self.start_time = time.perf_counter()

def log_args_info(self, num_fewshot_seeds, override_batch_size, max_samples, job_id, config=None) -> None:
def log_args_info(
self,
num_fewshot_seeds: int,
override_batch_size: Union[None, int],
max_samples: Union[None, int],
job_id: str,
config: "BrrrConfig" = None,
) -> None:
"""
Logs the information about the arguments passed to the method.

Args:
num_fewshot_seeds (int): number of few-shot seeds.
override_batch_size (Union[None, int]): overridden batch size.
max_samples (Union[None, int]): maximum number of samples, if None, use all the samples available.
job_id (str): job ID.
config (optional): BrrrConfig

Returns:
None
"""
self.num_fewshot_seeds = num_fewshot_seeds
self.override_batch_size = override_batch_size
self.max_samples = max_samples
self.job_id = job_id
self.config = config

def log_model_info(self, model_info: ModelInfo) -> None:
"""
Logs the model information.

Args:
model_info (ModelInfo): model information to be logged.
"""
self.model_name = model_info.model_name
self.model_sha = model_info.model_sha
self.model_dtype = model_info.model_dtype
Expand Down Expand Up @@ -160,7 +187,7 @@ class CompiledDetail:
padded (int): Total umber of samples which needed padding during the batching step for the current task.
non_padded (int): Total number of samples which did not need padding during the batching step for the current task.
effective_few_shots (float): Average effective few shots across all samples for the current task.
The effective few shot is the number of few shots actually used to fit the prompt in the model context
effective few shot is the number of few shots actually used to fit the prompt in the model context
length while allowing model generation of the expected size.
num_truncated_few_shots (int): Total number of samples which required truncated prompts to fit the model size for the current task.
"""
Expand All @@ -186,7 +213,7 @@ class CompiledDetailOverAllTasks:
padded (int): Number of samples which needed padding during the batching step across all tasks.
non_padded (int): Number of samples which did not need padding during the batching step across all tasks.
effective_few_shots (float): Average effective few shots across all samples across all tasks.
The effective few shot is the number of few shots actually used to fit the prompt in the model context
effective few shot is the number of few shots actually used to fit the prompt in the model context
length while allowing model generation of the expected size.
num_truncated_few_shots (int): Number of samples which required truncated prompts to fit the model size across all tasks.
"""
Expand Down
2 changes: 1 addition & 1 deletion src/lighteval/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def higher_is_better():
return res

@staticmethod
def corpus_level_fns():
def corpus_level_fns() -> dict[str, callable]:
res = {}
for metric in Metrics:
if metric.value.category == MetricCategory.IGNORED:
Expand Down
85 changes: 75 additions & 10 deletions src/lighteval/models/model_config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from argparse import Namespace
from dataclasses import dataclass
from typing import Optional, Union

import torch
from transformers import AutoConfig, BitsAndBytesConfig, GPTQConfig
from transformers import AutoConfig, BitsAndBytesConfig, GPTQConfig, PretrainedConfig

from lighteval.logging.hierarchical_logger import hlog
from lighteval.models.utils import _get_model_sha
Expand All @@ -23,15 +24,20 @@

@dataclass
class EnvConfig:
"""
Configuration class for environment settings.

Attributes:
cache_dir (str): directory for caching data.
token (str): authentication token used for accessing the HuggingFace Hub.
"""

cache_dir: str = None
token: str = None


@dataclass
class BaseModelConfig:
"""Args:
pretrained (str):
The HuggingFace Hub model ID name or the path to a pre-trained
HuggingFace Hub model ID name or the path to a pre-trained
model to load. This is effectively the `pretrained_model_name_or_path`
argument of `from_pretrained` in the HuggingFace `transformers` API.
add_special_tokens (bool, optional, defaults to True):
Expand All @@ -47,7 +53,51 @@ class BaseModelConfig:
dtype (Union[str, torch.dtype], optional, defaults to None):):
Converts the model weights to `dtype`, if specified. Strings get
converted to `torch.dtype` objects (e.g. `float16` -> `torch.float16`).
Use `dtype="auto"` to derive the type from the model’s weights.
Use `dtype="auto"` to derive the type from the model's weights.
"""


@dataclass
class BaseModelConfig:
"""
Base configuration class for models.

Attributes:
pretrained (str): HuggingFace Hub model ID name or the path to a
pre-trained model to load. This is effectively the
`pretrained_model_name_or_path` argument of `from_pretrained` in the
HuggingFace `transformers` API.
accelerator (Accelerator): accelerator to use for model training.
tokenizer (Optional[str]): HuggingFace Hub tokenizer ID that will be
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be str, optional

used for tokenization.
multichoice_continuations_start_space (Optional[bool]): Whether to add a
NathanHB marked this conversation as resolved.
Show resolved Hide resolved
space at the start of each continuation in multichoice generation.
For example, context: "What is the capital of France?" and choices: "Paris", "London".
Will be tokenized as: "What is the capital of France? Paris" and "What is the capital of France? London".
subfolder (Optional[str]): The subfolder within the model repository.
revision (str): The revision of the model.
batch_size (int): The batch size for model training.
max_gen_toks (Optional[int]): The maximum number of tokens to generate.
max_length (Optional[int]): The maximum length of the generated output.
add_special_tokens (bool, optional, defaults to True): Whether to add special tokens to the input sequences.
If `None`, the default value will be set to `True` for seq2seq models (e.g. T5) and
`False` for causal models.
model_parallel (Optional[bool]): Whether to use model parallelism.
dtype (Optional[Union[str, torch.dtype]]): data type of the model.
device (Union[int, str]): device to use for model training.
quantization_config (Optional[BitsAndBytesConfig]): quantization
configuration for the model. Needed for 4-bit and 8-bit precision.
load_in_8bit (bool): Whether to load the model in 8-bit precision.
load_in_4bit (bool): Whether to load the model in 4-bit precision.
trust_remote_code (bool): Whether to trust remote code during model
loading.

Methods:
__post_init__(): Performs post-initialization checks on the configuration.
_init_configs(model_name, env_config): Initializes the model configuration.
init_configs(env_config): Initializes the model configuration using the environment configuration.
get_model_sha(): Retrieves the SHA of the model.

"""

pretrained: str
Expand Down Expand Up @@ -77,7 +127,7 @@ def __post_init__(self):
if not isinstance(self.device, str):
raise ValueError("Current device must be passed as string.")

def _init_configs(self, model_name, env_config: EnvConfig):
def _init_configs(self, model_name: str, env_config: EnvConfig) -> PretrainedConfig:
revision = self.revision
if self.subfolder:
revision = f"{self.revision}/{self.subfolder}"
Expand All @@ -98,7 +148,7 @@ def _init_configs(self, model_name, env_config: EnvConfig):

return auto_config

def init_configs(self, env_config: EnvConfig):
def init_configs(self, env_config: EnvConfig) -> PretrainedConfig:
return self._init_configs(self.pretrained, env_config=env_config)

def get_model_sha(self):
Expand Down Expand Up @@ -146,8 +196,23 @@ class TGIModelConfig:
inference_server_auth: str


def create_model_config(args, accelerator: "Accelerator"): # noqa C901
# Tests
def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]) -> BaseModelConfig: # noqa: C901
"""
Create a model configuration based on the provided arguments.

Args:
args (Namespace): command-line arguments.
accelerator (Union[Accelerator, None]): accelerator to use for model training.

Returns:
BaseModelConfig: model configuration.

Raises:
ValueError: If both an inference server address and model arguments are provided.
ValueError: If multichoice continuations both should start with a space and should not start with a space.
ValueError: If a base model is not specified when using delta weights or adapter weights.
ValueError: If a base model is specified when not using delta weights or adapter weights.
"""
if args.inference_server_address is not None and args.model_args is not None:
raise ValueError("You cannot both use an inference server and load a model from its checkpoint.")

Expand Down
Loading
Loading