Skip to content

Commit

Permalink
fix inference when no chat_template is set, fix unsloth dora check (#…
Browse files Browse the repository at this point in the history
…2092)

* fix inference when no chat_template is set, fix unsloth dora check

* remove old unsloth version check

* update docs on installing unsloth
  • Loading branch information
winglian authored Nov 20, 2024
1 parent 68a26f1 commit 2e99bb3
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 24 deletions.
6 changes: 2 additions & 4 deletions docs/unsloth.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@ standard industry baselines.

### Installation

The following will install unsloth from source and downgrade xformers as unsloth is incompatible with the most up
to date libraries.
The following will install the correct unsloth and extras from source.

```bash
pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps --force-reinstall xformers==0.0.26.post1
python scripts/unsloth_install.py | sh
```

### Using unsloth w Axolotl
Expand Down
33 changes: 33 additions & 0 deletions scripts/unsloth_install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# noqa
# pylint: skip-file
try:
import torch
except ImportError:
raise ImportError("Install torch via `pip install torch`")
from packaging.version import Version as V

v = V(torch.__version__)
cuda = str(torch.version.cuda)
is_ampere = torch.cuda.get_device_capability()[0] >= 8
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
raise RuntimeError(f"CUDA = {cuda} not supported!")
if v <= V("2.1.0"):
raise RuntimeError(f"Torch = {v} too old!")
elif v <= V("2.1.1"):
x = "cu{}{}-torch211"
elif v <= V("2.1.2"):
x = "cu{}{}-torch212"
elif v < V("2.3.0"):
x = "cu{}{}-torch220"
elif v < V("2.4.0"):
x = "cu{}{}-torch230"
elif v < V("2.5.0"):
x = "cu{}{}-torch240"
elif v < V("2.6.0"):
x = "cu{}{}-torch250"
else:
raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
print(
f'pip install unsloth-zoo && pip install --no-deps "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"'
)
9 changes: 8 additions & 1 deletion src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
)
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
Expand Down Expand Up @@ -199,6 +202,10 @@ def do_inference(
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)

model = model.to(cfg.device, dtype=cfg.torch_dtype)

Expand Down
6 changes: 3 additions & 3 deletions src/axolotl/monkeypatch/unsloth_.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
for module in layer_modules
)
mlp_not_dora = all(
getattr(module, "lora_magnitude_vector", None) is None
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)

Expand All @@ -213,7 +213,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
for module in layer_modules
)
qkv_not_dora = all(
getattr(module, "lora_magnitude_vector", None) is None
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)

Expand All @@ -232,7 +232,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
for module in layer_modules
)
o_not_dora = all(
getattr(module, "lora_magnitude_vector", None) is None
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)

Expand Down
16 changes: 0 additions & 16 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import logging
import os
from enum import Enum
from importlib.metadata import version
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union

from pydantic import (
Expand Down Expand Up @@ -1425,21 +1424,6 @@ def check_qlora_unsloth(cls, data):
)
return data

@model_validator(mode="before")
@classmethod
def check_unsloth_xformers_version(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
xformers_version = version("xformers")
if xformers_version == "0.0.27":
raise ValueError(
"xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1"
)
return data

@model_validator(mode="before")
@classmethod
def check_torch_compile_deepspeed(cls, data):
Expand Down

0 comments on commit 2e99bb3

Please sign in to comment.