Skip to content

Commit

Permalink
moving tests around for flash_attn install
Browse files Browse the repository at this point in the history
  • Loading branch information
djsaunde committed Dec 18, 2024
1 parent 8290095 commit 16292e2
Show file tree
Hide file tree
Showing 15 changed files with 52 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ def __init__(

if config.split_heads:
# Split heads mode
assert (
self.base_num_heads % 2 == 0
), "Number of heads must be even for splitting"
# assert (
# self.base_num_heads % 2 == 0
# ), "Number of heads must be even for splitting"
self.heads_per_component = self.base_num_heads // 2

# Single projections
Expand Down
1 change: 1 addition & 0 deletions tests/cli/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Shared pytest fixtures for cli module."""

import pytest
from click.testing import CliRunner

Expand Down
Empty file removed tests/cli/integrations/__init__.py
Empty file.

This file was deleted.

1 change: 1 addition & 0 deletions tests/cli/test_cli_fetch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI fetch command."""

from unittest.mock import patch

from axolotl.cli.main import fetch
Expand Down
1 change: 1 addition & 0 deletions tests/cli/test_cli_inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI inference command."""

from unittest.mock import patch

from axolotl.cli.main import cli
Expand Down
1 change: 1 addition & 0 deletions tests/cli/test_cli_interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""General pytest tests for axolotl.cli.main interface."""

from axolotl.cli.main import build_command, cli


Expand Down
1 change: 1 addition & 0 deletions tests/cli/test_cli_merge_lora.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI merge_lora command."""

from unittest.mock import patch

from axolotl.cli.main import cli
Expand Down
1 change: 1 addition & 0 deletions tests/cli/test_cli_merge_sharded_fsdp_weights.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
# pylint: disable=duplicate-code

from unittest.mock import patch

from axolotl.cli.main import cli
Expand Down
1 change: 1 addition & 0 deletions tests/cli/test_cli_preprocess.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI preprocess command."""

import shutil
from pathlib import Path
from unittest.mock import patch
Expand Down
1 change: 1 addition & 0 deletions tests/cli/test_cli_shard.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI shard command."""
# pylint: disable=duplicate-code

from unittest.mock import patch

from axolotl.cli.main import cli
Expand Down
1 change: 1 addition & 0 deletions tests/cli/test_cli_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI --version"""

from axolotl.cli.main import cli


Expand Down
1 change: 1 addition & 0 deletions tests/cli/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI utils."""
# pylint: disable=redefined-outer-name

import json
from unittest.mock import Mock, patch

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Shared fixtures for differential transformer conversion tests."""

import pytest
from click.testing import CliRunner


@pytest.fixture()
Expand All @@ -26,3 +27,8 @@ def base_config():
"pad_token": "<|endoftext|>",
},
}


@pytest.fixture
def cli_runner():
return CliRunner()
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pathlib import Path
from typing import Optional
from unittest.mock import patch

import pytest
import yaml
Expand All @@ -12,9 +13,41 @@
from axolotl.cli.integrations.convert_differential_transformer import (
convert_differential_transformer,
)
from axolotl.cli.main import cli
from axolotl.common.cli import ConvertDiffTransformerCliArgs


def test_cli_validation(cli_runner):
# Test missing config file
result = cli_runner.invoke(cli, ["convert-differential-transformer"])
assert result.exit_code != 0
assert "Error: Missing argument 'CONFIG'." in result.output

# Test non-existent config file
result = cli_runner.invoke(
cli, ["convert-differential-transformer", "nonexistent.yml"]
)
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output


def test_basic_execution(cli_runner, tmp_path: Path, base_config):
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)

with patch(
"axolotl.cli.integrations.convert_differential_transformer.do_cli"
) as mock_do_cli:
result = cli_runner.invoke(
cli, ["convert-differential-transformer", str(config_path)]
)
assert result.exit_code == 0

mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)


def test_conversion_cli_basic(tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
Expand Down Expand Up @@ -113,7 +146,6 @@ def test_conversion_cli_repoduce_attentions(
)
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
output_dir = tmp_path / "converted"
base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B"
base_config["output_dir"] = str(output_dir)
base_config[attention] = True

Expand Down

0 comments on commit 16292e2

Please sign in to comment.