diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py index 4e885a76d..33d12157a 100644 --- a/tests/e2e/test_llama.py +++ b/tests/e2e/test_llama.py @@ -4,7 +4,6 @@ import logging import os -import unittest from pathlib import Path from axolotl.cli import load_datasets @@ -13,18 +12,15 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import with_temp_dir - LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" -class TestLlama(unittest.TestCase): +class TestLlama: """ Test case for Llama models """ - @with_temp_dir def test_fft_trust_remote_code(self, temp_dir): # pylint: disable=duplicate-code cfg = DictDefault( @@ -46,7 +42,8 @@ def test_fft_trust_remote_code(self, temp_dir): }, ], "num_epochs": 1, - "micro_batch_size": 8, + "max_steps": 5, + "micro_batch_size": 2, "gradient_accumulation_steps": 1, "output_dir": temp_dir, "learning_rate": 0.00001, @@ -64,3 +61,46 @@ def test_fft_trust_remote_code(self, temp_dir): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(temp_dir) / "model.safetensors").exists() + + def test_fix_untrained_tokens(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "fix_untrained_tokens": True, + "sequence_len": 512, + "val_set_size": 0.0, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "chat_template": "chatml", + "datasets": [ + { + "path": "mlabonne/FineTome-100k", + "type": "chat_template", + "split": "train[:10%]", + "field_messages": "conversations", + "message_field_role": "from", + "message_field_content": "value", + }, + ], + "num_epochs": 1, + "max_steps": 5, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_8bit", + "lr_scheduler": "cosine", + "flash_attention": True, + "sample_packing": True, + "bf16": True, + "save_safetensors": True, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "model.safetensors").exists()