-
-
Notifications
You must be signed in to change notification settings - Fork 921
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
attempt to also run e2e tests that needs gpus (#1070)
* attempt to also run e2e tests that needs gpus * fix stray quote * checkout specific github ref * dockerfile for tests with proper checkout ensure wandb is dissabled for docker pytests clear wandb env after testing clear wandb env after testing make sure to provide a default val for pop tryin skipping wandb validation tests explicitly disable wandb in the e2e tests explicitly report_to None to see if that fixes the docker e2e tests split gpu from non-gpu unit tests skip bf16 check in test for now build docker w/o cache since it uses branch name ref revert some changes now that caching is fixed skip bf16 check if on gpu w support * pytest skip for auto-gptq requirements * skip mamba tests for now, split multipack and non packed lora llama tests * split tests that use monkeypatches * fix relative import for prev commit * move other tests using monkeypatches to the correct run
- Loading branch information
Showing
13 changed files
with
214 additions
and
105 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
ARG BASE_TAG=main-base | ||
FROM winglian/axolotl-base:$BASE_TAG | ||
|
||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX" | ||
ARG AXOLOTL_EXTRAS="" | ||
ARG CUDA="118" | ||
ENV BNB_CUDA_VERSION=$CUDA | ||
ARG PYTORCH_VERSION="2.0.1" | ||
ARG GITHUB_REF="main" | ||
|
||
ENV PYTORCH_VERSION=$PYTORCH_VERSION | ||
|
||
RUN apt-get update && \ | ||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev | ||
|
||
WORKDIR /workspace | ||
|
||
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git | ||
|
||
WORKDIR /workspace/axolotl | ||
|
||
RUN git fetch origin +$GITHUB_REF && \ | ||
git checkout FETCH_HEAD | ||
|
||
# If AXOLOTL_EXTRAS is set, append it in brackets | ||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ | ||
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \ | ||
else \ | ||
pip install -e .[deepspeed,flash-attn,mamba-ssm]; \ | ||
fi | ||
|
||
# So we can test the Docker image | ||
RUN pip install pytest | ||
|
||
# fix so that git fetch/pull from remote works | ||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \ | ||
git config --get remote.origin.fetch | ||
|
||
# helper for huggingface-login cli | ||
RUN git config --global credential.helper store |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
""" | ||
E2E tests for lora llama | ||
""" | ||
|
||
import logging | ||
import os | ||
import unittest | ||
from pathlib import Path | ||
|
||
import pytest | ||
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available | ||
|
||
from axolotl.cli import load_datasets | ||
from axolotl.common.cli import TrainerCliArgs | ||
from axolotl.train import train | ||
from axolotl.utils.config import normalize_config | ||
from axolotl.utils.dict import DictDefault | ||
|
||
from ..utils import with_temp_dir | ||
|
||
LOG = logging.getLogger("axolotl.tests.e2e") | ||
os.environ["WANDB_DISABLED"] = "true" | ||
|
||
|
||
class TestLoraLlama(unittest.TestCase): | ||
""" | ||
Test case for Llama models using LoRA w multipack | ||
""" | ||
|
||
@with_temp_dir | ||
def test_lora_packing(self, temp_dir): | ||
# pylint: disable=duplicate-code | ||
cfg = DictDefault( | ||
{ | ||
"base_model": "JackFram/llama-68m", | ||
"tokenizer_type": "LlamaTokenizer", | ||
"sequence_len": 1024, | ||
"sample_packing": True, | ||
"flash_attention": True, | ||
"load_in_8bit": True, | ||
"adapter": "lora", | ||
"lora_r": 32, | ||
"lora_alpha": 64, | ||
"lora_dropout": 0.05, | ||
"lora_target_linear": True, | ||
"val_set_size": 0.1, | ||
"special_tokens": { | ||
"unk_token": "<unk>", | ||
"bos_token": "<s>", | ||
"eos_token": "</s>", | ||
}, | ||
"datasets": [ | ||
{ | ||
"path": "mhenrichsen/alpaca_2k_test", | ||
"type": "alpaca", | ||
}, | ||
], | ||
"num_epochs": 2, | ||
"micro_batch_size": 8, | ||
"gradient_accumulation_steps": 1, | ||
"output_dir": temp_dir, | ||
"learning_rate": 0.00001, | ||
"optimizer": "adamw_torch", | ||
"lr_scheduler": "cosine", | ||
} | ||
) | ||
if is_torch_bf16_gpu_available(): | ||
cfg.bf16 = True | ||
else: | ||
cfg.fp16 = True | ||
|
||
normalize_config(cfg) | ||
cli_args = TrainerCliArgs() | ||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) | ||
|
||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) | ||
assert (Path(temp_dir) / "adapter_model.bin").exists() | ||
|
||
@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available") | ||
@with_temp_dir | ||
def test_lora_gptq_packed(self, temp_dir): | ||
# pylint: disable=duplicate-code | ||
cfg = DictDefault( | ||
{ | ||
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ", | ||
"model_type": "AutoModelForCausalLM", | ||
"tokenizer_type": "LlamaTokenizer", | ||
"sequence_len": 1024, | ||
"sample_packing": True, | ||
"flash_attention": True, | ||
"load_in_8bit": True, | ||
"adapter": "lora", | ||
"gptq": True, | ||
"gptq_disable_exllama": True, | ||
"lora_r": 32, | ||
"lora_alpha": 64, | ||
"lora_dropout": 0.05, | ||
"lora_target_linear": True, | ||
"val_set_size": 0.1, | ||
"special_tokens": { | ||
"unk_token": "<unk>", | ||
"bos_token": "<s>", | ||
"eos_token": "</s>", | ||
}, | ||
"datasets": [ | ||
{ | ||
"path": "mhenrichsen/alpaca_2k_test", | ||
"type": "alpaca", | ||
}, | ||
], | ||
"num_epochs": 2, | ||
"save_steps": 0.5, | ||
"micro_batch_size": 8, | ||
"gradient_accumulation_steps": 1, | ||
"output_dir": temp_dir, | ||
"learning_rate": 0.00001, | ||
"optimizer": "adamw_torch", | ||
"lr_scheduler": "cosine", | ||
} | ||
) | ||
normalize_config(cfg) | ||
cli_args = TrainerCliArgs() | ||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) | ||
|
||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) | ||
assert (Path(temp_dir) / "adapter_model.bin").exists() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.