diff --git a/README.md b/README.md index 99432fa32b..054d7db54f 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,9 @@ Features: - Log results and optionally checkpoints to wandb or mlflow - And more! + + phorm.ai + @@ -28,6 +31,7 @@ Features: - [Cloud GPU](#cloud-gpu) - Latitude.sh, JarvisLabs, RunPod - [Bare Metal Cloud GPU](#bare-metal-cloud-gpu) - [Windows](#windows) + - [Mac](#mac) - [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot) - [Dataset](#dataset) - [How to Add Custom Prompts](#how-to-add-custom-prompts) @@ -99,24 +103,14 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo **Requirements**: Python >=3.10 and Pytorch >=2.1.1. -### For developers ```bash git clone https://github.com/OpenAccess-AI-Collective/axolotl cd axolotl pip3 install packaging -``` - -General case: -``` pip3 install -e '.[flash-attn,deepspeed]' ``` -Mac: see https://github.com/OpenAccess-AI-Collective/axolotl/blob/13199f678b9aab39e92961323bdbce3234ee4b2b/docs/mac.md -``` -pip3 install -e '.' -``` - ### Usage ```bash # preprocess datasets - optional but recommended @@ -249,9 +243,31 @@ For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud ``` +##### GCP + +
+Use a Deep Learning Linux OS with CUDA and PyTorch installed, then follow the Quickstart instructions.
+
+Make sure to run the following to uninstall `torch_xla`:
+```bash
+pip uninstall -y torch_xla[tpu]
+```
+
 #### Windows
 Please use WSL or Docker!
+
+#### Mac
+
+Use the following instead of the install method in the Quickstart:
+```bash
+pip3 install -e '.'
+```
+More info: [mac.md](/docs/mac.md)
 
 #### Launching on public clouds via SkyPilot
 To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
@@ -1084,6 +1100,10 @@ fsdp_config:
   fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
 ```
 
+##### FSDP + QLoRA
+
+Axolotl supports training with FSDP and QLoRA; see [these docs](docs/fsdp_qlora.md) for more information.
+
 ##### Weights & Biases Logging
 
 Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
diff --git a/deepspeed_configs/zero1.json b/deepspeed_configs/zero1.json
index 787fc0d6b7..8a57b8605a 100644
--- a/deepspeed_configs/zero1.json
+++ b/deepspeed_configs/zero1.json
@@ -16,6 +16,7 @@
     "min_loss_scale": 1
   },
   "gradient_accumulation_steps": "auto",
+  "gradient_clipping": "auto",
   "train_batch_size": "auto",
   "train_micro_batch_size_per_gpu": "auto",
   "wall_clock_breakdown": false
diff --git a/deepspeed_configs/zero2.json b/deepspeed_configs/zero2.json
index 5b22d996c5..153ac02802 100644
--- a/deepspeed_configs/zero2.json
+++ b/deepspeed_configs/zero2.json
@@ -20,6 +20,7 @@
     "min_loss_scale": 1
   },
   "gradient_accumulation_steps": "auto",
+  "gradient_clipping": "auto",
   "train_batch_size": "auto",
   "train_micro_batch_size_per_gpu": "auto",
   "wall_clock_breakdown": false
diff --git a/deepspeed_configs/zero3.json b/deepspeed_configs/zero3.json
index a185afab44..90ec3677ea 100644
--- a/deepspeed_configs/zero3.json
+++ b/deepspeed_configs/zero3.json
@@ -24,6 +24,7 @@
     "min_loss_scale": 1
   },
   "gradient_accumulation_steps": "auto",
+  "gradient_clipping": "auto",
   "train_batch_size": "auto",
   "train_micro_batch_size_per_gpu": "auto",
   "wall_clock_breakdown": false
diff --git a/deepspeed_configs/zero3_bf16.json b/deepspeed_configs/zero3_bf16.json
index 263caa393b..16e64d76b4 100644
--- a/deepspeed_configs/zero3_bf16.json
+++ b/deepspeed_configs/zero3_bf16.json
@@ -24,6 +24,7 @@
     "min_loss_scale": 1
   },
   "gradient_accumulation_steps": "auto",
+  "gradient_clipping": "auto",
   "train_batch_size": "auto",
   "train_micro_batch_size_per_gpu": "auto",
   "wall_clock_breakdown": false
diff --git a/docs/fsdp_qlora.md b/docs/fsdp_qlora.md
new file mode 100644
index 0000000000..14b2c1a571
--- /dev/null
+++ b/docs/fsdp_qlora.md
@@ -0,0 +1,37 @@
+# FSDP + QLoRA
+
+## Background
+
+Using FSDP with QLoRA is essential for **fine-tuning larger (70b+ parameter) LLMs on consumer GPUs.** For example, you can use FSDP + QLoRA to train a 70b model on two 24GB GPUs[^1].
+
+Below, we describe how to use this feature in Axolotl.
+
+## Usage
+
+To enable `QLoRA` with `FSDP`, you need to perform the following steps:
+
+> [!Tip]
+> See the [example config](#example-config) file in addition to reading these instructions.
+
+1. Set `adapter: qlora` in your axolotl config file.
+2. Enable FSDP in your axolotl config, as [described here](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#fsdp).
+3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
+
+## Example Config
+
+[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.
+
+## References
+
+- [PR #1378](https://github.com/OpenAccess-AI-Collective/axolotl/pull/1378) enabling QLoRA in FSDP in Axolotl.
+- [Blog Post](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the [Answer.AI](https://www.answer.ai/) team describing the work that enabled QLoRA in FSDP. +- Related HuggingFace PRs Enabling FDSP + QLoRA: + - Accelerate [PR#2544](https://github.com/huggingface/accelerate/pull/2544 ) + - Transformers [PR#29587](https://github.com/huggingface/transformers/pull/29587) + - TRL [PR#1416](https://github.com/huggingface/trl/pull/1416) + - PEFT [PR#1550](https://github.com/huggingface/peft/pull/1550) + + + + +[^1]: This was enabled by [this work](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the Answer.AI team. diff --git a/docs/rlhf.md b/docs/rlhf.md index 9f5ba05fdb..4f71184fc0 100644 --- a/docs/rlhf.md +++ b/docs/rlhf.md @@ -34,6 +34,21 @@ datasets: rl: ipo ``` +#### ORPO + +Paper: https://arxiv.org/abs/2403.07691 + +```yaml +rl: orpo +orpo_alpha: 0.1 +remove_unused_columns: false + +chat_template: chatml +datasets: + - path: argilla/ultrafeedback-binarized-preferences-cleaned + type: orpo.chat_template +``` + #### Using local dataset files ```yaml datasets: diff --git a/examples/gemma/qlora.yml b/examples/gemma/qlora.yml index 02a41e0cf4..262197cb7e 100644 --- a/examples/gemma/qlora.yml +++ b/examples/gemma/qlora.yml @@ -21,7 +21,7 @@ lora_dropout: 0.05 lora_target_linear: true sequence_len: 4096 -sample_packing: true +sample_packing: false pad_to_sequence_len: true wandb_project: diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral.yml index 7c18e7098c..5ee3da9d65 100644 --- a/examples/mistral/mixtral.yml +++ b/examples/mistral/mixtral.yml @@ -16,12 +16,12 @@ output_dir: ./qlora-out ## You can optionally freeze the entire model and unfreeze a subset of parameters unfrozen_parameters: -# - lm_head.* -# - model.embed_tokens.* -# - model.layers.2[0-9]+.block_sparse_moe.gate.* -# - model.layers.2[0-9]+.block_sparse_moe.experts.* -# - model.layers.3[0-9]+.block_sparse_moe.gate.* -# - model.layers.3[0-9]+.block_sparse_moe.experts.* +# - ^lm_head.weight$ +# - ^model.embed_tokens.weight$[:32000] +# - model.layers.2[0-9]+.block_sparse_moe.gate +# - model.layers.2[0-9]+.block_sparse_moe.experts +# - model.layers.3[0-9]+.block_sparse_moe.gate +# - model.layers.3[0-9]+.block_sparse_moe.experts model_config: output_router_logits: true diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index 89ab023e5f..a1a01d59de 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): LOG.warning(msg) parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH - if parsed_cfg.rl: + if parsed_cfg.rl and parsed_cfg.rl != "orpo": load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) else: load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 05fd63ae80..7e004567a6 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: else: register_chatml_template() - if cfg.rl: + if cfg.rl and cfg.rl != "orpo": dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index d11f0c6532..42180f32b3 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -11,10 +11,11 @@ import os import sys 
from abc import abstractmethod +from collections import defaultdict from dataclasses import dataclass, field from functools import wraps from pathlib import Path -from typing import List, Optional, Type, Union +from typing import Dict, List, Literal, Optional, Type, Union import torch import transformers @@ -200,6 +201,9 @@ class AxolotlTrainingArguments(TrainingArguments): default=False, metadata={"help": "whether this is a qlora training"}, ) + orpo_alpha: Optional[float] = field( + default=None, + ) class AxolotlTrainer(Trainer): @@ -223,6 +227,9 @@ def __init__( self.eval_data_collator = eval_data_collator super().__init__(*_args, **kwargs) self.train_data_collator = self.data_collator + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + if self.args.orpo_alpha: + self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") def create_optimizer(self): if self.args.loraplus_lr_ratio is None: @@ -465,8 +472,112 @@ def compute_loss(self, model, inputs, return_outputs=False): # outputs = model(**inputs) # loss = trainer_weighted_loss(outputs, labels, shift_labels=True) # return (loss, outputs) if return_outputs else loss + if self.args.orpo_alpha: + return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs) return super().compute_loss(model, inputs, return_outputs=return_outputs) + def orpo_compute_custom_loss(self, logits, labels): + logits = logits.contiguous() + loss = 0.0 + + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # Flatten the tokens + loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean( + dim=-1 + ) + + return loss + + def orpo_compute_logps( + self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits + ): + # Get the shape of chosen_attention_mask[:, :-1] + chosen_shape = chosen_attention_mask[:, :-1].shape + + # Calculate the padding size + pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1) + + # Pad prompt_attention_mask with zeros to match the desired shape + prompt_attention_mask_padded = torch.nn.functional.pad( + prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0 + ) + + # Perform the subtraction operation + mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded + + per_token_logps = torch.gather( + logits[:, :-1, :].log_softmax(-1), + dim=2, + index=(mask * chosen_inputs[:, 1:]).unsqueeze(2), + ).squeeze(2) + return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to( + dtype=torch.float64 + ) / mask.sum(dim=1).to(dtype=torch.float64) + + def orpo_compute_loss(self, model, inputs, return_outputs=False): + outputs_neg = model( + **{ + "input_ids": inputs["rejected_input_ids"], + "attention_mask": inputs["rejected_attention_mask"], + "labels": inputs["rejected_labels"], + }, + output_hidden_states=True, + ) + outputs_pos = model( + **{ + "input_ids": inputs["input_ids"], + "attention_mask": inputs["attention_mask"], + "labels": inputs["labels"], + }, + output_hidden_states=True, + ) + + # Calculate NLL loss + pos_loss = self.orpo_compute_custom_loss( + logits=outputs_pos.logits, labels=inputs["input_ids"] + ) + + # Calculate Log Probability + pos_prob = self.orpo_compute_logps( + prompt_attention_mask=inputs["prompt_attention_mask"], + chosen_inputs=inputs["input_ids"], + chosen_attention_mask=inputs["attention_mask"], + 
logits=outputs_pos.logits, + ) + neg_prob = self.orpo_compute_logps( + prompt_attention_mask=inputs["prompt_attention_mask"], + chosen_inputs=inputs["rejected_input_ids"], + chosen_attention_mask=inputs["rejected_attention_mask"], + logits=outputs_neg.logits, + ) + + # Calculate log odds + log_odds = (pos_prob - neg_prob) - ( + torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob)) + ) + sig_ratio = torch.nn.functional.sigmoid(log_odds) + ratio = torch.log(sig_ratio) + + # Calculate the Final Loss + loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to( + dtype=torch.bfloat16 + ) + + metrics = {} + metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item() + metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item() + metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item() + metrics["log_odds"] = torch.mean(log_odds).cpu().item() + self.store_metrics(metrics, train_eval="train") + + return (loss, outputs_pos) if return_outputs else loss + @wraps(Trainer.push_to_hub) def push_to_hub(self, *args, **kwargs) -> str: """ @@ -527,6 +638,28 @@ def create_accelerator_and_postprocess(self): return res + def log(self, logs: Dict[str, float]) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`Dict[str, float]`): + The values to log. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs) + + def store_metrics( + self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" + ) -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + class AxolotlMambaTrainer(AxolotlTrainer): """ @@ -903,6 +1036,11 @@ def build(self, total_num_steps): elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False: training_arguments_kwargs["dataloader_drop_last"] = True + if self.cfg.remove_unused_columns is not None: + training_arguments_kwargs[ + "remove_unused_columns" + ] = self.cfg.remove_unused_columns + if not self.cfg.test_datasets and self.cfg.val_set_size == 0: # no eval set, so don't eval training_arguments_kwargs["evaluation_strategy"] = "no" @@ -1070,6 +1208,9 @@ def build(self, total_num_steps): training_arguments_kwargs["model_type"] = self.cfg.model_config_type training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) + if self.cfg.rl == "orpo": + training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha + if self.cfg.neftune_noise_alpha is not None: training_arguments_kwargs[ "neftune_noise_alpha" diff --git a/src/axolotl/logging_config.py b/src/axolotl/logging_config.py index 8f473aa240..2ddf89a8c4 100644 --- a/src/axolotl/logging_config.py +++ b/src/axolotl/logging_config.py @@ -30,6 +30,7 @@ def format(self, record): DEFAULT_LOGGING_CONFIG: Dict[str, Any] = { "version": 1, + "disable_existing_loggers": False, "formatters": { "simple": { "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s", diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 964b41f707..fbcaf7a668 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -1,6 +1,9 @@ """multipack patching for v2 of sample packing""" +import importlib 
import transformers +from accelerate import init_empty_weights +from transformers import AutoConfig, AutoModelForCausalLM from transformers.integrations import is_deepspeed_zero3_enabled from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3 @@ -12,11 +15,12 @@ "falcon", "phi", "gemma", + "gemmoe", "starcoder2", ] -def patch_for_multipack(model_type): +def patch_for_multipack(model_type, model_name=None): if model_type == "mixtral": transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access get_unpad_data @@ -43,3 +47,15 @@ def patch_for_multipack(model_type): transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access get_unpad_data ) + elif model_type == "gemmoe": + model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + # we need to load the model here in order for modeling_gemmoe to be available + with init_empty_weights(): + AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) + module_name = model_config.__class__.__module__.replace( + ".configuration_gemmoe", ".modeling_gemmoe" + ) + modeling_gemmoe = importlib.import_module(module_name) + modeling_gemmoe._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) diff --git a/src/axolotl/prompt_strategies/base.py b/src/axolotl/prompt_strategies/base.py new file mode 100644 index 0000000000..fce2aba14a --- /dev/null +++ b/src/axolotl/prompt_strategies/base.py @@ -0,0 +1,20 @@ +""" +module for base dataset transform strategies +""" + +import importlib +import logging + +LOG = logging.getLogger("axolotl") + + +def load(strategy, cfg, module_base=None, **kwargs): + try: + load_fn = strategy.split(".")[-1] + strategy = ".".join(strategy.split(".")[:-1]) + mod = importlib.import_module(f".{strategy}", module_base) + func = getattr(mod, load_fn) + return func(cfg, **kwargs) + except Exception: # pylint: disable=broad-exception-caught + LOG.warning(f"unable to load strategy {strategy}") + return None diff --git a/src/axolotl/prompt_strategies/dpo/__init__.py b/src/axolotl/prompt_strategies/dpo/__init__.py index 8bd430f912..1a149f4528 100644 --- a/src/axolotl/prompt_strategies/dpo/__init__.py +++ b/src/axolotl/prompt_strategies/dpo/__init__.py @@ -1,20 +1,8 @@ """ module for DPO style dataset transform strategies """ +from functools import partial -import importlib -import logging +from ..base import load as load_base -LOG = logging.getLogger("axolotl") - - -def load(strategy, cfg, **kwargs): - try: - load_fn = strategy.split(".")[-1] - strategy = ".".join(strategy.split(".")[:-1]) - mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo") - func = getattr(mod, load_fn) - return func(cfg, **kwargs) - except Exception: # pylint: disable=broad-exception-caught - LOG.warning(f"unable to load strategy {strategy}") - return None +load = partial(load_base, module="axolotl.prompt_strategies.dpo") diff --git a/src/axolotl/prompt_strategies/dpo/chatml.py b/src/axolotl/prompt_strategies/dpo/chatml.py index e8c7f4088c..585696e29a 100644 --- a/src/axolotl/prompt_strategies/dpo/chatml.py +++ b/src/axolotl/prompt_strategies/dpo/chatml.py @@ -24,6 +24,25 @@ def transform_fn(sample): return transform_fn +def argilla_chat( + cfg, + **kwargs, +): # pylint: disable=possibly-unused-variable,unused-argument + """ + for argilla/dpo-mix-7k conversations + """ + + def transform_fn(sample): + sample[ + "prompt" + ] = 
f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>" + sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>" + return sample + + return transform_fn + + def icr( cfg, **kwargs, diff --git a/src/axolotl/prompt_strategies/orpo/__init__.py b/src/axolotl/prompt_strategies/orpo/__init__.py new file mode 100644 index 0000000000..3a961fcc92 --- /dev/null +++ b/src/axolotl/prompt_strategies/orpo/__init__.py @@ -0,0 +1,9 @@ +""" +module for ORPO style dataset transform strategies +""" + +from functools import partial + +from ..base import load as load_base + +load = partial(load_base, module="axolotl.prompt_strategies.orpo") diff --git a/src/axolotl/prompt_strategies/orpo/chat_template.py b/src/axolotl/prompt_strategies/orpo/chat_template.py new file mode 100644 index 0000000000..fb39bcf8f4 --- /dev/null +++ b/src/axolotl/prompt_strategies/orpo/chat_template.py @@ -0,0 +1,187 @@ +"""chatml prompt tokenization strategy for ORPO""" +from typing import Any, Dict, Generator, List, Optional, Tuple + +from pydantic import BaseModel + +from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy +from axolotl.prompters import Prompter +from axolotl.utils.chat_templates import chat_templates + + +class Message(BaseModel): + """message/turn""" + + role: str + content: str + label: Optional[bool] = None + + +class MessageList(BaseModel): + """conversation""" + + messages: List[Message] + + +def load( + tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, **kwargs +): # pylint: disable=possibly-unused-variable,unused-argument + """ + chatml transforms for datasets with system, input, chosen, rejected + """ + + chat_template = chat_templates("chatml") + if ds_cfg and "chat_template" in ds_cfg: + chat_template = ds_cfg["chat_template"] + try: + chat_template = chat_templates(chat_template) + except ValueError: + pass + + return ORPOTokenizingStrategy( + ORPOPrompter(chat_template, tokenizer), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + dataset_parser=ORPODatasetParsingStrategy(), + ) + + +class ORPODatasetParsingStrategy: + """Strategy to parse chosen rejected dataset into messagelist""" + + def get_chosen_conversation_thread(self, prompt) -> MessageList: + """Dataset structure mappings""" + + messages: List[Message] = [] + if system := prompt.get("system", None): + messages.append(Message(role="system", content=system, label=False)) + messages.append(Message(role="user", content=prompt["prompt"], label=False)) + messages.append( + Message( + role="assistant", content=prompt["chosen"][1]["content"], label=True + ) + ) + return MessageList(messages=messages) + + def get_rejected_conversation_thread(self, prompt) -> MessageList: + """Dataset structure mappings""" + + messages: List[Message] = [] + if system := prompt.get("system", None): + messages.append(Message(role="system", content=system, label=False)) + messages.append(Message(role="user", content=prompt["prompt"], label=False)) + messages.append( + Message( + role="assistant", content=prompt["rejected"][1]["content"], label=True + ) + ) + return MessageList(messages=messages) + + +class ORPOTokenizingStrategy(PromptTokenizingStrategy): + """ + rejected_input_ids + input_ids + rejected_attention_mask + attention_mask + rejected_labels + labels + """ + + def __init__( + self, + *args, + dataset_parser=None, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.dataset_parser = dataset_parser + + def 
tokenize_prompt(self, prompt): + # pass the rejected prompt/row to the Prompter to get the formatted prompt + prompt_len = 0 + rejected_message_list = self.dataset_parser.get_rejected_conversation_thread( + prompt + ) + input_ids = [] + labels = [] + for _, (part, label) in enumerate( + self.prompter.build_prompt(rejected_message_list) + ): + if not part: + continue + _input_ids = self.tokenizer.encode(part, add_special_tokens=False) + prev_idx = len(input_ids) + input_ids += _input_ids[prev_idx:] + if label: + labels += input_ids[prev_idx:] + else: + labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx) + prompt_len = len(input_ids) + # remap the input_ids, attention_mask and labels + rejected_input_ids = input_ids + rejected_labels = labels + # pass the chosen prompt/row to the Prompter to get the formatted prompt + chosen_message_list = self.dataset_parser.get_chosen_conversation_thread(prompt) + input_ids = [] + labels = [] + for _, (part, label) in enumerate( + self.prompter.build_prompt(chosen_message_list) + ): + if not part: + continue + _input_ids = self.tokenizer.encode(part, add_special_tokens=False) + prev_idx = len(input_ids) + input_ids += _input_ids[prev_idx:] + if label: + labels += input_ids[prev_idx:] + else: + labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx) + + return { + "rejected_input_ids": rejected_input_ids, + "rejected_labels": rejected_labels, + "rejected_attention_mask": [1] * len(rejected_labels), + "input_ids": input_ids, + "labels": labels, + "attention_mask": [1] * len(labels), + "prompt_attention_mask": [1] * prompt_len + + [0] * (len(labels) - prompt_len), + } + + +class ORPOPrompter(Prompter): + """Single Turn prompter for ORPO""" + + def __init__(self, chat_template, tokenizer): + self.chat_template = chat_template + self.tokenizer = tokenizer + + def build_prompt( + self, + message_list: MessageList, + ) -> Generator[Tuple[str, bool], None, None]: + conversation = [] + for message in message_list.messages: + conversation.append(message.model_dump()) + if message.role == "system": + yield self.tokenizer.apply_chat_template( + conversation, + add_generation_prompt=False, + chat_template=self.chat_template, + tokenize=False, + ), False + if message.role == "user": + yield self.tokenizer.apply_chat_template( + conversation, + add_generation_prompt=True, + chat_template=self.chat_template, + tokenize=False, + ), False + if message.role == "assistant": + yield self.tokenizer.apply_chat_template( + conversation, + add_generation_prompt=False, + chat_template=self.chat_template, + tokenize=False, + ), True diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py index 6ac7cbafe9..7a7f61a8ee 100644 --- a/src/axolotl/prompt_strategies/sharegpt.py +++ b/src/axolotl/prompt_strategies/sharegpt.py @@ -1,10 +1,15 @@ """Module containing the SimpleShareGPTPromptTokenizingStrategy class""" + from typing import Any, Dict, Optional from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy from axolotl.prompters import ShareGPTPrompterV2 +from axolotl.utils.tokenization import ( + chatml_to_conversation, + merge_consecutive_messages, +) def register_chatml_template(system_message=None): @@ -19,6 +24,16 @@ def register_chatml_template(system_message=None): sep="<|im_end|>", ) ) + register_conv_template( + Conversation( + name="chatml_glaive", + system_template="<|im_start|>system\n{system_message}", + 
system_message=system_message, + roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"], + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + ) + ) def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): @@ -77,6 +92,20 @@ def load_guanaco(tokenizer, cfg): ) +def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): + conversation = ( + ds_cfg["conversation"] + if ds_cfg and "conversation" in ds_cfg + else "chatml_glaive" + ) + return GlaiveShareGPTPromptTokenizingStrategy( + ShareGPTPrompterV2(conversation=conversation), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + + class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): """ basic sharegpt strategy to grab conversations from the sample row @@ -158,3 +187,15 @@ def get_conversation_thread(self, prompt): {"from": role_map[t["role"]], "value": t["content"]} for t in conversations ] return turns + + +class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy): + """ + sharegpt strategy that remaps glaive data to sharegpt format + """ + + def get_conversation_thread(self, prompt): + conversation = chatml_to_conversation(prompt) + conversation = merge_consecutive_messages(conversation) + + return conversation diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index a5c243f7e6..7e62a0cd4c 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -360,11 +360,19 @@ def tokenize_prompt(self, prompt): LOG.warning(f"expected tuple, got {part}") continue - user, assistant = conversation.roles + tool_role_label = None + if len(conversation.roles) == 3: + ( + user_role_label, + assistant_role_label, + tool_role_label, + ) = conversation.roles + else: + user_role_label, assistant_role_label = conversation.roles role, content = part # Uses "in" because role contains extra characters - if user in role: + if user_role_label in role: role = ( role.replace(role_remap[0]["from"], role_remap[0]["to"]) if role_remap @@ -384,7 +392,7 @@ def tokenize_prompt(self, prompt): else: # everything from this is masked out from the labels labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - elif assistant in role: + elif assistant_role_label in role: role = ( role.replace(role_remap[1]["from"], role_remap[1]["to"]) if role_remap @@ -426,6 +434,8 @@ def tokenize_prompt(self, prompt): else: # everything from this is masked out from the labels labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) + elif tool_role_label and tool_role_label in role: + labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) else: LOG.warning(f"unhandled role: {role}") continue diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 748db1a162..fa181f916d 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -267,6 +267,8 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods role_key_human = "human" role_key_model = "gpt" + # Optional, only used for tool usage datasets. 
+ role_key_tool = None def __init__( self, @@ -274,6 +276,7 @@ def __init__( conversation: Optional[Union[str, Conversation]] = None, role_key_human: Optional[str] = None, role_key_model: Optional[str] = None, + role_key_tool: Optional[str] = None, ): if conversation: if isinstance(conversation, Conversation): @@ -286,6 +289,8 @@ def __init__( self.role_key_human = role_key_human if role_key_model: self.role_key_model = role_key_model + if role_key_tool: + self.role_key_tool = role_key_tool def _build_result(self, source): if len(source) < 2: @@ -303,6 +308,8 @@ def _build_result(self, source): source.pop(0) roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]} + if self.role_key_tool: + roles[self.role_key_tool] = conv.roles[2] try: # Apply prompt templates diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 0b5ef76716..b6cd24672e 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -19,7 +19,7 @@ from axolotl.common.cli import TrainerCliArgs from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault -from axolotl.utils.freeze import freeze_parameters_except +from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.trainer import setup_trainer @@ -85,7 +85,7 @@ def train( model.generation_config.do_sample = True model_ref = None - if cfg.rl: + if cfg.rl and cfg.rl != "orpo": if cfg.adapter and not cfg.rl_adapter_ref_model: # use built-in trl autounwrap LOG.debug("Passing model_ref: None to RL trainer") @@ -99,7 +99,7 @@ def train( safe_serialization = cfg.save_safetensors is True if cfg.unfrozen_parameters: - freeze_parameters_except(model, cfg.unfrozen_parameters) + freeze_layers_except(model, cfg.unfrozen_parameters) trainer = setup_trainer( cfg, @@ -110,9 +110,6 @@ def train( total_num_steps, ) - if hasattr(model, "config"): - model.config.use_cache = False - # go ahead and presave, so we have the adapter config available to inspect if peft_config: LOG.info(f"Pre-saving adapter config to {cfg.output_dir}") diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 1ec83536d6..fd34b4ea99 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -21,7 +21,7 @@ def chat_templates(user_choice: str): templates = { "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}", "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral. - "chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' 
%}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", } diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 9151f288a8..3e743bda9f 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -191,6 +191,11 @@ def normalize_cfg_datasets(cfg): f"updating dataset {ds_cfg.path} with `conversation: chatml` to match your chat_template" ) cfg.datasets[idx].conversation = "chatml" + if ds_cfg.type == "orpo.chat_template" and not ds_cfg.chat_template: + LOG.info( + f"updating dataset {ds_cfg.path} with `chat_template: chatml` to match your chat_template" + ) + cfg.datasets[idx].chat_template = "chatml" def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None): diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index a536fa9ecc..b1c395bcc8 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -1,6 +1,7 @@ """ Module for pydantic models for configuration """ +# pylint: disable=too-many-lines import logging import os @@ -123,13 +124,16 @@ class RLType(str, Enum): dpo = "dpo" # pylint: disable=invalid-name ipo = "ipo" # pylint: disable=invalid-name kto_pair = "kto_pair" # pylint: disable=invalid-name + orpo = "orpo" # pylint: disable=invalid-name class ChatTemplate(str, Enum): """Chat templates configuration subset""" + alpaca = "alpaca" # pylint: disable=invalid-name chatml = "chatml" # pylint: disable=invalid-name inst = "inst" # pylint: disable=invalid-name + gemma = "gemma" # pylint: disable=invalid-name class LoftQConfig(BaseModel): @@ -179,6 +183,7 @@ class LoraConfig(BaseModel): peft_layers_to_transform: Optional[List[int]] = None peft: Optional[PeftConfig] = None peft_use_dora: Optional[bool] = None + peft_use_relora: Optional[bool] = None lora_on_cpu: Optional[bool] = None gptq: Optional[bool] = None @@ -428,6 +433,8 @@ class Config: dataloader_prefetch_factor: Optional[int] = None dataloader_drop_last: Optional[bool] = None + remove_unused_columns: Optional[bool] = None + push_dataset_to_hub: Optional[str] = None hf_use_auth_token: Optional[bool] = None @@ -512,10 +519,14 @@ class Config: neftune_noise_alpha: Optional[float] = None - 
max_memory: Optional[Union[int, str]] = None + orpo_alpha: Optional[float] = None + + max_memory: Optional[ + Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]] + ] = None gpu_memory_limit: Optional[Union[int, str]] = None - chat_template: Optional[Union[Literal["chatml", "inst"], ChatTemplate]] = None + chat_template: Optional[ChatTemplate] = None default_system_message: Optional[str] = None # INTERNALS - document for now, generally not set externally @@ -990,3 +1001,10 @@ def check_sample_packing_w_sdpa_bf16(cls, data): ) return data + + @model_validator(mode="before") + @classmethod + def check_fsdp_deepspeed(cls, data): + if data.get("deepspeed") and data.get("fsdp"): + raise ValueError("deepspeed and fsdp cannot be used together.") + return data diff --git a/src/axolotl/utils/freeze.py b/src/axolotl/utils/freeze.py index 05beda1caa..e3d0fd1446 100644 --- a/src/axolotl/utils/freeze.py +++ b/src/axolotl/utils/freeze.py @@ -3,13 +3,14 @@ """ import logging import re +from typing import Callable, List, Tuple, Union from axolotl.utils.distributed import is_main_process LOG = logging.getLogger("axolotl.utils.freeze") -def freeze_parameters_except(model, regex_patterns): +def freeze_layers_except(model, regex_patterns): """ Freezes all layers of the given model except for the layers that match given regex patterns. Periods in the patterns are treated as literal periods, not as wildcard characters. @@ -17,22 +18,211 @@ def freeze_parameters_except(model, regex_patterns): Parameters: - model (nn.Module): The PyTorch model to be modified. - regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen. + Note that you cannot use a dot as a wildcard character in the patterns since it is reserved for separating layer names. + Also, to match the entire layer name, the pattern should start with "^" and end with "$", otherwise it will match any part of the layer name. + The range pattern part is optional and it is not compiled as a regex pattern which means you must put "$" before the range pattern if you want to match the entire layer name. + E.g., ["^model.embed_tokens.weight$[:32000]", "layers.2[0-9]+.block_sparse_moe.gate.[a-z]+$"] Returns: None; the model is modified in place. 
""" - # Escape periods and compile the regex patterns - compiled_patterns = [ - re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns - ] + if isinstance(regex_patterns, str): + regex_patterns = [regex_patterns] - # First, freeze all parameters in the model - for param in model.parameters(): - param.requires_grad = False + patterns = [LayerNamePattern(pattern) for pattern in regex_patterns] # Unfreeze layers that match the regex patterns for name, param in model.named_parameters(): - if any(pattern.match(name) for pattern in compiled_patterns): - if is_main_process(): - LOG.debug(f"unfreezing {name}") + param.requires_grad = False + unfrozen_ranges = [] + for pattern in patterns: + if not pattern.match(name): + continue + param.requires_grad = True + + if pattern.range is not None: + unfrozen_ranges.append(pattern.range) + + merged_unfrozen_ranges = _merge_ranges(unfrozen_ranges, len(param)) + + if param.requires_grad and is_main_process(): + unfrozen_ranges = ( + f" with ranges {merged_unfrozen_ranges}" + if merged_unfrozen_ranges + else "" + ) + LOG.debug(f"Unfrozen {name}{unfrozen_ranges}") + + if not merged_unfrozen_ranges: + continue + + # The range list we need is actually the inverted of the merged ranges + ranges_to_freeze = _invert_ranges(merged_unfrozen_ranges, len(param)) + + param.register_hook(_create_freeze_parameters_hook(ranges_to_freeze)) + + if is_main_process() and all( + not param.requires_grad for param in model.parameters() + ): + LOG.warning("All parameters are frozen. Model will not be trained.") + + +def _invert_ranges( + given_ranges: List[Tuple[int, int]], layer_size: int +) -> List[Tuple[int, int]]: + """ + Inverts a list of ranges to obtain the ranges not covered by the given ranges. + + Parameters: + - given_ranges (List[Tuple[int, int]]): List of ranges to invert. Each range is represented as a tuple of start (inclusive) and end (exclusive) indices. + - layer_size (int): The length of the layer. E.g., len(model.layer.weight) + Returns: + - List[Tuple[int, int]]: List of inverted ranges, where each range is represented as a tuple of start (inclusive) and end (exclusive) indices. + """ + if not given_ranges: + return [(0, layer_size)] + + inverted_ranges = [] + current_start = 0 + + for start, end in sorted(given_ranges): + if start > current_start: + inverted_ranges.append((current_start, start)) + current_start = max(current_start, end) + + # Handle the case where the last given range does not reach the end of the total_size + if current_start < layer_size: + inverted_ranges.append((current_start, layer_size)) + + return inverted_ranges + + +def _merge_ranges( + given_ranges: List[Tuple[int, Union[int, None]]], layer_size: int +) -> List[Tuple[int, int]]: + """ + Merges overlapping ranges and sorts the given ranges. + + This function takes a list of ranges and merges any overlapping ranges. The ranges are represented + as tuples, where the first element is the start index (inclusive) and the second element is the end + index (exclusive). The end index can be None, indicating that the range extends to the end of the + sequence. + + Parameters: + - given_ranges (List[Tuple[int, int | None]]): List of ranges to merge. + - layer_size (int): The length of the layer. E.g., len(model.layer.weight) + + Returns: + - List[Tuple[int, int]]: List of merged ranges, as start (inclusive) and end (exclusive) indices. 
+ """ + # End of each range can be determined now since we have the total size + processed_ranges = [ + (start, end if end is not None else layer_size) for start, end in given_ranges + ] + + # No need to merge if there's only one or no ranges + if len(processed_ranges) <= 1: + return processed_ranges + + sorted_ranges = sorted(processed_ranges) + + merged_ranges = [sorted_ranges[0]] + for start, end in sorted_ranges[1:]: + prev_start, prev_end = merged_ranges[-1] + if start <= prev_end: + merged_ranges[-1] = (prev_start, max(prev_end, end)) + else: + merged_ranges.append((start, end)) + + return merged_ranges + + +def _create_freeze_parameters_hook(ranges_to_freeze: List[Tuple[int, int]]) -> Callable: + """ + Create a hook to freeze parameters in specified ranges by setting their gradients to zero. + + This function takes a list of tuples representing the ranges of indices to freeze. Each tuple should contain + two integers representing the start and end indices of the range. + + Parameters: + - ranges_to_freeze (List[Tuple[int, int]]): Ranges of indices to freeze. + + Returns: + - Callable: A hook function to be used with `register_hook` on parameters. + + Example usage: + ``` + ranges_to_freeze = [(0, 10), (20, 30)] + hook = _create_freeze_parameters_hook(ranges_to_freeze) + model.register_hook(hook) + ``` + """ + + def freeze_parameters_hook(gradients): + for start, end in ranges_to_freeze: + gradients[start:end].zero_() + + return freeze_parameters_hook + + +class LayerNamePattern: + """ + Represents a regex pattern for layer names, potentially including a parameter index range. + """ + + def __init__(self, pattern: str): + """ + Initializes a new instance of the LayerNamePattern class. + + Parameters: + - pattern (str): The regex pattern for layer names, potentially including a parameter index range. + """ + self.raw_pattern = pattern + name_pattern, self.range = self._parse_pattern(pattern) + self.name_regex = re.compile(name_pattern.replace(".", "\\.")) + + def match(self, name: str) -> bool: + """ + Checks if the given layer name matches the regex pattern. + + Parameters: + - name (str): The layer name to check. + + Returns: + - bool: True if the layer name matches the pattern, False otherwise. + """ + return self.name_regex.match(name) is not None + + def _parse_pattern( + self, pattern: str + ) -> Tuple[str, Union[Tuple[int, Union[int, None]], None]]: + """ + Extracts the range pattern from the given pattern. + + Parameters: + - pattern (str): The pattern to extract the range from. + + Returns: + - Tuple[str, Tuple[int, int | None] | None]: A tuple containing the regex pattern to match the layer name without the range pattern and the range of layer indices to match, if specified. + """ + match = re.match(r"^(.+)\[([0-9]*)(?::([0-9]*))?\]$", pattern) + if not match: + return pattern, None + + base_pattern, start_part, end_part = match.groups() + + if end_part is None and start_part.isdecimal(): + index = int(start_part) + return base_pattern, (index, index + 1) + + # [:end] or [start:] or [start:end] + start = int(start_part) if start_part else 0 + end = int(end_part) if end_part else None + + if end is not None and start >= end: + raise ValueError( + f"Invalid range in layer name pattern: {pattern}." + "End of range must be greater than start." 
+ ) + return base_pattern, (start, end) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 36c9c17e35..fce7b20a7a 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -429,7 +429,7 @@ def load_model( and cfg.flash_attention and cfg.sample_packing ): - patch_for_multipack(cfg.model_config_type) + patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model) elif cfg.is_llama_derived_model: # Modify all llama derived models in one block @@ -1055,6 +1055,8 @@ def load_lora(model, cfg, inference=False, config_only=False): lora_config_kwargs["init_lora_weights"] = "loftq" if cfg.peft_use_dora: lora_config_kwargs["use_dora"] = cfg.peft_use_dora + if cfg.peft_use_rslora: + lora_config_kwargs["use_rslora"] = cfg.use_rslora lora_config = LoraConfig( r=cfg.lora_r, diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py index 7f63a92fea..afbdef8778 100644 --- a/src/axolotl/utils/tokenization.py +++ b/src/axolotl/utils/tokenization.py @@ -2,6 +2,8 @@ import logging +import re +from typing import Dict, List from termcolor import colored @@ -36,3 +38,65 @@ def check_example_labels(example, tokenizer, text_only=False): LOG.info("\n\n\n") return " ".join(colored_tokens) + + +GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"] +GLAIVE_TO_SHAREGPT_ROLE = { + "SYSTEM": "system", + "USER": "human", + "ASSISTANT": "gpt", + "FUNCTION RESPONSE": "tool", +} + +GLAIVE_MSG_REGEX = re.compile(rf"({'|'.join(GLAIVE_ROLES)}): ") + + +def chatml_to_conversation(row: Dict[str, str]) -> List[Dict[str, str]]: + """ + Converts a ChatML formatted row to a list of messages in ShareGPT format. + Initially based off https://github.com/lilacai/lilac/blob/main/notebooks/GlaiveToShareGPT.ipynb. + """ + + system_prompt = row.get("system") + if system_prompt: + system_prompt = system_prompt.removeprefix("SYSTEM: ") + + chat_str = row["chat"] + chat_msgs = [s.strip() for s in GLAIVE_MSG_REGEX.split(chat_str) if s] + + chat_msg_dicts = [ + {"from": GLAIVE_TO_SHAREGPT_ROLE[role], "value": value} + for role, value in zip(chat_msgs[::2], chat_msgs[1::2]) + ] + + if system_prompt: + chat_msg_dicts = [ + {"from": GLAIVE_TO_SHAREGPT_ROLE["SYSTEM"], "value": system_prompt} + ] + chat_msg_dicts + + return chat_msg_dicts + + +def merge_consecutive_messages(messages): + """ + Merge consecutive messages from the same sender into a single message. + This can be useful with datasets that contain multiple consecutive tool calls. 
+ """ + + merged_messages = [] + current_from = None + current_message = "" + + for msg in messages: + if current_from == msg["from"]: + current_message += msg["value"] + else: + if current_from is not None: + merged_messages.append({"from": current_from, "value": current_message}) + current_from = msg["from"] + current_message = msg["value"] + + if current_from is not None: + merged_messages.append({"from": current_from, "value": current_message}) + + return merged_messages diff --git a/tests/prompt_strategies/test_sharegpt.py b/tests/prompt_strategies/test_sharegpt.py index 19f8217e04..c9290b220a 100644 --- a/tests/prompt_strategies/test_sharegpt.py +++ b/tests/prompt_strategies/test_sharegpt.py @@ -1,6 +1,7 @@ """ Test module for sharegpt integration w chatml """ + import pytest from datasets import Dataset from tokenizers import AddedToken @@ -8,6 +9,7 @@ from axolotl.datasets import TokenizedPromptDataset from axolotl.prompt_strategies.sharegpt import ( + GlaiveShareGPTPromptTokenizingStrategy, SimpleShareGPTPromptTokenizingStrategy, register_chatml_template, ) @@ -48,6 +50,18 @@ def fixture_sharegpt_dataset(): ) +@pytest.fixture(name="glaive_dataset") +def fixture_sharegpt_glaive_dataset(): + return Dataset.from_list( + [ + { + "system": "SYSTEM: This is a system prompt", + "chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>", + } + ] + ) + + @pytest.fixture(name="tokenizer") def fixture_tokenizer(): tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") @@ -156,3 +170,29 @@ def test_no_train_on_input(self, sharegpt_dataset, tokenizer): 32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt ] # fmt: on + + def test_chatml_glaive(self, glaive_dataset, tokenizer): + strategy = GlaiveShareGPTPromptTokenizingStrategy( + ShareGPTPrompterV2( + conversation="chatml", + role_key_model=None, + role_key_human=None, + ), + tokenizer, + True, # train_on_inputs + 2048, # sequence_len + ) + + dataset_wrapper = TokenizedPromptDataset( + strategy, glaive_dataset, process_count=1 + ) + + labels = dataset_wrapper[0]["labels"] + # fmt: off + assert labels == [ + 1, # bos + 32001, 1587, 13, 3260, 349, 264, 1587, 11510, 32000, 28705, 13, # system + 32001, 2188, 13, 6325, 368, 1820, 264, 9314, 354, 528, 477, 1450, 2726, 298, 4222, 28804, 32000, 28705, 13, # human + 32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt + ] + # fmt: on diff --git a/tests/test_freeze.py b/tests/test_freeze.py new file mode 100644 index 0000000000..49d30ba5fa --- /dev/null +++ b/tests/test_freeze.py @@ -0,0 +1,285 @@ +""" +This module contains unit tests for the `freeze_layers_except` function. + +The `freeze_layers_except` function is used to freeze layers in a model, except for the specified layers. +The unit tests in this module verify the behavior of the `freeze_layers_except` function in different scenarios. +""" + +import unittest + +import torch +from torch import nn + +from axolotl.utils.freeze import freeze_layers_except + +ZERO = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +ONE_TO_TEN = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + + +class TestFreezeLayersExcept(unittest.TestCase): + """ + A test case class for the `freeze_layers_except` function. 
+ """ + + def setUp(self): + self.model = _TestModel() + + def test_freeze_layers_with_dots_in_name(self): + freeze_layers_except(self.model, ["features.layer"]) + self.assertTrue( + self.model.features.layer.weight.requires_grad, + "model.features.layer should be trainable.", + ) + self.assertFalse( + self.model.classifier.weight.requires_grad, + "model.classifier should be frozen.", + ) + + def test_freeze_layers_without_dots_in_name(self): + freeze_layers_except(self.model, ["classifier"]) + self.assertFalse( + self.model.features.layer.weight.requires_grad, + "model.features.layer should be trainable.", + ) + self.assertTrue( + self.model.classifier.weight.requires_grad, + "model.classifier should be frozen.", + ) + + def test_freeze_layers_regex_patterns(self): + # The second pattern cannot match because only characters 'a' to 'c' are allowed after the word 'class', whereas it should be matching the character 'i'. + freeze_layers_except(self.model, [r"^features.[a-z]+.weight$", r"class[a-c]+"]) + self.assertTrue( + self.model.features.layer.weight.requires_grad, + "model.features.layer should be trainable.", + ) + self.assertFalse( + self.model.classifier.weight.requires_grad, + "model.classifier should be frozen.", + ) + + def test_all_layers_frozen(self): + freeze_layers_except(self.model, []) + self.assertFalse( + self.model.features.layer.weight.requires_grad, + "model.features.layer should be frozen.", + ) + self.assertFalse( + self.model.classifier.weight.requires_grad, + "model.classifier should be frozen.", + ) + + def test_all_layers_unfrozen(self): + freeze_layers_except(self.model, ["features.layer", "classifier"]) + self.assertTrue( + self.model.features.layer.weight.requires_grad, + "model.features.layer should be trainable.", + ) + self.assertTrue( + self.model.classifier.weight.requires_grad, + "model.classifier should be trainable.", + ) + + def test_freeze_layers_with_range_pattern_start_end(self): + freeze_layers_except(self.model, ["features.layer[1:5]"]) + self.assertTrue( + self.model.features.layer.weight.requires_grad, + "model.features.layer should be trainable.", + ) + self.assertFalse( + self.model.classifier.weight.requires_grad, + "model.classifier should be frozen.", + ) + + self._assert_gradient_output( + [ + ZERO, + ONE_TO_TEN, + ONE_TO_TEN, + ONE_TO_TEN, + ONE_TO_TEN, + ZERO, + ZERO, + ZERO, + ZERO, + ZERO, + ] + ) + + def test_freeze_layers_with_range_pattern_single_index(self): + freeze_layers_except(self.model, ["features.layer[5]"]) + self.assertTrue( + self.model.features.layer.weight.requires_grad, + "model.features.layer should be trainable.", + ) + self.assertFalse( + self.model.classifier.weight.requires_grad, + "model.classifier should be frozen.", + ) + + self._assert_gradient_output( + [ZERO, ZERO, ZERO, ZERO, ZERO, ONE_TO_TEN, ZERO, ZERO, ZERO, ZERO] + ) + + def test_freeze_layers_with_range_pattern_start_omitted(self): + freeze_layers_except(self.model, ["features.layer[:5]"]) + self.assertTrue( + self.model.features.layer.weight.requires_grad, + "model.features.layer should be trainable.", + ) + self.assertFalse( + self.model.classifier.weight.requires_grad, + "model.classifier should be frozen.", + ) + + self._assert_gradient_output( + [ + ONE_TO_TEN, + ONE_TO_TEN, + ONE_TO_TEN, + ONE_TO_TEN, + ONE_TO_TEN, + ZERO, + ZERO, + ZERO, + ZERO, + ZERO, + ] + ) + + def test_freeze_layers_with_range_pattern_end_omitted(self): + freeze_layers_except(self.model, ["features.layer[4:]"]) + self.assertTrue( + 
self.model.features.layer.weight.requires_grad, + "model.features.layer should be trainable.", + ) + self.assertFalse( + self.model.classifier.weight.requires_grad, + "model.classifier should be frozen.", + ) + + self._assert_gradient_output( + [ + ZERO, + ZERO, + ZERO, + ZERO, + ONE_TO_TEN, + ONE_TO_TEN, + ONE_TO_TEN, + ONE_TO_TEN, + ONE_TO_TEN, + ONE_TO_TEN, + ] + ) + + def test_freeze_layers_with_range_pattern_merge_included(self): + freeze_layers_except(self.model, ["features.layer[4:]", "features.layer[5:6]"]) + self.assertTrue( + self.model.features.layer.weight.requires_grad, + "model.features.layer should be trainable.", + ) + self.assertFalse( + self.model.classifier.weight.requires_grad, + "model.classifier should be frozen.", + ) + + self._assert_gradient_output( + [ + ZERO, + ZERO, + ZERO, + ZERO, + ONE_TO_TEN, + ONE_TO_TEN, + ONE_TO_TEN, + ONE_TO_TEN, + ONE_TO_TEN, + ONE_TO_TEN, + ] + ) + + def test_freeze_layers_with_range_pattern_merge_intersect(self): + freeze_layers_except(self.model, ["features.layer[4:7]", "features.layer[6:8]"]) + self.assertTrue( + self.model.features.layer.weight.requires_grad, + "model.features.layer should be trainable.", + ) + self.assertFalse( + self.model.classifier.weight.requires_grad, + "model.classifier should be frozen.", + ) + + self._assert_gradient_output( + [ + ZERO, + ZERO, + ZERO, + ZERO, + ONE_TO_TEN, + ONE_TO_TEN, + ONE_TO_TEN, + ONE_TO_TEN, + ZERO, + ZERO, + ] + ) + + def test_freeze_layers_with_range_pattern_merge_separate(self): + freeze_layers_except( + self.model, + ["features.layer[1:2]", "features.layer[3:4]", "features.layer[5:6]"], + ) + self.assertTrue( + self.model.features.layer.weight.requires_grad, + "model.features.layer should be trainable.", + ) + self.assertFalse( + self.model.classifier.weight.requires_grad, + "model.classifier should be frozen.", + ) + + self._assert_gradient_output( + [ + ZERO, + ONE_TO_TEN, + ZERO, + ONE_TO_TEN, + ZERO, + ONE_TO_TEN, + ZERO, + ZERO, + ZERO, + ZERO, + ] + ) + + def _assert_gradient_output(self, expected): + input_tensor = torch.tensor([ONE_TO_TEN], dtype=torch.float32) + + self.model.features.layer.weight.grad = None # Reset gradients + output = self.model.features.layer(input_tensor) + loss = output.sum() + loss.backward() + + expected_grads = torch.tensor(expected) + torch.testing.assert_close( + self.model.features.layer.weight.grad, expected_grads + ) + + +class _SubLayerModule(nn.Module): + def __init__(self): + super().__init__() + self.layer = nn.Linear(10, 10) + + +class _TestModel(nn.Module): + def __init__(self): + super().__init__() + self.features = _SubLayerModule() + self.classifier = nn.Linear(10, 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index cf662d95f7..4e659006fe 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -1,4 +1,5 @@ """Module for testing prompt tokenizers.""" + import json import logging import unittest @@ -7,7 +8,8 @@ from typing import Optional import pytest -from transformers import AutoTokenizer, LlamaTokenizer +from datasets import load_dataset +from transformers import AddedToken, AutoTokenizer, LlamaTokenizer from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter from axolotl.prompt_strategies.alpaca_w_system import ( @@ -18,11 +20,14 @@ Llama2ChatPrompter, LLama2ChatTokenizingStrategy, ) +from axolotl.prompt_strategies.orpo.chat_template import load +from axolotl.prompt_strategies.sharegpt import 
GlaiveShareGPTPromptTokenizingStrategy from axolotl.prompt_tokenizers import ( AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, ) from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2 +from axolotl.utils.dict import DictDefault LOG = logging.getLogger("axolotl") @@ -266,6 +271,23 @@ def test_sharegpt_assistant_label_ignore(self): idx = res["input_ids"].index(20255) # assistant token assert res["labels"][idx] == -100 + def test_glaive_tool_label_ignore(self): + conversation = { + "system": "SYSTEM: This is a system prompt", + "chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>", + } + prompter = ShareGPTPrompterV2() + strat = GlaiveShareGPTPromptTokenizingStrategy( + prompter, + self.tokenizer, + False, + 2048, + ) + with self._caplog.at_level(logging.WARNING): + res = strat.tokenize_prompt(conversation) + idx = res["input_ids"].index(13566) # assistant token + assert res["labels"][idx] == -100 + def test_no_sys_prompt(self): """ tests the interface between the user and assistant parts @@ -427,5 +449,57 @@ def compare_with_transformers_integration(self): ) +class OrpoTokenizationTest(unittest.TestCase): + """test case for the ORPO tokenization""" + + def setUp(self) -> None: + # pylint: disable=duplicate-code + tokenizer = LlamaTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + tokenizer.add_special_tokens( + { + "eos_token": AddedToken( + "<|im_end|>", rstrip=False, lstrip=False, normalized=False + ) + } + ) + tokenizer.add_tokens( + [ + AddedToken( + "<|im_start|>", rstrip=False, lstrip=False, normalized=False + ), + ] + ) + self.tokenizer = tokenizer + self.dataset = load_dataset( + "argilla/ultrafeedback-binarized-preferences-cleaned", split="train" + ).select([0]) + + def test_orpo_integration(self): + strat = load( + self.tokenizer, + DictDefault({"train_on_inputs": False}), + DictDefault({"chat_template": "chatml"}), + ) + res = strat.tokenize_prompt(self.dataset[0]) + assert "rejected_input_ids" in res + assert "rejected_labels" in res + assert "input_ids" in res + assert "labels" in res + assert "prompt_attention_mask" in res + + assert len(res["rejected_input_ids"]) == len(res["rejected_labels"]) + assert len(res["input_ids"]) == len(res["labels"]) + assert len(res["input_ids"]) == len(res["prompt_attention_mask"]) + + assert res["rejected_labels"][0] == -100 + assert res["rejected_input_ids"][-1] == res["rejected_labels"][-1] + + assert res["labels"][0] == -100 + assert res["input_ids"][-1] == res["labels"][-1] + + assert res["prompt_attention_mask"][0] == 1 + assert res["prompt_attention_mask"][-1] == 0 + + if __name__ == "__main__": unittest.main()
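
As an illustration of the ORPO objective this patch wires into `AxolotlTrainer.orpo_compute_loss`, here is a minimal, standalone sketch of the odds-ratio term that gets added to the chosen-response NLL. The tensor values below are placeholders invented for the example; only the arithmetic and the `orpo_alpha: 0.1` value (taken from the `docs/rlhf.md` example) follow the patch.

```python
import torch

# Placeholder mean per-token log-probs for the chosen (pos) and rejected (neg)
# responses of two examples, i.e. the kind of values orpo_compute_logps returns.
pos_prob = torch.tensor([-0.9, -1.1], dtype=torch.float64)
neg_prob = torch.tensor([-1.6, -1.4], dtype=torch.float64)

# Log odds of preferring the chosen response over the rejected one,
# mirroring the expression in orpo_compute_loss.
log_odds = (pos_prob - neg_prob) - (
    torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob))
)
ratio = torch.log(torch.sigmoid(log_odds))

# Placeholder token-averaged NLL of the chosen responses
# (what orpo_compute_custom_loss produces).
pos_loss = torch.tensor([2.3, 2.1], dtype=torch.float64)

orpo_alpha = 0.1  # same value as the docs/rlhf.md example config
loss = torch.mean(pos_loss - orpo_alpha * ratio)
print(loss.item())
```

In the trainer itself, `pos_prob` and `neg_prob` come from `orpo_compute_logps` over the chosen and rejected sequences produced by the new ORPO tokenization strategy, and the penalty term is subtracted from the chosen-response loss so that the model is pushed toward the chosen completion and away from the rejected one in a single forward/backward pass per pair.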