From 735bf17fc44f1be207c6a8eeb146531ec9166eab Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 16 Dec 2024 15:46:31 -0500 Subject: [PATCH] Basic evaluate CLI command / codepath (#2188) * basic evaluate CLI command / codepath * tests for evaluate CLI command * fixes and cleanup * review comments; slightly DRYing up things --------- Co-authored-by: Dan Saunders --- src/axolotl/cli/evaluate.py | 52 ++++++++++ src/axolotl/cli/main.py | 27 +++++- src/axolotl/common/cli.py | 25 +++-- src/axolotl/evaluate.py | 168 +++++++++++++++++++++++++++++++++ src/axolotl/train.py | 19 ++-- src/axolotl/utils/data/sft.py | 3 + src/axolotl/utils/trainer.py | 11 +++ tests/cli/test_cli_base.py | 73 ++++++++++++++ tests/cli/test_cli_evaluate.py | 67 +++++++++++++ tests/cli/test_cli_train.py | 161 +++++++++++++------------------ 10 files changed, 494 insertions(+), 112 deletions(-) create mode 100644 src/axolotl/cli/evaluate.py create mode 100644 src/axolotl/evaluate.py create mode 100644 tests/cli/test_cli_base.py create mode 100644 tests/cli/test_cli_evaluate.py diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py new file mode 100644 index 0000000000..8e99d6f4b1 --- /dev/null +++ b/src/axolotl/cli/evaluate.py @@ -0,0 +1,52 @@ +""" +CLI to run training on a model +""" +import logging +from pathlib import Path +from typing import Union + +import fire +from dotenv import load_dotenv +from transformers.hf_argparser import HfArgumentParser + +from axolotl.cli import ( + check_accelerate_default_config, + check_user_token, + load_cfg, + load_datasets, + load_rl_datasets, + print_axolotl_text_art, +) +from axolotl.common.cli import TrainerCliArgs +from axolotl.evaluate import evaluate + +LOG = logging.getLogger("axolotl.cli.evaluate") + + +def do_evaluate(cfg, cli_args) -> None: + # pylint: disable=duplicate-code + print_axolotl_text_art() + check_accelerate_default_config() + check_user_token() + + if cfg.rl: # and cfg.rl != "orpo": + dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + else: + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + + +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: + # pylint: disable=duplicate-code + parsed_cfg = load_cfg(config, **kwargs) + parser = HfArgumentParser(TrainerCliArgs) + parsed_cli_args, _ = parser.parse_args_into_dataclasses( + return_remaining_strings=True + ) + do_evaluate(parsed_cfg, parsed_cli_args) + + +if __name__ == "__main__": + load_dotenv() + fire.Fire(do_cli) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index a77410776a..ec7f5b6947 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -12,7 +12,7 @@ build_command, fetch_from_github, ) -from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs +from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig @@ -60,6 +60,31 @@ def train(config: str, accelerate: bool, **kwargs): do_cli(config=config, **kwargs) +@cli.command() +@click.argument("config", type=click.Path(exists=True, path_type=str)) +@click.option( + "--accelerate/--no-accelerate", + default=True, + help="Use accelerate launch for multi-GPU training", +) +@add_options_from_dataclass(EvaluateCliArgs) +@add_options_from_config(AxolotlInputConfig) +def evaluate(config: str, accelerate: bool, **kwargs): + """Evaluate a model.""" + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + if accelerate: + base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"] + if config: + base_cmd.append(config) + cmd = build_command(base_cmd, kwargs) + subprocess.run(cmd, check=True) # nosec B603 + else: + from axolotl.cli.evaluate import do_cli + + do_cli(config=config, **kwargs) + + @cli.command() @click.argument("config", type=click.Path(exists=True, path_type=str)) @click.option( diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index 6a3a22e637..02ad9201b8 100644 --- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -15,6 +15,19 @@ LOG = logging.getLogger("axolotl.common.cli") +@dataclass +class PreprocessCliArgs: + """ + dataclass representing arguments for preprocessing only + """ + + debug: bool = field(default=False) + debug_text_only: bool = field(default=False) + debug_num_examples: int = field(default=1) + prompter: Optional[str] = field(default=None) + download: Optional[bool] = field(default=True) + + @dataclass class TrainerCliArgs: """ @@ -31,16 +44,14 @@ class TrainerCliArgs: @dataclass -class PreprocessCliArgs: +class EvaluateCliArgs: """ - dataclass representing arguments for preprocessing only + dataclass representing the various evaluation arguments """ debug: bool = field(default=False) debug_text_only: bool = field(default=False) - debug_num_examples: int = field(default=1) - prompter: Optional[str] = field(default=None) - download: Optional[bool] = field(default=True) + debug_num_examples: int = field(default=0) def load_model_and_tokenizer( @@ -50,7 +61,9 @@ def load_model_and_tokenizer( ): LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") tokenizer = load_tokenizer(cfg) + LOG.info("loading model and (optionally) peft_config...") - model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) + inference = getattr(cli_args, "inference", False) + model, _ = load_model(cfg, tokenizer, inference=inference) return model, tokenizer diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py new file mode 100644 index 0000000000..7fd60cb5f4 --- /dev/null +++ b/src/axolotl/evaluate.py @@ -0,0 +1,168 @@ +"""Module for evaluating models.""" + +import csv +import os +import sys +from pathlib import Path +from typing import Dict, Optional + +import torch +from accelerate.logging import get_logger + +from axolotl.common.cli import TrainerCliArgs +from axolotl.logging_config import configure_logging +from axolotl.train import TrainDatasetMeta +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_model, load_processor, load_tokenizer +from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer + +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +src_dir = os.path.join(project_root, "src") +sys.path.insert(0, src_dir) + +configure_logging() +LOG = get_logger("axolotl.evaluate") + + +def evaluate_dataset( + trainer, dataset, dataset_type: str, flash_optimum: bool = False +) -> Optional[Dict[str, float]]: + """Helper function to evaluate a single dataset safely. + + Args: + trainer: The trainer instance + dataset: Dataset to evaluate + dataset_type: Type of dataset ('train' or 'eval') + flash_optimum: Whether to use flash optimum + + Returns: + Dictionary of metrics or None if dataset is None + """ + if dataset is None: + return None + + LOG.info(f"Starting {dataset_type} set evaluation...") + + if flash_optimum: + with torch.backends.cuda.sdp_kernel( + enable_flash=True, + enable_math=True, + enable_mem_efficient=True, + ): + metrics = trainer.evaluate(dataset, metric_key_prefix=dataset_type) + else: + metrics = trainer.evaluate(dataset, metric_key_prefix=dataset_type) + + LOG.info(f"{dataset_type.capitalize()} set evaluation completed!") + LOG.info(f"{dataset_type.capitalize()} Metrics:") + for key, value in metrics.items(): + LOG.info(f"{key}: {value}") + + return metrics + + +def evaluate( + *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta +) -> Dict[str, float]: + """ + Evaluate a model on training and validation datasets + + Args: + cfg: Configuration dictionary + cli_args: Command line arguments + dataset_meta: Dataset metadata containing training and evaluation datasets + + Returns: + Tuple containing: + - The model (either PeftModel or PreTrainedModel) + - The tokenizer + - Dictionary of evaluation metrics + """ + # pylint: disable=duplicate-code + # Enable expandable segments for cuda allocation to improve VRAM usage + set_pytorch_cuda_alloc_conf() + + # Load tokenizer + LOG.debug( + f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}", + main_process_only=True, + ) + tokenizer = load_tokenizer(cfg) + + # Load processor for multimodal models if needed + processor = None + if cfg.is_multimodal: + processor = load_processor(cfg, tokenizer) + + # Get datasets + train_dataset = dataset_meta.train_dataset + eval_dataset = dataset_meta.eval_dataset + total_num_steps = dataset_meta.total_num_steps + + # Load model + LOG.debug("loading model for evaluation...") + model, _ = load_model( + cfg, tokenizer, processor=processor, inference=cli_args.inference + ) + + # Set up trainer + trainer = setup_trainer( + cfg, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + model=(model, None, None), # No need for model_ref or peft_config + tokenizer=tokenizer, + processor=processor, + total_num_steps=total_num_steps, + ) + + # Evaluate datasets + all_metrics = {} + train_metrics = evaluate_dataset(trainer, train_dataset, "train", cfg.flash_optimum) + eval_metrics = evaluate_dataset(trainer, eval_dataset, "eval", cfg.flash_optimum) + + if train_metrics: + all_metrics.update(train_metrics) + if eval_metrics: + all_metrics.update(eval_metrics) + + # Save metrics to CSV if output directory is specified and we have metrics + if cfg.output_dir and (train_metrics or eval_metrics): + output_dir = Path(cfg.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + metrics_file = output_dir / "eval_summary.csv" + with metrics_file.open("w", newline="", encoding="utf-8") as file: + writer = csv.writer(file) + writer.writerow(["metric", "training", "validation"]) + + # Get unique metric names (removing prefixes) from available metrics + train_metric_names = { + k.replace("train_", ""): k for k in (train_metrics or {}) + } + eval_metric_names = { + k.replace("eval_", ""): k for k in (eval_metrics or {}) + } + all_metric_names = sorted( + set(train_metric_names.keys()) | set(eval_metric_names.keys()) + ) + + for metric_name in all_metric_names: + train_value = ( + train_metrics.get(train_metric_names.get(metric_name, ""), "") + if train_metrics + else "" + ) + eval_value = ( + eval_metrics.get(eval_metric_names.get(metric_name, ""), "") + if eval_metrics + else "" + ) + writer.writerow([metric_name, train_value, eval_value]) + + LOG.info(f"Evaluation results saved to {metrics_file}") + + del model + del tokenizer + + return all_metrics diff --git a/src/axolotl/train.py b/src/axolotl/train.py index c8576f1b48..851a71e547 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -24,7 +24,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.models import load_model, load_processor, load_tokenizer -from axolotl.utils.trainer import setup_trainer +from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer try: from optimum.bettertransformer import BetterTransformer @@ -53,25 +53,22 @@ class TrainDatasetMeta: def train( *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: - # enable expandable segments for cuda allocation to improve VRAM usage - torch_version = torch.__version__.split(".") - torch_major, torch_minor = int(torch_version[0]), int(torch_version[1]) - if torch_major == 2 and torch_minor >= 2: - if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None: - os.environ[ - "PYTORCH_CUDA_ALLOC_CONF" - ] = "expandable_segments:True,roundup_power2_divisions:16" - - # load the tokenizer first + # Enable expandable segments for cuda allocation to improve VRAM usage + set_pytorch_cuda_alloc_conf() + + # Load tokenizer LOG.debug( f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}", main_process_only=True, ) tokenizer = load_tokenizer(cfg) + + # Load processor for multimodal models if needed processor = None if cfg.is_multimodal: processor = load_processor(cfg, tokenizer) + # Get datasets train_dataset = dataset_meta.train_dataset eval_dataset = dataset_meta.eval_dataset total_num_steps = dataset_meta.total_num_steps diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index f56fe8f38c..286e5f2d70 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -119,7 +119,9 @@ def prepare_dataset(cfg, tokenizer, processor=None): eval_dataset = None if cfg.dataset_exact_deduplication: LOG.info("Deduplication not available for pretrained datasets") + return train_dataset, eval_dataset, cfg.max_steps, prompters + if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False: total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False) if total_eval_steps == 0: @@ -134,6 +136,7 @@ def prepare_dataset(cfg, tokenizer, processor=None): LOG.info(f"Maximum number of steps set at {total_num_steps}") else: total_num_steps = calculate_total_num_steps(cfg, train_dataset) + return train_dataset, eval_dataset, total_num_steps, prompters diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 32e54c9a86..fd09b3eb67 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -512,6 +512,17 @@ def prepare_opinionated_env(cfg): os.environ["TOKENIZERS_PARALLELISM"] = "false" +def set_pytorch_cuda_alloc_conf(): + """Set up CUDA allocation config if using PyTorch >= 2.2""" + torch_version = torch.__version__.split(".") + torch_major, torch_minor = int(torch_version[0]), int(torch_version[1]) + if torch_major == 2 and torch_minor >= 2: + if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None: + os.environ[ + "PYTORCH_CUDA_ALLOC_CONF" + ] = "expandable_segments:True,roundup_power2_divisions:16" + + def setup_trainer( cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps ): diff --git a/tests/cli/test_cli_base.py b/tests/cli/test_cli_base.py new file mode 100644 index 0000000000..6dbae045f6 --- /dev/null +++ b/tests/cli/test_cli_base.py @@ -0,0 +1,73 @@ +"""Base test class for CLI commands.""" + +from pathlib import Path +from unittest.mock import patch + +from axolotl.cli.main import cli + + +class BaseCliTest: + """Base class for CLI command tests.""" + + def _test_cli_validation(self, cli_runner, command: str): + """Test CLI validation for a command. + + Args: + cli_runner: CLI runner fixture + command: Command to test (train/evaluate) + """ + # Test missing config file + result = cli_runner.invoke(cli, [command, "--no-accelerate"]) + assert result.exit_code != 0 + + # Test non-existent config file + result = cli_runner.invoke(cli, [command, "nonexistent.yml", "--no-accelerate"]) + assert result.exit_code != 0 + assert "Error: Invalid value for 'CONFIG'" in result.output + + def _test_basic_execution( + self, cli_runner, tmp_path: Path, valid_test_config: str, command: str + ): + """Test basic execution with accelerate. + + Args: + cli_runner: CLI runner fixture + tmp_path: Temporary path fixture + valid_test_config: Valid config fixture + command: Command to test (train/evaluate) + """ + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("subprocess.run") as mock: + result = cli_runner.invoke(cli, [command, str(config_path)]) + + assert mock.called + assert mock.call_args.args[0] == [ + "accelerate", + "launch", + "-m", + f"axolotl.cli.{command}", + str(config_path), + "--debug-num-examples", + "0", + ] + assert mock.call_args.kwargs == {"check": True} + assert result.exit_code == 0 + + def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str): + """Test CLI argument overrides. + + Args: + tmp_path: Temporary path fixture + valid_test_config: Valid config fixture + command: Command to test (train/evaluate) + """ + config_path = tmp_path / "config.yml" + output_dir = tmp_path / "model-out" + + test_config = valid_test_config.replace( + "output_dir: model-out", f"output_dir: {output_dir}" + ) + config_path.write_text(test_config) + return config_path diff --git a/tests/cli/test_cli_evaluate.py b/tests/cli/test_cli_evaluate.py new file mode 100644 index 0000000000..d8eb41467f --- /dev/null +++ b/tests/cli/test_cli_evaluate.py @@ -0,0 +1,67 @@ +"""Tests for evaluate CLI command.""" + +from unittest.mock import patch + +from axolotl.cli.main import cli + +from .test_cli_base import BaseCliTest + + +class TestEvaluateCommand(BaseCliTest): + """Test cases for evaluate command.""" + + cli = cli + + def test_evaluate_cli_validation(self, cli_runner): + """Test CLI validation""" + self._test_cli_validation(cli_runner, "evaluate") + + def test_evaluate_basic_execution(self, cli_runner, tmp_path, valid_test_config): + """Test basic successful execution""" + self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "evaluate") + + def test_evaluate_basic_execution_no_accelerate( + self, cli_runner, tmp_path, valid_test_config + ): + """Test basic successful execution without accelerate""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate: + result = cli_runner.invoke( + cli, + [ + "evaluate", + str(config_path), + "--no-accelerate", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_evaluate.assert_called_once() + + def test_evaluate_cli_overrides(self, cli_runner, tmp_path, valid_test_config): + """Test CLI arguments properly override config values""" + config_path = self._test_cli_overrides(tmp_path, valid_test_config) + + with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate: + result = cli_runner.invoke( + cli, + [ + "evaluate", + str(config_path), + "--micro-batch-size", + "2", + "--sequence-len", + "128", + "--no-accelerate", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_evaluate.assert_called_once() + cfg = mock_evaluate.call_args[0][0] + assert cfg.micro_batch_size == 2 + assert cfg.sequence_len == 128 diff --git a/tests/cli/test_cli_train.py b/tests/cli/test_cli_train.py index 7f028fb4f2..560f3caf58 100644 --- a/tests/cli/test_cli_train.py +++ b/tests/cli/test_cli_train.py @@ -1,98 +1,71 @@ -"""pytest tests for axolotl CLI train command.""" +"""Tests for train CLI command.""" + from unittest.mock import MagicMock, patch from axolotl.cli.main import cli - -def test_train_cli_validation(cli_runner): - """Test CLI validation""" - # Test missing config file - result = cli_runner.invoke(cli, ["train", "--no-accelerate"]) - assert result.exit_code != 0 - - # Test non-existent config file - result = cli_runner.invoke(cli, ["train", "nonexistent.yml", "--no-accelerate"]) - assert result.exit_code != 0 - assert "Error: Invalid value for 'CONFIG'" in result.output - - -def test_train_basic_execution(cli_runner, tmp_path, valid_test_config): - """Test basic successful execution""" - config_path = tmp_path / "config.yml" - config_path.write_text(valid_test_config) - - with patch("subprocess.run") as mock: - result = cli_runner.invoke(cli, ["train", str(config_path)]) - - assert mock.called - assert mock.call_args.args[0] == [ - "accelerate", - "launch", - "-m", - "axolotl.cli.train", - str(config_path), - "--debug-num-examples", - "0", - ] - assert mock.call_args.kwargs == {"check": True} - assert result.exit_code == 0 - - -def test_train_basic_execution_no_accelerate(cli_runner, tmp_path, valid_test_config): - """Test basic successful execution""" - config_path = tmp_path / "config.yml" - config_path.write_text(valid_test_config) - - with patch("axolotl.cli.train.train") as mock_train: - mock_train.return_value = (MagicMock(), MagicMock()) - - result = cli_runner.invoke( - cli, - [ - "train", - str(config_path), - "--learning-rate", - "1e-4", - "--micro-batch-size", - "2", - "--no-accelerate", - ], - catch_exceptions=False, - ) - - assert result.exit_code == 0 - mock_train.assert_called_once() - - -def test_train_cli_overrides(cli_runner, tmp_path, valid_test_config): - """Test CLI arguments properly override config values""" - config_path = tmp_path / "config.yml" - output_dir = tmp_path / "model-out" - - test_config = valid_test_config.replace( - "output_dir: model-out", f"output_dir: {output_dir}" - ) - config_path.write_text(test_config) - - with patch("axolotl.cli.train.train") as mock_train: - mock_train.return_value = (MagicMock(), MagicMock()) - - result = cli_runner.invoke( - cli, - [ - "train", - str(config_path), - "--learning-rate", - "1e-4", - "--micro-batch-size", - "2", - "--no-accelerate", - ], - catch_exceptions=False, - ) - - assert result.exit_code == 0 - mock_train.assert_called_once() - cfg = mock_train.call_args[1]["cfg"] - assert cfg["learning_rate"] == 1e-4 - assert cfg["micro_batch_size"] == 2 +from .test_cli_base import BaseCliTest + + +class TestTrainCommand(BaseCliTest): + """Test cases for train command.""" + + cli = cli + + def test_train_cli_validation(self, cli_runner): + """Test CLI validation""" + self._test_cli_validation(cli_runner, "train") + + def test_train_basic_execution(self, cli_runner, tmp_path, valid_test_config): + """Test basic successful execution""" + self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "train") + + def test_train_basic_execution_no_accelerate( + self, cli_runner, tmp_path, valid_test_config + ): + """Test basic successful execution without accelerate""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("axolotl.cli.train.train") as mock_train: + mock_train.return_value = (MagicMock(), MagicMock()) + + result = cli_runner.invoke( + cli, + [ + "train", + str(config_path), + "--no-accelerate", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_train.assert_called_once() + + def test_train_cli_overrides(self, cli_runner, tmp_path, valid_test_config): + """Test CLI arguments properly override config values""" + config_path = self._test_cli_overrides(tmp_path, valid_test_config) + + with patch("axolotl.cli.train.train") as mock_train: + mock_train.return_value = (MagicMock(), MagicMock()) + + result = cli_runner.invoke( + cli, + [ + "train", + str(config_path), + "--learning-rate", + "1e-4", + "--micro-batch-size", + "2", + "--no-accelerate", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_train.assert_called_once() + cfg = mock_train.call_args[1]["cfg"] + assert cfg["learning_rate"] == 1e-4 + assert cfg["micro_batch_size"] == 2