Commit

fixes and cleanup

Dan Saunders committed Dec 16, 2024
1 parent 06b25bf commit c92aab6
Showing 7 changed files with 295 additions and 244 deletions.
1 change: 1 addition & 0 deletions src/axolotl/cli/__init__.py
@@ -476,6 +476,7 @@ def load_datasets(
tokenizer,
processor=processor,
)
print(train_dataset, eval_dataset, total_num_steps)

if (
cli_args.debug
3 changes: 2 additions & 1 deletion src/axolotl/cli/evaluate.py
@@ -20,10 +20,11 @@
from axolotl.common.cli import TrainerCliArgs
from axolotl.evaluate import evaluate

LOG = logging.getLogger("axolotl.cli.train")
LOG = logging.getLogger("axolotl.cli.evaluate")


def do_evaluate(cfg, cli_args) -> None:
# pylint: disable=duplicate-code
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
119 changes: 86 additions & 33 deletions src/axolotl/evaluate.py
@@ -1,14 +1,13 @@
"""Module for evaluating models."""

import csv
import os
import sys
from pathlib import Path
from typing import Tuple, Union
from typing import Dict, Optional

import torch
from accelerate.logging import get_logger
from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizer

from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
@@ -22,26 +21,64 @@
sys.path.insert(0, src_dir)

configure_logging()
LOG = get_logger("axolotl.eval")
LOG = get_logger("axolotl.evaluate")


def evaluate_dataset(
trainer, dataset, dataset_type: str, flash_optimum: bool = False
) -> Optional[Dict[str, float]]:
"""Helper function to evaluate a single dataset safely.
Args:
trainer: The trainer instance
dataset: Dataset to evaluate
dataset_type: Type of dataset ('train' or 'eval')
flash_optimum: Whether to use flash optimum
Returns:
Dictionary of metrics or None if dataset is None
"""
if dataset is None:
return None

LOG.info(f"Starting {dataset_type} set evaluation...")

if flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=True,
enable_mem_efficient=True,
):
metrics = trainer.evaluate(dataset, metric_key_prefix=dataset_type)
else:
metrics = trainer.evaluate(dataset, metric_key_prefix=dataset_type)

LOG.info(f"{dataset_type.capitalize()} set evaluation completed!")
LOG.info(f"{dataset_type.capitalize()} Metrics:")
for key, value in metrics.items():
LOG.info(f"{key}: {value}")

return metrics


def evaluate(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer, dict]:
) -> Dict[str, float]:
"""
Evaluate a model on a dataset
Evaluate a model on training and validation datasets
Args:
cfg: Configuration dictionary
cli_args: Command line arguments
dataset_meta: Dataset metadata containing evaluation dataset
dataset_meta: Dataset metadata containing training and evaluation datasets
Returns:
Tuple containing:
- The model (either PeftModel or PreTrainedModel)
- The tokenizer
- Dictionary of evaluation metrics
"""
# pylint: disable=duplicate-code
# Set up CUDA allocation config if using PyTorch >= 2.2
torch_version = torch.__version__.split(".")
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
Expand All @@ -63,7 +100,8 @@ def evaluate(
if cfg.is_multimodal:
processor = load_processor(cfg, tokenizer)

# Get evaluation dataset
# Get datasets
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps

@@ -76,46 +114,61 @@
# Set up trainer
trainer = setup_trainer(
cfg,
train_dataset=eval_dataset, # None # No training dataset needed for evaluation
train_dataset=train_dataset,
eval_dataset=eval_dataset,
model=(model, None, None), # No need for model_ref or peft_config
tokenizer=tokenizer,
processor=processor,
total_num_steps=total_num_steps,
)

# Run evaluation
LOG.info("Starting evaluation...")
# Evaluate datasets
all_metrics = {}
train_metrics = evaluate_dataset(trainer, train_dataset, "train", cfg.flash_optimum)
eval_metrics = evaluate_dataset(trainer, eval_dataset, "eval", cfg.flash_optimum)

if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=True,
enable_mem_efficient=True,
):
metrics = trainer.evaluate()
else:
metrics = trainer.evaluate()
if train_metrics:
all_metrics.update(train_metrics)
if eval_metrics:
all_metrics.update(eval_metrics)

# Log results
LOG.info("Evaluation completed!")
LOG.info("Metrics:")
for key, value in metrics.items():
LOG.info(f"{key}: {value}")

# Save metrics to file if output directory is specified
if cfg.output_dir:
# Save metrics to CSV if output directory is specified and we have metrics
if cfg.output_dir and (train_metrics or eval_metrics):
output_dir = Path(cfg.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

metrics_file = output_dir / "eval_results.txt"
with metrics_file.open("w", encoding="utf-8") as file:
for key, value in metrics.items():
file.write(f"{key} = {value}\n")
metrics_file = output_dir / "eval_summary.csv"
with metrics_file.open("w", newline="", encoding="utf-8") as file:
writer = csv.writer(file)
writer.writerow(["metric", "training", "validation"])

# Get unique metric names (removing prefixes) from available metrics
train_metric_names = {
k.replace("train_", ""): k for k in (train_metrics or {})
}
eval_metric_names = {
k.replace("eval_", ""): k for k in (eval_metrics or {})
}
all_metric_names = sorted(
set(train_metric_names.keys()) | set(eval_metric_names.keys())
)

for metric_name in all_metric_names:
train_value = (
train_metrics.get(train_metric_names.get(metric_name, ""), "")
if train_metrics
else ""
)
eval_value = (
eval_metrics.get(eval_metric_names.get(metric_name, ""), "")
if eval_metrics
else ""
)
writer.writerow([metric_name, train_value, eval_value])

LOG.info(f"Evaluation results saved to {metrics_file}")

del model
del tokenizer

return metrics
return all_metrics
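
To make the CSV layout concrete, here is a standalone sketch (editor-added, not part of this commit) that reproduces the prefix-stripping logic above with made-up metric dictionaries and prints the resulting eval_summary.csv rows:

"""Standalone illustration of the eval_summary.csv layout; metric values are invented."""
import csv
import io

# Hypothetical metric dicts shaped like trainer.evaluate() output
train_metrics = {"train_loss": 1.93, "train_runtime": 12.4}
eval_metrics = {"eval_loss": 2.01, "eval_runtime": 3.1}

# Strip the "train_"/"eval_" prefixes so both columns align on the bare metric name
train_names = {k.replace("train_", ""): k for k in train_metrics}
eval_names = {k.replace("eval_", ""): k for k in eval_metrics}
all_names = sorted(set(train_names) | set(eval_names))

buffer = io.StringIO()
writer = csv.writer(buffer)
writer.writerow(["metric", "training", "validation"])
for name in all_names:
    writer.writerow([
        name,
        train_metrics.get(train_names.get(name, ""), ""),
        eval_metrics.get(eval_names.get(name, ""), ""),
    ])

print(buffer.getvalue())
# metric,training,validation
# loss,1.93,2.01
# runtime,12.4,3.1
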
3 changes: 3 additions & 0 deletions src/axolotl/utils/data/sft.py
@@ -119,7 +119,9 @@ def prepare_dataset(cfg, tokenizer, processor=None):
eval_dataset = None
if cfg.dataset_exact_deduplication:
LOG.info("Deduplication not available for pretrained datasets")

return train_dataset, eval_dataset, cfg.max_steps, prompters

if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
if total_eval_steps == 0:
@@ -134,6 +136,7 @@
LOG.info(f"Maximum number of steps set at {total_num_steps}")
else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset)

return train_dataset, eval_dataset, total_num_steps, prompters


73 changes: 73 additions & 0 deletions tests/cli/test_cli_base.py
@@ -0,0 +1,73 @@
"""Base test class for CLI commands."""

from pathlib import Path
from unittest.mock import patch

from axolotl.cli.main import cli


class BaseCliTest:
"""Base class for CLI command tests."""

def _test_cli_validation(self, cli_runner, command: str):
"""Test CLI validation for a command.
Args:
cli_runner: CLI runner fixture
command: Command to test (train/evaluate)
"""
# Test missing config file
result = cli_runner.invoke(cli, [command, "--no-accelerate"])
assert result.exit_code != 0

# Test non-existent config file
result = cli_runner.invoke(cli, [command, "nonexistent.yml", "--no-accelerate"])
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output

def _test_basic_execution(
self, cli_runner, tmp_path: Path, valid_test_config: str, command: str
):
"""Test basic execution with accelerate.
Args:
cli_runner: CLI runner fixture
tmp_path: Temporary path fixture
valid_test_config: Valid config fixture
command: Command to test (train/evaluate)
"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)

with patch("subprocess.run") as mock:
result = cli_runner.invoke(cli, [command, str(config_path)])

assert mock.called
assert mock.call_args.args[0] == [
"accelerate",
"launch",
"-m",
f"axolotl.cli.{command}",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0

def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str):
"""Test CLI argument overrides.
Args:
tmp_path: Temporary path fixture
valid_test_config: Valid config fixture
command: Command to test (train/evaluate)
"""
config_path = tmp_path / "config.yml"
output_dir = tmp_path / "model-out"

test_config = valid_test_config.replace(
"output_dir: model-out", f"output_dir: {output_dir}"
)
config_path.write_text(test_config)
return config_path
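
For context, a hedged sketch of how a concrete test module might reuse this base class; the module path, class name, and pytest fixtures shown are assumptions for illustration and are not part of this commit:

"""Hypothetical usage of BaseCliTest for the `evaluate` command."""
from pathlib import Path

from tests.cli.test_cli_base import BaseCliTest  # assumed import path


class TestEvaluateCommand(BaseCliTest):
    """Exercises the shared helpers against the `evaluate` CLI command."""

    def test_cli_validation(self, cli_runner):
        self._test_cli_validation(cli_runner, "evaluate")

    def test_basic_execution(self, cli_runner, tmp_path: Path, valid_test_config: str):
        self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "evaluate")
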
