From 71d78afde4b5c2d69de4a94381bbdf5bf416fa24 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 14 Sep 2023 21:47:46 -0400 Subject: [PATCH] e2e testing --- .github/workflows/e2e.yml | 31 ++++++++++++++ .github/workflows/main.yml | 4 +- .github/workflows/tests.yml | 2 +- tests/e2e/test_lora_llama.py | 80 ++++++++++++++++++++++++++++++++++++ 4 files changed, 114 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/e2e.yml create mode 100644 tests/e2e/test_lora_llama.py diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml new file mode 100644 index 0000000000..ada1fd0c48 --- /dev/null +++ b/.github/workflows/e2e.yml @@ -0,0 +1,31 @@ +name: E2E +on: + workflow_dispatch: + +jobs: + e2e-test: + runs-on: [self-hosted, gpu] + strategy: + fail-fast: false + matrix: + python_version: ["3.10"] + timeout-minutes: 10 + + steps: + - name: Check out repository code + uses: actions/checkout@v3 + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python_version }} + cache: 'pip' # caching pip dependencies + + - name: Install dependencies + run: | + pip3 install -e . + pip3 install -r requirements-tests.txt + + - name: Run e2e tests + run: | + pytest tests/e2e/ diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 30d4774dbf..a5b4d30379 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -23,7 +23,7 @@ jobs: python_version: "3.10" pytorch: 2.0.1 axolotl_extras: - runs-on: self-hosted + runs-on: [self-hosted, gpu, docker] steps: - name: Checkout uses: actions/checkout@v3 @@ -68,7 +68,7 @@ jobs: pytorch: 2.0.1 axolotl_extras: is_latest: true - runs-on: self-hosted + runs-on: [self-hosted, gpu, docker] steps: - name: Checkout uses: actions/checkout@v3 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9ff08db074..a2ee392626 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -29,4 +29,4 @@ jobs: - name: Run tests run: | - pytest tests/ + pytest --ignore=tests/e2e/ tests/ diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py new file mode 100644 index 0000000000..7873b7ec20 --- /dev/null +++ b/tests/e2e/test_lora_llama.py @@ -0,0 +1,80 @@ +""" +E2E tests for lora llama +""" + +import logging +import os +import tempfile +import unittest + +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import TrainDatasetMeta, train +from axolotl.utils.config import normalize_config +from axolotl.utils.data import prepare_dataset +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_tokenizer + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +def load_datasets( + *, + cfg: DictDefault, + cli_args: TrainerCliArgs, # pylint:disable=unused-argument +) -> TrainDatasetMeta: + tokenizer = load_tokenizer(cfg) + + train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer) + + return TrainDatasetMeta( + train_dataset=train_dataset, + eval_dataset=eval_dataset, + total_num_steps=total_num_steps, + ) + + +class TestLoraLlama(unittest.TestCase): + """ + Test case for Llama models using LoRA + """ + + def test_lora(self): + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "base_model_config": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 32, + "lora_alpha": 64, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.1, + "special_tokens": { + "unk_token": "", + "bos_token": "", + "eos_token": "", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "output_dir": tempfile.mkdtemp(), + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)