diff --git a/docs/unsloth.qmd b/docs/unsloth.qmd index 90cb49bafa..73b2f03036 100644 --- a/docs/unsloth.qmd +++ b/docs/unsloth.qmd @@ -11,12 +11,10 @@ standard industry baselines. ### Installation -The following will install unsloth from source and downgrade xformers as unsloth is incompatible with the most up -to date libraries. +The following will install the correct unsloth and extras from source. ```bash -pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git" -pip install --no-deps --force-reinstall xformers==0.0.26.post1 +python scripts/unsloth_install.py | sh ``` ### Using unsloth w Axolotl diff --git a/scripts/unsloth_install.py b/scripts/unsloth_install.py new file mode 100644 index 0000000000..66b983e72d --- /dev/null +++ b/scripts/unsloth_install.py @@ -0,0 +1,33 @@ +# noqa +# pylint: skip-file +try: + import torch +except ImportError: + raise ImportError("Install torch via `pip install torch`") +from packaging.version import Version as V + +v = V(torch.__version__) +cuda = str(torch.version.cuda) +is_ampere = torch.cuda.get_device_capability()[0] >= 8 +if cuda != "12.1" and cuda != "11.8" and cuda != "12.4": + raise RuntimeError(f"CUDA = {cuda} not supported!") +if v <= V("2.1.0"): + raise RuntimeError(f"Torch = {v} too old!") +elif v <= V("2.1.1"): + x = "cu{}{}-torch211" +elif v <= V("2.1.2"): + x = "cu{}{}-torch212" +elif v < V("2.3.0"): + x = "cu{}{}-torch220" +elif v < V("2.4.0"): + x = "cu{}{}-torch230" +elif v < V("2.5.0"): + x = "cu{}{}-torch240" +elif v < V("2.6.0"): + x = "cu{}{}-torch250" +else: + raise RuntimeError(f"Torch = {v} too new!") +x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "") +print( + f'pip install unsloth-zoo && pip install --no-deps "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"' +) diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 589b6b575a..7c8db7ce8e 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -30,7 +30,10 @@ from axolotl.integrations.base import PluginManager from axolotl.logging_config import configure_logging from axolotl.train import TrainDatasetMeta -from axolotl.utils.chat_templates import get_chat_template +from axolotl.utils.chat_templates import ( + get_chat_template, + get_chat_template_from_config, +) from axolotl.utils.comet_ import setup_comet_env_vars from axolotl.utils.config import ( normalize_cfg_datasets, @@ -199,6 +202,10 @@ def do_inference( ) elif cfg.chat_template: chat_template_str = get_chat_template(cfg.chat_template) + elif cfg.datasets[0].type == "chat_template": + chat_template_str = get_chat_template_from_config( + cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer + ) model = model.to(cfg.device, dtype=cfg.torch_dtype) diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py index c8272ac735..38bbdc88fb 100644 --- a/src/axolotl/monkeypatch/unsloth_.py +++ b/src/axolotl/monkeypatch/unsloth_.py @@ -188,7 +188,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM): for module in layer_modules ) mlp_not_dora = all( - getattr(module, "lora_magnitude_vector", None) is None + len(getattr(module, "lora_magnitude_vector", []) or []) == 0 for module in layer_modules ) @@ -213,7 +213,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg): for module in layer_modules ) qkv_not_dora = all( - getattr(module, "lora_magnitude_vector", None) is None + len(getattr(module, "lora_magnitude_vector", []) or []) == 0 for module in layer_modules ) @@ -232,7 +232,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg): for module in layer_modules ) o_not_dora = all( - getattr(module, "lora_magnitude_vector", None) is None + len(getattr(module, "lora_magnitude_vector", []) or []) == 0 for module in layer_modules ) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 9d21a294c1..f4420ae2cd 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -7,7 +7,6 @@ import logging import os from enum import Enum -from importlib.metadata import version from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union from pydantic import ( @@ -1425,21 +1424,6 @@ def check_qlora_unsloth(cls, data): ) return data - @model_validator(mode="before") - @classmethod - def check_unsloth_xformers_version(cls, data): - if ( - data.get("unsloth_lora_mlp") - or data.get("unsloth_lora_qkv") - or data.get("unsloth_lora_o") - ): - xformers_version = version("xformers") - if xformers_version == "0.0.27": - raise ValueError( - "xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1" - ) - return data - @model_validator(mode="before") @classmethod def check_torch_compile_deepspeed(cls, data):