From 7d1d22f72f53af613cc168f3d473e174f0b79b47 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 19 Apr 2024 17:25:36 -0400 Subject: [PATCH 1/9] ORPO Trainer replacement (#1551) * WIP use trl ORPOTrainer * fixes to make orpo work with trl * fix the chat template laoding * make sure to handle the special tokens and add_generation for assistant turn too --- requirements.txt | 2 +- src/axolotl/cli/preprocess.py | 2 +- src/axolotl/cli/train.py | 2 +- src/axolotl/core/trainer_builder.py | 47 ++++++++--- .../prompt_strategies/orpo/__init__.py | 2 +- .../prompt_strategies/orpo/chat_template.py | 84 +++++++++++++++++++ src/axolotl/utils/data/__init__.py | 2 +- src/axolotl/utils/data/{dpo.py => rl.py} | 24 +++++- src/axolotl/utils/trainer.py | 6 +- tests/core/test_trainer_builder.py | 6 +- 10 files changed, 151 insertions(+), 26 deletions(-) rename src/axolotl/utils/data/{dpo.py => rl.py} (80%) diff --git a/requirements.txt b/requirements.txt index 4d74dee908..9289a40f39 100644 --- a/requirements.txt +++ b/requirements.txt @@ -39,6 +39,6 @@ s3fs gcsfs # adlfs -trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f +trl==0.8.5 zstandard==0.22.0 fastcore diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index a1a01d59de..fa71d67934 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): LOG.warning(msg) parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH - if parsed_cfg.rl and parsed_cfg.rl != "orpo": + if parsed_cfg.rl: # and parsed_cfg.rl != "orpo": load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) else: load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 7e004567a6..0cebe5a52b 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: else: register_chatml_template() - if cfg.rl and cfg.rl != "orpo": + if cfg.rl: # and cfg.rl != "orpo": dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 900dcb7887..fdb0810030 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -30,7 +30,7 @@ ) from transformers.trainer_utils import seed_worker from transformers.utils import is_sagemaker_mp_enabled -from trl import DPOTrainer +from trl import DPOTrainer, ORPOConfig, ORPOTrainer from trl.trainer.utils import pad_to_length from axolotl.loraplus import create_loraplus_optimizer @@ -810,6 +810,14 @@ def tokenize_row( return res +class AxolotlORPOTrainer(ORPOTrainer): + """ + Extend the base ORPOTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "orpo"] + + class TrainerBuilderBase(abc.ABC): """ Base class for trainer builder @@ -1404,7 +1412,7 @@ def build_collator( ) -class HFDPOTrainerBuilder(TrainerBuilderBase): +class HFRLTrainerBuilder(TrainerBuilderBase): """ Trainer factory class for DPO Trainer """ @@ -1497,7 +1505,15 @@ def build_training_arguments(self, total_num_steps): # default to saving each epoch if not defined training_args_kwargs["save_strategy"] = "epoch" - training_args = TrainingArguments( + if self.cfg.orpo_alpha: + # trl does some odd mapping of alpha to beta to reuse the beta parameter ??? 
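+            # (trl's ORPOConfig has no `alpha` field; its `beta` corresponds to the alpha/lambda weight in the ORPO paper)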
+ training_args_kwargs["beta"] = self.cfg.orpo_alpha + + training_args_cls = TrainingArguments + if self.cfg.rl == "orpo": + training_args_cls = ORPOConfig + + training_args = training_args_cls( per_device_train_batch_size=self.cfg.micro_batch_size, max_steps=self.cfg.max_steps or total_num_steps, gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, @@ -1530,17 +1546,26 @@ def build(self, total_num_steps): dpo_trainer_kwargs[ "precompute_ref_log_probs" ] = self.cfg.precompute_ref_log_probs - dpo_trainer = AxolotlDPOTrainer( - self.model, - self.model_ref, + if self.cfg.rl in ["dpo", "ipo", "kto_pair"]: + trainer_cls = AxolotlDPOTrainer + dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1 + trainer_cls_args = [self.model, self.model_ref] + + # these aren't used for the ORPO trainer + dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len + dpo_trainer_kwargs["max_target_length"] = None + dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len + dpo_trainer_kwargs["generate_during_eval"] = True + elif self.cfg.rl == "orpo": + trainer_cls = AxolotlORPOTrainer + trainer_cls_args = [self.model] + else: + raise ValueError(f"Unsupported RL: {self.cfg.rl}") + dpo_trainer = trainer_cls( + *trainer_cls_args, args=training_args, - beta=self.cfg.dpo_beta or 0.1, train_dataset=self.train_dataset, tokenizer=self.tokenizer, - max_length=self.cfg.sequence_len, - max_target_length=None, - max_prompt_length=self.cfg.sequence_len, - generate_during_eval=True, callbacks=self.get_callbacks(), **dpo_trainer_kwargs, ) diff --git a/src/axolotl/prompt_strategies/orpo/__init__.py b/src/axolotl/prompt_strategies/orpo/__init__.py index 3a961fcc92..4a02f3c625 100644 --- a/src/axolotl/prompt_strategies/orpo/__init__.py +++ b/src/axolotl/prompt_strategies/orpo/__init__.py @@ -6,4 +6,4 @@ from ..base import load as load_base -load = partial(load_base, module="axolotl.prompt_strategies.orpo") +load = partial(load_base, module_base="axolotl.prompt_strategies.orpo") diff --git a/src/axolotl/prompt_strategies/orpo/chat_template.py b/src/axolotl/prompt_strategies/orpo/chat_template.py index 9953fe87e8..a89dee1575 100644 --- a/src/axolotl/prompt_strategies/orpo/chat_template.py +++ b/src/axolotl/prompt_strategies/orpo/chat_template.py @@ -78,6 +78,57 @@ def get_rejected_conversation_thread(self, prompt) -> MessageList: ) return MessageList(messages=messages) + def get_prompt(self, prompt) -> MessageList: + """Map the data to extract everything up to the last turn""" + total_msg_len = len(prompt["chosen"]) + total_msg_turns, remainder = divmod(total_msg_len, 2) + assert remainder == 0, "invalid number of turns" + + messages: List[Message] = [] + if system := prompt.get("system", None): + messages.append(Message(role="system", content=system, label=False)) + for i in range(total_msg_turns): + if "prompt" in prompt: + messages.append( + Message(role="user", content=prompt["prompt"], label=False) + ) + else: + messages.append( + Message( + role="user", + content=prompt["chosen"][i * 2]["content"], + label=False, + ) + ) + if i < total_msg_turns - 1: + messages.append( + Message( + role="assistant", + content=prompt["chosen"][i * 2 + 1]["content"], + label=False, + ) + ) + + return MessageList(messages=messages) + + def get_chosen(self, prompt) -> MessageList: + res = self.get_prompt(prompt) + res.messages.append( + Message( + role="assistant", content=prompt["chosen"][-1]["content"], label=True + ) + ) + return res + + def get_rejected(self, prompt) -> MessageList: + res = self.get_prompt(prompt) + 
res.messages.append( + Message( + role="assistant", content=prompt["rejected"][-1]["content"], label=True + ) + ) + return res + class ORPOTokenizingStrategy(PromptTokenizingStrategy): """ @@ -186,3 +237,36 @@ def build_prompt( chat_template=self.chat_template, tokenize=False, ), True + + +def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument + dataset_parser = ORPODatasetParsingStrategy() + + chat_template_str = chat_templates(cfg.chat_template) + + def transform_fn(sample, tokenizer=None): + res = {} + + res["prompt"] = tokenizer.apply_chat_template( + [msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages], + add_generation_prompt=True, + chat_template=chat_template_str, + tokenize=False, + ) + prompt_str_len = len(res["prompt"]) + res["chosen"] = tokenizer.apply_chat_template( + [msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages], + add_generation_prompt=False, + chat_template=chat_template_str, + tokenize=False, + )[prompt_str_len:] + res["rejected"] = tokenizer.apply_chat_template( + [msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages], + add_generation_prompt=False, + chat_template=chat_template_str, + tokenize=False, + )[prompt_str_len:] + + return res + + return transform_fn diff --git a/src/axolotl/utils/data/__init__.py b/src/axolotl/utils/data/__init__.py index 1015c8370a..140d02106d 100644 --- a/src/axolotl/utils/data/__init__.py +++ b/src/axolotl/utils/data/__init__.py @@ -1,11 +1,11 @@ """ Data processing modules """ -from axolotl.utils.data.dpo import load_prepare_dpo_datasets # noqa: F401 from axolotl.utils.data.pretraining import ( # noqa: F401 encode_pretraining, wrap_pretraining_dataset, ) +from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401 from axolotl.utils.data.sft import ( # noqa: F401 get_dataset_wrapper, load_prepare_datasets, diff --git a/src/axolotl/utils/data/dpo.py b/src/axolotl/utils/data/rl.py similarity index 80% rename from src/axolotl/utils/data/dpo.py rename to src/axolotl/utils/data/rl.py index 765a3fc374..ff5ca87ddf 100644 --- a/src/axolotl/utils/data/dpo.py +++ b/src/axolotl/utils/data/rl.py @@ -1,17 +1,20 @@ """data handling specific to DPO""" - +import inspect import logging +from functools import partial from pathlib import Path from typing import Any, List import yaml -from datasets import concatenate_datasets, load_dataset, load_from_disk +from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.prompt_strategies.dpo import load as load_dpo +from axolotl.prompt_strategies.orpo import load as load_orpo from axolotl.utils.data.utils import md5 from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first +from axolotl.utils.models import load_tokenizer LOG = logging.getLogger("axolotl") @@ -72,16 +75,29 @@ def load_split(dataset_cfgs, _cfg): ) split_datasets.insert(i, ds) + tokenizer = None for i, data_set in enumerate(split_datasets): _type = dataset_cfgs[i]["type"] if _type: if isinstance(_type, DictDefault): _type = "user_defined.default" - ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i) - split_datasets[i] = data_set.map( + if _cfg.rl == "orpo": + ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i) + else: + ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i) + sig = inspect.signature(ds_transform_fn) + if "tokenizer" in sig.parameters: + if not tokenizer: + tokenizer = 
load_tokenizer(_cfg) + ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer) + + data_set = data_set.map( ds_transform_fn, desc="Mapping RL Dataset", ) + if isinstance(data_set, DatasetDict): + data_set = data_set["train"] + split_datasets[i] = data_set else: # If no `type` is provided, assume the dataset is already in the expected format with # "prompt", "chosen" and "rejected" already preprocessed diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 2a8ed216d1..808fbb59f5 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -13,7 +13,7 @@ from torch.utils.data import DataLoader, RandomSampler from transformers.utils import is_torch_bf16_gpu_available -from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder +from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths @@ -340,8 +340,8 @@ def prepare_optim_env(cfg): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): - if cfg.rl in ["dpo", "ipo", "kto_pair"]: - trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer) + if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]: + trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer) trainer_builder.model_ref = model[1] trainer_builder.peft_config = model[2] else: diff --git a/tests/core/test_trainer_builder.py b/tests/core/test_trainer_builder.py index 541fdb343d..82455922ef 100644 --- a/tests/core/test_trainer_builder.py +++ b/tests/core/test_trainer_builder.py @@ -4,7 +4,7 @@ import pytest -from axolotl.core.trainer_builder import HFDPOTrainerBuilder +from axolotl.core.trainer_builder import HFRLTrainerBuilder from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_tokenizer @@ -51,13 +51,13 @@ def fixture_model(cfg, tokenizer): return load_model(cfg, tokenizer) -class TestHFDPOTrainerBuilder: +class TestHFRLTrainerBuilder: """ TestCase class for DPO trainer builder """ def test_build_training_arguments(self, cfg, model, tokenizer): - builder = HFDPOTrainerBuilder(cfg, model, tokenizer) + builder = HFRLTrainerBuilder(cfg, model, tokenizer) training_arguments = builder.build_training_arguments(100) assert training_arguments.adam_beta1 == 0.998 assert training_arguments.adam_beta2 == 0.9 From 7477a53287cf0e84eec563dc3751dff6ad41713e Mon Sep 17 00:00:00 2001 From: Frank Ruis Date: Mon, 22 Apr 2024 01:55:20 +0200 Subject: [PATCH 2/9] wrap prepared_ds_path in str() to avoid TypeError in fsspec package (#1548) * wrap prepared_ds_path in str() to avoid TypeError in fsspec package `fsspec` calls `if "::" in path` on `prepared_ds_path`, which will throw an error if it is a `PosixPath` object. * update test too --------- Co-authored-by: Wing Lian --- src/axolotl/utils/data/sft.py | 2 +- tests/test_datasets.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 39c50b1a07..dbc4172b4b 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -421,7 +421,7 @@ def for_d_in_datasets(dataset_configs): if cfg.local_rank == 0: LOG.info(f"Saving merged prepared dataset to disk... 
{prepared_ds_path}") - dataset.save_to_disk(prepared_ds_path) + dataset.save_to_disk(str(prepared_ds_path)) if cfg.push_dataset_to_hub: LOG.info( f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 8b7b3dae6a..a274b7b894 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -110,7 +110,7 @@ def test_load_from_save_to_disk(self): """Usual use case. Verify datasets saved via `save_to_disk` can be loaded.""" with tempfile.TemporaryDirectory() as tmp_dir: tmp_ds_name = Path(tmp_dir) / "tmp_dataset" - self.dataset.save_to_disk(tmp_ds_name) + self.dataset.save_to_disk(str(tmp_ds_name)) prepared_path = Path(tmp_dir) / "prepared" cfg = DictDefault( From 60f5ce0569b7f1d522ef81ea986ebfdc98780e6a Mon Sep 17 00:00:00 2001 From: Haoxiang Wang Date: Sun, 21 Apr 2024 18:55:40 -0500 Subject: [PATCH 3/9] Add support for Gemma chat template (#1530) * Add support for Gemma chat template * Update fschat version to include its newest support for Gemma chat style * pin fastchat to current HEAD --------- Co-authored-by: Wing Lian --- requirements.txt | 2 +- src/axolotl/monkeypatch/fastchat_conversation_turns.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9289a40f39..26525be15d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,7 +28,7 @@ scipy scikit-learn==1.2.2 pynvml art -fschat==0.2.36 +fschat @ git+https://github.com/lm-sys/FastChat.git@5095615810cf613dba7f27dd155f571fcff976d8 gradio==3.50.2 tensorboard diff --git a/src/axolotl/monkeypatch/fastchat_conversation_turns.py b/src/axolotl/monkeypatch/fastchat_conversation_turns.py index d09ab5075d..7ab07d4854 100644 --- a/src/axolotl/monkeypatch/fastchat_conversation_turns.py +++ b/src/axolotl/monkeypatch/fastchat_conversation_turns.py @@ -123,6 +123,14 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role, "" return + if self.sep_style == SeparatorStyle.GEMMA: + if self.system_message: + raise ValueError("Gemma chat template does not support system messages") + for i, (role, message) in enumerate(self.messages): + prefix = "" if i == 0 else "" + message_str = message if message else "" + yield prefix + "" + role + "\n", message_str + "\n" + return if self.sep_style == SeparatorStyle.CHATGLM: # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 From 68601ec6ad1cc0e8cb855376586e6eef6a8aa270 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 22 Apr 2024 16:00:05 -0400 Subject: [PATCH 4/9] make sure everything stays in the same dtype when using dpo + FSDP (#1559) --- src/axolotl/core/trainer_builder.py | 4 ++++ src/axolotl/utils/models.py | 10 ++++++++++ 2 files changed, 14 insertions(+) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index fdb0810030..6bddb95740 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -54,6 +54,7 @@ MambaDataCollator, V2BatchSamplerDataCollatorForSeq2Seq, ) +from axolotl.utils.models import ensure_dtype from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.schedulers import ( get_cosine_schedule_with_min_lr, @@ -1569,6 +1570,9 @@ def build(self, total_num_steps): callbacks=self.get_callbacks(), 
**dpo_trainer_kwargs, ) + if self.cfg.fsdp: + ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype) + dpo_trainer = self.hook_post_create_trainer(dpo_trainer) for callback in self.get_post_trainer_create_callbacks(dpo_trainer): dpo_trainer.add_callback(callback) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 52d8db047f..8537b7e754 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -993,3 +993,13 @@ def load_lora(model, cfg, inference=False, config_only=False): setup_quantized_peft_meta_for_training(model) return model, lora_config + + +def ensure_dtype(model, dtype=torch.bfloat16): + for name, module in model.named_modules(): + try: + if module.weight.dtype != dtype: + print(f"Converting module {name}: {module.weight.dtype} -> {dtype}") + module.to(dtype) + except AttributeError: + pass From 98c25e15cb87f61fced4fd68d2d2b19f88a9e7aa Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Sat, 27 Apr 2024 09:07:06 -0700 Subject: [PATCH 5/9] Add ORPO example and e2e test (#1572) * add example for mistral orpo * sample_packing: false for orpo * go to load_dataset (since load_rl_datasets require a transfom_fn, which only dpo uses currently) --- .gitignore | 1 + docs/rlhf.qmd | 2 +- examples/mistral/mistral-qlora-orpo.yml | 82 +++++++++++++++++++++++++ tests/e2e/test_dpo.py | 47 ++++++++++++++ 4 files changed, 131 insertions(+), 1 deletion(-) create mode 100644 examples/mistral/mistral-qlora-orpo.yml diff --git a/.gitignore b/.gitignore index 589440abf6..e6dfee67db 100644 --- a/.gitignore +++ b/.gitignore @@ -133,6 +133,7 @@ venv/ ENV/ env.bak/ venv.bak/ +venv3.10/ # Spyder project settings .spyderproject diff --git a/docs/rlhf.qmd b/docs/rlhf.qmd index 7db68915ad..b8b2bded09 100644 --- a/docs/rlhf.qmd +++ b/docs/rlhf.qmd @@ -49,7 +49,7 @@ remove_unused_columns: false chat_template: chatml datasets: - path: argilla/ultrafeedback-binarized-preferences-cleaned - type: orpo.chat_template + type: chat_template.argilla ``` #### Using local dataset files diff --git a/examples/mistral/mistral-qlora-orpo.yml b/examples/mistral/mistral-qlora-orpo.yml new file mode 100644 index 0000000000..7727fd7485 --- /dev/null +++ b/examples/mistral/mistral-qlora-orpo.yml @@ -0,0 +1,82 @@ +base_model: mistralai/Mistral-7B-v0.1 +model_type: MistralForCausalLM +tokenizer_type: LlamaTokenizer + +load_in_8bit: false +load_in_4bit: true +strict: false + +rl: orpo +orpo_alpha: 0.1 +remove_unused_columns: false + +chat_template: chatml +datasets: + - path: argilla/ultrafeedback-binarized-preferences-cleaned + type: chat_template.argilla +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./mistral-qlora-orpo-out + +adapter: qlora +lora_model_dir: + +sequence_len: 4096 +sample_packing: false +pad_to_sequence_len: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +loss_watchdog_threshold: 5.0 +loss_watchdog_patience: 3 + +warmup_steps: 10 
+evals_per_epoch: 4 +eval_table_size: +eval_max_new_tokens: 128 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index e28df7411f..9596b1873f 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -158,3 +158,50 @@ def test_ipo_lora(self, temp_dir): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() + + @with_temp_dir + def test_orpo_lora(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 64, + "lora_alpha": 32, + "lora_dropout": 0.1, + "lora_target_linear": True, + "special_tokens": {}, + "rl": "orpo", + "orpo_alpha": 0.1, + "remove_unused_columns": False, + "chat_template": "chatml", + "datasets": [ + { + "path": "argilla/ultrafeedback-binarized-preferences-cleaned", + "type": "chat_template.argilla", + "split": "train", + }, + ], + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "paged_adamw_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "warmup_steps": 5, + "gradient_checkpointing": True, + "gradient_checkpointing_kwargs": {"use_reentrant": True}, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() From 5294653a2d353066600cbc66bb06f7c63c87147b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 27 Apr 2024 12:28:20 -0400 Subject: [PATCH 6/9] PoSE context length ext (#1567) * PoSE wip * fixes for pose splitting * set pose context len so we can pick that up seperately from the usable training context len * support min sample len and define num chunks * fix chunk splitting * support for curriculum/ordered learning with pose * fix sequence len sort * add curriculum_sampling to pydantic --- src/axolotl/core/trainer_builder.py | 7 ++ .../config/models/input/v0_4_1/__init__.py | 8 ++ src/axolotl/utils/trainer.py | 108 +++++++++++++++++- 3 files changed, 118 insertions(+), 5 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 6bddb95740..09651bdc9b 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -212,6 +212,10 @@ class AxolotlTrainingArguments(TrainingArguments): default=None, metadata={"help": "path under the model to access the layers"}, ) + curriculum_sampling: Optional[bool] = field( + default=None, + metadata={"help": "whether to use sequential sampling for curriculum learning"}, + ) class AxolotlTrainer(Trainer): @@ -347,6 +351,8 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: lengths=get_dataset_lengths(self.train_dataset), packing_efficiency_estimate=self.args.sample_packing_efficiency, ) + if self.args.curriculum_sampling: + return SequentialSampler(self.train_dataset) return super()._get_train_sampler() def _get_eval_sampler( @@ -1193,6 +1199,7 @@ def build(self, total_num_steps): False if self.cfg.ddp else None ) training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length + training_arguments_kwargs["curriculum_sampling"] = 
self.cfg.curriculum_sampling report_to = None if self.cfg.use_wandb: report_to = "wandb" diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index d99155ac25..e27a8ddd52 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -503,9 +503,17 @@ class Config: unfrozen_parameters: Optional[List[str]] = None sequence_len: int = Field(default=512) + min_sample_len: Optional[int] = None sample_packing: Optional[bool] = None eval_sample_packing: Optional[bool] = None pad_to_sequence_len: Optional[bool] = None + curriculum_sampling: Optional[bool] = None + + # for PoSE context length extension + use_pose: Optional[bool] = None + pose_split_on_token_ids: Optional[List[int]] = None + pose_max_context_len: Optional[int] = None + pose_num_chunks: Optional[int] = None pretrain_multipack_buffer_size: Optional[int] = 10_000 pretrain_multipack_attn: Optional[bool] = Field( diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 808fbb59f5..2e3728cc8a 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -1,9 +1,10 @@ """Module containing the Trainer class and related functions""" import math import os +import random from contextlib import contextmanager from functools import partial -from typing import List +from typing import List, Optional import numpy as np import torch @@ -98,17 +99,89 @@ def add_position_ids(sample): return sample +def add_pose_position_ids( + sample, + max_context_len=32768, + split_on_token_ids: Optional[List[int]] = None, + chunks: int = 2, +): + """ + use the PoSE technique to extend the context length by randomly skipping + positions in the context. We only want to skip right before tokens in + the split_on_token_ids list. We should attempt to randomly distribute + the skips, but we don't need the final position_ids to be the full + context_len. There may be multiple turns in the context, so we want to + make sure we take into account the maximum possible number of skips + remaining in each sample. 
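+
+    For example: with sample_len=1000, max_context_len=32768 and chunks=2,
+    max_skips starts at 31768; the first chunk keeps positions 0..499 and
+    the second chunk is shifted right by a random offset in [0, 31768],
+    so the model trains on position values far beyond the raw sample length.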
+ """ + + input_ids = sample["input_ids"] + sample_len = len(input_ids) + max_skips = max_context_len - sample_len + + if split_on_token_ids is None: + split_on_token_ids = [] + + if split_on_token_ids: + split_indices = [ + i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids + ] + else: + chunk_len = sample_len // chunks + split_indices = [i * chunk_len for i in range(1, chunks)] + split_indices.append(len(input_ids)) # make sure we go to the end of the sample + if split_indices[0] < 2: + # drop the first split index if it's too close to the beginning + split_indices = split_indices[1:] + + position_ids = [] + prev_index = 0 + total_skips = 0 + + for split_index in split_indices: + num_skips = ( + random.randint(0, max_skips) # nosec B311 + if prev_index != 0 and max_skips + else 0 + ) + max_skips -= num_skips + total_skips += num_skips + + segment_position_ids = list( + range(prev_index + total_skips, split_index + total_skips) + ) + + position_ids.extend(segment_position_ids) + prev_index = split_index + + sample["sequence_len"] = position_ids[-1] + position_ids = torch.tensor(position_ids) + + sample["position_ids"] = position_ids + sample["length"] = len(position_ids) + assert len(position_ids) == len(input_ids) + + return sample + + def add_length(sample): sample["length"] = len(sample["input_ids"]) return sample -def drop_long_seq(sample, sequence_len=2048): - return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0 +def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2): + return ( + len(sample["input_ids"]) <= sequence_len + and len(sample["input_ids"]) >= min_sequence_len + ) def process_datasets_for_packing(cfg, train_dataset, eval_dataset): - drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len) + drop_long = partial( + drop_long_seq, + sequence_len=cfg.sequence_len, + min_sequence_len=cfg.min_sample_len or 2, + ) with zero_first(is_main_process()): if cfg.is_preprocess: min_input_len = np.min(get_dataset_lengths(train_dataset)) @@ -153,7 +226,32 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): desc="Group By Length", ) - if cfg.sample_packing: + if cfg.use_pose: + pose_kwargs = {} + if cfg.pose_num_chunks is not None: + pose_kwargs["chunks"] = cfg.pose_num_chunks + pose_fn = partial( + add_pose_position_ids, + max_context_len=cfg.pose_max_context_len, + split_on_token_ids=cfg.pose_split_on_token_ids, + **pose_kwargs, + ) + train_dataset = train_dataset.map( + pose_fn, + num_proc=cfg.dataset_processes, + load_from_cache_file=not cfg.is_preprocess, + desc="Add position_id column (PoSE)", + ) + train_dataset = train_dataset.sort("sequence_len") + if cfg.eval_sample_packing is not False: + if eval_dataset: + eval_dataset = eval_dataset.map( + pose_fn, + num_proc=cfg.dataset_processes, + load_from_cache_file=not cfg.is_preprocess, + desc="Add position_id column (PoSE)", + ) + elif cfg.sample_packing: train_dataset = train_dataset.map( add_position_ids, num_proc=cfg.dataset_processes, From 1aeece6e247f709fa11059372a71869cc6ca6f80 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 1 May 2024 00:33:53 +0900 Subject: [PATCH 7/9] chore(doc): clarify micro_batch_size (#1579) [skip ci] --- docs/config.qmd | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/config.qmd b/docs/config.qmd index caa9b7649f..dadc5c487c 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -268,6 +268,7 @@ torch_compile_backend: # Optional[str] # If greater than 1, backpropagation will be skipped and the gradients 
will be accumulated for the given number of steps. gradient_accumulation_steps: 1 # The number of samples to include in each batch. This is the number of samples sent to each GPU. +# Batch size per gpu = micro_batch_size * gradient_accumulation_steps micro_batch_size: 2 eval_batch_size: num_epochs: 4 From cc5d31e0d934009708d8d0669a00107181c92b78 Mon Sep 17 00:00:00 2001 From: Abhinand Date: Tue, 30 Apr 2024 21:06:04 +0530 Subject: [PATCH 8/9] Add debug option for RL dataset preprocessing (#1404) * adding debug option for RL dataset preprocessing * Refine formatting of debugging code in RL dataset preprocessing * Update __init__.py * chore: fix lint --------- Co-authored-by: NanoCode012 --- src/axolotl/cli/__init__.py | 17 +++++++++ src/axolotl/utils/tokenization.py | 61 +++++++++++++++++++++++++++++-- 2 files changed, 75 insertions(+), 3 deletions(-) diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 0670f64e3a..9f40bb476d 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -433,6 +433,23 @@ def load_rl_datasets( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) ) + if cli_args.debug or cfg.debug: + LOG.info("check_dataset_labels...") + + tokenizer = load_tokenizer(cfg) + check_dataset_labels( + train_dataset.select( + [ + random.randrange(0, len(train_dataset) - 1) # nosec + for _ in range(cli_args.debug_num_examples) + ] + ), + tokenizer, + num_examples=cli_args.debug_num_examples, + text_only=cli_args.debug_text_only, + rl_mode=True, + ) + return TrainDatasetMeta( train_dataset=train_dataset, eval_dataset=eval_dataset, diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py index afbdef8778..845296b7a6 100644 --- a/src/axolotl/utils/tokenization.py +++ b/src/axolotl/utils/tokenization.py @@ -1,6 +1,5 @@ """Module for tokenization utilities""" - import logging import re from typing import Dict, List @@ -10,10 +9,19 @@ LOG = logging.getLogger("axolotl") -def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False): +def check_dataset_labels( + dataset, + tokenizer, + num_examples=5, + text_only=False, + rl_mode=False, +): # the dataset is already shuffled, so let's just check the first 5 elements for idx in range(num_examples): - check_example_labels(dataset[idx], tokenizer, text_only=text_only) + if not rl_mode: + check_example_labels(dataset[idx], tokenizer, text_only=text_only) + else: + check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only) def check_example_labels(example, tokenizer, text_only=False): @@ -40,6 +48,53 @@ def check_example_labels(example, tokenizer, text_only=False): return " ".join(colored_tokens) +def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only): + """Helper function to color tokens based on their type.""" + colored_text = colored(decoded_token, color) + return ( + colored_text + if text_only + else f"{colored_text}{colored(f'({encoded_token})', 'white')}" + ) + + +def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only): + """Helper function to process and color tokens.""" + colored_tokens = [ + color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only) + for token in tokenizer.encode(tokens) + ] + return colored_tokens + + +def check_rl_example_labels(example, tokenizer, text_only=False): + field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected" + + input_tokens = example[field_prompt] + labels_chosen, labels_rejected = example[field_chosen], example[field_rejected] + 
+ # Process and color each type of token + colored_tokens = process_tokens_for_rl_debug( + input_tokens, "yellow", tokenizer, text_only + ) + colored_chosens = process_tokens_for_rl_debug( + labels_chosen, "green", tokenizer, text_only + ) + colored_rejecteds = process_tokens_for_rl_debug( + labels_rejected, "red", tokenizer, text_only + ) + + # Create a delimiter based on text_only flag + delimiter = "" if text_only else " " + + # Logging information + LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n") + LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n") + LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n") + + return delimiter.join(colored_tokens) + + GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"] GLAIVE_TO_SHAREGPT_ROLE = { "SYSTEM": "system", From 601c08b4c2d9a8198527e5e33536d3ad499305f0 Mon Sep 17 00:00:00 2001 From: JohanWork <39947546+JohanWork@users.noreply.github.com> Date: Tue, 30 Apr 2024 18:05:12 +0200 Subject: [PATCH 9/9] ADD: warning hub model (#1301) * update warning for save_strategy * update * clean up * update * Update test_validation.py * fix validation step * update * test_validation * update * fix * fix --------- Co-authored-by: NanoCode012 --- src/axolotl/utils/config/__init__.py | 15 +++--- .../config/models/input/v0_4_1/__init__.py | 6 +-- tests/test_validation.py | 48 ++++++++++++++++--- 3 files changed, 51 insertions(+), 18 deletions(-) diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 8716cb93be..a054f24a7f 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -383,9 +383,9 @@ def legacy_validate_config(cfg): "push_to_hub_model_id is deprecated. Please use hub_model_id instead." ) - if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch): + if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]: LOG.warning( - "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch." + "hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty." ) if cfg.gptq and cfg.revision_of_model: @@ -448,10 +448,14 @@ def legacy_validate_config(cfg): raise ValueError( "save_steps and saves_per_epoch are mutually exclusive and cannot be used together." ) - if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps": + if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps": raise ValueError( "save_strategy must be empty or set to `steps` when used with saves_per_epoch." ) + if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps": + raise ValueError( + "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps." + ) if cfg.evals_per_epoch and cfg.eval_steps: raise ValueError( "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together." @@ -464,11 +468,6 @@ def legacy_validate_config(cfg): raise ValueError( "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch." ) - if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps": - raise ValueError( - "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps." 
- ) - if ( cfg.evaluation_strategy and cfg.eval_steps diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index e27a8ddd52..72e82e8231 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -780,11 +780,11 @@ def check_saves(cls, data): @model_validator(mode="before") @classmethod def check_push_save(cls, data): - if data.get("hub_model_id") and not ( - data.get("save_steps") or data.get("saves_per_epoch") + if data.get("hub_model_id") and ( + data.get("save_strategy") not in ["steps", "epoch", None] ): LOG.warning( - "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch." + "hub_model_id is set without any models being saved. To save a model, set save_strategy." ) return data diff --git a/tests/test_validation.py b/tests/test_validation.py index 4865712c47..27824f2887 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1067,17 +1067,51 @@ def test_unfrozen_parameters_w_peft_layers_to_transform(self, minimal_cfg): ): validate_config(cfg) - def test_hub_model_id_save_value_warns(self, minimal_cfg): - cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg + def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg): + cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg with self._caplog.at_level(logging.WARNING): validate_config(cfg) - assert ( - "set without any models being saved" in self._caplog.records[0].message - ) + assert len(self._caplog.records) == 1 + + def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg): + cfg = ( + DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg + ) + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert len(self._caplog.records) == 1 + + def test_hub_model_id_save_value_steps(self, minimal_cfg): + cfg = ( + DictDefault({"hub_model_id": "test", "save_strategy": "steps"}) + | minimal_cfg + ) + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert len(self._caplog.records) == 0 + + def test_hub_model_id_save_value_epochs(self, minimal_cfg): + cfg = ( + DictDefault({"hub_model_id": "test", "save_strategy": "epoch"}) + | minimal_cfg + ) - def test_hub_model_id_save_value(self, minimal_cfg): - cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) | minimal_cfg + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert len(self._caplog.records) == 0 + + def test_hub_model_id_save_value_none(self, minimal_cfg): + cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert len(self._caplog.records) == 0 + + def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg): + cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg with self._caplog.at_level(logging.WARNING): validate_config(cfg)