Commit: Unsloth rope (#1767)

* Add unsloth rope embeddings support

* support for model weights in 4bit and some memory gc

* use accelerate logger

* add unsloth llama rms norm optims

* update docs for unsloth

* more docs info
winglian authored Jul 18, 2024
1 parent c86c32a commit 7830fe0
Showing 6 changed files with 138 additions and 11 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -46,6 +46,7 @@ Features:
- [Multipack](./docs/multipack.qmd)
- [RLHF & DPO](./docs/rlhf.qmd)
- [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd)
- [Unsloth](./docs/unsloth.qmd)
- [Common Errors](#common-errors-)
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl)
1 change: 1 addition & 0 deletions _quarto.yml
@@ -36,6 +36,7 @@ website:
- docs/nccl.qmd
- docs/mac.qmd
- docs/multi-node.qmd
- docs/unsloth.qmd
- section: "Dataset Formats"
contents: docs/dataset-formats/*
- section: "Reference"
49 changes: 49 additions & 0 deletions docs/unsloth.qmd
@@ -0,0 +1,49 @@
---
title: "Unsloth"
description: "Hyper-optimized QLoRA finetuning for single GPUs"
---

### Overview

Unsloth provides hand-written, optimized kernels for LLM finetuning that modestly improve training speed and reduce VRAM usage compared with standard industry baselines.


### Installation

The following installs unsloth from source and downgrades xformers, since unsloth is incompatible with the latest xformers release.

```bash
pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps --force-reinstall xformers==0.0.26.post1
```
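
Before training, an optional sanity check (a hypothetical snippet, not part of Axolotl) can confirm that the xformers pin resolved and that unsloth imports cleanly:

```python
# Hypothetical post-install check: verify the xformers pin and the unsloth import.
from importlib.metadata import version

assert version("xformers") == "0.0.26.post1", version("xformers")

import unsloth  # noqa: E402,F401  # raises ImportError if the source install failed

print("unsloth and xformers look good")
```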

### Using unsloth with Axolotl

Axolotl exposes a few configuration options for trying out unsloth and capturing most of its performance gains.

Our unsloth integration is currently limited to the following model architectures:
- llama

These options are specific to LoRA finetuning and cannot be used with multi-GPU finetuning:
```yaml
unsloth_lora_mlp: true
unsloth_lora_qkv: true
unsloth_lora_o: true
```
These options are composable and can be used with multi-GPU finetuning:
```yaml
unsloth_cross_entropy_loss: true
unsloth_rms_norm: true
unsloth_rope: true
```
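
For a sense of how these flags take effect, here is a simplified sketch mirroring the model-loading logic added in this commit. `apply_unsloth_optimizations` is a hypothetical wrapper written for illustration; the imported patch helpers are the ones defined in `src/axolotl/monkeypatch/unsloth_.py` (the cross-entropy-loss patch is wired up the same way and omitted for brevity).

```python
# Simplified sketch: how the flags above are consumed at model-load time.
# `apply_unsloth_optimizations` is hypothetical; the imported helpers are real.
from axolotl.monkeypatch.unsloth_ import (
    integrate_lora_mlp_patch,
    integrate_lora_patch,
    integrate_rope_embeddings,
    patch_unsloth_layernorm,
)


def apply_unsloth_optimizations(model, cfg):
    # Composable kernel swaps, safe to combine with multi-GPU training.
    if cfg.unsloth_rms_norm:
        patch_unsloth_layernorm()  # swap LlamaRMSNorm for unsloth's fast RMS norm
    if cfg.unsloth_rope:
        integrate_rope_embeddings()  # swap apply_rotary_pos_emb for fast_rope_embedding
    # LoRA-specific fast paths, single-GPU LoRA/QLoRA only.
    if cfg.unsloth_lora_mlp:
        integrate_lora_mlp_patch(model)
    if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
        integrate_lora_patch(model, cfg)
    return model
```
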
### Limitations
- Single GPU only (no multi-GPU support)
- No DeepSpeed or FSDP support (both require multi-GPU)
- LoRA and QLoRA support only; no full finetunes or fp8 support
- Limited model architecture support: Llama, Phi, Gemma, and Mistral only
- No MoE support
63 changes: 54 additions & 9 deletions src/axolotl/monkeypatch/unsloth_.py
@@ -1,18 +1,20 @@
"""module for patching with unsloth optimizations"""

import inspect
import logging
import re
import types
from typing import Tuple

import torch
from accelerate.logging import get_logger
from peft import PeftModelForCausalLM
from torch import nn
from transformers.models.llama.modeling_llama import (
LlamaFlashAttention2,
LlamaForCausalLM,
)

LOG = logging.getLogger("axolotl.monkeypatch.unsloth")
LOG = get_logger("axolotl.monkeypatch.unsloth")

ORIGINAL_CEL_CODE = """ if labels is not None:
# Shift so that tokens < n predict n
@@ -137,7 +139,7 @@ def integrate_cross_entropy_loss_patch():
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
print("patching unsloth fast_cross_entropy_loss")
LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821


@@ -179,12 +181,30 @@ def patch_self_attn_lora():
globals(),
)
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
print("patching unsloth attn lora")
LOG.info("patching unsloth attn lora", main_process_only=True)
LlamaFlashAttention2.forward = (
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
)


def integrate_rope_embeddings():
import transformers.models.llama.modeling_llama
from unsloth.kernels.rope_embedding import fast_rope_embedding

def apply_rotary_pos_emb( # pylint: disable=unused-argument
q, # pylint: disable=invalid-name
k, # pylint: disable=invalid-name
cos,
sin,
position_ids=None,
unsqueeze_dim=1,
):
return fast_rope_embedding(q, k, cos, sin)

LOG.info("patching unsloth RoPE embeddings", main_process_only=True)
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb


def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
if peft_model.base_model.config.model_type in ["llama", "mistral"]:
from unsloth.kernels import apply_lora_mlp_swiglu
@@ -217,7 +237,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
else:
logging.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
LOG.warning("unable to apply unsloth lora mlp patch to layer %d", idx)


def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
@@ -243,9 +263,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
layer.self_attn.apply_qkv = apply_lora_qkv
else:
layer.self_attn.apply_qkv = original_apply_qkv
logging.warning(
"unable to apply unsloth lora qkv patch to layer %d", idx
)
LOG.warning("unable to apply unsloth lora qkv patch to layer %d", idx)
if cfg.unsloth_lora_o:
layer_modules = [
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
@@ -264,6 +282,33 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
layer.self_attn.apply_o = apply_lora_o
else:
layer.self_attn.apply_o = original_apply_o
logging.warning(
LOG.warning(
"unable to apply unsloth lora o_proj patch to layer %d", idx
)


def patch_unsloth_layernorm():
try:
import transformers.models.llama.modeling_llama
from unsloth.kernels.rms_layernorm import Fast_RMS_Layernorm

class LlamaRMSNorm(nn.Module):
"""LlamaRMSNorm"""

def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def forward(self, hidden_states):
return Fast_RMS_Layernorm.apply(
hidden_states, self.weight, self.variance_epsilon, False
)

LOG.info("patching with unsloth.kernels.rms_layernorm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.warning("missing unsloth library")
18 changes: 18 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -7,6 +7,7 @@
import logging
import os
from enum import Enum
from importlib.metadata import version
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from pydantic import BaseModel, Field, conlist, field_validator, model_validator
@@ -596,6 +597,8 @@ class Config:
unsloth_lora_mlp: Optional[bool] = None
unsloth_lora_qkv: Optional[bool] = None
unsloth_lora_o: Optional[bool] = None
unsloth_rms_norm: Optional[bool] = None
unsloth_rope: Optional[bool] = None

deepspeed: Optional[Union[str, Dict[str, Any]]] = None
fsdp: Optional[List[str]] = None
@@ -1164,6 +1167,21 @@ def check_qlora_unsloth(cls, data):
)
return data

@model_validator(mode="before")
@classmethod
def check_unsloth_xformers_version(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
xformers_version = version("xformers")
if xformers_version == "0.0.27":
raise ValueError(
"xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1"
)
return data

@model_validator(mode="before")
@classmethod
def check_torch_compile_deepspeed(cls, data):
17 changes: 15 additions & 2 deletions src/axolotl/utils/models.py
@@ -1,7 +1,7 @@
"""Module for models and model loading"""

# pylint: disable=too-many-lines

import gc
import logging
import math
import os
@@ -94,7 +94,7 @@ def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDef
"Please make sure to point to a GPTQ model."
)

if not cfg.gptq and quant_config_exists:
if not cfg.gptq and quant_config_exists and not cfg.load_in_4bit:
raise ValueError(
"model_config.quantization_config is set but `gptq` flag is not. "
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
@@ -358,6 +358,10 @@ def load_model(
patch_llama_cross_entropy()
if cfg.flash_attn_rms_norm:
patch_llama_rms_norm()
elif cfg.unsloth_rms_norm:
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm

patch_unsloth_layernorm()
if cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import (
integrate_cross_entropy_loss_patch,
@@ -884,6 +888,15 @@ def load_model(

integrate_lora_patch(model, cfg)

if cfg.unsloth_rope:
from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings

integrate_rope_embeddings()

for _ in range(3):
gc.collect()
torch.cuda.empty_cache()

# TODO resume_from_checkpoint handling
return model, lora_config
