Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Documentation and typing of functions #5

Merged
merged 15 commits into from
Feb 6, 2024
27 changes: 19 additions & 8 deletions src/lighteval/logging/evaluation_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@


class EnhancedJSONEncoder(json.JSONEncoder):
"""Provides a proper json encoding for the loggers and trackers json dumps.
"""
Provides a proper json encoding for the loggers and trackers json dumps.
Notably manages the json encoding of dataclasses.
"""

Expand All @@ -39,10 +40,16 @@ def default(self, o):


class EvaluationTracker:
"""Keeps track of the overall evaluation process and relevant informations.

The [`EvaluationTracker`] contains specific loggers for experiments details ([`DetailsLogger`]), metrics ([`MetricsLogger`]), task versions ([`VersionsLogger`]) as well as for the general configurations of both the specific task ([`TaskConfigLogger`]) and overall evaluation run ([`GeneralConfigLogger`]).
It compiles the data from these loggers and writes it to files, which can be published to the Hugging Face hub if requested.
"""
Keeps track of the overall evaluation process and relevant information.

The [`EvaluationTracker`] contains specific loggers for experiments details
([`DetailsLogger`]), metrics ([`MetricsLogger`]), task versions
([`VersionsLogger`]) as well as for the general configurations of both the
specific task ([`TaskConfigLogger`]) and overall evaluation run
([`GeneralConfigLogger`]). It compiles the data from these loggers and
writes it to files, which can be published to the Hugging Face hub if
requested.
"""

details_logger: DetailsLogger
Expand All @@ -53,11 +60,15 @@ class EvaluationTracker:
hub_results_org: str

def __init__(self, hub_results_org: str = "", token: str = "") -> None:
"""Creates all the necessary loggers for evaluation tracking.
"""
Creates all the necessary loggers for evaluation tracking.

Args:
hub_results_org (str): The organisation to push the results to. See more details about the datasets organisation in [`EvaluationTracker.save`]
token (str): Token to use when pushing to the hub. This token should have write access to `hub_results_org`.
hub_results_org (str): The organisation to push the results to. See
more details about the datasets organisation in
[`EvaluationTracker.save`]
token (str): Token to use when pushing to the hub. This token should
have write access to `hub_results_org`.
"""
self.details_logger = DetailsLogger()
self.metrics_logger = MetricsLogger()
Expand Down
29 changes: 28 additions & 1 deletion src/lighteval/logging/info_loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import time
from dataclasses import asdict, dataclass, field
from typing import Union

import git
import numpy as np
Expand Down Expand Up @@ -72,14 +73,40 @@ def __init__(self) -> None:
self.lighteval_sha = repo.git.rev_parse("HEAD")
self.start_time = time.perf_counter()

def log_args_info(self, num_fewshot_seeds, override_batch_size, max_samples, job_id, config=None) -> None:
def log_args_info(
    self,
    num_fewshot_seeds: int,
    override_batch_size: Optional[int],
    max_samples: Optional[int],
    job_id: str,
    config: Optional["BrrrConfig"] = None,
) -> None:
    """
    Logs the arguments of the current evaluation run on the logger instance.

    Args:
        num_fewshot_seeds (int): Number of few-shot seeds.
        override_batch_size (Optional[int]): Overridden batch size, or None
            when no override was requested.
        max_samples (Optional[int]): Maximum number of samples to evaluate,
            or None for no limit.
        job_id (str): The job ID.
        config (Optional[BrrrConfig]): Full run configuration, if available.
    """
    # Plain attribute assignments: the values are only stored for later
    # serialization by the tracker.
    self.num_fewshot_seeds = num_fewshot_seeds
    self.override_batch_size = override_batch_size
    self.max_samples = max_samples
    self.job_id = job_id
    self.config = config

def log_model_info(self, model_info: ModelInfo) -> None:
"""
Logs the model information.

Args:
model_info (ModelInfo): The model information to be logged.
"""
self.model_name = model_info.model_name
self.model_sha = model_info.model_sha
self.model_dtype = model_info.model_dtype
Expand Down
2 changes: 1 addition & 1 deletion src/lighteval/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def higher_is_better():
return res

@staticmethod
def corpus_level_fns():
def corpus_level_fns() -> dict[str, callable]:
res = {}
for metric in Metrics:
if metric.value.category == MetricCategory.IGNORED:
Expand Down
80 changes: 72 additions & 8 deletions src/lighteval/models/model_config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from argparse import Namespace
from dataclasses import dataclass
from typing import Optional, Union

import torch
from transformers import AutoConfig, BitsAndBytesConfig, GPTQConfig
from transformers import AutoConfig, BitsAndBytesConfig, GPTQConfig, PretrainedConfig

from lighteval.logging.hierarchical_logger import hlog
from lighteval.models.utils import _get_model_sha
Expand All @@ -23,12 +24,17 @@

@dataclass
class EnvConfig:
    """
    Configuration class for environment settings.

    Attributes:
        cache_dir (Optional[str]): Directory for caching data, or None if unset.
        token (Optional[str]): Authentication token used for accessing the
            HuggingFace Hub, if needed.
    """

    # `Optional[str]` rather than `str`: both fields legitimately default to None.
    cache_dir: Optional[str] = None
    token: Optional[str] = None


@dataclass
class BaseModelConfig:
"""Args:
pretrained (str):
The HuggingFace Hub model ID name or the path to a pre-trained
Expand All @@ -50,6 +56,49 @@ class BaseModelConfig:
Use `dtype="auto"` to derive the type from the model’s weights.
"""


@dataclass
class BaseModelConfig:
"""
Base configuration class for models.

