From b3a61e8ce2aa547226d77361e08f1c5b7c99c0dd Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 15 Nov 2023 23:05:55 -0500
Subject: [PATCH] add e2e tests for checking functionality of resume from
 checkpoint (#865)

* use tensorboard to see if resume from checkpoint works

* make sure e2e test is either fp16 or bf16

* set max_steps and save limit so we have the checkpoint when testing resuming

* fix test parameters
---
 requirements.txt             |  1 +
 tests/e2e/test_lora_llama.py |  1 +
 tests/e2e/test_resume.py     | 95 ++++++++++++++++++++++++++++++++++++
 tests/e2e/utils.py           | 13 ++++-
 4 files changed, 109 insertions(+), 1 deletion(-)
 create mode 100644 tests/e2e/test_resume.py

diff --git a/requirements.txt b/requirements.txt
index a8c01b53fc..9ed66033bd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -32,3 +32,4 @@ pynvml
 art
 fschat==0.2.29
 gradio
+tensorboard
diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py
index c13243dd88..9d795601a4 100644
--- a/tests/e2e/test_lora_llama.py
+++ b/tests/e2e/test_lora_llama.py
@@ -101,6 +101,7 @@ def test_lora_packing(self, temp_dir):
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_torch",
                 "lr_scheduler": "cosine",
+                "bf16": True,
             }
         )
         normalize_config(cfg)
diff --git a/tests/e2e/test_resume.py b/tests/e2e/test_resume.py
new file mode 100644
index 0000000000..98ec3ac6bf
--- /dev/null
+++ b/tests/e2e/test_resume.py
@@ -0,0 +1,95 @@
+"""
+E2E tests for resuming training
+"""
+
+import logging
+import os
+import re
+import subprocess
+import unittest
+from pathlib import Path
+
+from transformers.utils import is_torch_bf16_gpu_available
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import most_recent_subdir, with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestResumeLlama(unittest.TestCase):
+    """
+    Test case for resuming training of llama models
+    """
+
+    @with_temp_dir
+    def test_resume_qlora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "sample_packing": True,
+                "flash_attention": True,
+                "load_in_4bit": True,
+                "adapter": "qlora",
+                "lora_r": 32,
+                "lora_alpha": 64,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "vicgalle/alpaca-gpt4",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "save_steps": 10,
+                "save_total_limit": 5,
+                "max_steps": 40,
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+
+        resume_cfg = cfg | DictDefault(
+            {
+                "resume_from_checkpoint": f"{temp_dir}/checkpoint-30/",
+            }
+        )
+        normalize_config(resume_cfg)
+        cli_args = TrainerCliArgs()
+
+        train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+        tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
+        cmd = f"tensorboard --inspect --logdir {tb_log_path_1}"
+        res = subprocess.run(
+            cmd, shell=True, text=True, capture_output=True, check=True
+        )
+        pattern = r"first_step\s+(\d+)"
+        first_steps = int(re.findall(pattern, res.stdout)[0])
+        assert first_steps == 31
diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py
index 8b6c566d1d..203824fc9d 100644
--- a/tests/e2e/utils.py
+++ b/tests/e2e/utils.py
@@ -1,10 +1,11 @@
 """
 helper utils for tests
 """
-
+import os
 import shutil
 import tempfile
 from functools import wraps
+from pathlib import Path
 
 
 def with_temp_dir(test_func):
@@ -20,3 +21,13 @@ def wrapper(*args, **kwargs):
             shutil.rmtree(temp_dir)
 
     return wrapper
+
+
+def most_recent_subdir(path):
+    base_path = Path(path)
+    subdirectories = [d for d in base_path.iterdir() if d.is_dir()]
+    if not subdirectories:
+        return None
+    subdir = max(subdirectories, key=os.path.getctime)
+
+    return subdir
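
Note on the assertion at the end of tests/e2e/test_resume.py: with save_steps: 10 and max_steps: 40, the first run writes checkpoint-10 through checkpoint-40, so a run resumed from checkpoint-30 should log its first step as 31. Below is a minimal standalone sketch of that check; the first_logged_step helper is hypothetical and not part of this patch, but the command and regex are taken directly from the test above.

import re
import subprocess


def first_logged_step(logdir: str) -> int:
    """Return the first_step value that tensorboard --inspect reports for an event directory."""
    # the test shells out to tensorboard and scrapes its inspection summary
    res = subprocess.run(
        f"tensorboard --inspect --logdir {logdir}",
        shell=True,
        text=True,
        capture_output=True,
        check=True,
    )
    # the summary contains a line like "first_step           31"
    return int(re.findall(r"first_step\s+(\d+)", res.stdout)[0])


# hypothetical usage against the resumed run's tensorboard directory:
# assert first_logged_step(most_recent_subdir(f"{output_dir}/runs")) == 31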