diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml index 004df55186..f81f711de4 100644 --- a/.github/workflows/tests-nightly.yml +++ b/.github/workflows/tests-nightly.yml @@ -60,7 +60,8 @@ jobs: - name: Run tests run: | - pytest --ignore=tests/e2e/ tests/ + pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/ + pytest tests/patched/ - name: cleanup pip cache run: | diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5abab4a2df..89f19ca9c3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -79,7 +79,8 @@ jobs: - name: Run tests run: | - pytest -n8 --ignore=tests/e2e/ tests/ + pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/ + pytest tests/patched/ - name: cleanup pip cache run: | @@ -123,7 +124,8 @@ jobs: - name: Run tests run: | - pytest -n8 --ignore=tests/e2e/ tests/ + pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/ + pytest tests/patched/ - name: cleanup pip cache run: | diff --git a/README.md b/README.md index 8ee959d218..75e7faa642 100644 --- a/README.md +++ b/README.md @@ -135,6 +135,46 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml ``` +### Axolotl CLI + +If you've installed this package using `pip` from source, we now support a new, more +streamlined CLI using [click](https://click.palletsprojects.com/en/stable/). 
Rewriting +the above commands: + +```bash +# preprocess datasets - optional but recommended +CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/openllama-3b/lora.yml + +# finetune lora +axolotl train examples/openllama-3b/lora.yml + +# inference +axolotl inference examples/openllama-3b/lora.yml \ + --lora-model-dir="./outputs/lora-out" + +# gradio +axolotl inference examples/openllama-3b/lora.yml \ + --lora-model-dir="./outputs/lora-out" --gradio + +# remote yaml files - the yaml config can be hosted on a public URL +# Note: the yaml config must directly link to the **raw** yaml +axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml +``` + +We've also added a new command for fetching `examples` and `deepspeed_configs` to your +local machine. This will come in handy when installing `axolotl` from PyPI. + +```bash +# Fetch example YAML files (stores in "examples/" folder) +axolotl fetch examples + +# Fetch deepspeed config files (stores in "deepspeed_configs/" folder) +axolotl fetch deepspeed_configs + +# Optionally, specify a destination folder +axolotl fetch examples --dest path/to/folder +``` + ## Badge ❤🏷️ Building something cool with Axolotl? Consider adding a badge to your model card. @@ -206,7 +246,6 @@ Thanks to all of our contributors to date. 
Help drive open source AI progress fo ❌: not supported ❓: untested - ## Advanced Setup ### Environment diff --git a/cicd/cicd.sh b/cicd/cicd.sh index 68c40bb78d..79b3cc95e0 100755 --- a/cicd/cicd.sh +++ b/cicd/cicd.sh @@ -1,6 +1,7 @@ #!/bin/bash set -e -pytest -v --durations=10 -n8 --ignore=tests/e2e/ /workspace/axolotl/tests/ -pytest -v --durations=10 -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/ +pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ +pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/patched/ +pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/ pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ diff --git a/outputs b/outputs new file mode 120000 index 0000000000..be3c4a823f --- /dev/null +++ b/outputs @@ -0,0 +1 @@ +/workspace/data/axolotl-artifacts \ No newline at end of file diff --git a/setup.py b/setup.py index ea779d5996..336da98f4b 100644 --- a/setup.py +++ b/setup.py @@ -103,6 +103,11 @@ def parse_requirements(): packages=find_packages("src"), install_requires=install_requires, dependency_links=dependency_links, + entry_points={ + "console_scripts": [ + "axolotl=axolotl.cli.main:main", + ], + }, extras_require={ "flash-attn": [ "flash-attn==2.7.0.post2", diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 1e61b220b9..e8ef862854 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -380,7 +380,7 @@ def choose_config(path: Path): if len(yaml_files) == 1: print(f"Using default YAML file '{yaml_files[0]}'") - return yaml_files[0] + return str(yaml_files[0]) print("Choose a YAML file:") for idx, file in enumerate(yaml_files): @@ -391,7 +391,7 @@ def choose_config(path: Path): try: choice = int(input("Enter the 
number of your choice: ")) if 1 <= choice <= len(yaml_files): - chosen_file = yaml_files[choice - 1] + chosen_file = str(yaml_files[choice - 1]) else: print("Invalid choice. Please choose a number from the list.") except ValueError: diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index b738e5c222..a5f1a8ad8b 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -2,6 +2,7 @@ CLI to run inference on a trained model """ from pathlib import Path +from typing import Union import fire import transformers @@ -16,7 +17,7 @@ from axolotl.common.cli import TrainerCliArgs -def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs): +def do_cli(config: Union[Path, str] = Path("examples/"), gradio=False, **kwargs): # pylint: disable=duplicate-code print_axolotl_text_art() parsed_cfg = load_cfg(config, inference=True, **kwargs) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py new file mode 100644 index 0000000000..a138daab25 --- /dev/null +++ b/src/axolotl/cli/main.py @@ -0,0 +1,231 @@ +"""CLI definition for various axolotl commands.""" +# pylint: disable=redefined-outer-name +import subprocess # nosec B404 +from typing import Optional + +import click + +from axolotl.cli.utils import ( + add_options_from_config, + add_options_from_dataclass, + build_command, + fetch_from_github, +) +from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs +from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig + + +@click.group() +def cli(): + """Axolotl CLI - Train and fine-tune large language models""" + + +@cli.command() +@click.argument("config", type=click.Path(exists=True, path_type=str)) +@add_options_from_dataclass(PreprocessCliArgs) +@add_options_from_config(AxolotlInputConfig) +def preprocess(config: str, **kwargs): + """Preprocess datasets before training.""" + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + from axolotl.cli.preprocess import do_cli + + 
do_cli(config=config, **kwargs) + + +@cli.command() +@click.argument("config", type=click.Path(exists=True, path_type=str)) +@click.option( + "--accelerate/--no-accelerate", + default=True, + help="Use accelerate launch for multi-GPU training", +) +@add_options_from_dataclass(TrainerCliArgs) +@add_options_from_config(AxolotlInputConfig) +def train(config: str, accelerate: bool, **kwargs): + """Train or fine-tune a model.""" + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + if accelerate: + base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"] + if config: + base_cmd.append(config) + cmd = build_command(base_cmd, kwargs) + subprocess.run(cmd, check=True) # nosec B603 + else: + from axolotl.cli.train import do_cli + + do_cli(config=config, **kwargs) + + +@cli.command() +@click.argument("config", type=click.Path(exists=True, path_type=str)) +@click.option( + "--accelerate/--no-accelerate", + default=True, + help="Use accelerate launch for multi-GPU inference", +) +@click.option( + "--lora-model-dir", + type=click.Path(exists=True, path_type=str), + help="Directory containing LoRA model", +) +@click.option( + "--base-model", + type=click.Path(exists=True, path_type=str), + help="Path to base model for non-LoRA models", +) +@click.option("--gradio", is_flag=True, help="Launch Gradio interface") +@click.option("--load-in-8bit", is_flag=True, help="Load model in 8-bit mode") +@add_options_from_dataclass(TrainerCliArgs) +@add_options_from_config(AxolotlInputConfig) +def inference( + config: str, + accelerate: bool, + lora_model_dir: Optional[str] = None, + base_model: Optional[str] = None, + **kwargs, +): + """Run inference with a trained model.""" + kwargs = {k: v for k, v in kwargs.items() if v is not None} + del kwargs["inference"] # interferes with inference.do_cli + + if lora_model_dir: + kwargs["lora_model_dir"] = lora_model_dir + if base_model: + kwargs["output_dir"] = base_model + + if accelerate: + base_cmd = ["accelerate", "launch", "-m", 
"axolotl.cli.inference"] + if config: + base_cmd.append(config) + cmd = build_command(base_cmd, kwargs) + subprocess.run(cmd, check=True) # nosec B603 + else: + from axolotl.cli.inference import do_cli + + do_cli(config=config, **kwargs) + + +@cli.command() +@click.argument("config", type=click.Path(exists=True, path_type=str)) +@click.option( + "--accelerate/--no-accelerate", + default=False, + help="Use accelerate launch for multi-GPU operations", +) +@click.option( + "--model-dir", + type=click.Path(exists=True, path_type=str), + help="Directory containing model weights to shard", +) +@click.option( + "--save-dir", + type=click.Path(path_type=str), + help="Directory to save sharded weights", +) +@add_options_from_dataclass(TrainerCliArgs) +@add_options_from_config(AxolotlInputConfig) +def shard(config: str, accelerate: bool, **kwargs): + """Shard model weights.""" + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + if accelerate: + base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.shard"] + if config: + base_cmd.append(config) + cmd = build_command(base_cmd, kwargs) + subprocess.run(cmd, check=True) # nosec B603 + else: + from axolotl.cli.shard import do_cli + + do_cli(config=config, **kwargs) + + +@cli.command() +@click.argument("config", type=click.Path(exists=True, path_type=str)) +@click.option( + "--accelerate/--no-accelerate", + default=True, + help="Use accelerate launch for weight merging", +) +@click.option( + "--model-dir", + type=click.Path(exists=True, path_type=str), + help="Directory containing sharded weights", +) +@click.option( + "--save-path", type=click.Path(path_type=str), help="Path to save merged weights" +) +@add_options_from_dataclass(TrainerCliArgs) +@add_options_from_config(AxolotlInputConfig) +def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs): + """Merge sharded FSDP model weights.""" + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + if accelerate: + base_cmd = [ + "accelerate", 
+ "launch", + "-m", + "axolotl.cli.merge_sharded_fsdp_weights", + ] + if config: + base_cmd.append(config) + cmd = build_command(base_cmd, kwargs) + subprocess.run(cmd, check=True) # nosec B603 + else: + from axolotl.cli.merge_sharded_fsdp_weights import do_cli + + do_cli(config=config, **kwargs) + + +@cli.command() +@click.argument("config", type=click.Path(exists=True, path_type=str)) +@click.option( + "--lora-model-dir", + type=click.Path(exists=True, path_type=str), + help="Directory containing the LoRA model to merge", +) +@click.option( + "--output-dir", + type=click.Path(path_type=str), + help="Directory to save the merged model", +) +def merge_lora( + config: str, + lora_model_dir: Optional[str] = None, + output_dir: Optional[str] = None, +): + """Merge a trained LoRA into a base model""" + kwargs = {} + if lora_model_dir: + kwargs["lora_model_dir"] = lora_model_dir + if output_dir: + kwargs["output_dir"] = output_dir + + from axolotl.cli.merge_lora import do_cli + + do_cli(config=config, **kwargs) + + +@cli.command() +@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"])) +@click.option("--dest", help="Destination directory") +def fetch(directory: str, dest: Optional[str]): + """ + Fetch example configs or other resources. 
+ + Available directories: + - examples: Example configuration files + - deepspeed_configs: DeepSpeed configuration files + """ + fetch_from_github(f"{directory}/", dest) + + +def main(): + cli() + + +if __name__ == "__main__": + main() diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 6588b5ee4e..8c321bc48e 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -2,6 +2,7 @@ CLI to run merge a trained LoRA into a base model """ from pathlib import Path +from typing import Union import fire import transformers @@ -11,7 +12,7 @@ from axolotl.common.cli import TrainerCliArgs -def do_cli(config: Path = Path("examples/"), **kwargs): +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): # pylint: disable=duplicate-code print_axolotl_text_art() parser = transformers.HfArgumentParser((TrainerCliArgs)) diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py index 25408fd57e..6be9af1f76 100644 --- a/src/axolotl/cli/merge_sharded_fsdp_weights.py +++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py @@ -177,7 +177,7 @@ def merge_fsdp_weights( state.wait_for_everyone() -def do_cli(config: Path = Path("examples/"), **kwargs): +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): # pylint: disable=duplicate-code print_axolotl_text_art() parser = transformers.HfArgumentParser((TrainerCliArgs)) diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py new file mode 100644 index 0000000000..f0e2573f72 --- /dev/null +++ b/src/axolotl/cli/utils.py @@ -0,0 +1,218 @@ +"""Utility methods for axolotl CLI.""" +import concurrent.futures +import dataclasses +import hashlib +import json +import logging +from pathlib import Path +from types import NoneType +from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin + +import click +import requests +from pydantic import BaseModel + +LOG = logging.getLogger("axolotl.cli.utils") 
+ + +def add_options_from_dataclass(config_class: Type[Any]): + """Create Click options from the fields of a dataclass.""" + + def decorator(function): + # Process dataclass fields in reverse order for correct option ordering + for field in reversed(dataclasses.fields(config_class)): + field_type = field.type + + if get_origin(field_type) is Union and type(None) in get_args(field_type): + field_type = next( + t for t in get_args(field_type) if t is not NoneType + ) + + if field_type == bool: + field_name = field.name.replace("_", "-") + option_name = f"--{field_name}/--no-{field_name}" + function = click.option( + option_name, + default=field.default, + help=field.metadata.get("description"), + )(function) + else: + option_name = f"--{field.name.replace('_', '-')}" + function = click.option( + option_name, + type=field_type, + default=field.default, + help=field.metadata.get("description"), + )(function) + return function + + return decorator + + +def add_options_from_config(config_class: Type[BaseModel]): + """Create Click options from the fields of a Pydantic model.""" + + def decorator(function): + # Process model fields in reverse order for correct option ordering + for name, field in reversed(config_class.model_fields.items()): + if field.annotation == bool: + field_name = name.replace("_", "-") + option_name = f"--{field_name}/--no-{field_name}" + function = click.option( + option_name, default=None, help=field.description + )(function) + else: + option_name = f"--{name.replace('_', '-')}" + function = click.option( + option_name, default=None, help=field.description + )(function) + return function + + return decorator + + +def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]: + """Build command list from base command and options.""" + cmd = base_cmd.copy() + + for key, value in options.items(): + if value is None: + continue + + key = key.replace("_", "-") + + if isinstance(value, bool): + if value: + cmd.append(f"--{key}") 
+ else: + cmd.extend([f"--{key}", str(value)]) + + return cmd + + +def download_file( + file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str +) -> Tuple[str, str]: + """ + Download a single file and return its processing status. + + Args: + file_info: Tuple of (file_path, remote_sha) + raw_base_url: Base URL for raw GitHub content + dest_path: Local destination directory + dir_prefix: Directory prefix to filter files + + Returns: + Tuple of (file_path, status) where status is 'new', 'updated', 'unchanged', or 'error' + """ + file_path, remote_sha = file_info + raw_url = f"{raw_base_url}/{file_path}" + dest_file = dest_path / file_path.removeprefix(dir_prefix) + + # Check if file exists and needs updating + if dest_file.exists(): + with open(dest_file, "rb") as file: + content = file.read() + # Calculate git blob SHA + blob = b"blob " + str(len(content)).encode() + b"\0" + content + local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest() + + if local_sha == remote_sha: + print(f"Skipping {file_path} (unchanged)") + return file_path, "unchanged" + + print(f"Updating {file_path}") + status = "updated" + else: + print(f"Downloading {file_path}") + status = "new" + + # Create directories if needed + dest_file.parent.mkdir(parents=True, exist_ok=True) + + # Download and save file + try: + response = requests.get(raw_url, timeout=30) + response.raise_for_status() + + with open(dest_file, "wb") as file: + file.write(response.content) + + return file_path, status + except (requests.RequestException, IOError) as request_error: + print(f"Error downloading {file_path}: {str(request_error)}") + return file_path, "error" + + +def fetch_from_github( + dir_prefix: str, dest_dir: Optional[str] = None, max_workers: int = 5 +) -> None: + """ + Sync files from a specific directory in the GitHub repository. + Only downloads files that don't exist locally or have changed. 
+ + Args: + dir_prefix: Directory prefix to filter files (e.g., 'examples/', 'deepspeed_configs/') + dest_dir: Local destination directory + max_workers: Maximum number of concurrent downloads + """ + api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1" + raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main" + + # Get repository tree with timeout + response = requests.get(api_url, timeout=30) + response.raise_for_status() + tree = json.loads(response.text) + + # Filter for files and get their SHA + files = { + item["path"]: item["sha"] + for item in tree["tree"] + if item["type"] == "blob" and item["path"].startswith(dir_prefix) + } + + if not files: + raise click.ClickException(f"No files found in {dir_prefix}") + + # Default destination directory is the last part of dir_prefix + default_dest = Path(dir_prefix.rstrip("/")) + dest_path = Path(dest_dir) if dest_dir else default_dest + + # Keep track of processed files for summary + files_processed: Dict[str, List[str]] = { + "new": [], + "updated": [], + "unchanged": [], + "error": [], + } + + # Process files in parallel using ThreadPoolExecutor + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_file = { + executor.submit( + download_file, + (file_path, remote_sha), + raw_base_url, + dest_path, + dir_prefix, + ): file_path + for file_path, remote_sha in files.items() + } + + # Process completed tasks as they finish + for future in concurrent.futures.as_completed(future_to_file): + file_path = future_to_file[future] + try: + file_path, status = future.result() + files_processed[status].append(file_path) + except (requests.RequestException, IOError) as request_error: + print(f"Error processing {file_path}: {str(request_error)}") + files_processed["error"].append(file_path) + + # Log summary + LOG.info("\nSync Summary:") + LOG.info(f"New files: {len(files_processed['new'])}") + LOG.info(f"Updated files: 
{len(files_processed['updated'])}") + LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}") + if files_processed["error"]: + LOG.info(f"Failed files: {len(files_processed['error'])}") diff --git a/tests/cli/__init__.py b/tests/cli/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py new file mode 100644 index 0000000000..78b090e19e --- /dev/null +++ b/tests/cli/conftest.py @@ -0,0 +1,36 @@ +"""Shared pytest fixtures for cli module.""" +import pytest +from click.testing import CliRunner + +VALID_TEST_CONFIG = """ +base_model: HuggingFaceTB/SmolLM2-135M +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca +sequence_len: 2048 +max_steps: 1 +micro_batch_size: 1 +gradient_accumulation_steps: 1 +learning_rate: 1e-3 +special_tokens: + pad_token: <|endoftext|> +""" + + +@pytest.fixture +def cli_runner(): + return CliRunner() + + +@pytest.fixture +def valid_test_config(): + return VALID_TEST_CONFIG + + +@pytest.fixture +def config_path(tmp_path): + """Creates a temporary config file""" + path = tmp_path / "config.yml" + path.write_text(VALID_TEST_CONFIG) + + return path diff --git a/tests/cli/test_cli_fetch.py b/tests/cli/test_cli_fetch.py new file mode 100644 index 0000000000..0df87b0299 --- /dev/null +++ b/tests/cli/test_cli_fetch.py @@ -0,0 +1,38 @@ +"""pytest tests for axolotl CLI fetch command.""" +from unittest.mock import patch + +from axolotl.cli.main import fetch + + +def test_fetch_cli_examples(cli_runner): + """Test fetch command with examples directory""" + with patch("axolotl.cli.main.fetch_from_github") as mock_fetch: + result = cli_runner.invoke(fetch, ["examples"]) + + assert result.exit_code == 0 + mock_fetch.assert_called_once_with("examples/", None) + + +def test_fetch_cli_deepspeed(cli_runner): + """Test fetch command with deepspeed_configs directory""" + with patch("axolotl.cli.main.fetch_from_github") as mock_fetch: + result = cli_runner.invoke(fetch, 
["deepspeed_configs"]) + + assert result.exit_code == 0 + mock_fetch.assert_called_once_with("deepspeed_configs/", None) + + +def test_fetch_cli_with_dest(cli_runner, tmp_path): + """Test fetch command with custom destination""" + with patch("axolotl.cli.main.fetch_from_github") as mock_fetch: + custom_dir = tmp_path / "tmp_examples" + result = cli_runner.invoke(fetch, ["examples", "--dest", str(custom_dir)]) + + assert result.exit_code == 0 + mock_fetch.assert_called_once_with("examples/", str(custom_dir)) + + +def test_fetch_cli_invalid_directory(cli_runner): + """Test fetch command with invalid directory choice""" + result = cli_runner.invoke(fetch, ["invalid"]) + assert result.exit_code != 0 diff --git a/tests/cli/test_cli_inference.py b/tests/cli/test_cli_inference.py new file mode 100644 index 0000000000..7cb163d255 --- /dev/null +++ b/tests/cli/test_cli_inference.py @@ -0,0 +1,30 @@ +"""pytest tests for axolotl CLI inference command.""" +from unittest.mock import patch + +from axolotl.cli.main import cli + + +def test_inference_basic(cli_runner, config_path): + """Test basic inference""" + with patch("axolotl.cli.inference.do_inference") as mock: + result = cli_runner.invoke( + cli, + ["inference", str(config_path), "--no-accelerate"], + catch_exceptions=False, + ) + + assert mock.called + assert result.exit_code == 0 + + +def test_inference_gradio(cli_runner, config_path): + """Test basic inference (gradio path)""" + with patch("axolotl.cli.inference.do_inference_gradio") as mock: + result = cli_runner.invoke( + cli, + ["inference", str(config_path), "--no-accelerate", "--gradio"], + catch_exceptions=False, + ) + + assert mock.called + assert result.exit_code == 0 diff --git a/tests/cli/test_cli_interface.py b/tests/cli/test_cli_interface.py new file mode 100644 index 0000000000..ed8335b766 --- /dev/null +++ b/tests/cli/test_cli_interface.py @@ -0,0 +1,47 @@ +"""General pytest tests for axolotl.cli.main interface.""" +from axolotl.cli.main import 
build_command, cli + + +def test_build_command(): + """Test converting dict of options to CLI arguments""" + base_cmd = ["accelerate", "launch"] + options = { + "learning_rate": 1e-4, + "batch_size": 8, + "debug": True, + "use_fp16": False, + "null_value": None, + } + + result = build_command(base_cmd, options) + assert result == [ + "accelerate", + "launch", + "--learning-rate", + "0.0001", + "--batch-size", + "8", + "--debug", + ] + + +def test_invalid_command_options(cli_runner): + """Test handling of invalid command options""" + result = cli_runner.invoke( + cli, + [ + "train", + "config.yml", + "--invalid-option", + "value", + ], + ) + assert result.exit_code != 0 + assert "No such option" in result.output + + +def test_required_config_argument(cli_runner): + """Test commands fail properly when config argument is missing""" + result = cli_runner.invoke(cli, ["train"]) + assert result.exit_code != 0 + assert "Missing argument 'CONFIG'" in result.output diff --git a/tests/cli/test_cli_merge_lora.py b/tests/cli/test_cli_merge_lora.py new file mode 100644 index 0000000000..165a64e98c --- /dev/null +++ b/tests/cli/test_cli_merge_lora.py @@ -0,0 +1,56 @@ +"""pytest tests for axolotl CLI merge_lora command.""" +from unittest.mock import patch + +from axolotl.cli.main import cli + + +def test_merge_lora_basic(cli_runner, config_path): + """Test basic merge_lora command""" + with patch("axolotl.cli.merge_lora.do_cli") as mock_do_cli: + result = cli_runner.invoke(cli, ["merge-lora", str(config_path)]) + assert result.exit_code == 0 + + mock_do_cli.assert_called_once() + assert mock_do_cli.call_args.kwargs["config"] == str(config_path) + + +def test_merge_lora_with_dirs(cli_runner, config_path, tmp_path): + """Test merge_lora with custom lora and output directories""" + lora_dir = tmp_path / "lora" + output_dir = tmp_path / "output" + lora_dir.mkdir() + + with patch("axolotl.cli.merge_lora.do_cli") as mock_do_cli: + result = cli_runner.invoke( + cli, + [ + "merge-lora", 
+ str(config_path), + "--lora-model-dir", + str(lora_dir), + "--output-dir", + str(output_dir), + ], + ) + assert result.exit_code == 0 + + mock_do_cli.assert_called_once() + assert mock_do_cli.call_args.kwargs["config"] == str(config_path) + assert mock_do_cli.call_args.kwargs["lora_model_dir"] == str(lora_dir) + assert mock_do_cli.call_args.kwargs["output_dir"] == str(output_dir) + + +def test_merge_lora_nonexistent_config(cli_runner, tmp_path): + """Test merge_lora with nonexistent config""" + config_path = tmp_path / "nonexistent.yml" + result = cli_runner.invoke(cli, ["merge-lora", str(config_path)]) + assert result.exit_code != 0 + + +def test_merge_lora_nonexistent_lora_dir(cli_runner, config_path, tmp_path): + """Test merge_lora with nonexistent lora directory""" + lora_dir = tmp_path / "nonexistent" + result = cli_runner.invoke( + cli, ["merge-lora", str(config_path), "--lora-model-dir", str(lora_dir)] + ) + assert result.exit_code != 0 diff --git a/tests/cli/test_cli_merge_sharded_fsdp_weights.py b/tests/cli/test_cli_merge_sharded_fsdp_weights.py new file mode 100644 index 0000000000..cff0f3b773 --- /dev/null +++ b/tests/cli/test_cli_merge_sharded_fsdp_weights.py @@ -0,0 +1,60 @@ +"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command.""" +# pylint: disable=duplicate-code +from unittest.mock import patch + +from axolotl.cli.main import cli + + +def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path): + """Test merge_sharded_fsdp_weights command without accelerate""" + with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock: + result = cli_runner.invoke( + cli, ["merge-sharded-fsdp-weights", str(config_path), "--no-accelerate"] + ) + + assert mock.called + assert mock.call_args.kwargs["config"] == str(config_path) + assert result.exit_code == 0 + + +def test_merge_sharded_fsdp_weights_with_model_dir(cli_runner, config_path, tmp_path): + """Test merge_sharded_fsdp_weights command with model_dir option""" + 
model_dir = tmp_path / "model" + model_dir.mkdir() + + with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock: + result = cli_runner.invoke( + cli, + [ + "merge-sharded-fsdp-weights", + str(config_path), + "--no-accelerate", + "--model-dir", + str(model_dir), + ], + ) + + assert mock.called + assert mock.call_args.kwargs["config"] == str(config_path) + assert mock.call_args.kwargs["model_dir"] == str(model_dir) + assert result.exit_code == 0 + + +def test_merge_sharded_fsdp_weights_with_save_path(cli_runner, config_path): + """Test merge_sharded_fsdp_weights command with save_path option""" + with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock: + result = cli_runner.invoke( + cli, + [ + "merge-sharded-fsdp-weights", + str(config_path), + "--no-accelerate", + "--save-path", + "/path/to/save", + ], + ) + + assert mock.called + assert mock.call_args.kwargs["config"] == str(config_path) + assert mock.call_args.kwargs["save_path"] == "/path/to/save" + assert result.exit_code == 0 diff --git a/tests/cli/test_cli_preprocess.py b/tests/cli/test_cli_preprocess.py new file mode 100644 index 0000000000..4719461aaf --- /dev/null +++ b/tests/cli/test_cli_preprocess.py @@ -0,0 +1,71 @@ +"""pytest tests for axolotl CLI preprocess command.""" +import shutil +from pathlib import Path +from unittest.mock import patch + +import pytest + +from axolotl.cli.main import cli + + +@pytest.fixture(autouse=True) +def cleanup_last_run_prepared(): + yield + + if Path("last_run_prepared").exists(): + shutil.rmtree("last_run_prepared") + + +def test_preprocess_config_not_found(cli_runner): + """Test preprocess fails when config not found""" + result = cli_runner.invoke(cli, ["preprocess", "nonexistent.yml"]) + assert result.exit_code != 0 + + +def test_preprocess_basic(cli_runner, config_path): + """Test basic preprocessing with minimal config""" + with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli: + result = cli_runner.invoke(cli, ["preprocess", 
str(config_path)]) + assert result.exit_code == 0 + + mock_do_cli.assert_called_once() + assert mock_do_cli.call_args.kwargs["config"] == str(config_path) + assert mock_do_cli.call_args.kwargs["download"] is True + + +def test_preprocess_without_download(cli_runner, config_path): + """Test preprocessing without model download""" + with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli: + result = cli_runner.invoke( + cli, ["preprocess", str(config_path), "--no-download"] + ) + assert result.exit_code == 0 + + mock_do_cli.assert_called_once() + assert mock_do_cli.call_args.kwargs["config"] == str(config_path) + assert mock_do_cli.call_args.kwargs["download"] is False + + +def test_preprocess_custom_path(cli_runner, tmp_path, valid_test_config): + """Test preprocessing with custom dataset path""" + config_path = tmp_path / "config.yml" + custom_path = tmp_path / "custom_prepared" + config_path.write_text(valid_test_config) + + with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli: + result = cli_runner.invoke( + cli, + [ + "preprocess", + str(config_path), + "--dataset-prepared-path", + str(custom_path.absolute()), + ], + ) + assert result.exit_code == 0 + + mock_do_cli.assert_called_once() + assert mock_do_cli.call_args.kwargs["config"] == str(config_path) + assert mock_do_cli.call_args.kwargs["dataset_prepared_path"] == str( + custom_path.absolute() + ) diff --git a/tests/cli/test_cli_shard.py b/tests/cli/test_cli_shard.py new file mode 100644 index 0000000000..505a2a7372 --- /dev/null +++ b/tests/cli/test_cli_shard.py @@ -0,0 +1,76 @@ +"""pytest tests for axolotl CLI shard command.""" +# pylint: disable=duplicate-code +from unittest.mock import patch + +from axolotl.cli.main import cli + + +def test_shard_with_accelerate(cli_runner, config_path): + """Test shard command with accelerate""" + with patch("subprocess.run") as mock: + result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"]) + + assert mock.called + assert 
mock.call_args.args[0] == [ + "accelerate", + "launch", + "-m", + "axolotl.cli.shard", + str(config_path), + "--debug-num-examples", + "0", + ] + assert mock.call_args.kwargs == {"check": True} + assert result.exit_code == 0 + + +def test_shard_no_accelerate(cli_runner, config_path): + """Test shard command without accelerate""" + with patch("axolotl.cli.shard.do_cli") as mock: + result = cli_runner.invoke(cli, ["shard", str(config_path), "--no-accelerate"]) + + assert mock.called + assert result.exit_code == 0 + + +def test_shard_with_model_dir(cli_runner, config_path, tmp_path): + """Test shard command with model_dir option""" + model_dir = tmp_path / "model" + model_dir.mkdir() + + with patch("axolotl.cli.shard.do_cli") as mock: + result = cli_runner.invoke( + cli, + [ + "shard", + str(config_path), + "--no-accelerate", + "--model-dir", + str(model_dir), + ], + catch_exceptions=False, + ) + + assert mock.called + assert mock.call_args.kwargs["config"] == str(config_path) + assert mock.call_args.kwargs["model_dir"] == str(model_dir) + assert result.exit_code == 0 + + +def test_shard_with_save_dir(cli_runner, config_path): + with patch("axolotl.cli.shard.do_cli") as mock: + result = cli_runner.invoke( + cli, + [ + "shard", + str(config_path), + "--no-accelerate", + "--save-dir", + "/path/to/save", + ], + ) + + assert mock.called + assert mock.call_args.kwargs["config"] == str(config_path) + assert mock.call_args.kwargs["save_dir"] == "/path/to/save" + assert result.exit_code == 0 diff --git a/tests/cli/test_cli_train.py b/tests/cli/test_cli_train.py new file mode 100644 index 0000000000..7f028fb4f2 --- /dev/null +++ b/tests/cli/test_cli_train.py @@ -0,0 +1,98 @@ +"""pytest tests for axolotl CLI train command.""" +from unittest.mock import MagicMock, patch + +from axolotl.cli.main import cli + + +def test_train_cli_validation(cli_runner): + """Test CLI validation""" + # Test missing config file + result = cli_runner.invoke(cli, ["train", "--no-accelerate"]) + 
assert result.exit_code != 0 + + # Test non-existent config file + result = cli_runner.invoke(cli, ["train", "nonexistent.yml", "--no-accelerate"]) + assert result.exit_code != 0 + assert "Error: Invalid value for 'CONFIG'" in result.output + + +def test_train_basic_execution(cli_runner, tmp_path, valid_test_config): + """Test basic successful execution""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("subprocess.run") as mock: + result = cli_runner.invoke(cli, ["train", str(config_path)]) + + assert mock.called + assert mock.call_args.args[0] == [ + "accelerate", + "launch", + "-m", + "axolotl.cli.train", + str(config_path), + "--debug-num-examples", + "0", + ] + assert mock.call_args.kwargs == {"check": True} + assert result.exit_code == 0 + + +def test_train_basic_execution_no_accelerate(cli_runner, tmp_path, valid_test_config): + """Test basic successful execution without accelerate""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("axolotl.cli.train.train") as mock_train: + mock_train.return_value = (MagicMock(), MagicMock()) + + result = cli_runner.invoke( + cli, + [ + "train", + str(config_path), + "--learning-rate", + "1e-4", + "--micro-batch-size", + "2", + "--no-accelerate", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_train.assert_called_once() + + +def test_train_cli_overrides(cli_runner, tmp_path, valid_test_config): + """Test CLI arguments properly override config values""" + config_path = tmp_path / "config.yml" + output_dir = tmp_path / "model-out" + + test_config = valid_test_config.replace( + "output_dir: model-out", f"output_dir: {output_dir}" + ) + config_path.write_text(test_config) + + with patch("axolotl.cli.train.train") as mock_train: + mock_train.return_value = (MagicMock(), MagicMock()) + + result = cli_runner.invoke( + cli, + [ + "train", + str(config_path), + "--learning-rate", + "1e-4", + "--micro-batch-size", + "2", +
"--no-accelerate", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_train.assert_called_once() + cfg = mock_train.call_args[1]["cfg"] + assert cfg["learning_rate"] == 1e-4 + assert cfg["micro_batch_size"] == 2 diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py new file mode 100644 index 0000000000..a593479a41 --- /dev/null +++ b/tests/cli/test_utils.py @@ -0,0 +1,89 @@ +"""pytest tests for axolotl CLI utils.""" +# pylint: disable=redefined-outer-name +import json +from unittest.mock import Mock, patch + +import click +import pytest +import requests + +from axolotl.cli.utils import fetch_from_github + +# Sample GitHub API response +MOCK_TREE_RESPONSE = { + "tree": [ + {"path": "examples/config1.yml", "type": "blob", "sha": "abc123"}, + {"path": "examples/config2.yml", "type": "blob", "sha": "def456"}, + {"path": "other/file.txt", "type": "blob", "sha": "xyz789"}, + ] +} + + +@pytest.fixture +def mock_responses(): + """Mock responses for API and file downloads""" + + def mock_get(url, timeout=None): # pylint: disable=unused-argument + response = Mock() + if "api.github.com" in url: + response.text = json.dumps(MOCK_TREE_RESPONSE) + else: + response.content = b"file content" + return response + + return mock_get + + +def test_fetch_from_github_new_files(tmp_path, mock_responses): + """Test fetching new files""" + with patch("requests.get", mock_responses): + fetch_from_github("examples/", tmp_path) + + # Verify files were created + assert (tmp_path / "config1.yml").exists() + assert (tmp_path / "config2.yml").exists() + assert not (tmp_path / "file.txt").exists() + + +def test_fetch_from_github_unchanged_files(tmp_path, mock_responses): + """Test handling of unchanged files""" + # Create existing file with matching SHA + existing_file = tmp_path / "config1.yml" + existing_file.write_bytes(b"file content") + + with patch("requests.get", mock_responses): + fetch_from_github("examples/", tmp_path) + + # File should not be downloaded 
again + assert existing_file.read_bytes() == b"file content" + + +def test_fetch_from_github_invalid_prefix(mock_responses): + """Test error handling for invalid directory prefix""" + with patch("requests.get", mock_responses): + with pytest.raises(click.ClickException): + fetch_from_github("nonexistent/", None) + + +def test_fetch_from_github_network_error(): + """Test handling of network errors""" + with patch("requests.get", side_effect=requests.RequestException): + with pytest.raises(requests.RequestException): + fetch_from_github("examples/", None) + + +@pytest.fixture +def integration_test_dir(tmp_path): + """Fixture providing an integration test directory under pytest's tmp_path""" + test_dir = tmp_path / "github_downloads" + test_dir.mkdir(parents=True) + yield test_dir + + +def test_fetch_from_github_real(integration_test_dir): + """Test actual GitHub API interaction""" + fetch_from_github("examples/", integration_test_dir) + + # Verify some known files exist + assert (integration_test_dir / "openllama-3b" / "lora.yml").exists() + assert (integration_test_dir / "openllama-3b" / "qlora.yml").exists() diff --git a/tests/test_validation.py b/tests/patched/test_validation.py similarity index 100% rename from tests/test_validation.py rename to tests/patched/test_validation.py diff --git a/tests/test_data.py b/tests/test_data.py index 9d7f5a0412..e156e1f3c5 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,5 +1,5 @@ """ -test module for the axolotl.utis.data module +test module for the axolotl.utils.data module """ import unittest