diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 611881ab1a..27d27831fd 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -8,6 +8,7 @@ from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art from axolotl.common.cli import TrainerCliArgs +from axolotl.utils.dict import DictDefault def do_cli(config: Path = Path("examples/"), **kwargs): @@ -27,21 +28,26 @@ def do_cli(config: Path = Path("examples/"), **kwargs): flash_attention=False, **kwargs, ) + cfg = modify_cfg_for_merge(parsed_cfg) - if not parsed_cfg.lora_model_dir and parsed_cfg.output_dir: - parsed_cfg.lora_model_dir = parsed_cfg.output_dir - if not Path(parsed_cfg.lora_model_dir).exists(): + do_merge_lora(cfg=cfg, cli_args=parsed_cli_args) + + +def modify_cfg_for_merge(cfg: DictDefault) -> DictDefault: + if not cfg.lora_model_dir and cfg.output_dir: + cfg.lora_model_dir = cfg.output_dir + if not Path(cfg.lora_model_dir).exists(): raise ValueError( - f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist." + f"Target directory for merge: `{cfg.lora_model_dir}` does not exist." ) - parsed_cfg.load_in_4bit = False - parsed_cfg.load_in_8bit = False - parsed_cfg.flash_attention = False - parsed_cfg.deepspeed = None - parsed_cfg.fsdp = None + cfg.load_in_4bit = False + cfg.load_in_8bit = False + cfg.flash_attention = False + cfg.deepspeed = None + cfg.fsdp = None - do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) + return cfg if __name__ == "__main__": diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py index c79652bef7..02d71d1746 100644 --- a/tests/e2e/test_lora_llama.py +++ b/tests/e2e/test_lora_llama.py @@ -1,13 +1,16 @@ """ E2E tests for lora llama """ - +import json import logging import os import unittest from pathlib import Path -from axolotl.cli import load_datasets +from transformers.utils import is_torch_bf16_gpu_available + +from axolotl.cli import do_merge_lora, load_datasets +from axolotl.cli.merge_lora import modify_cfg_for_merge from axolotl.common.cli import TrainerCliArgs from axolotl.train import train from axolotl.utils.config import normalize_config @@ -39,11 +42,6 @@ def test_lora(self, temp_dir): "lora_dropout": 0.05, "lora_target_linear": True, "val_set_size": 0.1, - "special_tokens": { - "unk_token": "", - "bos_token": "", - "eos_token": "", - }, "datasets": [ { "path": "mhenrichsen/alpaca_2k_test", @@ -57,6 +55,7 @@ def test_lora(self, temp_dir): "learning_rate": 0.00001, "optimizer": "adamw_torch", "lr_scheduler": "cosine", + "max_steps": 10, } ) normalize_config(cfg) @@ -65,3 +64,67 @@ def test_lora(self, temp_dir): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(temp_dir) / "adapter_model.bin").exists() + + @with_temp_dir + def test_lora_merge(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 32, + "lora_alpha": 64, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.1, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 10, + "bf16": "auto", + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() + + cfg.lora_model_dir = cfg.output_dir + cfg.load_in_4bit = False + cfg.load_in_8bit = False + cfg.flash_attention = False + cfg.deepspeed = None + cfg.fsdp = None + + cfg = modify_cfg_for_merge(cfg) + cfg.merge_lora = True + + cli_args = TrainerCliArgs(merge_lora=True) + + do_merge_lora(cfg=cfg, cli_args=cli_args) + assert (Path(temp_dir) / "merged/pytorch_model.bin").exists() + + with open( + Path(temp_dir) / "merged/config.json", "r", encoding="utf-8" + ) as f_handle: + config = f_handle.read() + config = json.loads(config) + if is_torch_bf16_gpu_available(): + assert config["torch_dtype"] == "bfloat16" + else: + assert config["torch_dtype"] == "float16"