Basic evaluate CLI command / codepath (#2188)
* basic evaluate CLI command / codepath
* tests for evaluate CLI command
* fixes and cleanup
* review comments; slightly DRYing up things

Co-authored-by: Dan Saunders <[email protected]>
Showing 10 changed files with 494 additions and 112 deletions.
@@ -0,0 +1,52 @@
""" | ||
CLI to run training on a model | ||
""" | ||
import logging | ||
from pathlib import Path | ||
from typing import Union | ||
|
||
import fire | ||
from dotenv import load_dotenv | ||
from transformers.hf_argparser import HfArgumentParser | ||
|
||
from axolotl.cli import ( | ||
check_accelerate_default_config, | ||
check_user_token, | ||
load_cfg, | ||
load_datasets, | ||
load_rl_datasets, | ||
print_axolotl_text_art, | ||
) | ||
from axolotl.common.cli import TrainerCliArgs | ||
from axolotl.evaluate import evaluate | ||
|
||
LOG = logging.getLogger("axolotl.cli.evaluate") | ||
|
||
|
||
def do_evaluate(cfg, cli_args) -> None: | ||
# pylint: disable=duplicate-code | ||
print_axolotl_text_art() | ||
check_accelerate_default_config() | ||
check_user_token() | ||
|
||
if cfg.rl: # and cfg.rl != "orpo": | ||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) | ||
else: | ||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) | ||
|
||
evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) | ||
|
||
|
||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: | ||
# pylint: disable=duplicate-code | ||
parsed_cfg = load_cfg(config, **kwargs) | ||
parser = HfArgumentParser(TrainerCliArgs) | ||
parsed_cli_args, _ = parser.parse_args_into_dataclasses( | ||
return_remaining_strings=True | ||
) | ||
do_evaluate(parsed_cfg, parsed_cli_args) | ||
|
||
|
||
if __name__ == "__main__": | ||
load_dotenv() | ||
fire.Fire(do_cli) |
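For context, a minimal usage sketch, not part of the commit: because the module guards fire.Fire(do_cli) behind __main__, the new codepath can be driven from the command line (e.g. python -m axolotl.cli.evaluate path/to/config.yml, assuming the file lives at src/axolotl/cli/evaluate.py as the logger name suggests) or programmatically. The config path below is a placeholder.

# Illustrative only: driving the new evaluate entry point programmatically.
# "path/to/config.yml" is a placeholder; extra keyword overrides are forwarded
# to load_cfg just as they would be from the command line via fire.
from dotenv import load_dotenv

from axolotl.cli.evaluate import do_cli  # assumed module path for the file above

if __name__ == "__main__":
    load_dotenv()
    do_cli(config="path/to/config.yml")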
@@ -0,0 +1,168 @@
"""Module for evaluating models.""" | ||
|
||
import csv | ||
import os | ||
import sys | ||
from pathlib import Path | ||
from typing import Dict, Optional | ||
|
||
import torch | ||
from accelerate.logging import get_logger | ||
|
||
from axolotl.common.cli import TrainerCliArgs | ||
from axolotl.logging_config import configure_logging | ||
from axolotl.train import TrainDatasetMeta | ||
from axolotl.utils.dict import DictDefault | ||
from axolotl.utils.models import load_model, load_processor, load_tokenizer | ||
from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer | ||
|
||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) | ||
src_dir = os.path.join(project_root, "src") | ||
sys.path.insert(0, src_dir) | ||
|
||
configure_logging() | ||
LOG = get_logger("axolotl.evaluate") | ||
|
||
|
||
def evaluate_dataset( | ||
trainer, dataset, dataset_type: str, flash_optimum: bool = False | ||
) -> Optional[Dict[str, float]]: | ||
"""Helper function to evaluate a single dataset safely. | ||
Args: | ||
trainer: The trainer instance | ||
dataset: Dataset to evaluate | ||
dataset_type: Type of dataset ('train' or 'eval') | ||
flash_optimum: Whether to use flash optimum | ||
Returns: | ||
Dictionary of metrics or None if dataset is None | ||
""" | ||
if dataset is None: | ||
return None | ||
|
||
LOG.info(f"Starting {dataset_type} set evaluation...") | ||
|
||
if flash_optimum: | ||
with torch.backends.cuda.sdp_kernel( | ||
enable_flash=True, | ||
enable_math=True, | ||
enable_mem_efficient=True, | ||
): | ||
metrics = trainer.evaluate(dataset, metric_key_prefix=dataset_type) | ||
else: | ||
metrics = trainer.evaluate(dataset, metric_key_prefix=dataset_type) | ||
|
||
LOG.info(f"{dataset_type.capitalize()} set evaluation completed!") | ||
LOG.info(f"{dataset_type.capitalize()} Metrics:") | ||
for key, value in metrics.items(): | ||
LOG.info(f"{key}: {value}") | ||
|
||
return metrics | ||
|
||
|
||
def evaluate( | ||
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta | ||
) -> Dict[str, float]: | ||
""" | ||
Evaluate a model on training and validation datasets | ||
Args: | ||
cfg: Configuration dictionary | ||
cli_args: Command line arguments | ||
dataset_meta: Dataset metadata containing training and evaluation datasets | ||
Returns: | ||
Tuple containing: | ||
- The model (either PeftModel or PreTrainedModel) | ||
- The tokenizer | ||
- Dictionary of evaluation metrics | ||
""" | ||
    # pylint: disable=duplicate-code
    # Enable expandable segments for cuda allocation to improve VRAM usage
    set_pytorch_cuda_alloc_conf()

    # Load tokenizer
    LOG.debug(
        f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
        main_process_only=True,
    )
    tokenizer = load_tokenizer(cfg)

    # Load processor for multimodal models if needed
    processor = None
    if cfg.is_multimodal:
        processor = load_processor(cfg, tokenizer)

    # Get datasets
    train_dataset = dataset_meta.train_dataset
    eval_dataset = dataset_meta.eval_dataset
    total_num_steps = dataset_meta.total_num_steps

    # Load model
    LOG.debug("loading model for evaluation...")
    model, _ = load_model(
        cfg, tokenizer, processor=processor, inference=cli_args.inference
    )

    # Set up trainer
    trainer = setup_trainer(
        cfg,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        model=(model, None, None),  # No need for model_ref or peft_config
        tokenizer=tokenizer,
        processor=processor,
        total_num_steps=total_num_steps,
    )

    # Evaluate datasets
    all_metrics = {}
    train_metrics = evaluate_dataset(trainer, train_dataset, "train", cfg.flash_optimum)
    eval_metrics = evaluate_dataset(trainer, eval_dataset, "eval", cfg.flash_optimum)

    if train_metrics:
        all_metrics.update(train_metrics)
    if eval_metrics:
        all_metrics.update(eval_metrics)

    # Save metrics to CSV if output directory is specified and we have metrics
    if cfg.output_dir and (train_metrics or eval_metrics):
        output_dir = Path(cfg.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        metrics_file = output_dir / "eval_summary.csv"
        with metrics_file.open("w", newline="", encoding="utf-8") as file:
            writer = csv.writer(file)
            writer.writerow(["metric", "training", "validation"])

            # Get unique metric names (removing prefixes) from available metrics
            train_metric_names = {
                k.replace("train_", ""): k for k in (train_metrics or {})
            }
            eval_metric_names = {
                k.replace("eval_", ""): k for k in (eval_metrics or {})
            }
            all_metric_names = sorted(
                set(train_metric_names.keys()) | set(eval_metric_names.keys())
            )

            for metric_name in all_metric_names:
                train_value = (
                    train_metrics.get(train_metric_names.get(metric_name, ""), "")
                    if train_metrics
                    else ""
                )
                eval_value = (
                    eval_metrics.get(eval_metric_names.get(metric_name, ""), "")
                    if eval_metrics
                    else ""
                )
                writer.writerow([metric_name, train_value, eval_value])

        LOG.info(f"Evaluation results saved to {metrics_file}")

    del model
    del tokenizer

    return all_metrics
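As a rough illustration of the artifact this produces, not part of the commit: when cfg.output_dir is set, the loop above writes one row per de-prefixed metric name to eval_summary.csv, using the column header shown in the writer call. A sketch of reading that file back follows; the "outputs" directory is an assumption standing in for whatever cfg.output_dir points at.

# Illustrative only: reading back the eval_summary.csv written by evaluate().
# "outputs" is a placeholder path; the column names ("metric", "training",
# "validation") come from the writerow header in the code above.
import csv
from pathlib import Path

summary = Path("outputs") / "eval_summary.csv"
with summary.open(newline="", encoding="utf-8") as file:
    for row in csv.DictReader(file):
        print(f'{row["metric"]}: train={row["training"]} eval={row["validation"]}')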