diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index b5c336246b..5b380c7845 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -67,6 +67,7 @@ jobs:
         run: |
           pip3 show torch
           pip3 install -U -e .
+          python scripts/unsloth_install.py | sh
           pip3 install -r requirements-dev.txt -r requirements-tests.txt
 
       - name: Run tests
diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja
index 8fd040d77b..65553d60b5 100644
--- a/cicd/Dockerfile.jinja
+++ b/cicd/Dockerfile.jinja
@@ -37,6 +37,8 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
         pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
     fi
 
+RUN python scripts/unsloth_install.py | sh
+
 # So we can test the Docker image
 RUN pip install -r requirements-dev.txt -r requirements-tests.txt
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 6e14a70a76..173a508792 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -26,6 +26,8 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
         pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
     fi
 
+RUN python scripts/unsloth_install.py | sh
+
 # So we can test the Docker image
 RUN pip install pytest
diff --git a/scripts/unsloth_install.py b/scripts/unsloth_install.py
index 66b983e72d..a6570b4e9f 100644
--- a/scripts/unsloth_install.py
+++ b/scripts/unsloth_install.py
@@ -8,7 +8,10 @@
 v = V(torch.__version__)
 cuda = str(torch.version.cuda)
-is_ampere = torch.cuda.get_device_capability()[0] >= 8
+try:
+    is_ampere = torch.cuda.get_device_capability()[0] >= 8
+except RuntimeError:
+    is_ampere = False
 if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
     raise RuntimeError(f"CUDA = {cuda} not supported!")
 if v <= V("2.1.0"):
@@ -29,5 +32,5 @@
     raise RuntimeError(f"Torch = {v} too new!")
 x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
 print(
-    f'pip install unsloth-zoo && pip install --no-deps "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"'
+    f'pip install unsloth-zoo==2024.11.7 && pip install --no-deps "unsloth[{x}]==2024.11.9"'
 )
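The `unsloth_install.py` change matters because the script now also runs during CI and Docker image builds, where no GPU is visible and `torch.cuda.get_device_capability()` raises `RuntimeError`. A minimal standalone sketch of the guard, assuming only that torch is installed:

```python
# Guarded capability probe: on a GPU-less machine (e.g. a Docker build
# stage) get_device_capability() raises RuntimeError, so fall back to the
# non-Ampere wheel selection instead of crashing the install script.
import torch

try:
    # (major, minor) compute capability of device 0; Ampere and newer report >= 8
    is_ampere = torch.cuda.get_device_capability()[0] >= 8
except RuntimeError:
    is_ampere = False  # no CUDA device visible at install time

print("-ampere" if is_ampere else "")
```

Pinning `unsloth-zoo==2024.11.7` and `unsloth[...]==2024.11.9` also makes the install reproducible, where the previous `git+https` URL tracked the moving main branch.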
diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index c804d0c6b9..ad0459ccc0 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -4,7 +4,6 @@
 
 import logging
 import warnings
-from functools import partial
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -94,14 +93,33 @@ def replace_llama_qkv_with_fused(model):
             set_module_name(model, name, qkv)
 
 
-def patch_llama_cross_entropy():
-    from flash_attn.losses.cross_entropy import CrossEntropyLoss
-
-    LOG.info("patching with flash_attn.losses.cross_entropy")
-    transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
-        CrossEntropyLoss, inplace_backward=True
+def patch_fa_llama_cross_entropy():
+    LOG.info(
+        "patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy"
+    )
+    from flash_attn.ops.triton.cross_entropy import (
+        cross_entropy_loss as flash_attn_cross_entropy_loss,
     )
+
+    def fa2_fixed_cross_entropy(
+        source,
+        target,
+        num_items_in_batch: int = None,
+        ignore_index: int = -100,
+        **kwargs,
+    ):  # pylint: disable=unused-argument
+        reduction = "sum" if num_items_in_batch is not None else "mean"
+        loss, _ = flash_attn_cross_entropy_loss(
+            source, target, ignore_index=ignore_index
+        )
+        if reduction == "sum":
+            loss = loss.sum() / num_items_in_batch
+        else:
+            loss = loss.sum() / (target != ignore_index).sum()
+        return loss
+
+    transformers.loss.loss_utils.fixed_cross_entropy = fa2_fixed_cross_entropy
 
 
 def patch_llama_rms_norm():
     try:
@@ -147,7 +165,7 @@ def replace_llama_attn_with_flash_attn(
 
     # skip only if explicitly disabled
     if cross_entropy:
-        patch_llama_cross_entropy()
+        patch_fa_llama_cross_entropy()
 
     # skip only if explicitly disabled
     if rms_norm:
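The old patch replaced `transformers.models.llama.modeling_llama.CrossEntropyLoss`, a symbol newer transformers releases no longer route through; the new one hooks `transformers.loss.loss_utils.fixed_cross_entropy`, which current causal-LM loss computation calls directly. A sketch of the reduction semantics, with `torch.nn.functional.cross_entropy` standing in for the flash-attn Triton kernel (both return per-token losses):

```python
# Reduction logic of fa2_fixed_cross_entropy, illustrated with
# F.cross_entropy(reduction="none") standing in for
# flash_attn.ops.triton.cross_entropy.cross_entropy_loss.
import torch
import torch.nn.functional as F


def fixed_cross_entropy_sketch(source, target, num_items_in_batch=None, ignore_index=-100):
    per_token = F.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
    if num_items_in_batch is not None:
        # gradient-accumulation-aware "sum" path: divide by the token count
        # of the full accumulated batch, not of this micro-batch
        return per_token.sum() / num_items_in_batch
    # "mean" path: average only over tokens that aren't masked out
    return per_token.sum() / (target != ignore_index).sum()


logits = torch.randn(8, 32)   # (tokens, vocab_size)
labels = torch.randint(0, 32, (8,))
labels[:2] = -100             # mask two tokens
print(fixed_cross_entropy_sketch(logits, labels))
print(fixed_cross_entropy_sketch(logits, labels, num_items_in_batch=6))
```

The `num_items_in_batch` path is what keeps the loss consistent across different gradient-accumulation settings, which the updated test below exercises directly.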
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 082df7c27b..fc1f0cf1c2 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -2,10 +2,12 @@
 # pylint: disable=too-many-lines
 
 import gc
+import importlib
 import logging
 import math
 import os
 import types
+from functools import cached_property
 from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401
 
 import addict
@@ -409,7 +411,7 @@ def apply_patches(self) -> None:
             )
 
         if self.cfg.is_llama_derived_model:
-            self.patch_loss()
+            self.patch_loss_llama()
 
         if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
             from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -451,27 +453,34 @@ def patch_attention(self) -> None:
 
             replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
 
-    def patch_loss(self) -> None:
+    @cached_property
+    def has_flash_attn(self) -> bool:
+        """Check if flash attention is installed"""
+        return importlib.util.find_spec("flash_attn") is not None
+
+    def patch_loss_llama(self) -> None:
         """
         Patch loss functions
         """
-        from axolotl.monkeypatch.llama_attn_hijack_flash import (
-            patch_llama_cross_entropy,
-            patch_llama_rms_norm,
-        )
+        if self.has_flash_attn:
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                patch_fa_llama_cross_entropy,
+                patch_llama_rms_norm,
+            )
+
+        if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
+            patch_fa_llama_cross_entropy()
+        elif self.cfg.unsloth_cross_entropy_loss:
+            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
+
+            integrate_cross_entropy_loss_patch(model_type="llama")
 
-        if self.cfg.flash_attn_cross_entropy:
-            patch_llama_cross_entropy()
-        if self.cfg.flash_attn_rms_norm:
+        if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
             patch_llama_rms_norm()
         elif self.cfg.unsloth_rms_norm:
             from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
 
             patch_unsloth_layernorm()
 
-        if self.cfg.unsloth_cross_entropy_loss:
-            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
-
-            integrate_cross_entropy_loss_patch(model_type="llama")
         if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
             from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
 
@@ -481,6 +490,7 @@ def patch_llama_derived_model(self) -> None:
         """
         Modify all llama derived models in one block
         """
+        self.patch_loss_llama()
 
         if self.cfg.flash_attention:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
@@ -528,16 +538,6 @@ def patch_llama_derived_model(self) -> None:
                 "Shifted-sparse attention not currently implemented without flash attention."
             )
 
-        if self.cfg.unsloth_cross_entropy_loss:
-            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
-
-            integrate_cross_entropy_loss_patch(model_type="llama")
-
-        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
-            from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
-
-            patch_self_attn_lora()
-
     def set_auto_model_loader(self) -> None:
         """set self.AutoModelLoader
         - default value: AutoModelForCausalLM (set at __init__)
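With this change, `patch_loss_llama` is safe to call on machines where flash-attn isn't installed: the imports are gated on a cached availability probe instead of failing at import time, and the new `if`/`elif` ordering makes the flash-attn cross-entropy patch win over the unsloth one when both are enabled. The probe as a standalone sketch (`PatchManager` is a hypothetical stand-in for axolotl's loader class):

```python
# find_spec() checks whether a package is importable without importing it,
# so probing for flash_attn is cheap and side-effect free; cached_property
# memoizes the result per instance.
import importlib.util
from functools import cached_property


class PatchManager:  # hypothetical stand-in for axolotl's loader class
    @cached_property
    def has_flash_attn(self) -> bool:
        """Check if flash attention is installed"""
        return importlib.util.find_spec("flash_attn") is not None


mgr = PatchManager()
print(mgr.has_flash_attn)  # filesystem lookup happens once, then is cached
```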
diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py
index 8b76362fb4..7ca1c08365 100644
--- a/tests/e2e/patched/test_fa_xentropy.py
+++ b/tests/e2e/patched/test_fa_xentropy.py
@@ -4,11 +4,11 @@
 
 import logging
 import os
-import unittest
 from importlib import reload
 from pathlib import Path
 
 import pytest
+from tbparse import SummaryReader
 from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.cli import load_datasets
@@ -17,7 +17,7 @@
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from ..utils import with_temp_dir
+from ..utils import most_recent_subdir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -31,18 +31,20 @@ def reload_transformers():
     reload(transformers.models.llama.modeling_llama)
 
 
-class TestFAXentropyLlama(unittest.TestCase):
+class TestFAXentropyLlama:
     """
     Test case for Llama models using LoRA w multipack
     """
 
-    @with_temp_dir
-    def test_lora_packing_fa_cross_entropy(self, temp_dir):
+    @pytest.mark.parametrize(
+        "gradient_accumulation_steps",
+        [1, 4],
+    )
+    def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumulation_steps):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
                 "sequence_len": 1024,
                 "sample_packing": True,
                 "flash_attention": True,
@@ -55,25 +57,29 @@ def test_lora_packing_fa_cross_entropy(self, temp_dir):
                 "lora_target_linear": True,
                 "val_set_size": 0.2,
                 "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
+                    "pad_token": "<|endoftext|>",
                 },
+                "chat_template": "chatml",
                 "datasets": [
                     {
-                        "path": "mhenrichsen/alpaca_2k_test",
-                        "type": "alpaca",
+                        "path": "mlabonne/FineTome-100k",
+                        "field_messages": "conversations",
+                        "message_field_content": "value",
+                        "message_field_role": "from",
+                        "type": "chat_template",
+                        "split": "train[:2%]",
                     },
                 ],
                 "num_epochs": 1,
-                "max_steps": 10,
-                "save_steps": 10,
-                "micro_batch_size": 8,
-                "gradient_accumulation_steps": 1,
+                "max_steps": 5,
+                "save_steps": 5,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": gradient_accumulation_steps,
                 "output_dir": temp_dir,
                 "learning_rate": 0.00001,
-                "optimizer": "adamw_torch",
+                "optimizer": "adamw_8bit",
                 "lr_scheduler": "cosine",
+                "use_tensorboard": True,
             }
         )
         if is_torch_bf16_gpu_available():
@@ -87,3 +93,10 @@ def test_lora_packing_fa_cross_entropy(self, temp_dir):
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+        tb_log_path = most_recent_subdir(temp_dir + "/runs")
+        event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
+        reader = SummaryReader(event_file)
+        df = reader.scalars  # pylint: disable=invalid-name
+        df = df[(df.tag == "train/train_loss")]  # pylint: disable=invalid-name
+        assert df.value.values[-1] < 1.5, "Loss is too high"
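Instead of only checking that an adapter file exists, the reworked test asserts on the train loss actually logged to tensorboard (enabled via `use_tensorboard: True`). The read-back pattern, extracted; the run directory path here is illustrative:

```python
# Read the last logged train loss back out of the tensorboard event file
# that the trainer writes under <output_dir>/runs/.
import os

from tbparse import SummaryReader

run_dir = "/tmp/axolotl-out/runs/example-run"  # illustrative path
event_file = os.path.join(run_dir, sorted(os.listdir(run_dir))[0])

df = SummaryReader(event_file).scalars        # DataFrame: step, tag, value
train_loss = df[df.tag == "train/train_loss"]
assert train_loss.value.values[-1] < 1.5, "Loss is too high"
```

The 1.5 bound is a smoke-test threshold: loose enough not to flake, tight enough to catch a broken loss patch. The same pattern repeats in the new unsloth QLoRA tests below with a 2.0 bound.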
high" + + @pytest.mark.parametrize( + "sdp_attention", + [True, False], + ) + def test_unsloth_llama_qlora_unpacked_no_fa2_fp16(self, temp_dir, sdp_attention): + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "sequence_len": 1024, + "sample_packing": False, + "load_in_4bit": True, + "adapter": "qlora", + "lora_r": 16, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.2, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 5, + "save_steps": 10, + "micro_batch_size": 4, + "gradient_accumulation_steps": 2, + "sdp_attention": sdp_attention, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_8bit", + "lr_scheduler": "cosine", + "use_tensorboard": True, + "fp16": True, + } + ) + + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() + + tb_log_path = most_recent_subdir(temp_dir + "/runs") + event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0]) + reader = SummaryReader(event_file) + df = reader.scalars # pylint: disable=invalid-name + df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name + assert df.value.values[-1] < 2.0, "Loss is too high"