diff --git a/src/axolotl/integrations/differential_transformer/differential_attention.py b/src/axolotl/integrations/differential_transformer/differential_attention.py index 58d4b94ec..af7473436 100644 --- a/src/axolotl/integrations/differential_transformer/differential_attention.py +++ b/src/axolotl/integrations/differential_transformer/differential_attention.py @@ -84,9 +84,9 @@ def __init__( if config.split_heads: # Split heads mode - assert ( - self.base_num_heads % 2 == 0 - ), "Number of heads must be even for splitting" + # assert ( + # self.base_num_heads % 2 == 0 + # ), "Number of heads must be even for splitting" self.heads_per_component = self.base_num_heads // 2 # Single projections diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py index 78b090e19..d360e29d6 100644 --- a/tests/cli/conftest.py +++ b/tests/cli/conftest.py @@ -1,4 +1,5 @@ """Shared pytest fixtures for cli module.""" + import pytest from click.testing import CliRunner diff --git a/tests/cli/integrations/__init__.py b/tests/cli/integrations/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/cli/integrations/test_cli_convert_differential_transformer.py b/tests/cli/integrations/test_cli_convert_differential_transformer.py deleted file mode 100644 index cd2a464c6..000000000 --- a/tests/cli/integrations/test_cli_convert_differential_transformer.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Tests for convert-differential-transformer CLI command.""" - -from pathlib import Path -from unittest.mock import patch - -from axolotl.cli.main import cli - - -def test_cli_validation(cli_runner): - """Test CLI validation for a command. - - Args: - cli_runner: CLI runner fixture - """ - # Test missing config file - result = cli_runner.invoke(cli, ["convert-differential-transformer"]) - assert result.exit_code != 0 - assert "Error: Missing argument 'CONFIG'." in result.output - - # Test non-existent config file - result = cli_runner.invoke( - cli, ["convert-differential-transformer", "nonexistent.yml"] - ) - assert result.exit_code != 0 - assert "Error: Invalid value for 'CONFIG'" in result.output - - -def test_basic_execution(cli_runner, tmp_path: Path, valid_test_config: str): - """Test basic execution. - - Args: - cli_runner: CLI runner fixture - tmp_path: Temporary path fixture - valid_test_config: Valid config fixture - """ - config_path = tmp_path / "config.yml" - config_path.write_text(valid_test_config) - - with patch( - "axolotl.cli.integrations.convert_differential_transformer.do_cli" - ) as mock_do_cli: - result = cli_runner.invoke( - cli, ["convert-differential-transformer", str(config_path)] - ) - assert result.exit_code == 0 - - mock_do_cli.assert_called_once() - assert mock_do_cli.call_args.kwargs["config"] == str(config_path) diff --git a/tests/cli/test_cli_fetch.py b/tests/cli/test_cli_fetch.py index 0df87b029..f06f06717 100644 --- a/tests/cli/test_cli_fetch.py +++ b/tests/cli/test_cli_fetch.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI fetch command.""" + from unittest.mock import patch from axolotl.cli.main import fetch diff --git a/tests/cli/test_cli_inference.py b/tests/cli/test_cli_inference.py index 7cb163d25..b8effa3d2 100644 --- a/tests/cli/test_cli_inference.py +++ b/tests/cli/test_cli_inference.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI inference command.""" + from unittest.mock import patch from axolotl.cli.main import cli diff --git a/tests/cli/test_cli_interface.py b/tests/cli/test_cli_interface.py index ed8335b76..8b5fec17f 100644 --- a/tests/cli/test_cli_interface.py +++ b/tests/cli/test_cli_interface.py @@ -1,4 +1,5 @@ """General pytest tests for axolotl.cli.main interface.""" + from axolotl.cli.main import build_command, cli diff --git a/tests/cli/test_cli_merge_lora.py b/tests/cli/test_cli_merge_lora.py index 165a64e98..aac016760 100644 --- a/tests/cli/test_cli_merge_lora.py +++ b/tests/cli/test_cli_merge_lora.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI merge_lora command.""" + from unittest.mock import patch from axolotl.cli.main import cli diff --git a/tests/cli/test_cli_merge_sharded_fsdp_weights.py b/tests/cli/test_cli_merge_sharded_fsdp_weights.py index cff0f3b77..420c28b9e 100644 --- a/tests/cli/test_cli_merge_sharded_fsdp_weights.py +++ b/tests/cli/test_cli_merge_sharded_fsdp_weights.py @@ -1,5 +1,6 @@ """pytest tests for axolotl CLI merge_sharded_fsdp_weights command.""" # pylint: disable=duplicate-code + from unittest.mock import patch from axolotl.cli.main import cli diff --git a/tests/cli/test_cli_preprocess.py b/tests/cli/test_cli_preprocess.py index 4719461aa..e2dd3a6c3 100644 --- a/tests/cli/test_cli_preprocess.py +++ b/tests/cli/test_cli_preprocess.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI preprocess command.""" + import shutil from pathlib import Path from unittest.mock import patch diff --git a/tests/cli/test_cli_shard.py b/tests/cli/test_cli_shard.py index 505a2a737..3176ed27e 100644 --- a/tests/cli/test_cli_shard.py +++ b/tests/cli/test_cli_shard.py @@ -1,5 +1,6 @@ """pytest tests for axolotl CLI shard command.""" # pylint: disable=duplicate-code + from unittest.mock import patch from axolotl.cli.main import cli diff --git a/tests/cli/test_cli_version.py b/tests/cli/test_cli_version.py index 819780e94..533dd5c0e 100644 --- a/tests/cli/test_cli_version.py +++ b/tests/cli/test_cli_version.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI --version""" + from axolotl.cli.main import cli diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index b88e4ac72..ecb0025e4 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -1,5 +1,6 @@ """pytest tests for axolotl CLI utils.""" # pylint: disable=redefined-outer-name + import json from unittest.mock import Mock, patch diff --git a/tests/e2e/integrations/convert_differential_transformer/conftest.py b/tests/e2e/integrations/convert_differential_transformer/conftest.py index 17a424ddb..ed1eb3f36 100644 --- a/tests/e2e/integrations/convert_differential_transformer/conftest.py +++ b/tests/e2e/integrations/convert_differential_transformer/conftest.py @@ -1,6 +1,7 @@ """Shared fixtures for differential transformer conversion tests.""" import pytest +from click.testing import CliRunner @pytest.fixture() @@ -26,3 +27,8 @@ def base_config(): "pad_token": "<|endoftext|>", }, } + + +@pytest.fixture +def cli_runner(): + return CliRunner() diff --git a/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py b/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py index 84e5fdaa1..42ce3e612 100644 --- a/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py +++ b/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Optional +from unittest.mock import patch import pytest import yaml @@ -12,9 +13,41 @@ from axolotl.cli.integrations.convert_differential_transformer import ( convert_differential_transformer, ) +from axolotl.cli.main import cli from axolotl.common.cli import ConvertDiffTransformerCliArgs +def test_cli_validation(cli_runner): + # Test missing config file + result = cli_runner.invoke(cli, ["convert-differential-transformer"]) + assert result.exit_code != 0 + assert "Error: Missing argument 'CONFIG'." in result.output + + # Test non-existent config file + result = cli_runner.invoke( + cli, ["convert-differential-transformer", "nonexistent.yml"] + ) + assert result.exit_code != 0 + assert "Error: Invalid value for 'CONFIG'" in result.output + + +def test_basic_execution(cli_runner, tmp_path: Path, base_config): + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + with patch( + "axolotl.cli.integrations.convert_differential_transformer.do_cli" + ) as mock_do_cli: + result = cli_runner.invoke( + cli, ["convert-differential-transformer", str(config_path)] + ) + assert result.exit_code == 0 + + mock_do_cli.assert_called_once() + assert mock_do_cli.call_args.kwargs["config"] == str(config_path) + + def test_conversion_cli_basic(tmp_path: Path, base_config): output_dir = tmp_path / "converted" base_config["output_dir"] = str(output_dir) @@ -113,7 +146,6 @@ def test_conversion_cli_repoduce_attentions( ) def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str): output_dir = tmp_path / "converted" - base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B" base_config["output_dir"] = str(output_dir) base_config[attention] = True