Attributes:
pretrained (str): The HuggingFace Hub model ID name or the path to a
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Remove all the The at the start of arg docstrings.
    Ex: instead of The revision of the model, use Model revision (commit hash) to use for evaluation.
  • Follow the ref doc for optional items
  • When an item is optional, don't forget to add it in the arg description. For example, for quantization config, you can add an "if needed" for ex.

pre-trained model to load. This is effectively the
`pretrained_model_name_or_path` argument of `from_pretrained` in the
HuggingFace `transformers` API.
accelerator (Accelerator): The accelerator to use for model training.
tokenizer (Optional[str]): The HuggingFace Hub tokenizer ID that will be
used for tokenization.
multichoice_continuations_start_space (Optional[bool]): Whether to add a
NathanHB marked this conversation as resolved.
Show resolved Hide resolved
space at the start of each continuation in multichoice generation.
subfolder (Optional[str]): The subfolder within the model repository.
revision (str): The revision of the model.
batch_size (int): The batch size for model training.
max_gen_toks (Optional[int]): The maximum number of tokens to generate.
max_length (Optional[int]): The maximum length of the generated output.
add_special_tokens (bool, optional, defaults to True):
Whether to add special tokens to the input sequences. If `None`, the
default value will be set to `True` for seq2seq models (e.g. T5) and
`False` for causal models.
NathanHB marked this conversation as resolved.
Show resolved Hide resolved
model_parallel (Optional[bool]): Whether to use model parallelism.
dtype (Optional[Union[str, torch.dtype]]): The data type of the model.
device (Union[int, str]): The device to use for model training.
quantization_config (Optional[BitsAndBytesConfig]): The quantization
configuration for the model.
load_in_8bit (bool): Whether to load the model in 8-bit precision.
load_in_4bit (bool): Whether to load the model in 4-bit precision.
trust_remote_code (bool): Whether to trust remote code during model
loading.

Methods:
__post_init__(): Performs post-initialization checks on the configuration.
_init_configs(model_name, env_config): Initializes the model configuration.
init_configs(env_config): Initializes the model configuration using the environment configuration.
get_model_sha(): Retrieves the SHA of the model.

"""

pretrained: str
accelerator: "Accelerator" = None
tokenizer: Optional[str] = None
Expand Down Expand Up @@ -77,7 +126,7 @@ def __post_init__(self):
if not isinstance(self.device, str):
raise ValueError("Current device must be passed as string.")

def _init_configs(self, model_name, env_config: EnvConfig):
def _init_configs(self, model_name: str, env_config: EnvConfig) -> PretrainedConfig:
revision = self.revision
if self.subfolder:
revision = f"{self.revision}/{self.subfolder}"
Expand All @@ -98,7 +147,7 @@ def _init_configs(self, model_name, env_config: EnvConfig):

return auto_config

def init_configs(self, env_config: EnvConfig):
def init_configs(self, env_config: EnvConfig) -> PretrainedConfig:
    """
    Initializes and returns the transformers configuration of this model.

    Args:
        env_config (EnvConfig): Environment settings (cache directory, hub
            token) forwarded to `_init_configs`.

    Returns:
        PretrainedConfig: The loaded model configuration.
    """
    # Delegates to `_init_configs` with the configured pretrained model name.
    return self._init_configs(self.pretrained, env_config=env_config)

def get_model_sha(self):
Expand Down Expand Up @@ -146,8 +195,23 @@ class TGIModelConfig:
inference_server_auth: str


def create_model_config(args, accelerator: Accelerator): # noqa C901
# Tests
def create_model_config(args: Namespace, accelerator: Union[Accelerator, None]) -> BaseModelConfig: # noqa: C901
"""
Create a model configuration based on the provided arguments.

Args:
args (Namespace): The command-line arguments.
accelerator (Union[Accelerator, None]): The accelerator to use for model training.

Returns:
BaseModelConfig: The model configuration.

Raises:
ValueError: If both an inference server address and model arguments are provided.
ValueError: If the multichoice-continuation options are contradictory (both requesting and forbidding a leading space).
NathanHB marked this conversation as resolved.
Show resolved Hide resolved
ValueError: If a base model is not specified when using delta weights or adapter weights.
ValueError: If a base model is specified when not using delta weights or adapter weights.
"""
if args.inference_server_address is not None and args.model_args is not None:
raise ValueError("You cannot both use an inference server and load a model from its checkpoint.")

Expand Down
Loading
Loading