add e2e tests for Unsloth qlora and test the builds (#2093)
* see if unsloth installs cleanly in ci

* check unsloth install on regular tests, not sdist

* fix ampere check exception for ci

* use cached_property instead

* add an e2e test for unsloth qlora

* reduce seq len and mbsz to prevent oom in ci

* add checks for fp16 and sdp_attention

* pin unsloth to a specific release

* add unsloth to docker image too

* fix flash attn xentropy patch

* fix loss, add check for loss when using fa_xentropy

* fix special tokens for test

* typo

* test fa xentropy with and without gradient accum

* pr feedback changes
winglian authored Nov 30, 2024
1 parent 1cf7075 commit 5f1d98e
Showing 8 changed files with 275 additions and 50 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
@@ -67,6 +67,7 @@ jobs:
        run: |
          pip3 show torch
          pip3 install -U -e .
          python scripts/unsloth_install.py | sh
          pip3 install -r requirements-dev.txt -r requirements-tests.txt
      - name: Run tests
2 changes: 2 additions & 0 deletions cicd/Dockerfile.jinja
@@ -37,6 +37,8 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi

RUN python scripts/unsloth_install.py | sh

# So we can test the Docker image
RUN pip install -r requirements-dev.txt -r requirements-tests.txt

2 changes: 2 additions & 0 deletions docker/Dockerfile
@@ -26,6 +26,8 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi

RUN python scripts/unsloth_install.py | sh

# So we can test the Docker image
RUN pip install pytest

7 changes: 5 additions & 2 deletions scripts/unsloth_install.py
@@ -8,7 +8,10 @@

v = V(torch.__version__)
cuda = str(torch.version.cuda)
is_ampere = torch.cuda.get_device_capability()[0] >= 8
try:
    is_ampere = torch.cuda.get_device_capability()[0] >= 8
except RuntimeError:
    is_ampere = False
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
    raise RuntimeError(f"CUDA = {cuda} not supported!")
if v <= V("2.1.0"):
@@ -29,5 +32,5 @@
raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
print(
f'pip install unsloth-zoo && pip install --no-deps "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"'
f'pip install unsloth-zoo==2024.11.7 && pip install --no-deps "unsloth[{x}]==2024.11.9"'
)
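
Why the try/except matters: on a runner with no visible CUDA device (for example while building the CI Docker image), torch.cuda.get_device_capability() raises a RuntimeError, so the script now falls back to the non-Ampere extra instead of crashing. A minimal standalone sketch of that detection pattern — detect_ampere is a made-up name for illustration, not part of the script:

    import torch

    def detect_ampere() -> bool:
        # Ampere and newer GPUs report compute capability 8.x or higher.
        # On a machine with no usable CUDA device, get_device_capability()
        # raises RuntimeError, so treat that case as "not Ampere".
        try:
            return torch.cuda.get_device_capability()[0] >= 8
        except RuntimeError:
            return False

    print(detect_ampere())  # False on a CPU-only runner, True on e.g. an A100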
34 changes: 26 additions & 8 deletions src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -4,7 +4,6 @@

import logging
import warnings
from functools import partial
from typing import List, Optional, Tuple, Union

import torch
@@ -94,14 +93,33 @@ def replace_llama_qkv_with_fused(model):
set_module_name(model, name, qkv)


def patch_llama_cross_entropy():
    from flash_attn.losses.cross_entropy import CrossEntropyLoss

    LOG.info("patching with flash_attn.losses.cross_entropy")
    transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
        CrossEntropyLoss, inplace_backward=True
def patch_fa_llama_cross_entropy():
    LOG.info(
        "patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy"
    )
    from flash_attn.ops.triton.cross_entropy import (
        cross_entropy_loss as flash_attn_cross_entropy_loss,
    )

    def fa2_fixed_cross_entropy(
        source,
        target,
        num_items_in_batch: int = None,
        ignore_index: int = -100,
        **kwargs,
    ):  # pylint: disable=unused-argument
        reduction = "sum" if num_items_in_batch is not None else "mean"
        loss, _ = flash_attn_cross_entropy_loss(
            source, target, ignore_index=ignore_index
        )
        if reduction == "sum":
            loss = loss.sum() / num_items_in_batch
        else:
            loss = loss.sum() / (target != ignore_index).sum()
        return loss

    transformers.loss.loss_utils.fixed_cross_entropy = fa2_fixed_cross_entropy


def patch_llama_rms_norm():
    try:
@@ -147,7 +165,7 @@ def replace_llama_attn_with_flash_attn(

    # skip only if explicitly disabled
    if cross_entropy:
        patch_llama_cross_entropy()
        patch_fa_llama_cross_entropy()

    # skip only if explicitly disabled
    if rms_norm:
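
What the new loss shim does: recent transformers routes Llama loss through transformers.loss.loss_utils.fixed_cross_entropy, which receives num_items_in_batch when the trainer wants gradient-accumulation-correct scaling (sum of per-token losses divided by the global token count) and otherwise takes a mean over non-ignored tokens. The patch swaps the kernel for flash-attn's Triton cross entropy while reproducing that reduction. A rough standalone sanity check of the reduction logic — it assumes flash-attn is installed and a CUDA device is available, and mirrors (rather than imports) the patched function:

    import torch
    import torch.nn.functional as F
    from flash_attn.ops.triton.cross_entropy import cross_entropy_loss

    logits = torch.randn(64, 32000, device="cuda", dtype=torch.float16)
    labels = torch.randint(0, 32000, (64,), device="cuda")
    labels[::7] = -100  # some positions are masked out of the loss

    # flash-attn returns per-token losses (plus z-losses, ignored here)
    per_token, _ = cross_entropy_loss(logits, labels, ignore_index=-100)

    # "mean" path: average over the tokens that actually contribute
    mean_loss = per_token.sum() / (labels != -100).sum()
    # "sum" path used under gradient accumulation: divide by the global token count
    num_items_in_batch = (labels != -100).sum() * 4  # e.g. 4 accumulation steps
    summed_loss = per_token.sum() / num_items_in_batch

    # the mean path should closely track PyTorch's own mean reduction
    reference = F.cross_entropy(logits.float(), labels, ignore_index=-100)
    print(mean_loss.item(), reference.item(), summed_loss.item())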
46 changes: 23 additions & 23 deletions src/axolotl/utils/models.py
@@ -2,10 +2,12 @@

# pylint: disable=too-many-lines
import gc
import importlib
import logging
import math
import os
import types
from functools import cached_property
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401

import addict
@@ -409,7 +411,7 @@ def apply_patches(self) -> None:
)

        if self.cfg.is_llama_derived_model:
            self.patch_loss()
            self.patch_loss_llama()
            if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
                from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora

@@ -451,27 +453,34 @@ def patch_attention(self) -> None:

        replace_stablelm_attn_with_flash_attn(self.cfg.base_model)

    def patch_loss(self) -> None:
    @cached_property
    def has_flash_attn(self) -> bool:
        """Check if flash attention is installed"""
        return importlib.util.find_spec("flash_attn") is not None

    def patch_loss_llama(self) -> None:
        """
        Patch loss functions
        """
        from axolotl.monkeypatch.llama_attn_hijack_flash import (
            patch_llama_cross_entropy,
            patch_llama_rms_norm,
        )
        if self.has_flash_attn:
            from axolotl.monkeypatch.llama_attn_hijack_flash import (
                patch_fa_llama_cross_entropy,
                patch_llama_rms_norm,
            )

        if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
            patch_fa_llama_cross_entropy()
        elif self.cfg.unsloth_cross_entropy_loss:
            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch

            integrate_cross_entropy_loss_patch(model_type="llama")

        if self.cfg.flash_attn_cross_entropy:
            patch_llama_cross_entropy()
        if self.cfg.flash_attn_rms_norm:
        if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
            patch_llama_rms_norm()
        elif self.cfg.unsloth_rms_norm:
            from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm

            patch_unsloth_layernorm()
        if self.cfg.unsloth_cross_entropy_loss:
            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch

            integrate_cross_entropy_loss_patch(model_type="llama")
        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
            from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora

@@ -481,6 +490,7 @@ def patch_llama_derived_model(self) -> None:
        """
        Modify all llama derived models in one block
        """
        self.patch_loss_llama()

        if self.cfg.flash_attention:
            from axolotl.monkeypatch.llama_attn_hijack_flash import (
@@ -528,16 +538,6 @@ def patch_llama_derived_model(self) -> None:
                "Shifted-sparse attention not currently implemented without flash attention."
            )

        if self.cfg.unsloth_cross_entropy_loss:
            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch

            integrate_cross_entropy_loss_patch(model_type="llama")

        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
            from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora

            patch_self_attn_lora()

    def set_auto_model_loader(self) -> None:
        """set self.AutoModelLoader
        - default value: AutoModelForCausalLM (set at __init__)
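
The net effect in models.py: flash-attn availability is probed once per loader via cached_property, the flash-attn loss and RMSNorm patches are applied only when the package is actually importable, and the unsloth cross-entropy and LoRA patch calls move out of patch_llama_derived_model into patch_loss_llama / apply_patches. A small illustration of the find_spec + cached_property guard, using a made-up class name rather than axolotl's loader:

    import importlib.util
    from functools import cached_property


    class PatchManager:  # hypothetical stand-in for axolotl's model loader
        @cached_property
        def has_flash_attn(self) -> bool:
            # find_spec returns None when the package is not importable,
            # so this check never actually imports flash_attn
            return importlib.util.find_spec("flash_attn") is not None


    pm = PatchManager()
    print(pm.has_flash_attn)  # evaluated once
    print(pm.has_flash_attn)  # served from the per-instance cache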
47 changes: 30 additions & 17 deletions tests/e2e/patched/test_fa_xentropy.py
@@ -4,11 +4,11 @@

import logging
import os
import unittest
from importlib import reload
from pathlib import Path

import pytest
from tbparse import SummaryReader
from transformers.utils import is_torch_bf16_gpu_available

from axolotl.cli import load_datasets
@@ -17,7 +17,7 @@
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir
from ..utils import most_recent_subdir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -31,18 +31,20 @@ def reload_transformers():
    reload(transformers.models.llama.modeling_llama)


class TestFAXentropyLlama(unittest.TestCase):
class TestFAXentropyLlama:
    """
    Test case for Llama models using LoRA w multipack
    """

    @with_temp_dir
    def test_lora_packing_fa_cross_entropy(self, temp_dir):
    @pytest.mark.parametrize(
        "gradient_accumulation_steps",
        [1, 4],
    )
    def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumulation_steps):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "JackFram/llama-68m",
                "tokenizer_type": "LlamaTokenizer",
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "sequence_len": 1024,
                "sample_packing": True,
                "flash_attention": True,
@@ -55,25 +57,29 @@ def test_lora_packing_fa_cross_entropy(self, temp_dir):
                "lora_target_linear": True,
                "val_set_size": 0.2,
                "special_tokens": {
                    "unk_token": "<unk>",
                    "bos_token": "<s>",
                    "eos_token": "</s>",
                    "pad_token": "<|endoftext|>",
                },
                "chat_template": "chatml",
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                        "path": "mlabonne/FineTome-100k",
                        "field_messages": "conversations",
                        "message_field_content": "value",
                        "message_field_role": "from",
                        "type": "chat_template",
                        "split": "train[:2%]",
                    },
                ],
                "num_epochs": 1,
                "max_steps": 10,
                "save_steps": 10,
                "micro_batch_size": 8,
                "gradient_accumulation_steps": 1,
                "max_steps": 5,
                "save_steps": 5,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": gradient_accumulation_steps,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "optimizer": "adamw_8bit",
                "lr_scheduler": "cosine",
                "use_tensorboard": True,
            }
        )
        if is_torch_bf16_gpu_available():
@@ -87,3 +93,10 @@

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        assert (Path(temp_dir) / "adapter_model.bin").exists()

        tb_log_path = most_recent_subdir(temp_dir + "/runs")
        event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
        reader = SummaryReader(event_file)
        df = reader.scalars  # pylint: disable=invalid-name
        df = df[(df.tag == "train/train_loss")]  # pylint: disable=invalid-name
        assert df.value.values[-1] < 1.5, "Loss is too high"
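
The final assertion works because use_tensorboard: True makes the trainer write event files under <output_dir>/runs/<run name>/, and tbparse loads them into a pandas DataFrame with step/tag/value columns. A small sketch of the same readback, assuming you already have the path to one run directory (the example path below is hypothetical):

    import os

    from tbparse import SummaryReader


    def final_scalar(run_dir: str, tag: str = "train/train_loss") -> float:
        """Return the last logged value for `tag` from the first event file in run_dir."""
        event_file = os.path.join(run_dir, sorted(os.listdir(run_dir))[0])
        scalars = SummaryReader(event_file).scalars  # columns: step, tag, value
        return float(scalars[scalars.tag == tag].value.values[-1])


    # e.g. assert final_scalar("/tmp/axolotl-test/runs/Nov30_12-00-00") < 1.5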