From 2cda88413772bbb2a65570004f0b5f25a717d66c Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Fri, 13 Dec 2024 16:45:17 -0500
Subject: [PATCH 1/4] basic evaluate CLI command / codepath

---
 src/axolotl/cli/evaluate.py |  51 +++++++++++++++
 src/axolotl/cli/main.py     |  27 +++++++-
 src/axolotl/common/cli.py   |  25 ++++++--
 src/axolotl/evaluate.py     | 121 ++++++++++++++++++++++++++++++++++++
 4 files changed, 217 insertions(+), 7 deletions(-)
 create mode 100644 src/axolotl/cli/evaluate.py
 create mode 100644 src/axolotl/evaluate.py

diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py
new file mode 100644
index 0000000000..202a7a6917
--- /dev/null
+++ b/src/axolotl/cli/evaluate.py
@@ -0,0 +1,51 @@
+"""
+CLI to run evaluation on a model
+"""
+import logging
+from pathlib import Path
+from typing import Union
+
+import fire
+from dotenv import load_dotenv
+from transformers.hf_argparser import HfArgumentParser
+
+from axolotl.cli import (
+    check_accelerate_default_config,
+    check_user_token,
+    load_cfg,
+    load_datasets,
+    load_rl_datasets,
+    print_axolotl_text_art,
+)
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.evaluate import evaluate
+
+LOG = logging.getLogger("axolotl.cli.train")
+
+
+def do_evaluate(cfg, cli_args) -> None:
+    print_axolotl_text_art()
+    check_accelerate_default_config()
+    check_user_token()
+
+    if cfg.rl:  # and cfg.rl != "orpo":
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
+    else:
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+    evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+
+
+def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
+    # pylint: disable=duplicate-code
+    parsed_cfg = load_cfg(config, **kwargs)
+    parser = HfArgumentParser(TrainerCliArgs)
+    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+        return_remaining_strings=True
+    )
+    do_evaluate(parsed_cfg, parsed_cli_args)
+
+
+if __name__ == "__main__":
+    load_dotenv()
+    fire.Fire(do_cli)
diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py
index a77410776a..ec7f5b6947 100644
--- a/src/axolotl/cli/main.py
+++ b/src/axolotl/cli/main.py
@@ -12,7 +12,7 @@
     build_command,
     fetch_from_github,
 )
-from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs
+from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
 from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
 
 
@@ -60,6 +60,31 @@ def train(config: str, accelerate: bool, **kwargs):
     do_cli(config=config, **kwargs)
 
 
+@cli.command()
+@click.argument("config", type=click.Path(exists=True, path_type=str))
+@click.option(
+    "--accelerate/--no-accelerate",
+    default=True,
+    help="Use accelerate launch for multi-GPU evaluation",
+)
+@add_options_from_dataclass(EvaluateCliArgs)
+@add_options_from_config(AxolotlInputConfig)
+def evaluate(config: str, accelerate: bool, **kwargs):
+    """Evaluate a model."""
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
+    if accelerate:
+        base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
+        if config:
+            base_cmd.append(config)
+        cmd = build_command(base_cmd, kwargs)
+        subprocess.run(cmd, check=True)  # nosec B603
+    else:
+        from axolotl.cli.evaluate import do_cli
+
+        do_cli(config=config, **kwargs)
+
+
 @cli.command()
 @click.argument("config", type=click.Path(exists=True, path_type=str))
 @click.option(
diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py
index 6a3a22e637..02ad9201b8 100644
--- a/src/axolotl/common/cli.py
+++ b/src/axolotl/common/cli.py
@@ 
-15,6 +15,19 @@ LOG = logging.getLogger("axolotl.common.cli") +@dataclass +class PreprocessCliArgs: + """ + dataclass representing arguments for preprocessing only + """ + + debug: bool = field(default=False) + debug_text_only: bool = field(default=False) + debug_num_examples: int = field(default=1) + prompter: Optional[str] = field(default=None) + download: Optional[bool] = field(default=True) + + @dataclass class TrainerCliArgs: """ @@ -31,16 +44,14 @@ class TrainerCliArgs: @dataclass -class PreprocessCliArgs: +class EvaluateCliArgs: """ - dataclass representing arguments for preprocessing only + dataclass representing the various evaluation arguments """ debug: bool = field(default=False) debug_text_only: bool = field(default=False) - debug_num_examples: int = field(default=1) - prompter: Optional[str] = field(default=None) - download: Optional[bool] = field(default=True) + debug_num_examples: int = field(default=0) def load_model_and_tokenizer( @@ -50,7 +61,9 @@ def load_model_and_tokenizer( ): LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") tokenizer = load_tokenizer(cfg) + LOG.info("loading model and (optionally) peft_config...") - model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) + inference = getattr(cli_args, "inference", False) + model, _ = load_model(cfg, tokenizer, inference=inference) return model, tokenizer diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py new file mode 100644 index 0000000000..0febfaac68 --- /dev/null +++ b/src/axolotl/evaluate.py @@ -0,0 +1,121 @@ +"""Module for evaluating models.""" + +import os +import sys +from pathlib import Path +from typing import Tuple, Union + +import torch +from accelerate.logging import get_logger +from peft import PeftModel +from transformers import PreTrainedModel, PreTrainedTokenizer + +from axolotl.common.cli import TrainerCliArgs +from axolotl.logging_config import configure_logging +from axolotl.train import TrainDatasetMeta +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_model, load_processor, load_tokenizer +from axolotl.utils.trainer import setup_trainer + +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +src_dir = os.path.join(project_root, "src") +sys.path.insert(0, src_dir) + +configure_logging() +LOG = get_logger("axolotl.eval") + + +def evaluate( + *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta +) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer, dict]: + """ + Evaluate a model on a dataset + + Args: + cfg: Configuration dictionary + cli_args: Command line arguments + dataset_meta: Dataset metadata containing evaluation dataset + + Returns: + Tuple containing: + - The model (either PeftModel or PreTrainedModel) + - The tokenizer + - Dictionary of evaluation metrics + """ + # Set up CUDA allocation config if using PyTorch >= 2.2 + torch_version = torch.__version__.split(".") + torch_major, torch_minor = int(torch_version[0]), int(torch_version[1]) + if torch_major == 2 and torch_minor >= 2: + if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None: + os.environ[ + "PYTORCH_CUDA_ALLOC_CONF" + ] = "expandable_segments:True,roundup_power2_divisions:16" + + # Load tokenizer + LOG.debug( + f"loading tokenizer... 
{cfg.tokenizer_config or cfg.base_model_config}", + main_process_only=True, + ) + tokenizer = load_tokenizer(cfg) + + # Load processor for multimodal models if needed + processor = None + if cfg.is_multimodal: + processor = load_processor(cfg, tokenizer) + + # Get evaluation dataset + eval_dataset = dataset_meta.eval_dataset + total_num_steps = dataset_meta.total_num_steps + + # Load model + LOG.debug("loading model for evaluation...") + model, _ = load_model( + cfg, tokenizer, processor=processor, inference=cli_args.inference + ) + + # Set up trainer + trainer = setup_trainer( + cfg, + train_dataset=eval_dataset, # None # No training dataset needed for evaluation + eval_dataset=eval_dataset, + model=(model, None, None), # No need for model_ref or peft_config + tokenizer=tokenizer, + processor=processor, + total_num_steps=total_num_steps, + ) + + # Run evaluation + LOG.info("Starting evaluation...") + + if cfg.flash_optimum: + with torch.backends.cuda.sdp_kernel( + enable_flash=True, + enable_math=True, + enable_mem_efficient=True, + ): + metrics = trainer.evaluate() + else: + metrics = trainer.evaluate() + + # Log results + LOG.info("Evaluation completed!") + LOG.info("Metrics:") + for key, value in metrics.items(): + LOG.info(f"{key}: {value}") + + # Save metrics to file if output directory is specified + if cfg.output_dir: + output_dir = Path(cfg.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + metrics_file = output_dir / "eval_results.txt" + with metrics_file.open("w", encoding="utf-8") as file: + for key, value in metrics.items(): + file.write(f"{key} = {value}\n") + + LOG.info(f"Evaluation results saved to {metrics_file}") + + del model + del tokenizer + + return metrics From 5918dc288e8d1be303237b0e4ffcec2774f0647c Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 16 Dec 2024 12:02:25 -0500 Subject: [PATCH 2/4] tests for evaluate CLI command --- tests/cli/test_cli_evaluate.py | 120 +++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 tests/cli/test_cli_evaluate.py diff --git a/tests/cli/test_cli_evaluate.py b/tests/cli/test_cli_evaluate.py new file mode 100644 index 0000000000..4e6e352d68 --- /dev/null +++ b/tests/cli/test_cli_evaluate.py @@ -0,0 +1,120 @@ +"""pytest tests for axolotl CLI evaluate command.""" +from unittest.mock import patch + +from axolotl.cli.main import cli + + +def test_evaluate_cli_validation(cli_runner): + """Test CLI validation""" + # Test missing config file + result = cli_runner.invoke(cli, ["evaluate", "--no-accelerate"]) + assert result.exit_code != 0 + + # Test non-existent config file + result = cli_runner.invoke(cli, ["evaluate", "nonexistent.yml", "--no-accelerate"]) + assert result.exit_code != 0 + assert "Error: Invalid value for 'CONFIG'" in result.output + + +def test_evaluate_basic_execution(cli_runner, tmp_path, valid_test_config): + """Test basic successful execution""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("subprocess.run") as mock: + result = cli_runner.invoke(cli, ["evaluate", str(config_path)]) + + assert mock.called + assert mock.call_args.args[0] == [ + "accelerate", + "launch", + "-m", + "axolotl.cli.evaluate", + str(config_path), + "--debug-num-examples", + "0", + ] + assert mock.call_args.kwargs == {"check": True} + assert result.exit_code == 0 + + +def test_evaluate_basic_execution_no_accelerate( + cli_runner, tmp_path, valid_test_config +): + """Test basic successful execution without accelerate""" + config_path = 
tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate: + result = cli_runner.invoke( + cli, + [ + "evaluate", + str(config_path), + "--micro-batch-size", + "2", + "--no-accelerate", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_evaluate.assert_called_once() + cfg = mock_evaluate.call_args[0][0] + assert cfg.micro_batch_size == 2 + + +def test_evaluate_cli_overrides(cli_runner, tmp_path, valid_test_config): + """Test CLI arguments properly override config values""" + config_path = tmp_path / "config.yml" + output_dir = tmp_path / "model-out" + + test_config = valid_test_config.replace( + "output_dir: model-out", f"output_dir: {output_dir}" + ) + config_path.write_text(test_config) + + with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate: + result = cli_runner.invoke( + cli, + [ + "evaluate", + str(config_path), + "--micro-batch-size", + "2", + "--sequence-len", + "128", + "--no-accelerate", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_evaluate.assert_called_once() + cfg = mock_evaluate.call_args[0][0] + assert cfg.micro_batch_size == 2 + assert cfg.sequence_len == 128 + + +def test_evaluate_with_rl_dpo(cli_runner, tmp_path, valid_test_config): + """Test evaluation with DPO reinforcement learning""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate: + result = cli_runner.invoke( + cli, + [ + "evaluate", + str(config_path), + "--rl", + "dpo", + "--no-accelerate", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_evaluate.assert_called_once() + cfg = mock_evaluate.call_args[0][0] + assert cfg.rl == "dpo" From 2fb3ed5aa349fdbd7ce0df97b632c68c122dce05 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 16 Dec 2024 13:39:35 -0500 Subject: [PATCH 3/4] fixes and cleanup --- src/axolotl/cli/__init__.py | 1 + src/axolotl/cli/evaluate.py | 3 +- src/axolotl/evaluate.py | 119 ++++++++++++++++------ src/axolotl/utils/data/sft.py | 3 + tests/cli/test_cli_base.py | 73 ++++++++++++++ tests/cli/test_cli_evaluate.py | 179 ++++++++++++--------------------- tests/cli/test_cli_train.py | 161 ++++++++++++----------------- 7 files changed, 295 insertions(+), 244 deletions(-) create mode 100644 tests/cli/test_cli_base.py diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index d07b10ce3d..6506e44f77 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -476,6 +476,7 @@ def load_datasets( tokenizer, processor=processor, ) + print(train_dataset, eval_dataset, total_num_steps) if ( cli_args.debug diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py index 202a7a6917..8e99d6f4b1 100644 --- a/src/axolotl/cli/evaluate.py +++ b/src/axolotl/cli/evaluate.py @@ -20,10 +20,11 @@ from axolotl.common.cli import TrainerCliArgs from axolotl.evaluate import evaluate -LOG = logging.getLogger("axolotl.cli.train") +LOG = logging.getLogger("axolotl.cli.evaluate") def do_evaluate(cfg, cli_args) -> None: + # pylint: disable=duplicate-code print_axolotl_text_art() check_accelerate_default_config() check_user_token() diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py index 0febfaac68..d53d7bb5da 100644 --- a/src/axolotl/evaluate.py +++ b/src/axolotl/evaluate.py @@ -1,14 +1,13 @@ """Module for evaluating models.""" +import csv import os import sys from pathlib import Path -from typing import 
Tuple, Union +from typing import Dict, Optional import torch from accelerate.logging import get_logger -from peft import PeftModel -from transformers import PreTrainedModel, PreTrainedTokenizer from axolotl.common.cli import TrainerCliArgs from axolotl.logging_config import configure_logging @@ -22,19 +21,56 @@ sys.path.insert(0, src_dir) configure_logging() -LOG = get_logger("axolotl.eval") +LOG = get_logger("axolotl.evaluate") + + +def evaluate_dataset( + trainer, dataset, dataset_type: str, flash_optimum: bool = False +) -> Optional[Dict[str, float]]: + """Helper function to evaluate a single dataset safely. + + Args: + trainer: The trainer instance + dataset: Dataset to evaluate + dataset_type: Type of dataset ('train' or 'eval') + flash_optimum: Whether to use flash optimum + + Returns: + Dictionary of metrics or None if dataset is None + """ + if dataset is None: + return None + + LOG.info(f"Starting {dataset_type} set evaluation...") + + if flash_optimum: + with torch.backends.cuda.sdp_kernel( + enable_flash=True, + enable_math=True, + enable_mem_efficient=True, + ): + metrics = trainer.evaluate(dataset, metric_key_prefix=dataset_type) + else: + metrics = trainer.evaluate(dataset, metric_key_prefix=dataset_type) + + LOG.info(f"{dataset_type.capitalize()} set evaluation completed!") + LOG.info(f"{dataset_type.capitalize()} Metrics:") + for key, value in metrics.items(): + LOG.info(f"{key}: {value}") + + return metrics def evaluate( *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta -) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer, dict]: +) -> Dict[str, float]: """ - Evaluate a model on a dataset + Evaluate a model on training and validation datasets Args: cfg: Configuration dictionary cli_args: Command line arguments - dataset_meta: Dataset metadata containing evaluation dataset + dataset_meta: Dataset metadata containing training and evaluation datasets Returns: Tuple containing: @@ -42,6 +78,7 @@ def evaluate( - The tokenizer - Dictionary of evaluation metrics """ + # pylint: disable=duplicate-code # Set up CUDA allocation config if using PyTorch >= 2.2 torch_version = torch.__version__.split(".") torch_major, torch_minor = int(torch_version[0]), int(torch_version[1]) @@ -63,7 +100,8 @@ def evaluate( if cfg.is_multimodal: processor = load_processor(cfg, tokenizer) - # Get evaluation dataset + # Get datasets + train_dataset = dataset_meta.train_dataset eval_dataset = dataset_meta.eval_dataset total_num_steps = dataset_meta.total_num_steps @@ -76,7 +114,7 @@ def evaluate( # Set up trainer trainer = setup_trainer( cfg, - train_dataset=eval_dataset, # None # No training dataset needed for evaluation + train_dataset=train_dataset, eval_dataset=eval_dataset, model=(model, None, None), # No need for model_ref or peft_config tokenizer=tokenizer, @@ -84,38 +122,53 @@ def evaluate( total_num_steps=total_num_steps, ) - # Run evaluation - LOG.info("Starting evaluation...") + # Evaluate datasets + all_metrics = {} + train_metrics = evaluate_dataset(trainer, train_dataset, "train", cfg.flash_optimum) + eval_metrics = evaluate_dataset(trainer, eval_dataset, "eval", cfg.flash_optimum) - if cfg.flash_optimum: - with torch.backends.cuda.sdp_kernel( - enable_flash=True, - enable_math=True, - enable_mem_efficient=True, - ): - metrics = trainer.evaluate() - else: - metrics = trainer.evaluate() + if train_metrics: + all_metrics.update(train_metrics) + if eval_metrics: + all_metrics.update(eval_metrics) - # Log results - LOG.info("Evaluation completed!") - 
LOG.info("Metrics:") - for key, value in metrics.items(): - LOG.info(f"{key}: {value}") - - # Save metrics to file if output directory is specified - if cfg.output_dir: + # Save metrics to CSV if output directory is specified and we have metrics + if cfg.output_dir and (train_metrics or eval_metrics): output_dir = Path(cfg.output_dir) output_dir.mkdir(parents=True, exist_ok=True) - metrics_file = output_dir / "eval_results.txt" - with metrics_file.open("w", encoding="utf-8") as file: - for key, value in metrics.items(): - file.write(f"{key} = {value}\n") + metrics_file = output_dir / "eval_summary.csv" + with metrics_file.open("w", newline="", encoding="utf-8") as file: + writer = csv.writer(file) + writer.writerow(["metric", "training", "validation"]) + + # Get unique metric names (removing prefixes) from available metrics + train_metric_names = { + k.replace("train_", ""): k for k in (train_metrics or {}) + } + eval_metric_names = { + k.replace("eval_", ""): k for k in (eval_metrics or {}) + } + all_metric_names = sorted( + set(train_metric_names.keys()) | set(eval_metric_names.keys()) + ) + + for metric_name in all_metric_names: + train_value = ( + train_metrics.get(train_metric_names.get(metric_name, ""), "") + if train_metrics + else "" + ) + eval_value = ( + eval_metrics.get(eval_metric_names.get(metric_name, ""), "") + if eval_metrics + else "" + ) + writer.writerow([metric_name, train_value, eval_value]) LOG.info(f"Evaluation results saved to {metrics_file}") del model del tokenizer - return metrics + return all_metrics diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index f56fe8f38c..286e5f2d70 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -119,7 +119,9 @@ def prepare_dataset(cfg, tokenizer, processor=None): eval_dataset = None if cfg.dataset_exact_deduplication: LOG.info("Deduplication not available for pretrained datasets") + return train_dataset, eval_dataset, cfg.max_steps, prompters + if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False: total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False) if total_eval_steps == 0: @@ -134,6 +136,7 @@ def prepare_dataset(cfg, tokenizer, processor=None): LOG.info(f"Maximum number of steps set at {total_num_steps}") else: total_num_steps = calculate_total_num_steps(cfg, train_dataset) + return train_dataset, eval_dataset, total_num_steps, prompters diff --git a/tests/cli/test_cli_base.py b/tests/cli/test_cli_base.py new file mode 100644 index 0000000000..6dbae045f6 --- /dev/null +++ b/tests/cli/test_cli_base.py @@ -0,0 +1,73 @@ +"""Base test class for CLI commands.""" + +from pathlib import Path +from unittest.mock import patch + +from axolotl.cli.main import cli + + +class BaseCliTest: + """Base class for CLI command tests.""" + + def _test_cli_validation(self, cli_runner, command: str): + """Test CLI validation for a command. + + Args: + cli_runner: CLI runner fixture + command: Command to test (train/evaluate) + """ + # Test missing config file + result = cli_runner.invoke(cli, [command, "--no-accelerate"]) + assert result.exit_code != 0 + + # Test non-existent config file + result = cli_runner.invoke(cli, [command, "nonexistent.yml", "--no-accelerate"]) + assert result.exit_code != 0 + assert "Error: Invalid value for 'CONFIG'" in result.output + + def _test_basic_execution( + self, cli_runner, tmp_path: Path, valid_test_config: str, command: str + ): + """Test basic execution with accelerate. 
+ + Args: + cli_runner: CLI runner fixture + tmp_path: Temporary path fixture + valid_test_config: Valid config fixture + command: Command to test (train/evaluate) + """ + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("subprocess.run") as mock: + result = cli_runner.invoke(cli, [command, str(config_path)]) + + assert mock.called + assert mock.call_args.args[0] == [ + "accelerate", + "launch", + "-m", + f"axolotl.cli.{command}", + str(config_path), + "--debug-num-examples", + "0", + ] + assert mock.call_args.kwargs == {"check": True} + assert result.exit_code == 0 + + def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str): + """Test CLI argument overrides. + + Args: + tmp_path: Temporary path fixture + valid_test_config: Valid config fixture + command: Command to test (train/evaluate) + """ + config_path = tmp_path / "config.yml" + output_dir = tmp_path / "model-out" + + test_config = valid_test_config.replace( + "output_dir: model-out", f"output_dir: {output_dir}" + ) + config_path.write_text(test_config) + return config_path diff --git a/tests/cli/test_cli_evaluate.py b/tests/cli/test_cli_evaluate.py index 4e6e352d68..d8eb41467f 100644 --- a/tests/cli/test_cli_evaluate.py +++ b/tests/cli/test_cli_evaluate.py @@ -1,120 +1,67 @@ -"""pytest tests for axolotl CLI evaluate command.""" +"""Tests for evaluate CLI command.""" + from unittest.mock import patch from axolotl.cli.main import cli - -def test_evaluate_cli_validation(cli_runner): - """Test CLI validation""" - # Test missing config file - result = cli_runner.invoke(cli, ["evaluate", "--no-accelerate"]) - assert result.exit_code != 0 - - # Test non-existent config file - result = cli_runner.invoke(cli, ["evaluate", "nonexistent.yml", "--no-accelerate"]) - assert result.exit_code != 0 - assert "Error: Invalid value for 'CONFIG'" in result.output - - -def test_evaluate_basic_execution(cli_runner, tmp_path, valid_test_config): - """Test basic successful execution""" - config_path = tmp_path / "config.yml" - config_path.write_text(valid_test_config) - - with patch("subprocess.run") as mock: - result = cli_runner.invoke(cli, ["evaluate", str(config_path)]) - - assert mock.called - assert mock.call_args.args[0] == [ - "accelerate", - "launch", - "-m", - "axolotl.cli.evaluate", - str(config_path), - "--debug-num-examples", - "0", - ] - assert mock.call_args.kwargs == {"check": True} - assert result.exit_code == 0 - - -def test_evaluate_basic_execution_no_accelerate( - cli_runner, tmp_path, valid_test_config -): - """Test basic successful execution without accelerate""" - config_path = tmp_path / "config.yml" - config_path.write_text(valid_test_config) - - with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate: - result = cli_runner.invoke( - cli, - [ - "evaluate", - str(config_path), - "--micro-batch-size", - "2", - "--no-accelerate", - ], - catch_exceptions=False, - ) - - assert result.exit_code == 0 - mock_evaluate.assert_called_once() - cfg = mock_evaluate.call_args[0][0] - assert cfg.micro_batch_size == 2 - - -def test_evaluate_cli_overrides(cli_runner, tmp_path, valid_test_config): - """Test CLI arguments properly override config values""" - config_path = tmp_path / "config.yml" - output_dir = tmp_path / "model-out" - - test_config = valid_test_config.replace( - "output_dir: model-out", f"output_dir: {output_dir}" - ) - config_path.write_text(test_config) - - with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate: - result = cli_runner.invoke( - cli, - [ - 
"evaluate", - str(config_path), - "--micro-batch-size", - "2", - "--sequence-len", - "128", - "--no-accelerate", - ], - catch_exceptions=False, - ) - - assert result.exit_code == 0 - mock_evaluate.assert_called_once() - cfg = mock_evaluate.call_args[0][0] - assert cfg.micro_batch_size == 2 - assert cfg.sequence_len == 128 - - -def test_evaluate_with_rl_dpo(cli_runner, tmp_path, valid_test_config): - """Test evaluation with DPO reinforcement learning""" - config_path = tmp_path / "config.yml" - config_path.write_text(valid_test_config) - - with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate: - result = cli_runner.invoke( - cli, - [ - "evaluate", - str(config_path), - "--rl", - "dpo", - "--no-accelerate", - ], - catch_exceptions=False, - ) - - assert result.exit_code == 0 - mock_evaluate.assert_called_once() - cfg = mock_evaluate.call_args[0][0] - assert cfg.rl == "dpo" +from .test_cli_base import BaseCliTest + + +class TestEvaluateCommand(BaseCliTest): + """Test cases for evaluate command.""" + + cli = cli + + def test_evaluate_cli_validation(self, cli_runner): + """Test CLI validation""" + self._test_cli_validation(cli_runner, "evaluate") + + def test_evaluate_basic_execution(self, cli_runner, tmp_path, valid_test_config): + """Test basic successful execution""" + self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "evaluate") + + def test_evaluate_basic_execution_no_accelerate( + self, cli_runner, tmp_path, valid_test_config + ): + """Test basic successful execution without accelerate""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate: + result = cli_runner.invoke( + cli, + [ + "evaluate", + str(config_path), + "--no-accelerate", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_evaluate.assert_called_once() + + def test_evaluate_cli_overrides(self, cli_runner, tmp_path, valid_test_config): + """Test CLI arguments properly override config values""" + config_path = self._test_cli_overrides(tmp_path, valid_test_config) + + with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate: + result = cli_runner.invoke( + cli, + [ + "evaluate", + str(config_path), + "--micro-batch-size", + "2", + "--sequence-len", + "128", + "--no-accelerate", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_evaluate.assert_called_once() + cfg = mock_evaluate.call_args[0][0] + assert cfg.micro_batch_size == 2 + assert cfg.sequence_len == 128 diff --git a/tests/cli/test_cli_train.py b/tests/cli/test_cli_train.py index 7f028fb4f2..560f3caf58 100644 --- a/tests/cli/test_cli_train.py +++ b/tests/cli/test_cli_train.py @@ -1,98 +1,71 @@ -"""pytest tests for axolotl CLI train command.""" +"""Tests for train CLI command.""" + from unittest.mock import MagicMock, patch from axolotl.cli.main import cli - -def test_train_cli_validation(cli_runner): - """Test CLI validation""" - # Test missing config file - result = cli_runner.invoke(cli, ["train", "--no-accelerate"]) - assert result.exit_code != 0 - - # Test non-existent config file - result = cli_runner.invoke(cli, ["train", "nonexistent.yml", "--no-accelerate"]) - assert result.exit_code != 0 - assert "Error: Invalid value for 'CONFIG'" in result.output - - -def test_train_basic_execution(cli_runner, tmp_path, valid_test_config): - """Test basic successful execution""" - config_path = tmp_path / "config.yml" - config_path.write_text(valid_test_config) - - with patch("subprocess.run") as 
mock: - result = cli_runner.invoke(cli, ["train", str(config_path)]) - - assert mock.called - assert mock.call_args.args[0] == [ - "accelerate", - "launch", - "-m", - "axolotl.cli.train", - str(config_path), - "--debug-num-examples", - "0", - ] - assert mock.call_args.kwargs == {"check": True} - assert result.exit_code == 0 - - -def test_train_basic_execution_no_accelerate(cli_runner, tmp_path, valid_test_config): - """Test basic successful execution""" - config_path = tmp_path / "config.yml" - config_path.write_text(valid_test_config) - - with patch("axolotl.cli.train.train") as mock_train: - mock_train.return_value = (MagicMock(), MagicMock()) - - result = cli_runner.invoke( - cli, - [ - "train", - str(config_path), - "--learning-rate", - "1e-4", - "--micro-batch-size", - "2", - "--no-accelerate", - ], - catch_exceptions=False, - ) - - assert result.exit_code == 0 - mock_train.assert_called_once() - - -def test_train_cli_overrides(cli_runner, tmp_path, valid_test_config): - """Test CLI arguments properly override config values""" - config_path = tmp_path / "config.yml" - output_dir = tmp_path / "model-out" - - test_config = valid_test_config.replace( - "output_dir: model-out", f"output_dir: {output_dir}" - ) - config_path.write_text(test_config) - - with patch("axolotl.cli.train.train") as mock_train: - mock_train.return_value = (MagicMock(), MagicMock()) - - result = cli_runner.invoke( - cli, - [ - "train", - str(config_path), - "--learning-rate", - "1e-4", - "--micro-batch-size", - "2", - "--no-accelerate", - ], - catch_exceptions=False, - ) - - assert result.exit_code == 0 - mock_train.assert_called_once() - cfg = mock_train.call_args[1]["cfg"] - assert cfg["learning_rate"] == 1e-4 - assert cfg["micro_batch_size"] == 2 +from .test_cli_base import BaseCliTest + + +class TestTrainCommand(BaseCliTest): + """Test cases for train command.""" + + cli = cli + + def test_train_cli_validation(self, cli_runner): + """Test CLI validation""" + self._test_cli_validation(cli_runner, "train") + + def test_train_basic_execution(self, cli_runner, tmp_path, valid_test_config): + """Test basic successful execution""" + self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "train") + + def test_train_basic_execution_no_accelerate( + self, cli_runner, tmp_path, valid_test_config + ): + """Test basic successful execution without accelerate""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("axolotl.cli.train.train") as mock_train: + mock_train.return_value = (MagicMock(), MagicMock()) + + result = cli_runner.invoke( + cli, + [ + "train", + str(config_path), + "--no-accelerate", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_train.assert_called_once() + + def test_train_cli_overrides(self, cli_runner, tmp_path, valid_test_config): + """Test CLI arguments properly override config values""" + config_path = self._test_cli_overrides(tmp_path, valid_test_config) + + with patch("axolotl.cli.train.train") as mock_train: + mock_train.return_value = (MagicMock(), MagicMock()) + + result = cli_runner.invoke( + cli, + [ + "train", + str(config_path), + "--learning-rate", + "1e-4", + "--micro-batch-size", + "2", + "--no-accelerate", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_train.assert_called_once() + cfg = mock_train.call_args[1]["cfg"] + assert cfg["learning_rate"] == 1e-4 + assert cfg["micro_batch_size"] == 2 From edfee9e4c0786364e04bb3a44af88b8657807627 Mon Sep 17 00:00:00 2001 From: Dan 
Saunders Date: Mon, 16 Dec 2024 15:12:08 -0500 Subject: [PATCH 4/4] review comments; slightly DRYing up things --- src/axolotl/cli/__init__.py | 1 - src/axolotl/evaluate.py | 12 +++--------- src/axolotl/train.py | 19 ++++++++----------- src/axolotl/utils/trainer.py | 11 +++++++++++ 4 files changed, 22 insertions(+), 21 deletions(-) diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 6506e44f77..d07b10ce3d 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -476,7 +476,6 @@ def load_datasets( tokenizer, processor=processor, ) - print(train_dataset, eval_dataset, total_num_steps) if ( cli_args.debug diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py index d53d7bb5da..7fd60cb5f4 100644 --- a/src/axolotl/evaluate.py +++ b/src/axolotl/evaluate.py @@ -14,7 +14,7 @@ from axolotl.train import TrainDatasetMeta from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_processor, load_tokenizer -from axolotl.utils.trainer import setup_trainer +from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) src_dir = os.path.join(project_root, "src") @@ -79,14 +79,8 @@ def evaluate( - Dictionary of evaluation metrics """ # pylint: disable=duplicate-code - # Set up CUDA allocation config if using PyTorch >= 2.2 - torch_version = torch.__version__.split(".") - torch_major, torch_minor = int(torch_version[0]), int(torch_version[1]) - if torch_major == 2 and torch_minor >= 2: - if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None: - os.environ[ - "PYTORCH_CUDA_ALLOC_CONF" - ] = "expandable_segments:True,roundup_power2_divisions:16" + # Enable expandable segments for cuda allocation to improve VRAM usage + set_pytorch_cuda_alloc_conf() # Load tokenizer LOG.debug( diff --git a/src/axolotl/train.py b/src/axolotl/train.py index c8576f1b48..851a71e547 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -24,7 +24,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.models import load_model, load_processor, load_tokenizer -from axolotl.utils.trainer import setup_trainer +from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer try: from optimum.bettertransformer import BetterTransformer @@ -53,25 +53,22 @@ class TrainDatasetMeta: def train( *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: - # enable expandable segments for cuda allocation to improve VRAM usage - torch_version = torch.__version__.split(".") - torch_major, torch_minor = int(torch_version[0]), int(torch_version[1]) - if torch_major == 2 and torch_minor >= 2: - if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None: - os.environ[ - "PYTORCH_CUDA_ALLOC_CONF" - ] = "expandable_segments:True,roundup_power2_divisions:16" - - # load the tokenizer first + # Enable expandable segments for cuda allocation to improve VRAM usage + set_pytorch_cuda_alloc_conf() + + # Load tokenizer LOG.debug( f"loading tokenizer... 
{cfg.tokenizer_config or cfg.base_model_config}", main_process_only=True, ) tokenizer = load_tokenizer(cfg) + + # Load processor for multimodal models if needed processor = None if cfg.is_multimodal: processor = load_processor(cfg, tokenizer) + # Get datasets train_dataset = dataset_meta.train_dataset eval_dataset = dataset_meta.eval_dataset total_num_steps = dataset_meta.total_num_steps diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 32e54c9a86..fd09b3eb67 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -512,6 +512,17 @@ def prepare_opinionated_env(cfg): os.environ["TOKENIZERS_PARALLELISM"] = "false" +def set_pytorch_cuda_alloc_conf(): + """Set up CUDA allocation config if using PyTorch >= 2.2""" + torch_version = torch.__version__.split(".") + torch_major, torch_minor = int(torch_version[0]), int(torch_version[1]) + if torch_major == 2 and torch_minor >= 2: + if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None: + os.environ[ + "PYTORCH_CUDA_ALLOC_CONF" + ] = "expandable_segments:True,roundup_power2_divisions:16" + + def setup_trainer( cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps ):
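
Usage sketch (illustrative, not part of the patches above): with the series applied, the new command is driven through the click entry point added in main.py, which by default shells out to `accelerate launch -m axolotl.cli.evaluate <config>` and, with `--no-accelerate`, runs in-process via `do_cli`. The same codepath can be exercised programmatically by mirroring `do_evaluate`; the config path below is a placeholder, not a file shipped by these patches.

# Hypothetical programmatic use of the evaluate codepath added in this series;
# the YAML path is a placeholder and must point at a valid axolotl config.
from axolotl.cli import load_cfg, load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.evaluate import evaluate

cfg = load_cfg("path/to/config.yml")  # placeholder config path
cli_args = TrainerCliArgs()  # dataclass defaults; no debug flags set
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
metrics = evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
print(metrics)  # keys are prefixed "train_"/"eval_" per evaluate_dataset()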