diff --git a/README.md b/README.md
index 99432fa32b..054d7db54f 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,9 @@ Features:
- Log results and optionally checkpoints to wandb or mlflow
- And more!
+
+
+
@@ -28,6 +31,7 @@ Features:
- [Cloud GPU](#cloud-gpu) - Latitude.sh, JarvisLabs, RunPod
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
- [Windows](#windows)
+ - [Mac](#mac)
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
- [Dataset](#dataset)
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
@@ -99,24 +103,14 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
-### For developers
```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install packaging
-```
-
-General case:
-```
pip3 install -e '.[flash-attn,deepspeed]'
```
-Mac: see https://github.com/OpenAccess-AI-Collective/axolotl/blob/13199f678b9aab39e92961323bdbce3234ee4b2b/docs/mac.md
-```
-pip3 install -e '.'
-```
-
### Usage
```bash
# preprocess datasets - optional but recommended
@@ -249,9 +243,31 @@ For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud
```
+##### GCP
+
+<details>
+
+<summary>Click to Expand</summary>
+
+Use a Deep Learning Linux OS with CUDA and PyTorch installed, then follow the Quickstart instructions above.
+
+Make sure to run the following to uninstall xla:
+```bash
+pip uninstall -y torch_xla[tpu]
+```
+
+</details>
+
#### Windows
Please use WSL or Docker!
+#### Mac
+
+Use the following instead of the install method in the Quickstart:
+```bash
+pip3 install -e '.'
+```
+
+More info: [mac.md](/docs/mac.md)
#### Launching on public clouds via SkyPilot
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
@@ -1084,6 +1100,10 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
+##### FSDP + QLoRA
+
+Axolotl supports training with FSDP and QLoRA; see [these docs](docs/fsdp_qlora.md) for more information.
+
##### Weights & Biases Logging
Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
diff --git a/deepspeed_configs/zero1.json b/deepspeed_configs/zero1.json
index 787fc0d6b7..8a57b8605a 100644
--- a/deepspeed_configs/zero1.json
+++ b/deepspeed_configs/zero1.json
@@ -16,6 +16,7 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
diff --git a/deepspeed_configs/zero2.json b/deepspeed_configs/zero2.json
index 5b22d996c5..153ac02802 100644
--- a/deepspeed_configs/zero2.json
+++ b/deepspeed_configs/zero2.json
@@ -20,6 +20,7 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
diff --git a/deepspeed_configs/zero3.json b/deepspeed_configs/zero3.json
index a185afab44..90ec3677ea 100644
--- a/deepspeed_configs/zero3.json
+++ b/deepspeed_configs/zero3.json
@@ -24,6 +24,7 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
diff --git a/deepspeed_configs/zero3_bf16.json b/deepspeed_configs/zero3_bf16.json
index 263caa393b..16e64d76b4 100644
--- a/deepspeed_configs/zero3_bf16.json
+++ b/deepspeed_configs/zero3_bf16.json
@@ -24,6 +24,7 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
diff --git a/docs/fsdp_qlora.md b/docs/fsdp_qlora.md
new file mode 100644
index 0000000000..14b2c1a571
--- /dev/null
+++ b/docs/fsdp_qlora.md
@@ -0,0 +1,37 @@
+# FSDP + QLoRA
+
+## Background
+
+Using FSDP with QLoRA is essential for **fine-tuning larger (70b+ parameter) LLMs on consumer GPUs.** For example, you can use FSDP + QLoRA to train a 70b model on two 24GB GPUs[^1].
+
+Below, we describe how to use this feature in Axolotl.
+
+## Usage
+
+To enable `QLoRA` with `FSDP`, you need to perform the following steps (a minimal config sketch is shown after the list):
+
+> [!TIP]
+> See the [example config](#example-config) file in addition to reading these instructions.
+
+1. Set `adapter: qlora` in your axolotl config file.
+2. Enable FSDP in your axolotl config, as [described here](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#fsdp).
+3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
+
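+Putting these steps together, a minimal sketch of the relevant config keys might look like the following. This is illustrative only (the `base_model` shown is just an example of a supported llama-type model); see the [example config](#example-config) below for a complete, tested setup.
+
+```yaml
+base_model: NousResearch/Llama-2-7b-hf  # example only; any supported llama/mistral/mixtral model
+adapter: qlora
+load_in_4bit: true
+
+fsdp:
+  - full_shard
+  - auto_wrap
+fsdp_config:
+  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
+```
+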
+## Example Config
+
+[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.
+
+## References
+
+- [PR #1378](https://github.com/OpenAccess-AI-Collective/axolotl/pull/1378), which enabled QLoRA with FSDP in Axolotl.
+- [Blog Post](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the [Answer.AI](https://www.answer.ai/) team describing the work that enabled QLoRA in FSDP.
+- Related Hugging Face PRs enabling FSDP + QLoRA:
+  - Accelerate [PR#2544](https://github.com/huggingface/accelerate/pull/2544)
+ - Transformers [PR#29587](https://github.com/huggingface/transformers/pull/29587)
+ - TRL [PR#1416](https://github.com/huggingface/trl/pull/1416)
+ - PEFT [PR#1550](https://github.com/huggingface/peft/pull/1550)
+
+
+
+
+[^1]: This was enabled by [this work](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the Answer.AI team.
diff --git a/docs/rlhf.md b/docs/rlhf.md
index 9f5ba05fdb..4f71184fc0 100644
--- a/docs/rlhf.md
+++ b/docs/rlhf.md
@@ -34,6 +34,21 @@ datasets:
rl: ipo
```
+#### ORPO
+
+Paper: https://arxiv.org/abs/2403.07691
+
+```yaml
+rl: orpo
+orpo_alpha: 0.1
+remove_unused_columns: false
+
+chat_template: chatml
+datasets:
+ - path: argilla/ultrafeedback-binarized-preferences-cleaned
+ type: orpo.chat_template
+```
+
#### Using local dataset files
```yaml
datasets:
diff --git a/examples/gemma/qlora.yml b/examples/gemma/qlora.yml
index 02a41e0cf4..262197cb7e 100644
--- a/examples/gemma/qlora.yml
+++ b/examples/gemma/qlora.yml
@@ -21,7 +21,7 @@ lora_dropout: 0.05
lora_target_linear: true
sequence_len: 4096
-sample_packing: true
+sample_packing: false
pad_to_sequence_len: true
wandb_project:
diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral.yml
index 7c18e7098c..5ee3da9d65 100644
--- a/examples/mistral/mixtral.yml
+++ b/examples/mistral/mixtral.yml
@@ -16,12 +16,12 @@ output_dir: ./qlora-out
## You can optionally freeze the entire model and unfreeze a subset of parameters
unfrozen_parameters:
-# - lm_head.*
-# - model.embed_tokens.*
-# - model.layers.2[0-9]+.block_sparse_moe.gate.*
-# - model.layers.2[0-9]+.block_sparse_moe.experts.*
-# - model.layers.3[0-9]+.block_sparse_moe.gate.*
-# - model.layers.3[0-9]+.block_sparse_moe.experts.*
+# - ^lm_head.weight$
+# - ^model.embed_tokens.weight$[:32000]
+# - model.layers.2[0-9]+.block_sparse_moe.gate
+# - model.layers.2[0-9]+.block_sparse_moe.experts
+# - model.layers.3[0-9]+.block_sparse_moe.gate
+# - model.layers.3[0-9]+.block_sparse_moe.experts
model_config:
output_router_logits: true
diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py
index 89ab023e5f..a1a01d59de 100644
--- a/src/axolotl/cli/preprocess.py
+++ b/src/axolotl/cli/preprocess.py
@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
- if parsed_cfg.rl:
+ if parsed_cfg.rl and parsed_cfg.rl != "orpo":
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py
index 05fd63ae80..7e004567a6 100644
--- a/src/axolotl/cli/train.py
+++ b/src/axolotl/cli/train.py
@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
else:
register_chatml_template()
- if cfg.rl:
+ if cfg.rl and cfg.rl != "orpo":
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index d11f0c6532..42180f32b3 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -11,10 +11,11 @@
import os
import sys
from abc import abstractmethod
+from collections import defaultdict
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
-from typing import List, Optional, Type, Union
+from typing import Dict, List, Literal, Optional, Type, Union
import torch
import transformers
@@ -200,6 +201,9 @@ class AxolotlTrainingArguments(TrainingArguments):
default=False,
metadata={"help": "whether this is a qlora training"},
)
+ orpo_alpha: Optional[float] = field(
+ default=None,
+ )
class AxolotlTrainer(Trainer):
@@ -223,6 +227,9 @@ def __init__(
self.eval_data_collator = eval_data_collator
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
+ if self.args.orpo_alpha:
+ self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
@@ -465,8 +472,112 @@ def compute_loss(self, model, inputs, return_outputs=False):
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
+ if self.args.orpo_alpha:
+ return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
return super().compute_loss(model, inputs, return_outputs=return_outputs)
+ def orpo_compute_custom_loss(self, logits, labels):
+ logits = logits.contiguous()
+ loss = 0.0
+
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+
+ # Flatten the tokens
+ loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(
+ dim=-1
+ )
+
+ return loss
+
+ def orpo_compute_logps(
+ self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits
+ ):
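+        # Returns the length-normalized log-probability of the completion tokens;
+        # prompt tokens are excluded via the attention-mask comparison below.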
+ # Get the shape of chosen_attention_mask[:, :-1]
+ chosen_shape = chosen_attention_mask[:, :-1].shape
+
+ # Calculate the padding size
+ pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1)
+
+ # Pad prompt_attention_mask with zeros to match the desired shape
+ prompt_attention_mask_padded = torch.nn.functional.pad(
+ prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0
+ )
+
+ # Perform the subtraction operation
+ mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded
+
+ per_token_logps = torch.gather(
+ logits[:, :-1, :].log_softmax(-1),
+ dim=2,
+ index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
+ ).squeeze(2)
+ return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(
+ dtype=torch.float64
+ ) / mask.sum(dim=1).to(dtype=torch.float64)
+
+ def orpo_compute_loss(self, model, inputs, return_outputs=False):
+ outputs_neg = model(
+ **{
+ "input_ids": inputs["rejected_input_ids"],
+ "attention_mask": inputs["rejected_attention_mask"],
+ "labels": inputs["rejected_labels"],
+ },
+ output_hidden_states=True,
+ )
+ outputs_pos = model(
+ **{
+ "input_ids": inputs["input_ids"],
+ "attention_mask": inputs["attention_mask"],
+ "labels": inputs["labels"],
+ },
+ output_hidden_states=True,
+ )
+
+ # Calculate NLL loss
+ pos_loss = self.orpo_compute_custom_loss(
+ logits=outputs_pos.logits, labels=inputs["input_ids"]
+ )
+
+ # Calculate Log Probability
+ pos_prob = self.orpo_compute_logps(
+ prompt_attention_mask=inputs["prompt_attention_mask"],
+ chosen_inputs=inputs["input_ids"],
+ chosen_attention_mask=inputs["attention_mask"],
+ logits=outputs_pos.logits,
+ )
+ neg_prob = self.orpo_compute_logps(
+ prompt_attention_mask=inputs["prompt_attention_mask"],
+ chosen_inputs=inputs["rejected_input_ids"],
+ chosen_attention_mask=inputs["rejected_attention_mask"],
+ logits=outputs_neg.logits,
+ )
+
+ # Calculate log odds
+ log_odds = (pos_prob - neg_prob) - (
+ torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob))
+ )
+ sig_ratio = torch.nn.functional.sigmoid(log_odds)
+ ratio = torch.log(sig_ratio)
+
+ # Calculate the Final Loss
+ loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to(
+ dtype=torch.bfloat16
+ )
+
+ metrics = {}
+ metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item()
+ metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item()
+ metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item()
+ metrics["log_odds"] = torch.mean(log_odds).cpu().item()
+ self.store_metrics(metrics, train_eval="train")
+
+ return (loss, outputs_pos) if return_outputs else loss
+
@wraps(Trainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
@@ -527,6 +638,28 @@ def create_accelerator_and_postprocess(self):
return res
+ def log(self, logs: Dict[str, float]) -> None:
+ """
+ Log `logs` on the various objects watching training, including stored metrics.
+
+ Args:
+ logs (`Dict[str, float]`):
+ The values to log.
+ """
+ # logs either has 'loss' or 'eval_loss'
+ train_eval = "train" if "loss" in logs else "eval"
+ # Add averaged stored metrics to logs
+ for key, metrics in self._stored_metrics[train_eval].items():
+ logs[key] = torch.tensor(metrics).mean().item()
+ del self._stored_metrics[train_eval]
+ return super().log(logs)
+
+ def store_metrics(
+ self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
+ ) -> None:
+ for key, value in metrics.items():
+ self._stored_metrics[train_eval][key].append(value)
+
class AxolotlMambaTrainer(AxolotlTrainer):
"""
@@ -903,6 +1036,11 @@ def build(self, total_num_steps):
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
training_arguments_kwargs["dataloader_drop_last"] = True
+ if self.cfg.remove_unused_columns is not None:
+ training_arguments_kwargs[
+ "remove_unused_columns"
+ ] = self.cfg.remove_unused_columns
+
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
@@ -1070,6 +1208,9 @@ def build(self, total_num_steps):
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
+ if self.cfg.rl == "orpo":
+ training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
+
if self.cfg.neftune_noise_alpha is not None:
training_arguments_kwargs[
"neftune_noise_alpha"
diff --git a/src/axolotl/logging_config.py b/src/axolotl/logging_config.py
index 8f473aa240..2ddf89a8c4 100644
--- a/src/axolotl/logging_config.py
+++ b/src/axolotl/logging_config.py
@@ -30,6 +30,7 @@ def format(self, record):
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
"version": 1,
+ "disable_existing_loggers": False,
"formatters": {
"simple": {
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index 964b41f707..fbcaf7a668 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -1,6 +1,9 @@
"""multipack patching for v2 of sample packing"""
+import importlib
import transformers
+from accelerate import init_empty_weights
+from transformers import AutoConfig, AutoModelForCausalLM
from transformers.integrations import is_deepspeed_zero3_enabled
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
@@ -12,11 +15,12 @@
"falcon",
"phi",
"gemma",
+ "gemmoe",
"starcoder2",
]
-def patch_for_multipack(model_type):
+def patch_for_multipack(model_type, model_name=None):
if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
@@ -43,3 +47,15 @@ def patch_for_multipack(model_type):
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
+ elif model_type == "gemmoe":
+ model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+ # we need to load the model here in order for modeling_gemmoe to be available
+ with init_empty_weights():
+ AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+ module_name = model_config.__class__.__module__.replace(
+ ".configuration_gemmoe", ".modeling_gemmoe"
+ )
+ modeling_gemmoe = importlib.import_module(module_name)
+ modeling_gemmoe._get_unpad_data = ( # pylint: disable=protected-access
+ get_unpad_data
+ )
diff --git a/src/axolotl/prompt_strategies/base.py b/src/axolotl/prompt_strategies/base.py
new file mode 100644
index 0000000000..fce2aba14a
--- /dev/null
+++ b/src/axolotl/prompt_strategies/base.py
@@ -0,0 +1,20 @@
+"""
+module for base dataset transform strategies
+"""
+
+import importlib
+import logging
+
+LOG = logging.getLogger("axolotl")
+
+
+def load(strategy, cfg, module_base=None, **kwargs):
+ try:
+ load_fn = strategy.split(".")[-1]
+ strategy = ".".join(strategy.split(".")[:-1])
+ mod = importlib.import_module(f".{strategy}", module_base)
+ func = getattr(mod, load_fn)
+ return func(cfg, **kwargs)
+ except Exception: # pylint: disable=broad-exception-caught
+ LOG.warning(f"unable to load strategy {strategy}")
+ return None
diff --git a/src/axolotl/prompt_strategies/dpo/__init__.py b/src/axolotl/prompt_strategies/dpo/__init__.py
index 8bd430f912..1a149f4528 100644
--- a/src/axolotl/prompt_strategies/dpo/__init__.py
+++ b/src/axolotl/prompt_strategies/dpo/__init__.py
@@ -1,20 +1,8 @@
"""
module for DPO style dataset transform strategies
"""
+from functools import partial
-import importlib
-import logging
+from ..base import load as load_base
-LOG = logging.getLogger("axolotl")
-
-
-def load(strategy, cfg, **kwargs):
- try:
- load_fn = strategy.split(".")[-1]
- strategy = ".".join(strategy.split(".")[:-1])
- mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
- func = getattr(mod, load_fn)
- return func(cfg, **kwargs)
- except Exception: # pylint: disable=broad-exception-caught
- LOG.warning(f"unable to load strategy {strategy}")
- return None
+load = partial(load_base, module_base="axolotl.prompt_strategies.dpo")
diff --git a/src/axolotl/prompt_strategies/dpo/chatml.py b/src/axolotl/prompt_strategies/dpo/chatml.py
index e8c7f4088c..585696e29a 100644
--- a/src/axolotl/prompt_strategies/dpo/chatml.py
+++ b/src/axolotl/prompt_strategies/dpo/chatml.py
@@ -24,6 +24,25 @@ def transform_fn(sample):
return transform_fn
+def argilla_chat(
+ cfg,
+ **kwargs,
+): # pylint: disable=possibly-unused-variable,unused-argument
+ """
+ for argilla/dpo-mix-7k conversations
+ """
+
+ def transform_fn(sample):
+ sample[
+ "prompt"
+ ] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
+ sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
+ sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
+ return sample
+
+ return transform_fn
+
+
def icr(
cfg,
**kwargs,
diff --git a/src/axolotl/prompt_strategies/orpo/__init__.py b/src/axolotl/prompt_strategies/orpo/__init__.py
new file mode 100644
index 0000000000..3a961fcc92
--- /dev/null
+++ b/src/axolotl/prompt_strategies/orpo/__init__.py
@@ -0,0 +1,9 @@
+"""
+module for ORPO style dataset transform strategies
+"""
+
+from functools import partial
+
+from ..base import load as load_base
+
+load = partial(load_base, module_base="axolotl.prompt_strategies.orpo")
diff --git a/src/axolotl/prompt_strategies/orpo/chat_template.py b/src/axolotl/prompt_strategies/orpo/chat_template.py
new file mode 100644
index 0000000000..fb39bcf8f4
--- /dev/null
+++ b/src/axolotl/prompt_strategies/orpo/chat_template.py
@@ -0,0 +1,187 @@
+"""chatml prompt tokenization strategy for ORPO"""
+from typing import Any, Dict, Generator, List, Optional, Tuple
+
+from pydantic import BaseModel
+
+from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
+from axolotl.prompters import Prompter
+from axolotl.utils.chat_templates import chat_templates
+
+
+class Message(BaseModel):
+ """message/turn"""
+
+ role: str
+ content: str
+ label: Optional[bool] = None
+
+
+class MessageList(BaseModel):
+ """conversation"""
+
+ messages: List[Message]
+
+
+def load(
+ tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, **kwargs
+): # pylint: disable=possibly-unused-variable,unused-argument
+ """
+ chatml transforms for datasets with system, input, chosen, rejected
+ """
+
+ chat_template = chat_templates("chatml")
+ if ds_cfg and "chat_template" in ds_cfg:
+ chat_template = ds_cfg["chat_template"]
+ try:
+ chat_template = chat_templates(chat_template)
+ except ValueError:
+ pass
+
+ return ORPOTokenizingStrategy(
+ ORPOPrompter(chat_template, tokenizer),
+ tokenizer,
+ cfg.train_on_inputs,
+ cfg.sequence_len,
+ dataset_parser=ORPODatasetParsingStrategy(),
+ )
+
+
+class ORPODatasetParsingStrategy:
+ """Strategy to parse chosen rejected dataset into messagelist"""
+
+ def get_chosen_conversation_thread(self, prompt) -> MessageList:
+ """Dataset structure mappings"""
+
+ messages: List[Message] = []
+ if system := prompt.get("system", None):
+ messages.append(Message(role="system", content=system, label=False))
+ messages.append(Message(role="user", content=prompt["prompt"], label=False))
+ messages.append(
+ Message(
+ role="assistant", content=prompt["chosen"][1]["content"], label=True
+ )
+ )
+ return MessageList(messages=messages)
+
+ def get_rejected_conversation_thread(self, prompt) -> MessageList:
+ """Dataset structure mappings"""
+
+ messages: List[Message] = []
+ if system := prompt.get("system", None):
+ messages.append(Message(role="system", content=system, label=False))
+ messages.append(Message(role="user", content=prompt["prompt"], label=False))
+ messages.append(
+ Message(
+ role="assistant", content=prompt["rejected"][1]["content"], label=True
+ )
+ )
+ return MessageList(messages=messages)
+
+
+class ORPOTokenizingStrategy(PromptTokenizingStrategy):
+ """
+ rejected_input_ids
+ input_ids
+ rejected_attention_mask
+ attention_mask
+ rejected_labels
+ labels
+ """
+
+ def __init__(
+ self,
+ *args,
+ dataset_parser=None,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.dataset_parser = dataset_parser
+
+ def tokenize_prompt(self, prompt):
+ # pass the rejected prompt/row to the Prompter to get the formatted prompt
+ prompt_len = 0
+ rejected_message_list = self.dataset_parser.get_rejected_conversation_thread(
+ prompt
+ )
+ input_ids = []
+ labels = []
+ for _, (part, label) in enumerate(
+ self.prompter.build_prompt(rejected_message_list)
+ ):
+ if not part:
+ continue
+ _input_ids = self.tokenizer.encode(part, add_special_tokens=False)
+ prev_idx = len(input_ids)
+ input_ids += _input_ids[prev_idx:]
+ if label:
+ labels += input_ids[prev_idx:]
+ else:
+                labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
+                # the prompt ends at the last non-assistant part
+                prompt_len = len(input_ids)
+ # remap the input_ids, attention_mask and labels
+ rejected_input_ids = input_ids
+ rejected_labels = labels
+ # pass the chosen prompt/row to the Prompter to get the formatted prompt
+ chosen_message_list = self.dataset_parser.get_chosen_conversation_thread(prompt)
+ input_ids = []
+ labels = []
+ for _, (part, label) in enumerate(
+ self.prompter.build_prompt(chosen_message_list)
+ ):
+ if not part:
+ continue
+ _input_ids = self.tokenizer.encode(part, add_special_tokens=False)
+ prev_idx = len(input_ids)
+ input_ids += _input_ids[prev_idx:]
+ if label:
+ labels += input_ids[prev_idx:]
+ else:
+ labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
+
+ return {
+ "rejected_input_ids": rejected_input_ids,
+ "rejected_labels": rejected_labels,
+ "rejected_attention_mask": [1] * len(rejected_labels),
+ "input_ids": input_ids,
+ "labels": labels,
+ "attention_mask": [1] * len(labels),
+ "prompt_attention_mask": [1] * prompt_len
+ + [0] * (len(labels) - prompt_len),
+ }
+
+
+class ORPOPrompter(Prompter):
+ """Single Turn prompter for ORPO"""
+
+ def __init__(self, chat_template, tokenizer):
+ self.chat_template = chat_template
+ self.tokenizer = tokenizer
+
+ def build_prompt(
+ self,
+ message_list: MessageList,
+ ) -> Generator[Tuple[str, bool], None, None]:
+ conversation = []
+ for message in message_list.messages:
+ conversation.append(message.model_dump())
+ if message.role == "system":
+ yield self.tokenizer.apply_chat_template(
+ conversation,
+ add_generation_prompt=False,
+ chat_template=self.chat_template,
+ tokenize=False,
+ ), False
+ if message.role == "user":
+ yield self.tokenizer.apply_chat_template(
+ conversation,
+ add_generation_prompt=True,
+ chat_template=self.chat_template,
+ tokenize=False,
+ ), False
+ if message.role == "assistant":
+ yield self.tokenizer.apply_chat_template(
+ conversation,
+ add_generation_prompt=False,
+ chat_template=self.chat_template,
+ tokenize=False,
+ ), True
diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py
index 6ac7cbafe9..7a7f61a8ee 100644
--- a/src/axolotl/prompt_strategies/sharegpt.py
+++ b/src/axolotl/prompt_strategies/sharegpt.py
@@ -1,10 +1,15 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
+
from typing import Any, Dict, Optional
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2
+from axolotl.utils.tokenization import (
+ chatml_to_conversation,
+ merge_consecutive_messages,
+)
def register_chatml_template(system_message=None):
@@ -19,6 +24,16 @@ def register_chatml_template(system_message=None):
sep="<|im_end|>",
)
)
+ register_conv_template(
+ Conversation(
+ name="chatml_glaive",
+ system_template="<|im_start|>system\n{system_message}",
+ system_message=system_message,
+ roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"],
+ sep_style=SeparatorStyle.CHATML,
+ sep="<|im_end|>",
+ )
+ )
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
@@ -77,6 +92,20 @@ def load_guanaco(tokenizer, cfg):
)
+def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+ conversation = (
+ ds_cfg["conversation"]
+ if ds_cfg and "conversation" in ds_cfg
+ else "chatml_glaive"
+ )
+ return GlaiveShareGPTPromptTokenizingStrategy(
+ ShareGPTPrompterV2(conversation=conversation),
+ tokenizer,
+ cfg.train_on_inputs,
+ cfg.sequence_len,
+ )
+
+
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row
@@ -158,3 +187,15 @@ def get_conversation_thread(self, prompt):
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
]
return turns
+
+
+class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
+ """
+ sharegpt strategy that remaps glaive data to sharegpt format
+ """
+
+ def get_conversation_thread(self, prompt):
+ conversation = chatml_to_conversation(prompt)
+ conversation = merge_consecutive_messages(conversation)
+
+ return conversation
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index a5c243f7e6..7e62a0cd4c 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -360,11 +360,19 @@ def tokenize_prompt(self, prompt):
LOG.warning(f"expected tuple, got {part}")
continue
- user, assistant = conversation.roles
+ tool_role_label = None
+ if len(conversation.roles) == 3:
+ (
+ user_role_label,
+ assistant_role_label,
+ tool_role_label,
+ ) = conversation.roles
+ else:
+ user_role_label, assistant_role_label = conversation.roles
role, content = part
# Uses "in" because role contains extra characters
- if user in role:
+ if user_role_label in role:
role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
@@ -384,7 +392,7 @@ def tokenize_prompt(self, prompt):
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
- elif assistant in role:
+ elif assistant_role_label in role:
role = (
role.replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
@@ -426,6 +434,8 @@ def tokenize_prompt(self, prompt):
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+ elif tool_role_label and tool_role_label in role:
+ labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {role}")
continue
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 748db1a162..fa181f916d 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -267,6 +267,8 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
role_key_human = "human"
role_key_model = "gpt"
+ # Optional, only used for tool usage datasets.
+ role_key_tool = None
def __init__(
self,
@@ -274,6 +276,7 @@ def __init__(
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
+ role_key_tool: Optional[str] = None,
):
if conversation:
if isinstance(conversation, Conversation):
@@ -286,6 +289,8 @@ def __init__(
self.role_key_human = role_key_human
if role_key_model:
self.role_key_model = role_key_model
+ if role_key_tool:
+ self.role_key_tool = role_key_tool
def _build_result(self, source):
if len(source) < 2:
@@ -303,6 +308,8 @@ def _build_result(self, source):
source.pop(0)
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
+ if self.role_key_tool:
+ roles[self.role_key_tool] = conv.roles[2]
try:
# Apply prompt templates
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 0b5ef76716..b6cd24672e 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -19,7 +19,7 @@
from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
-from axolotl.utils.freeze import freeze_parameters_except
+from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
@@ -85,7 +85,7 @@ def train(
model.generation_config.do_sample = True
model_ref = None
- if cfg.rl:
+ if cfg.rl and cfg.rl != "orpo":
if cfg.adapter and not cfg.rl_adapter_ref_model:
# use built-in trl autounwrap
LOG.debug("Passing model_ref: None to RL trainer")
@@ -99,7 +99,7 @@ def train(
safe_serialization = cfg.save_safetensors is True
if cfg.unfrozen_parameters:
- freeze_parameters_except(model, cfg.unfrozen_parameters)
+ freeze_layers_except(model, cfg.unfrozen_parameters)
trainer = setup_trainer(
cfg,
@@ -110,9 +110,6 @@ def train(
total_num_steps,
)
- if hasattr(model, "config"):
- model.config.use_cache = False
-
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py
index 1ec83536d6..fd34b4ea99 100644
--- a/src/axolotl/utils/chat_templates.py
+++ b/src/axolotl/utils/chat_templates.py
@@ -21,7 +21,7 @@ def chat_templates(user_choice: str):
templates = {
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
- "chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+ "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}",
}
diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index 9151f288a8..3e743bda9f 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -191,6 +191,11 @@ def normalize_cfg_datasets(cfg):
f"updating dataset {ds_cfg.path} with `conversation: chatml` to match your chat_template"
)
cfg.datasets[idx].conversation = "chatml"
+ if ds_cfg.type == "orpo.chat_template" and not ds_cfg.chat_template:
+ LOG.info(
+ f"updating dataset {ds_cfg.path} with `chat_template: chatml` to match your chat_template"
+ )
+ cfg.datasets[idx].chat_template = "chatml"
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index a536fa9ecc..b1c395bcc8 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -1,6 +1,7 @@
"""
Module for pydantic models for configuration
"""
+# pylint: disable=too-many-lines
import logging
import os
@@ -123,13 +124,16 @@ class RLType(str, Enum):
dpo = "dpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
kto_pair = "kto_pair" # pylint: disable=invalid-name
+ orpo = "orpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
"""Chat templates configuration subset"""
+ alpaca = "alpaca" # pylint: disable=invalid-name
chatml = "chatml" # pylint: disable=invalid-name
inst = "inst" # pylint: disable=invalid-name
+ gemma = "gemma" # pylint: disable=invalid-name
class LoftQConfig(BaseModel):
@@ -179,6 +183,7 @@ class LoraConfig(BaseModel):
peft_layers_to_transform: Optional[List[int]] = None
peft: Optional[PeftConfig] = None
peft_use_dora: Optional[bool] = None
+    peft_use_rslora: Optional[bool] = None
lora_on_cpu: Optional[bool] = None
gptq: Optional[bool] = None
@@ -428,6 +433,8 @@ class Config:
dataloader_prefetch_factor: Optional[int] = None
dataloader_drop_last: Optional[bool] = None
+ remove_unused_columns: Optional[bool] = None
+
push_dataset_to_hub: Optional[str] = None
hf_use_auth_token: Optional[bool] = None
@@ -512,10 +519,14 @@ class Config:
neftune_noise_alpha: Optional[float] = None
- max_memory: Optional[Union[int, str]] = None
+ orpo_alpha: Optional[float] = None
+
+ max_memory: Optional[
+ Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
+ ] = None
gpu_memory_limit: Optional[Union[int, str]] = None
- chat_template: Optional[Union[Literal["chatml", "inst"], ChatTemplate]] = None
+ chat_template: Optional[ChatTemplate] = None
default_system_message: Optional[str] = None
# INTERNALS - document for now, generally not set externally
@@ -990,3 +1001,10 @@ def check_sample_packing_w_sdpa_bf16(cls, data):
)
return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_fsdp_deepspeed(cls, data):
+ if data.get("deepspeed") and data.get("fsdp"):
+ raise ValueError("deepspeed and fsdp cannot be used together.")
+ return data
diff --git a/src/axolotl/utils/freeze.py b/src/axolotl/utils/freeze.py
index 05beda1caa..e3d0fd1446 100644
--- a/src/axolotl/utils/freeze.py
+++ b/src/axolotl/utils/freeze.py
@@ -3,13 +3,14 @@
"""
import logging
import re
+from typing import Callable, List, Tuple, Union
from axolotl.utils.distributed import is_main_process
LOG = logging.getLogger("axolotl.utils.freeze")
-def freeze_parameters_except(model, regex_patterns):
+def freeze_layers_except(model, regex_patterns):
"""
Freezes all layers of the given model except for the layers that match given regex patterns.
Periods in the patterns are treated as literal periods, not as wildcard characters.
@@ -17,22 +18,211 @@ def freeze_parameters_except(model, regex_patterns):
Parameters:
- model (nn.Module): The PyTorch model to be modified.
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
+    Note that a dot cannot be used as a wildcard character in the patterns, since it is reserved for separating layer names.
+    To match the entire layer name, the pattern should start with "^" and end with "$"; otherwise it will also match layer names that merely start with the pattern.
+    The optional range part is not compiled as part of the regex, so place "$" before the range part if you want to match the entire layer name.
+ E.g., ["^model.embed_tokens.weight$[:32000]", "layers.2[0-9]+.block_sparse_moe.gate.[a-z]+$"]
Returns:
None; the model is modified in place.
"""
- # Escape periods and compile the regex patterns
- compiled_patterns = [
- re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
- ]
+ if isinstance(regex_patterns, str):
+ regex_patterns = [regex_patterns]
- # First, freeze all parameters in the model
- for param in model.parameters():
- param.requires_grad = False
+ patterns = [LayerNamePattern(pattern) for pattern in regex_patterns]
# Unfreeze layers that match the regex patterns
for name, param in model.named_parameters():
- if any(pattern.match(name) for pattern in compiled_patterns):
- if is_main_process():
- LOG.debug(f"unfreezing {name}")
+ param.requires_grad = False
+ unfrozen_ranges = []
+ for pattern in patterns:
+ if not pattern.match(name):
+ continue
+
param.requires_grad = True
+
+ if pattern.range is not None:
+ unfrozen_ranges.append(pattern.range)
+
+ merged_unfrozen_ranges = _merge_ranges(unfrozen_ranges, len(param))
+
+ if param.requires_grad and is_main_process():
+ unfrozen_ranges = (
+ f" with ranges {merged_unfrozen_ranges}"
+ if merged_unfrozen_ranges
+ else ""
+ )
+ LOG.debug(f"Unfrozen {name}{unfrozen_ranges}")
+
+ if not merged_unfrozen_ranges:
+ continue
+
+ # The range list we need is actually the inverted of the merged ranges
+ ranges_to_freeze = _invert_ranges(merged_unfrozen_ranges, len(param))
+
+ param.register_hook(_create_freeze_parameters_hook(ranges_to_freeze))
+
+ if is_main_process() and all(
+ not param.requires_grad for param in model.parameters()
+ ):
+ LOG.warning("All parameters are frozen. Model will not be trained.")
+
+
+def _invert_ranges(
+ given_ranges: List[Tuple[int, int]], layer_size: int
+) -> List[Tuple[int, int]]:
+ """
+ Inverts a list of ranges to obtain the ranges not covered by the given ranges.
+
+ Parameters:
+ - given_ranges (List[Tuple[int, int]]): List of ranges to invert. Each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
+ - layer_size (int): The length of the layer. E.g., len(model.layer.weight)
+ Returns:
+ - List[Tuple[int, int]]: List of inverted ranges, where each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
+ """
+ if not given_ranges:
+ return [(0, layer_size)]
+
+ inverted_ranges = []
+ current_start = 0
+
+ for start, end in sorted(given_ranges):
+ if start > current_start:
+ inverted_ranges.append((current_start, start))
+ current_start = max(current_start, end)
+
+ # Handle the case where the last given range does not reach the end of the total_size
+ if current_start < layer_size:
+ inverted_ranges.append((current_start, layer_size))
+
+ return inverted_ranges
+
+
+def _merge_ranges(
+ given_ranges: List[Tuple[int, Union[int, None]]], layer_size: int
+) -> List[Tuple[int, int]]:
+ """
+ Merges overlapping ranges and sorts the given ranges.
+
+ This function takes a list of ranges and merges any overlapping ranges. The ranges are represented
+ as tuples, where the first element is the start index (inclusive) and the second element is the end
+ index (exclusive). The end index can be None, indicating that the range extends to the end of the
+ sequence.
+
+ Parameters:
+ - given_ranges (List[Tuple[int, int | None]]): List of ranges to merge.
+ - layer_size (int): The length of the layer. E.g., len(model.layer.weight)
+
+ Returns:
+ - List[Tuple[int, int]]: List of merged ranges, as start (inclusive) and end (exclusive) indices.
+ """
+ # End of each range can be determined now since we have the total size
+ processed_ranges = [
+ (start, end if end is not None else layer_size) for start, end in given_ranges
+ ]
+
+ # No need to merge if there's only one or no ranges
+ if len(processed_ranges) <= 1:
+ return processed_ranges
+
+ sorted_ranges = sorted(processed_ranges)
+
+ merged_ranges = [sorted_ranges[0]]
+ for start, end in sorted_ranges[1:]:
+ prev_start, prev_end = merged_ranges[-1]
+ if start <= prev_end:
+ merged_ranges[-1] = (prev_start, max(prev_end, end))
+ else:
+ merged_ranges.append((start, end))
+
+ return merged_ranges
+
+
+def _create_freeze_parameters_hook(ranges_to_freeze: List[Tuple[int, int]]) -> Callable:
+ """
+ Create a hook to freeze parameters in specified ranges by setting their gradients to zero.
+
+ This function takes a list of tuples representing the ranges of indices to freeze. Each tuple should contain
+ two integers representing the start and end indices of the range.
+
+ Parameters:
+ - ranges_to_freeze (List[Tuple[int, int]]): Ranges of indices to freeze.
+
+ Returns:
+ - Callable: A hook function to be used with `register_hook` on parameters.
+
+ Example usage:
+ ```
+ ranges_to_freeze = [(0, 10), (20, 30)]
+ hook = _create_freeze_parameters_hook(ranges_to_freeze)
+ model.register_hook(hook)
+ ```
+ """
+
+ def freeze_parameters_hook(gradients):
+ for start, end in ranges_to_freeze:
+ gradients[start:end].zero_()
+
+ return freeze_parameters_hook
+
+
+class LayerNamePattern:
+ """
+ Represents a regex pattern for layer names, potentially including a parameter index range.
+ """
+
+ def __init__(self, pattern: str):
+ """
+ Initializes a new instance of the LayerNamePattern class.
+
+ Parameters:
+ - pattern (str): The regex pattern for layer names, potentially including a parameter index range.
+ """
+ self.raw_pattern = pattern
+ name_pattern, self.range = self._parse_pattern(pattern)
+ self.name_regex = re.compile(name_pattern.replace(".", "\\."))
+
+ def match(self, name: str) -> bool:
+ """
+ Checks if the given layer name matches the regex pattern.
+
+ Parameters:
+ - name (str): The layer name to check.
+
+ Returns:
+ - bool: True if the layer name matches the pattern, False otherwise.
+ """
+ return self.name_regex.match(name) is not None
+
+ def _parse_pattern(
+ self, pattern: str
+ ) -> Tuple[str, Union[Tuple[int, Union[int, None]], None]]:
+ """
+ Extracts the range pattern from the given pattern.
+
+ Parameters:
+ - pattern (str): The pattern to extract the range from.
+
+ Returns:
+ - Tuple[str, Tuple[int, int | None] | None]: A tuple containing the regex pattern to match the layer name without the range pattern and the range of layer indices to match, if specified.
+ """
+ match = re.match(r"^(.+)\[([0-9]*)(?::([0-9]*))?\]$", pattern)
+ if not match:
+ return pattern, None
+
+ base_pattern, start_part, end_part = match.groups()
+
+ if end_part is None and start_part.isdecimal():
+ index = int(start_part)
+ return base_pattern, (index, index + 1)
+
+ # [:end] or [start:] or [start:end]
+ start = int(start_part) if start_part else 0
+ end = int(end_part) if end_part else None
+
+ if end is not None and start >= end:
+ raise ValueError(
+ f"Invalid range in layer name pattern: {pattern}."
+ "End of range must be greater than start."
+ )
+ return base_pattern, (start, end)
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 36c9c17e35..fce7b20a7a 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -429,7 +429,7 @@ def load_model(
and cfg.flash_attention
and cfg.sample_packing
):
- patch_for_multipack(cfg.model_config_type)
+ patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
elif cfg.is_llama_derived_model:
# Modify all llama derived models in one block
@@ -1055,6 +1055,8 @@ def load_lora(model, cfg, inference=False, config_only=False):
lora_config_kwargs["init_lora_weights"] = "loftq"
if cfg.peft_use_dora:
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
+ if cfg.peft_use_rslora:
+ lora_config_kwargs["use_rslora"] = cfg.use_rslora
lora_config = LoraConfig(
r=cfg.lora_r,
diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py
index 7f63a92fea..afbdef8778 100644
--- a/src/axolotl/utils/tokenization.py
+++ b/src/axolotl/utils/tokenization.py
@@ -2,6 +2,8 @@
import logging
+import re
+from typing import Dict, List
from termcolor import colored
@@ -36,3 +38,65 @@ def check_example_labels(example, tokenizer, text_only=False):
LOG.info("\n\n\n")
return " ".join(colored_tokens)
+
+
+GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
+GLAIVE_TO_SHAREGPT_ROLE = {
+ "SYSTEM": "system",
+ "USER": "human",
+ "ASSISTANT": "gpt",
+ "FUNCTION RESPONSE": "tool",
+}
+
+GLAIVE_MSG_REGEX = re.compile(rf"({'|'.join(GLAIVE_ROLES)}): ")
+
+
+def chatml_to_conversation(row: Dict[str, str]) -> List[Dict[str, str]]:
+ """
+ Converts a ChatML formatted row to a list of messages in ShareGPT format.
+ Initially based off https://github.com/lilacai/lilac/blob/main/notebooks/GlaiveToShareGPT.ipynb.
+ """
+
+ system_prompt = row.get("system")
+ if system_prompt:
+ system_prompt = system_prompt.removeprefix("SYSTEM: ")
+
+ chat_str = row["chat"]
+ chat_msgs = [s.strip() for s in GLAIVE_MSG_REGEX.split(chat_str) if s]
+
+ chat_msg_dicts = [
+ {"from": GLAIVE_TO_SHAREGPT_ROLE[role], "value": value}
+ for role, value in zip(chat_msgs[::2], chat_msgs[1::2])
+ ]
+
+ if system_prompt:
+ chat_msg_dicts = [
+ {"from": GLAIVE_TO_SHAREGPT_ROLE["SYSTEM"], "value": system_prompt}
+ ] + chat_msg_dicts
+
+ return chat_msg_dicts
+
+
+def merge_consecutive_messages(messages):
+ """
+ Merge consecutive messages from the same sender into a single message.
+ This can be useful with datasets that contain multiple consecutive tool calls.
+ """
+
+ merged_messages = []
+ current_from = None
+ current_message = ""
+
+ for msg in messages:
+ if current_from == msg["from"]:
+ current_message += msg["value"]
+ else:
+ if current_from is not None:
+ merged_messages.append({"from": current_from, "value": current_message})
+ current_from = msg["from"]
+ current_message = msg["value"]
+
+ if current_from is not None:
+ merged_messages.append({"from": current_from, "value": current_message})
+
+ return merged_messages
diff --git a/tests/prompt_strategies/test_sharegpt.py b/tests/prompt_strategies/test_sharegpt.py
index 19f8217e04..c9290b220a 100644
--- a/tests/prompt_strategies/test_sharegpt.py
+++ b/tests/prompt_strategies/test_sharegpt.py
@@ -1,6 +1,7 @@
"""
Test module for sharegpt integration w chatml
"""
+
import pytest
from datasets import Dataset
from tokenizers import AddedToken
@@ -8,6 +9,7 @@
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.sharegpt import (
+ GlaiveShareGPTPromptTokenizingStrategy,
SimpleShareGPTPromptTokenizingStrategy,
register_chatml_template,
)
@@ -48,6 +50,18 @@ def fixture_sharegpt_dataset():
)
+@pytest.fixture(name="glaive_dataset")
+def fixture_sharegpt_glaive_dataset():
+ return Dataset.from_list(
+ [
+ {
+ "system": "SYSTEM: This is a system prompt",
+ "chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
+ }
+ ]
+ )
+
+
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
@@ -156,3 +170,29 @@ def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt
]
# fmt: on
+
+ def test_chatml_glaive(self, glaive_dataset, tokenizer):
+ strategy = GlaiveShareGPTPromptTokenizingStrategy(
+ ShareGPTPrompterV2(
+ conversation="chatml",
+ role_key_model=None,
+ role_key_human=None,
+ ),
+ tokenizer,
+ True, # train_on_inputs
+ 2048, # sequence_len
+ )
+
+ dataset_wrapper = TokenizedPromptDataset(
+ strategy, glaive_dataset, process_count=1
+ )
+
+ labels = dataset_wrapper[0]["labels"]
+ # fmt: off
+ assert labels == [
+ 1, # bos
+ 32001, 1587, 13, 3260, 349, 264, 1587, 11510, 32000, 28705, 13, # system
+ 32001, 2188, 13, 6325, 368, 1820, 264, 9314, 354, 528, 477, 1450, 2726, 298, 4222, 28804, 32000, 28705, 13, # human
+ 32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt
+ ]
+ # fmt: on
diff --git a/tests/test_freeze.py b/tests/test_freeze.py
new file mode 100644
index 0000000000..49d30ba5fa
--- /dev/null
+++ b/tests/test_freeze.py
@@ -0,0 +1,285 @@
+"""
+This module contains unit tests for the `freeze_layers_except` function.
+
+The `freeze_layers_except` function is used to freeze layers in a model, except for the specified layers.
+The unit tests in this module verify the behavior of the `freeze_layers_except` function in different scenarios.
+"""
+
+import unittest
+
+import torch
+from torch import nn
+
+from axolotl.utils.freeze import freeze_layers_except
+
+ZERO = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+ONE_TO_TEN = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
+
+
+class TestFreezeLayersExcept(unittest.TestCase):
+ """
+ A test case class for the `freeze_layers_except` function.
+ """
+
+ def setUp(self):
+ self.model = _TestModel()
+
+ def test_freeze_layers_with_dots_in_name(self):
+ freeze_layers_except(self.model, ["features.layer"])
+ self.assertTrue(
+ self.model.features.layer.weight.requires_grad,
+ "model.features.layer should be trainable.",
+ )
+ self.assertFalse(
+ self.model.classifier.weight.requires_grad,
+ "model.classifier should be frozen.",
+ )
+
+ def test_freeze_layers_without_dots_in_name(self):
+ freeze_layers_except(self.model, ["classifier"])
+        self.assertFalse(
+            self.model.features.layer.weight.requires_grad,
+            "model.features.layer should be frozen.",
+        )
+        self.assertTrue(
+            self.model.classifier.weight.requires_grad,
+            "model.classifier should be trainable.",
+        )
+
+ def test_freeze_layers_regex_patterns(self):
+        # The second pattern cannot match 'classifier' because only the characters 'a' to 'c' are allowed after 'class', while the next character is actually 'i'.
+ freeze_layers_except(self.model, [r"^features.[a-z]+.weight$", r"class[a-c]+"])
+ self.assertTrue(
+ self.model.features.layer.weight.requires_grad,
+ "model.features.layer should be trainable.",
+ )
+ self.assertFalse(
+ self.model.classifier.weight.requires_grad,
+ "model.classifier should be frozen.",
+ )
+
+ def test_all_layers_frozen(self):
+ freeze_layers_except(self.model, [])
+ self.assertFalse(
+ self.model.features.layer.weight.requires_grad,
+ "model.features.layer should be frozen.",
+ )
+ self.assertFalse(
+ self.model.classifier.weight.requires_grad,
+ "model.classifier should be frozen.",
+ )
+
+ def test_all_layers_unfrozen(self):
+ freeze_layers_except(self.model, ["features.layer", "classifier"])
+ self.assertTrue(
+ self.model.features.layer.weight.requires_grad,
+ "model.features.layer should be trainable.",
+ )
+ self.assertTrue(
+ self.model.classifier.weight.requires_grad,
+ "model.classifier should be trainable.",
+ )
+
+ def test_freeze_layers_with_range_pattern_start_end(self):
+ freeze_layers_except(self.model, ["features.layer[1:5]"])
+ self.assertTrue(
+ self.model.features.layer.weight.requires_grad,
+ "model.features.layer should be trainable.",
+ )
+ self.assertFalse(
+ self.model.classifier.weight.requires_grad,
+ "model.classifier should be frozen.",
+ )
+
+ self._assert_gradient_output(
+ [
+ ZERO,
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ZERO,
+ ZERO,
+ ZERO,
+ ZERO,
+ ZERO,
+ ]
+ )
+
+ def test_freeze_layers_with_range_pattern_single_index(self):
+ freeze_layers_except(self.model, ["features.layer[5]"])
+ self.assertTrue(
+ self.model.features.layer.weight.requires_grad,
+ "model.features.layer should be trainable.",
+ )
+ self.assertFalse(
+ self.model.classifier.weight.requires_grad,
+ "model.classifier should be frozen.",
+ )
+
+ self._assert_gradient_output(
+ [ZERO, ZERO, ZERO, ZERO, ZERO, ONE_TO_TEN, ZERO, ZERO, ZERO, ZERO]
+ )
+
+ def test_freeze_layers_with_range_pattern_start_omitted(self):
+ freeze_layers_except(self.model, ["features.layer[:5]"])
+ self.assertTrue(
+ self.model.features.layer.weight.requires_grad,
+ "model.features.layer should be trainable.",
+ )
+ self.assertFalse(
+ self.model.classifier.weight.requires_grad,
+ "model.classifier should be frozen.",
+ )
+
+ self._assert_gradient_output(
+ [
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ZERO,
+ ZERO,
+ ZERO,
+ ZERO,
+ ZERO,
+ ]
+ )
+
+ def test_freeze_layers_with_range_pattern_end_omitted(self):
+ freeze_layers_except(self.model, ["features.layer[4:]"])
+ self.assertTrue(
+ self.model.features.layer.weight.requires_grad,
+ "model.features.layer should be trainable.",
+ )
+ self.assertFalse(
+ self.model.classifier.weight.requires_grad,
+ "model.classifier should be frozen.",
+ )
+
+ self._assert_gradient_output(
+ [
+ ZERO,
+ ZERO,
+ ZERO,
+ ZERO,
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ]
+ )
+
+ def test_freeze_layers_with_range_pattern_merge_included(self):
+ freeze_layers_except(self.model, ["features.layer[4:]", "features.layer[5:6]"])
+ self.assertTrue(
+ self.model.features.layer.weight.requires_grad,
+ "model.features.layer should be trainable.",
+ )
+ self.assertFalse(
+ self.model.classifier.weight.requires_grad,
+ "model.classifier should be frozen.",
+ )
+
+ self._assert_gradient_output(
+ [
+ ZERO,
+ ZERO,
+ ZERO,
+ ZERO,
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ]
+ )
+
+ def test_freeze_layers_with_range_pattern_merge_intersect(self):
+ freeze_layers_except(self.model, ["features.layer[4:7]", "features.layer[6:8]"])
+ self.assertTrue(
+ self.model.features.layer.weight.requires_grad,
+ "model.features.layer should be trainable.",
+ )
+ self.assertFalse(
+ self.model.classifier.weight.requires_grad,
+ "model.classifier should be frozen.",
+ )
+
+ self._assert_gradient_output(
+ [
+ ZERO,
+ ZERO,
+ ZERO,
+ ZERO,
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ONE_TO_TEN,
+ ZERO,
+ ZERO,
+ ]
+ )
+
+ def test_freeze_layers_with_range_pattern_merge_separate(self):
+ freeze_layers_except(
+ self.model,
+ ["features.layer[1:2]", "features.layer[3:4]", "features.layer[5:6]"],
+ )
+ self.assertTrue(
+ self.model.features.layer.weight.requires_grad,
+ "model.features.layer should be trainable.",
+ )
+ self.assertFalse(
+ self.model.classifier.weight.requires_grad,
+ "model.classifier should be frozen.",
+ )
+
+ self._assert_gradient_output(
+ [
+ ZERO,
+ ONE_TO_TEN,
+ ZERO,
+ ONE_TO_TEN,
+ ZERO,
+ ONE_TO_TEN,
+ ZERO,
+ ZERO,
+ ZERO,
+ ZERO,
+ ]
+ )
+
+ def _assert_gradient_output(self, expected):
+ input_tensor = torch.tensor([ONE_TO_TEN], dtype=torch.float32)
+
+ self.model.features.layer.weight.grad = None # Reset gradients
+ output = self.model.features.layer(input_tensor)
+ loss = output.sum()
+ loss.backward()
+
+ expected_grads = torch.tensor(expected)
+ torch.testing.assert_close(
+ self.model.features.layer.weight.grad, expected_grads
+ )
+
+
+class _SubLayerModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layer = nn.Linear(10, 10)
+
+
+class _TestModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.features = _SubLayerModule()
+ self.classifier = nn.Linear(10, 2)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py
index cf662d95f7..4e659006fe 100644
--- a/tests/test_prompt_tokenizers.py
+++ b/tests/test_prompt_tokenizers.py
@@ -1,4 +1,5 @@
"""Module for testing prompt tokenizers."""
+
import json
import logging
import unittest
@@ -7,7 +8,8 @@
from typing import Optional
import pytest
-from transformers import AutoTokenizer, LlamaTokenizer
+from datasets import load_dataset
+from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
from axolotl.prompt_strategies.alpaca_w_system import (
@@ -18,11 +20,14 @@
Llama2ChatPrompter,
LLama2ChatTokenizingStrategy,
)
+from axolotl.prompt_strategies.orpo.chat_template import load
+from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
)
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
+from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl")
@@ -266,6 +271,23 @@ def test_sharegpt_assistant_label_ignore(self):
idx = res["input_ids"].index(20255) # assistant token
assert res["labels"][idx] == -100
+ def test_glaive_tool_label_ignore(self):
+ conversation = {
+ "system": "SYSTEM: This is a system prompt",
+ "chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
+ }
+ prompter = ShareGPTPrompterV2()
+ strat = GlaiveShareGPTPromptTokenizingStrategy(
+ prompter,
+ self.tokenizer,
+ False,
+ 2048,
+ )
+ with self._caplog.at_level(logging.WARNING):
+ res = strat.tokenize_prompt(conversation)
+ idx = res["input_ids"].index(13566) # assistant token
+ assert res["labels"][idx] == -100
+
def test_no_sys_prompt(self):
"""
tests the interface between the user and assistant parts
@@ -427,5 +449,57 @@ def compare_with_transformers_integration(self):
)
+class OrpoTokenizationTest(unittest.TestCase):
+ """test case for the ORPO tokenization"""
+
+ def setUp(self) -> None:
+ # pylint: disable=duplicate-code
+ tokenizer = LlamaTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+ tokenizer.add_special_tokens(
+ {
+ "eos_token": AddedToken(
+ "<|im_end|>", rstrip=False, lstrip=False, normalized=False
+ )
+ }
+ )
+ tokenizer.add_tokens(
+ [
+ AddedToken(
+ "<|im_start|>", rstrip=False, lstrip=False, normalized=False
+ ),
+ ]
+ )
+ self.tokenizer = tokenizer
+ self.dataset = load_dataset(
+ "argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
+ ).select([0])
+
+ def test_orpo_integration(self):
+ strat = load(
+ self.tokenizer,
+ DictDefault({"train_on_inputs": False}),
+ DictDefault({"chat_template": "chatml"}),
+ )
+ res = strat.tokenize_prompt(self.dataset[0])
+ assert "rejected_input_ids" in res
+ assert "rejected_labels" in res
+ assert "input_ids" in res
+ assert "labels" in res
+ assert "prompt_attention_mask" in res
+
+ assert len(res["rejected_input_ids"]) == len(res["rejected_labels"])
+ assert len(res["input_ids"]) == len(res["labels"])
+ assert len(res["input_ids"]) == len(res["prompt_attention_mask"])
+
+ assert res["rejected_labels"][0] == -100
+ assert res["rejected_input_ids"][-1] == res["rejected_labels"][-1]
+
+ assert res["labels"][0] == -100
+ assert res["input_ids"][-1] == res["labels"][-1]
+
+ assert res["prompt_attention_mask"][0] == 1
+ assert res["prompt_attention_mask"][-1] == 0
+
+
if __name__ == "__main__":
unittest.main()