From ac57b78e7b34cf41ae93d1a9f8fae1b23f52ffe5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mentine=20Fourrier?= <22726840+clefourrier@users.noreply.github.com> Date: Tue, 9 Jul 2024 12:13:01 +0200 Subject: [PATCH] Homogeneize logging system (#150) --------- Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com> Co-authored-by: Nathan Habib --- README.md | 2 +- pyproject.toml | 3 +- run_evals_accelerate.py | 1 + src/lighteval/logging/evaluation_tracker.py | 230 +++++++++----------- src/lighteval/main_accelerate.py | 14 +- src/lighteval/main_nanotron.py | 8 +- src/lighteval/utils.py | 9 + 7 files changed, 136 insertions(+), 131 deletions(-) diff --git a/README.md b/README.md index dc4827735..8c6f10635 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ Install the dependencies. For the default installation, you just need: pip install . ``` -If you want to evaluate models with frameworks like `accelerate` or `peft`, you will need to specify the optional dependencies group that fits your use case (`accelerate`,`tgi`,`optimum`,`quantization`,`adapters`,`nanotron`): +If you want to evaluate models with frameworks like `accelerate` or `peft`, you will need to specify the optional dependencies group that fits your use case (`accelerate`,`tgi`,`optimum`,`quantization`,`adapters`,`nanotron`,`tensorboardX`): ```bash pip install '.[optional1,optional2]' diff --git a/pyproject.toml b/pyproject.toml index a9fe4bc7a..b771942d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ keywords = ["evaluation", "nlp", "llm"] dependencies = [ # Base dependencies "transformers>=4.38.0", - "huggingface_hub>=0.22.0", + "huggingface_hub>=0.23.0", "torch>=2.0", "GitPython>=3.1.41", # for logging "datasets>=2.14.0", @@ -86,6 +86,7 @@ nanotron = [ "nanotron", "tensorboardX" ] +tensorboardX = ["tensorboardX"] quality = ["ruff==v0.2.2","pre-commit"] tests = ["pytest==7.4.0"] dev = ["lighteval[accelerate,quality,tests]"] diff --git a/run_evals_accelerate.py b/run_evals_accelerate.py index 20b6ec9f1..d623de256 100644 --- a/run_evals_accelerate.py +++ b/run_evals_accelerate.py @@ -48,6 +48,7 @@ def get_parser(): parser.add_argument("--push_results_to_hub", default=False, action="store_true") parser.add_argument("--save_details", action="store_true") parser.add_argument("--push_details_to_hub", default=False, action="store_true") + parser.add_argument("--push_results_to_tensorboard", default=False, action="store_true") parser.add_argument( "--public_run", default=False, action="store_true", help="Push results and details to a public repo" ) diff --git a/src/lighteval/logging/evaluation_tracker.py b/src/lighteval/logging/evaluation_tracker.py index f4bdf9566..b1dbe616d 100644 --- a/src/lighteval/logging/evaluation_tracker.py +++ b/src/lighteval/logging/evaluation_tracker.py @@ -41,11 +41,11 @@ TaskConfigLogger, VersionsLogger, ) -from lighteval.utils import is_nanotron_available, obj_to_markdown +from lighteval.utils import NO_TENSORBOARDX_WARN_MSG, is_nanotron_available, is_tensorboardX_available, obj_to_markdown if is_nanotron_available(): - from nanotron.config import Config + from nanotron.config import GeneralArgs class EnhancedJSONEncoder(json.JSONEncoder): @@ -80,56 +80,74 @@ class EvaluationTracker: task_config_logger: TaskConfigLogger hub_results_org: str - def __init__(self, hub_results_org: str = "", token: str = "") -> None: - """ + def __init__( + self, + output_dir: str = None, + hub_results_org: str = "", + push_results_to_hub: bool = False, + push_details_to_hub: bool = False, + push_results_to_tensorboard: bool = False, + tensorboard_metric_prefix: str = "eval", + public: bool = False, + token: str = "", + nanotron_run_info: "GeneralArgs" = None, + ) -> None: + """) Creates all the necessary loggers for evaluation tracking. Args: + output_dir (str): Local folder path where you want results to be saved hub_results_org (str): The organisation to push the results to. See more details about the datasets organisation in [`EvaluationTracker.save`] + push_results_to_hub (bool): If True, results are pushed to the hub. + Results will be pushed either to `{hub_results_org}/results`, a public dataset, if `public` is True else to `{hub_results_org}/private-results`, a private dataset. + push_details_to_hub (bool): If True, details are pushed to the hub. + Results are pushed to `{hub_results_org}/details__{sanitized model_name}` for the model `model_name`, a public dataset, + if `public` is True else `{hub_results_org}/details__{sanitized model_name}_private`, a private dataset. + push_results_to_tensorboard (bool): If True, will create and push the results for a tensorboard folder on the hub + public (bool): If True, results and details are pushed in private orgs token (str): Token to use when pushing to the hub. This token should have write access to `hub_results_org`. + nanotron_run_info (GeneralArgs): Reference to informations about Nanotron models runs """ self.details_logger = DetailsLogger() self.metrics_logger = MetricsLogger() self.versions_logger = VersionsLogger() self.general_config_logger = GeneralConfigLogger() self.task_config_logger = TaskConfigLogger() - self.hub_results_org = hub_results_org - self.hub_results_repo = f"{hub_results_org}/results" - self.hub_private_results_repo = f"{hub_results_org}/private-results" + self.api = HfApi(token=token) - def save( - self, - output_dir: str, - push_results_to_hub: bool, - push_details_to_hub: bool, - public: bool, - push_results_to_tensorboard: bool = False, - ) -> None: - """Saves the experiment information and results to files, and to the hub if requested. + self.output_dir = output_dir - Note: - In case of save failure, this function will only print a warning, with the error message. + self.hub_results_org = hub_results_org # will also contain tensorboard results + if hub_results_org in ["", None] and any( + [push_details_to_hub, push_results_to_hub, push_results_to_tensorboard] + ): + raise Exception( + "You need to select which org to push to, using `--results_org`, if you want to save information to the hub." + ) - Args: - output_dir (str): Local folder path where you want results to be saved - push_results_to_hub (bool): If True, results are pushed to the hub. - Results will be pushed either to `{hub_results_org}/results`, a public dataset, if `public` is True else to `{hub_results_org}/private-results`, a private dataset. - push_details_to_hub (bool): If True, details are pushed to the hub. - Results are pushed to `{hub_results_org}/details__{sanitized model_name}` for the model `model_name`, a public dataset, - if `public` is True else `{hub_results_org}/details__{sanitized model_name}_private`, a private dataset. - public (bool): If True, results and details are pushed in private orgs + self.hub_results_repo = f"{hub_results_org}/results" + self.hub_private_results_repo = f"{hub_results_org}/private-results" + self.push_results_to_hub = push_results_to_hub + self.push_details_to_hub = push_details_to_hub - """ + self.push_results_to_tensorboard = push_results_to_tensorboard + self.tensorboard_repo = f"{hub_results_org}/tensorboard_logs" + self.tensorboard_metric_prefix = tensorboard_metric_prefix + self.nanotron_run_info = nanotron_run_info + + self.public = public + + def save(self) -> None: + """Saves the experiment information and results to files, and to the hub if requested.""" hlog("Saving experiment tracker") - # try: date_id = datetime.now().isoformat().replace(":", "-") - output_dir_results = Path(output_dir) / "results" / self.general_config_logger.model_name - output_dir_details = Path(output_dir) / "details" / self.general_config_logger.model_name + output_dir_results = Path(self.output_dir) / "results" / self.general_config_logger.model_name + output_dir_details = Path(self.output_dir) / "details" / self.general_config_logger.model_name output_dir_details_sub_folder = output_dir_details / date_id output_dir_results.mkdir(parents=True, exist_ok=True) output_dir_details_sub_folder.mkdir(parents=True, exist_ok=True) @@ -140,9 +158,6 @@ def save( hlog(f"Saving results to {output_results_file} and {output_results_in_details_file}") config_general = copy.deepcopy(self.general_config_logger) - config_general.config = ( - config_general.config.as_dict() if is_dataclass(config_general.config) else config_general.config - ) config_general = asdict(config_general) to_dump = { @@ -163,14 +178,8 @@ def save( for task_name, task_details in self.details_logger.details.items(): output_file_details = output_dir_details_sub_folder / f"details_{task_name}_{date_id}.parquet" - # Create a dataset from the dictionary - try: - dataset = Dataset.from_list([asdict(detail) for detail in task_details]) - except Exception: - # We force cast to str to avoid formatting problems for nested objects - dataset = Dataset.from_list( - [{k: str(v) for k, v in asdict(detail).items()} for detail in task_details] - ) + # Create a dataset from the dictionary - we force cast to str to avoid formatting problems for nested objects + dataset = Dataset.from_list([{k: str(v) for k, v in asdict(detail).items()} for detail in task_details]) # We don't keep 'id' around if it's there column_names = dataset.column_names @@ -182,30 +191,25 @@ def save( # Save the dataset to a Parquet file dataset.to_parquet(output_file_details.as_posix()) - if push_results_to_hub: + if self.push_results_to_hub: self.api.upload_folder( - repo_id=self.hub_results_repo if public else self.hub_private_results_repo, + repo_id=self.hub_results_repo if self.public else self.hub_private_results_repo, folder_path=output_dir_results, path_in_repo=self.general_config_logger.model_name, repo_type="dataset", commit_message=f"Updating model {self.general_config_logger.model_name}", ) - if push_details_to_hub: + if self.push_details_to_hub: self.details_to_hub( - model_name=self.general_config_logger.model_name, results_file_path=output_results_in_details_file, details_folder_path=output_dir_details_sub_folder, - push_as_public=public, ) - if push_results_to_tensorboard: - self.push_results_to_tensorboard( + if self.push_results_to_tensorboard: + self.push_to_tensorboard( results=self.metrics_logger.metric_aggregated, details=self.details_logger.details ) - # except Exception as e: - # hlog("WARNING: Could not save results") - # hlog(repr(e)) def generate_final_dict(self) -> dict: """Aggregates and returns all the logger's experiment information in a dictionary. @@ -230,29 +234,25 @@ def generate_final_dict(self) -> dict: def details_to_hub( self, - model_name: str, results_file_path: Path | str, details_folder_path: Path | str, - push_as_public: bool = False, ) -> None: """Pushes the experiment details (all the model predictions for every step) to the hub. Args: - model_name (str): Name of the currently evaluated model results_file_path (str or Path): Local path of the current's experiment aggregated results individual file details_folder_path (str or Path): Local path of the current's experiment details folder. The details folder (created by [`EvaluationTracker.save`]) should contain one parquet file per task used during the evaluation run of the current model. - push_as_public (bool, optional): If True, the results will be pushed publicly, else the datasets will be private. """ results_file_path = str(results_file_path) details_folder_path = str(details_folder_path) - sanitized_model_name = model_name.replace("/", "__") + sanitized_model_name = self.general_config_logger.model_name.replace("/", "__") # "Default" detail names are the public detail names (same as results vs private-results) repo_id = f"{self.hub_results_org}/details_{sanitized_model_name}" - if not push_as_public: # if not public, we add `_private` + if not self.public: # if not public, we add `_private` repo_id = f"{repo_id}_private" sub_folder_path = os.path.basename(results_file_path).replace(".json", "").replace("results_", "") @@ -265,7 +265,7 @@ def details_to_hub( if len(checked_paths) == 0: hlog(f"Repo {repo_id} not found for {results_file_path}. Creating it.") - self.api.create_repo(repo_id, private=not (push_as_public), repo_type="dataset", exist_ok=True) + self.api.create_repo(repo_id, private=not (self.public), repo_type="dataset", exist_ok=True) # Create parquet version of results file as well results = load_dataset("json", data_files=results_file_path) @@ -287,43 +287,45 @@ def details_to_hub( repo_id=repo_id, folder_path=details_folder_path, path_in_repo=sub_folder_path, repo_type="dataset" ) - self.recreate_metadata_card(repo_id, model_name) + self.recreate_metadata_card(repo_id) - def recreate_metadata_card(self, repo_id: str, model_name: str = None) -> None: # noqa: C901 + def recreate_metadata_card(self, repo_id: str) -> None: # noqa: C901 """Fully updates the details repository metadata card for the currently evaluated model Args: repo_id (str): Details dataset repository path on the hub (`org/dataset`) - model_name (str): Name of the currently evaluated model. - """ # Add a nice dataset card and the configuration YAML files_in_repo = self.api.list_repo_files(repo_id=repo_id, repo_type="dataset") results_files = [f for f in files_in_repo if ".json" in f] - parquet_results_files = [f for f in files_in_repo if ".parquet" in f and "results_" in f] - parquet_files = [f for f in files_in_repo if ".parquet" in f and "results_" not in f] + parquet_files = [f for f in files_in_repo if ".parquet" in f] multiple_results = len(results_files) > 1 # Get last eval results date for each task (evals might be non overlapping) last_eval_date_results = {} for sub_file in parquet_files: + # We focus on details only + if "results_" in sub_file: + continue + # subfile have this general format: # `2023-09-03T10-57-04.203304/details_harness|hendrycksTest-us_foreign_policy|5_2023-09-03T10-57-04.203304.parquet` # in the iso date, the `:` are replaced by `-` because windows does not allow `:` in their filenames - - task_name = os.path.basename(sub_file).replace("details_", "").split("_2023")[0].split("_2024")[0] + task_name = ( + os.path.basename(sub_file).replace("details_", "").split("_202")[0] + ) # 202 for dates, 2023, 2024, ... # task_name is then equal to `leaderboard|mmlu:us_foreign_policy|5` - iso_date = os.path.dirname(sub_file) # to be able to parse the filename as iso dates, we need to re-replace the `-` with `:` # iso_date[13] = iso_date[16] = ':' - iso_date = iso_date[:13] + ":" + iso_date[14:16] + ":" + iso_date[17:] - + dir_name = os.path.dirname(sub_file) + iso_date = ":".join(dir_name.rsplit("-", 2)) eval_date = datetime.fromisoformat(iso_date) last_eval_date_results[task_name] = ( max(last_eval_date_results[task_name], eval_date) if task_name in last_eval_date_results else eval_date ) + max_last_eval_date_results = list(last_eval_date_results.values())[0] # Now we convert them in iso-format for task in last_eval_date_results: @@ -336,43 +338,20 @@ def recreate_metadata_card(self, repo_id: str, model_name: str = None) -> None: card_metadata = MetadataConfigs() # Add the results config and add the result file as a parquet file - for sub_file in parquet_results_files: - eval_date = os.path.basename(sub_file).replace("results_", "").replace(".parquet", "") - sanitized_eval_date = re.sub(r"[^\w\.]", "_", eval_date) - sanitized_last_eval_date_results = re.sub(r"[^\w\.]", "_", max_last_eval_date_results) - - repo_file_name = os.path.basename(sub_file) - - if multiple_results: - if "results" not in card_metadata: - card_metadata["results"] = { - "data_files": [{"split": sanitized_eval_date, "path": [repo_file_name]}] - } - else: - former_entry = card_metadata["results"] - card_metadata["results"] = { - "data_files": former_entry["data_files"] - + [{"split": sanitized_eval_date, "path": [repo_file_name]}] - } + for sub_file in parquet_files: + if "results_" in sub_file: + eval_date = os.path.basename(sub_file).replace("results_", "").replace(".parquet", "") + sanitized_task = "results" + sanitized_last_eval_date_results = re.sub(r"[^\w\.]", "_", max_last_eval_date_results) + repo_file_name = os.path.basename(sub_file) else: - if "results" in card_metadata: - raise ValueError( - f"Entry for results already exists in {former_entry} for repo {repo_id} and file {sub_file}" - ) - card_metadata["results"] = {"data_files": [{"split": sanitized_eval_date, "path": [repo_file_name]}]} + task_name = os.path.basename(sub_file).replace("details_", "").split("_2023")[0].split("_2024")[0] + sanitized_task = re.sub(r"\W", "_", task_name) + eval_date = os.path.dirname(sub_file) + sanitized_last_eval_date_results = re.sub(r"[^\w\.]", "_", last_eval_date_results[task_name]) + repo_file_name = os.path.join("**", os.path.basename(sub_file)) - if sanitized_eval_date == sanitized_last_eval_date_results: - all_entry = card_metadata["results"]["data_files"] - card_metadata["results"] = {"data_files": all_entry + [{"split": "latest", "path": [repo_file_name]}]} - - # Add the tasks details configs - for sub_file in parquet_files: - task_name = os.path.basename(sub_file).replace("details_", "").split("_2023")[0].split("_2024")[0] - sanitized_task = re.sub(r"\W", "_", task_name) - eval_date = os.path.dirname(sub_file) sanitized_eval_date = re.sub(r"[^\w\.]", "_", eval_date) - repo_file_name = os.path.join("**", os.path.basename(sub_file)) - sanitized_last_eval_date_results = re.sub(r"[^\w\.]", "_", last_eval_date_results[task_name]) if multiple_results: if sanitized_task not in card_metadata: @@ -400,6 +379,9 @@ def recreate_metadata_card(self, repo_id: str, model_name: str = None) -> None: "data_files": all_entry + [{"split": "latest", "path": [repo_file_name]}] } + if "results_" in sub_file: + continue + # Special case for MMLU with a single split covering it all # We add another config with all MMLU splits results together for easy inspection SPECIAL_TASKS = [ @@ -481,7 +463,7 @@ def recreate_metadata_card(self, repo_id: str, model_name: str = None) -> None: card_data = DatasetCardData( dataset_summary=f"Dataset automatically created during the evaluation run of model " - f"[{model_name}](https://huggingface.co/{model_name})" + f"[{self.general_config_logger.model_name}](https://huggingface.co/{self.general_config_logger.model_name})" f"{org_string}.\n\n" f"The dataset is composed of {len(card_metadata) - 1} configuration, each one coresponding to one of the evaluated task.\n\n" f"The dataset has been created from {len(results_files)} run(s). Each run can be found as a specific split in each " @@ -494,8 +476,8 @@ def recreate_metadata_card(self, repo_id: str, model_name: str = None) -> None: f"(note that their might be results for other tasks in the repos if successive evals didn't cover the same tasks. " f'You find each in the results and the "latest" split for each eval):\n\n' f"```python\n{results_string}\n```", - repo_url=f"https://huggingface.co/{model_name}", - pretty_name=f"Evaluation run of {model_name}", + repo_url=f"https://huggingface.co/{self.general_config_logger.model_name}", + pretty_name=f"Evaluation run of {self.general_config_logger.model_name}", leaderboard_url=leaderboard_url, point_of_contact=point_of_contact, ) @@ -507,27 +489,30 @@ def recreate_metadata_card(self, repo_id: str, model_name: str = None) -> None: ) card.push_to_hub(repo_id, repo_type="dataset") - def push_results_to_tensorboard( # noqa: C901 + def push_to_tensorboard( # noqa: C901 self, results: dict[str, dict[str, float]], details: dict[str, DetailsLogger.CompiledDetail] ): + if not is_tensorboardX_available: + hlog_warn(NO_TENSORBOARDX_WARN_MSG) + return + if not is_nanotron_available(): hlog_warn("You cannot push results to tensorboard without having nanotron installed. Skipping") return - config: Config = self.general_config_logger.config - lighteval_config = config.lighteval - try: - global_step = config.general.step - except ValueError: - global_step = 0 - if config.lighteval.logging.tensorboard_metric_prefix is not None: - prefix = config.lighteval.logging.tensorboard_metric_prefix + prefix = self.tensorboard_metric_prefix + + if self.nanotron_run_info is not None: + global_step = self.nanotron_run_info.step + run = f"{self.nanotron_run_info.run}_{prefix}" else: - prefix = "eval" - output_dir_tb = Path(lighteval_config.logging.local_output_path) / "tb" / (config.general.run + "_" + prefix) + global_step = 0 + run = prefix + + output_dir_tb = Path(self.output_dir) / "tb" / run output_dir_tb.mkdir(parents=True, exist_ok=True) tb_context = HFSummaryWriter( logdir=str(output_dir_tb), - repo_id=lighteval_config.logging.hub_repo_tensorboard, + repo_id=self.tensorboard_repo, repo_private=True, path_in_repo="tb", commit_every=6000, # Very long time so that we can change our files names and trigger push ourselves (see below) @@ -559,14 +544,13 @@ def push_results_to_tensorboard( # noqa: C901 ) else: tb_context.add_scalar(f"{prefix}/{task_name}/{metric}", value, global_step=global_step) - # e.g. MMLU + # Tasks with subtasks for name, values in bench_averages.items(): for metric, values in values.items(): hlog(f"Pushing average {name} {metric} {sum(values) / len(values)} to tensorboard") tb_context.add_scalar(f"{prefix}/{name}/{metric}", sum(values) / len(values), global_step=global_step) tb_context.add_text("eval_config", obj_to_markdown(results), global_step=global_step) - # tb_context.add_text("eval_sizes", obj_to_markdown(sizes), global_step=global_step) for task_name, task_details in details.items(): tb_context.add_text( @@ -589,8 +573,6 @@ def push_results_to_tensorboard( # noqa: C901 # Now we can push to the hub tb_context.scheduler.trigger() hlog( - f"Pushed to tensorboard at https://huggingface.co/tensorboard/{lighteval_config.logging.hub_repo_tensorboard}/" - f" at {output_dir_tb} and global_step {global_step}" + f"Pushed to tensorboard at https://huggingface.co/{self.tensorboard_repo}/{output_dir_tb}/tensorboard" + f"at global_step {global_step}" ) - # except Exception as e: - # logger.warning(f"Could not push to tensorboard\n{e}") diff --git a/src/lighteval/main_accelerate.py b/src/lighteval/main_accelerate.py index d2ffbbe3b..12122c527 100644 --- a/src/lighteval/main_accelerate.py +++ b/src/lighteval/main_accelerate.py @@ -56,7 +56,15 @@ @htrack() def main(args): env_config = EnvConfig(token=TOKEN, cache_dir=args.cache_dir) - evaluation_tracker = EvaluationTracker(hub_results_org=args.results_org, token=TOKEN) + evaluation_tracker = EvaluationTracker( + output_dir=args.output_dir, + hub_results_org=args.results_org, + push_results_to_hub=args.push_results_to_hub, + push_details_to_hub=args.push_details_to_hub, + push_results_to_tensorboard=args.push_results_to_tensorboard, + public=args.public_run, + token=TOKEN, + ) evaluation_tracker.general_config_logger.log_args_info( args.num_fewshot_seeds, args.override_batch_size, args.max_samples, args.job_id ) @@ -124,9 +132,7 @@ def main(args): evaluation_tracker.details_logger.aggregate() if args.output_dir: - evaluation_tracker.save( - args.output_dir, args.push_results_to_hub, args.push_details_to_hub, args.public_run - ) + evaluation_tracker.save() final_dict = evaluation_tracker.generate_final_dict() diff --git a/src/lighteval/main_nanotron.py b/src/lighteval/main_nanotron.py index 4610ea869..f479c5d7a 100644 --- a/src/lighteval/main_nanotron.py +++ b/src/lighteval/main_nanotron.py @@ -96,7 +96,13 @@ def main( data_parallel_size=lighteval_config.parallelism.dp, ) - evaluation_tracker = EvaluationTracker(token=TOKEN) + evaluation_tracker = EvaluationTracker( + token=TOKEN, + output_dir=lighteval_config.logging.local_output_path, + hub_results_org=lighteval_config.logging.hub_repo_tensorboard, + tensorboard_metric_prefix=lighteval_config.logging.tensorboard_metric_prefix, + nanotron_run_info=nanotron_config.general, + ) evaluation_tracker.general_config_logger.log_args_info( num_fewshot_seeds=1, override_batch_size=None, diff --git a/src/lighteval/utils.py b/src/lighteval/utils.py index 768a1cd88..162357858 100644 --- a/src/lighteval/utils.py +++ b/src/lighteval/utils.py @@ -191,6 +191,15 @@ def is_peft_available() -> bool: NO_PEFT_ERROR_MSG = "You are trying to use adapter weights models, for which you need `peft`, which is not available in your environment. Please install it using pip." +def is_tensorboardX_available() -> bool: + return importlib.util.find_spec("tensorboardX") is not None + + +NO_TENSORBOARDX_WARN_MSG = ( + "You are trying to log using tensorboardX, which is not installed. Please install it using pip. Skipping." +) + + def is_openai_available() -> bool: return importlib.util.find_spec("openai") is not None