diff --git a/.gitignore b/.gitignore
index 589440abf6..e6dfee67db 100644
--- a/.gitignore
+++ b/.gitignore
@@ -133,6 +133,7 @@ venv/
 ENV/
 env.bak/
 venv.bak/
+venv3.10/
 
 # Spyder project settings
 .spyderproject
diff --git a/docs/config.qmd b/docs/config.qmd
index caa9b7649f..dadc5c487c 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -268,6 +268,7 @@ torch_compile_backend: # Optional[str]
 # If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
 gradient_accumulation_steps: 1
 # The number of samples to include in each batch. This is the number of samples sent to each GPU.
+# Batch size per gpu = micro_batch_size * gradient_accumulation_steps
 micro_batch_size: 2
 eval_batch_size:
 num_epochs: 4
diff --git a/docs/rlhf.qmd b/docs/rlhf.qmd
index 7db68915ad..b8b2bded09 100644
--- a/docs/rlhf.qmd
+++ b/docs/rlhf.qmd
@@ -49,7 +49,7 @@ remove_unused_columns: false
 chat_template: chatml
 datasets:
   - path: argilla/ultrafeedback-binarized-preferences-cleaned
-    type: orpo.chat_template
+    type: chat_template.argilla
 ```
 
 #### Using local dataset files
diff --git a/examples/mistral/mistral-qlora-orpo.yml b/examples/mistral/mistral-qlora-orpo.yml
new file mode 100644
index 0000000000..7727fd7485
--- /dev/null
+++ b/examples/mistral/mistral-qlora-orpo.yml
@@ -0,0 +1,82 @@
+base_model: mistralai/Mistral-7B-v0.1
+model_type: MistralForCausalLM
+tokenizer_type: LlamaTokenizer
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+rl: orpo
+orpo_alpha: 0.1
+remove_unused_columns: false
+
+chat_template: chatml
+datasets:
+  - path: argilla/ultrafeedback-binarized-preferences-cleaned
+    type: chat_template.argilla
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.1
+output_dir: ./mistral-qlora-orpo-out
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 4096
+sample_packing: false
+pad_to_sequence_len: true
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+lora_target_modules:
+  - gate_proj
+  - down_proj
+  - up_proj
+  - q_proj
+  - v_proj
+  - k_proj
+  - o_proj
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+loss_watchdog_threshold: 5.0
+loss_watchdog_patience: 3
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
diff --git a/requirements.txt b/requirements.txt
index 4d74dee908..26525be15d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -28,7 +28,7 @@ scipy
 scikit-learn==1.2.2
 pynvml
 art
-fschat==0.2.36
+fschat @ git+https://github.com/lm-sys/FastChat.git@5095615810cf613dba7f27dd155f571fcff976d8
 gradio==3.50.2
 tensorboard
 
@@ -39,6 +39,6 @@ s3fs
 gcsfs
 # adlfs
 
-trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
+trl==0.8.5
 zstandard==0.22.0
 fastcore
diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index 44cf62ed9c..81d20802cd 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -436,6 +436,23 @@ def load_rl_datasets(
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
 
+    if cli_args.debug or cfg.debug:
+        LOG.info("check_dataset_labels...")
+
+        tokenizer = load_tokenizer(cfg)
+        check_dataset_labels(
+            train_dataset.select(
+                [
+                    random.randrange(0, len(train_dataset) - 1)  # nosec
+                    for _ in range(cli_args.debug_num_examples)
+                ]
+            ),
+            tokenizer,
+            num_examples=cli_args.debug_num_examples,
+            text_only=cli_args.debug_text_only,
+            rl_mode=True,
+        )
+
     return TrainDatasetMeta(
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py
index a1a01d59de..fa71d67934 100644
--- a/src/axolotl/cli/preprocess.py
+++ b/src/axolotl/cli/preprocess.py
@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
         LOG.warning(msg)
         parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
 
-    if parsed_cfg.rl and parsed_cfg.rl != "orpo":
+    if parsed_cfg.rl:  # and parsed_cfg.rl != "orpo"
         load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
     else:
         load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py
index 7e004567a6..0cebe5a52b 100644
--- a/src/axolotl/cli/train.py
+++ b/src/axolotl/cli/train.py
@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
     else:
         register_chatml_template()
 
-    if cfg.rl and cfg.rl != "orpo":
+    if cfg.rl:  # and cfg.rl != "orpo"
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 417678345e..aaa3420f74 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -30,7 +30,7 @@
 )
 from transformers.trainer_utils import seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOTrainer
+from trl import DPOTrainer, ORPOConfig, ORPOTrainer
 from trl.trainer.utils import pad_to_length
 
 from axolotl.loraplus import create_loraplus_optimizer
@@ -54,6 +54,7 @@
     MambaDataCollator,
     V2BatchSamplerDataCollatorForSeq2Seq,
 )
+from axolotl.utils.models import ensure_dtype
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (
     get_cosine_schedule_with_min_lr,
@@ -211,6 +212,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "path under the model to access the layers"},
     )
+    curriculum_sampling: Optional[bool] = field(
+        default=None,
+        metadata={"help": "whether to use sequential sampling for curriculum learning"},
+    )
 
 
 class AxolotlTrainer(Trainer):
@@ -346,6 +351,8 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
                 lengths=get_dataset_lengths(self.train_dataset),
                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
             )
+        if self.args.curriculum_sampling:
+            return SequentialSampler(self.train_dataset)
         return super()._get_train_sampler()
 
     def _get_eval_sampler(
@@ -810,6 +817,14 @@ def tokenize_row(
         return res
 
 
+class AxolotlORPOTrainer(ORPOTrainer):
+    """
+    Extend the base ORPOTrainer for axolotl helpers
+    """
+
+    tag_names = ["axolotl", "orpo"]
+
+
 class TrainerBuilderBase(abc.ABC):
     """
     Base class for trainer builder
@@ -1189,6 +1204,7 @@ def build(self, total_num_steps):
             False if self.cfg.ddp else None
         )
         training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
+        training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
         report_to = None
         if self.cfg.use_wandb:
report_to = "wandb" @@ -1409,7 +1425,7 @@ def build_collator( ) -class HFDPOTrainerBuilder(TrainerBuilderBase): +class HFRLTrainerBuilder(TrainerBuilderBase): """ Trainer factory class for DPO Trainer """ @@ -1502,7 +1518,15 @@ def build_training_arguments(self, total_num_steps): # default to saving each epoch if not defined training_args_kwargs["save_strategy"] = "epoch" - training_args = TrainingArguments( + if self.cfg.orpo_alpha: + # trl does some odd mapping of alpha to beta to reuse the beta parameter ??? + training_args_kwargs["beta"] = self.cfg.orpo_alpha + + training_args_cls = TrainingArguments + if self.cfg.rl == "orpo": + training_args_cls = ORPOConfig + + training_args = training_args_cls( per_device_train_batch_size=self.cfg.micro_batch_size, max_steps=self.cfg.max_steps or total_num_steps, gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, @@ -1535,20 +1559,32 @@ def build(self, total_num_steps): dpo_trainer_kwargs[ "precompute_ref_log_probs" ] = self.cfg.precompute_ref_log_probs - dpo_trainer = AxolotlDPOTrainer( - self.model, - self.model_ref, + if self.cfg.rl in ["dpo", "ipo", "kto_pair"]: + trainer_cls = AxolotlDPOTrainer + dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1 + trainer_cls_args = [self.model, self.model_ref] + + # these aren't used for the ORPO trainer + dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len + dpo_trainer_kwargs["max_target_length"] = None + dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len + dpo_trainer_kwargs["generate_during_eval"] = True + elif self.cfg.rl == "orpo": + trainer_cls = AxolotlORPOTrainer + trainer_cls_args = [self.model] + else: + raise ValueError(f"Unsupported RL: {self.cfg.rl}") + dpo_trainer = trainer_cls( + *trainer_cls_args, args=training_args, - beta=self.cfg.dpo_beta or 0.1, train_dataset=self.train_dataset, tokenizer=self.tokenizer, - max_length=self.cfg.sequence_len, - max_target_length=None, - max_prompt_length=self.cfg.sequence_len, - generate_during_eval=True, callbacks=self.get_callbacks(), **dpo_trainer_kwargs, ) + if self.cfg.fsdp: + ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype) + dpo_trainer = self.hook_post_create_trainer(dpo_trainer) for callback in self.get_post_trainer_create_callbacks(dpo_trainer): dpo_trainer.add_callback(callback) diff --git a/src/axolotl/monkeypatch/fastchat_conversation_turns.py b/src/axolotl/monkeypatch/fastchat_conversation_turns.py index d09ab5075d..7ab07d4854 100644 --- a/src/axolotl/monkeypatch/fastchat_conversation_turns.py +++ b/src/axolotl/monkeypatch/fastchat_conversation_turns.py @@ -123,6 +123,14 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role, "" return + if self.sep_style == SeparatorStyle.GEMMA: + if self.system_message: + raise ValueError("Gemma chat template does not support system messages") + for i, (role, message) in enumerate(self.messages): + prefix = "" if i == 0 else "" + message_str = message if message else "" + yield prefix + "" + role + "\n", message_str + "\n" + return if self.sep_style == SeparatorStyle.CHATGLM: # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 diff --git a/src/axolotl/prompt_strategies/orpo/__init__.py b/src/axolotl/prompt_strategies/orpo/__init__.py index 3a961fcc92..4a02f3c625 100644 --- a/src/axolotl/prompt_strategies/orpo/__init__.py +++ 
b/src/axolotl/prompt_strategies/orpo/__init__.py @@ -6,4 +6,4 @@ from ..base import load as load_base -load = partial(load_base, module="axolotl.prompt_strategies.orpo") +load = partial(load_base, module_base="axolotl.prompt_strategies.orpo") diff --git a/src/axolotl/prompt_strategies/orpo/chat_template.py b/src/axolotl/prompt_strategies/orpo/chat_template.py index 9953fe87e8..a89dee1575 100644 --- a/src/axolotl/prompt_strategies/orpo/chat_template.py +++ b/src/axolotl/prompt_strategies/orpo/chat_template.py @@ -78,6 +78,57 @@ def get_rejected_conversation_thread(self, prompt) -> MessageList: ) return MessageList(messages=messages) + def get_prompt(self, prompt) -> MessageList: + """Map the data to extract everything up to the last turn""" + total_msg_len = len(prompt["chosen"]) + total_msg_turns, remainder = divmod(total_msg_len, 2) + assert remainder == 0, "invalid number of turns" + + messages: List[Message] = [] + if system := prompt.get("system", None): + messages.append(Message(role="system", content=system, label=False)) + for i in range(total_msg_turns): + if "prompt" in prompt: + messages.append( + Message(role="user", content=prompt["prompt"], label=False) + ) + else: + messages.append( + Message( + role="user", + content=prompt["chosen"][i * 2]["content"], + label=False, + ) + ) + if i < total_msg_turns - 1: + messages.append( + Message( + role="assistant", + content=prompt["chosen"][i * 2 + 1]["content"], + label=False, + ) + ) + + return MessageList(messages=messages) + + def get_chosen(self, prompt) -> MessageList: + res = self.get_prompt(prompt) + res.messages.append( + Message( + role="assistant", content=prompt["chosen"][-1]["content"], label=True + ) + ) + return res + + def get_rejected(self, prompt) -> MessageList: + res = self.get_prompt(prompt) + res.messages.append( + Message( + role="assistant", content=prompt["rejected"][-1]["content"], label=True + ) + ) + return res + class ORPOTokenizingStrategy(PromptTokenizingStrategy): """ @@ -186,3 +237,36 @@ def build_prompt( chat_template=self.chat_template, tokenize=False, ), True + + +def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument + dataset_parser = ORPODatasetParsingStrategy() + + chat_template_str = chat_templates(cfg.chat_template) + + def transform_fn(sample, tokenizer=None): + res = {} + + res["prompt"] = tokenizer.apply_chat_template( + [msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages], + add_generation_prompt=True, + chat_template=chat_template_str, + tokenize=False, + ) + prompt_str_len = len(res["prompt"]) + res["chosen"] = tokenizer.apply_chat_template( + [msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages], + add_generation_prompt=False, + chat_template=chat_template_str, + tokenize=False, + )[prompt_str_len:] + res["rejected"] = tokenizer.apply_chat_template( + [msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages], + add_generation_prompt=False, + chat_template=chat_template_str, + tokenize=False, + )[prompt_str_len:] + + return res + + return transform_fn diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 8716cb93be..a054f24a7f 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -383,9 +383,9 @@ def legacy_validate_config(cfg): "push_to_hub_model_id is deprecated. Please use hub_model_id instead." 
         )
 
-    if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
+    if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
         LOG.warning(
-            "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
+            "hub_model_id is set without any models being saved. To save a model, set save_strategy to `steps` or `epoch`, or leave it empty."
         )
 
     if cfg.gptq and cfg.revision_of_model:
@@ -448,10 +448,14 @@ def legacy_validate_config(cfg):
         raise ValueError(
             "save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
         )
-    if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
+    if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
         raise ValueError(
             "save_strategy must be empty or set to `steps` when used with saves_per_epoch."
         )
+    if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
+        raise ValueError(
+            "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
+        )
     if cfg.evals_per_epoch and cfg.eval_steps:
         raise ValueError(
             "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
@@ -464,11 +468,6 @@ def legacy_validate_config(cfg):
         raise ValueError(
             "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
         )
-    if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
-        raise ValueError(
-            "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
-        )
-
     if (
         cfg.evaluation_strategy
         and cfg.eval_steps
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index bb5fc6ce6c..e33774972c 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -510,9 +510,17 @@ class Config:
     unfrozen_parameters: Optional[List[str]] = None
 
     sequence_len: int = Field(default=512)
+    min_sample_len: Optional[int] = None
     sample_packing: Optional[bool] = None
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None
+    curriculum_sampling: Optional[bool] = None
+
+    # for PoSE context length extension
+    use_pose: Optional[bool] = None
+    pose_split_on_token_ids: Optional[List[int]] = None
+    pose_max_context_len: Optional[int] = None
+    pose_num_chunks: Optional[int] = None
 
     pretrain_multipack_buffer_size: Optional[int] = 10_000
     pretrain_multipack_attn: Optional[bool] = Field(
@@ -779,11 +787,11 @@ def check_saves(cls, data):
 
     @model_validator(mode="before")
     @classmethod
     def check_push_save(cls, data):
-        if data.get("hub_model_id") and not (
-            data.get("save_steps") or data.get("saves_per_epoch")
+        if data.get("hub_model_id") and (
+            data.get("save_strategy") not in ["steps", "epoch", None]
         ):
             LOG.warning(
-                "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
+                "hub_model_id is set without any models being saved. To save a model, set save_strategy."
             )
         return data
diff --git a/src/axolotl/utils/data/__init__.py b/src/axolotl/utils/data/__init__.py
index 1015c8370a..140d02106d 100644
--- a/src/axolotl/utils/data/__init__.py
+++ b/src/axolotl/utils/data/__init__.py
@@ -1,11 +1,11 @@
 """
 Data processing modules
 """
-from axolotl.utils.data.dpo import load_prepare_dpo_datasets  # noqa: F401
 from axolotl.utils.data.pretraining import (  # noqa: F401
     encode_pretraining,
     wrap_pretraining_dataset,
 )
+from axolotl.utils.data.rl import load_prepare_dpo_datasets  # noqa: F401
 from axolotl.utils.data.sft import (  # noqa: F401
     get_dataset_wrapper,
     load_prepare_datasets,
diff --git a/src/axolotl/utils/data/dpo.py b/src/axolotl/utils/data/rl.py
similarity index 80%
rename from src/axolotl/utils/data/dpo.py
rename to src/axolotl/utils/data/rl.py
index 765a3fc374..ff5ca87ddf 100644
--- a/src/axolotl/utils/data/dpo.py
+++ b/src/axolotl/utils/data/rl.py
@@ -1,17 +1,20 @@
 """data handling specific to DPO"""
-
+import inspect
 import logging
+from functools import partial
 from pathlib import Path
 from typing import Any, List
 
 import yaml
-from datasets import concatenate_datasets, load_dataset, load_from_disk
+from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
 
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.prompt_strategies.dpo import load as load_dpo
+from axolotl.prompt_strategies.orpo import load as load_orpo
 from axolotl.utils.data.utils import md5
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
+from axolotl.utils.models import load_tokenizer
 
 
 LOG = logging.getLogger("axolotl")
@@ -72,16 +75,29 @@ def load_split(dataset_cfgs, _cfg):
             )
             split_datasets.insert(i, ds)
 
+    tokenizer = None
     for i, data_set in enumerate(split_datasets):
         _type = dataset_cfgs[i]["type"]
         if _type:
             if isinstance(_type, DictDefault):
                 _type = "user_defined.default"
-            ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
-            split_datasets[i] = data_set.map(
+            if _cfg.rl == "orpo":
+                ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
+            else:
+                ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
+            sig = inspect.signature(ds_transform_fn)
+            if "tokenizer" in sig.parameters:
+                if not tokenizer:
+                    tokenizer = load_tokenizer(_cfg)
+                ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
+
+            data_set = data_set.map(
                 ds_transform_fn,
                 desc="Mapping RL Dataset",
             )
+            if isinstance(data_set, DatasetDict):
+                data_set = data_set["train"]
+            split_datasets[i] = data_set
         else:
             # If no `type` is provided, assume the dataset is already in the expected format with
             # "prompt", "chosen" and "rejected" already preprocessed
diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index defb3bcd24..7a8d00adbf 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -445,7 +445,7 @@ def for_d_in_datasets(dataset_configs):
 
         if cfg.local_rank == 0:
             LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
-            dataset.save_to_disk(prepared_ds_path)
+            dataset.save_to_disk(str(prepared_ds_path))
             if cfg.push_dataset_to_hub:
                 LOG.info(
                     f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 52d8db047f..8537b7e754 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -993,3 +993,13 @@ def load_lora(model, cfg, inference=False, config_only=False):
         setup_quantized_peft_meta_for_training(model)
 
     return model, lora_config
+
+
+def ensure_dtype(model, dtype=torch.bfloat16):
+    for name, module in model.named_modules():
+        try:
+            if module.weight.dtype != dtype:
+                print(f"Converting module {name}: {module.weight.dtype} -> {dtype}")
+                module.to(dtype)
+        except AttributeError:
+            pass
diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py
index afbdef8778..845296b7a6 100644
--- a/src/axolotl/utils/tokenization.py
+++ b/src/axolotl/utils/tokenization.py
@@ -1,6 +1,5 @@
 """Module for tokenization utilities"""
 
-
 import logging
 import re
 from typing import Dict, List
@@ -10,10 +9,19 @@
 LOG = logging.getLogger("axolotl")
 
 
-def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
+def check_dataset_labels(
+    dataset,
+    tokenizer,
+    num_examples=5,
+    text_only=False,
+    rl_mode=False,
+):
     # the dataset is already shuffled, so let's just check the first 5 elements
     for idx in range(num_examples):
-        check_example_labels(dataset[idx], tokenizer, text_only=text_only)
+        if not rl_mode:
+            check_example_labels(dataset[idx], tokenizer, text_only=text_only)
+        else:
+            check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only)
 
 
 def check_example_labels(example, tokenizer, text_only=False):
@@ -40,6 +48,53 @@ def check_example_labels(example, tokenizer, text_only=False):
     return " ".join(colored_tokens)
 
 
+def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only):
+    """Helper function to color tokens based on their type."""
+    colored_text = colored(decoded_token, color)
+    return (
+        colored_text
+        if text_only
+        else f"{colored_text}{colored(f'({encoded_token})', 'white')}"
+    )
+
+
+def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
+    """Helper function to process and color tokens."""
+    colored_tokens = [
+        color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
+        for token in tokenizer.encode(tokens)
+    ]
+    return colored_tokens
+
+
+def check_rl_example_labels(example, tokenizer, text_only=False):
+    field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"
+
+    input_tokens = example[field_prompt]
+    labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]
+
+    # Process and color each type of token
+    colored_tokens = process_tokens_for_rl_debug(
+        input_tokens, "yellow", tokenizer, text_only
+    )
+    colored_chosens = process_tokens_for_rl_debug(
+        labels_chosen, "green", tokenizer, text_only
+    )
+    colored_rejecteds = process_tokens_for_rl_debug(
+        labels_rejected, "red", tokenizer, text_only
+    )
+
+    # Create a delimiter based on text_only flag
+    delimiter = "" if text_only else " "
+
+    # Logging information
+    LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
+    LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
+    LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
+
+    return delimiter.join(colored_tokens)
+
+
 GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
 GLAIVE_TO_SHAREGPT_ROLE = {
     "SYSTEM": "system",
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 2a8ed216d1..2e3728cc8a 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -1,9 +1,10 @@
 """Module containing the Trainer class and related functions"""
 import math
 import os
+import random
 from contextlib import contextmanager
 from functools import partial
-from typing import List
+from typing import List, Optional
 
 import numpy as np
 import torch
@@ -13,7 +14,7 @@
 from torch.utils.data import DataLoader, RandomSampler
 from transformers.utils import is_torch_bf16_gpu_available
 
-from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
+from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
 from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 
@@ -98,17 +99,89 @@ def add_position_ids(sample):
     return sample
 
 
+def add_pose_position_ids(
+    sample,
+    max_context_len=32768,
+    split_on_token_ids: Optional[List[int]] = None,
+    chunks: int = 2,
+):
+    """
+    use the PoSE technique to extend the context length by randomly skipping
+    positions in the context. We only want to skip right before tokens in
+    the split_on_token_ids list. We should attempt to randomly distribute
+    the skips, but we don't need the final position_ids to be the full
+    context_len. There may be multiple turns in the context, so we want to
+    make sure we take into account the maximum possible number of skips
+    remaining in each sample.
+    """
+
+    input_ids = sample["input_ids"]
+    sample_len = len(input_ids)
+    max_skips = max_context_len - sample_len
+
+    if split_on_token_ids is None:
+        split_on_token_ids = []
+
+    if split_on_token_ids:
+        split_indices = [
+            i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids
+        ]
+    else:
+        chunk_len = sample_len // chunks
+        split_indices = [i * chunk_len for i in range(1, chunks)]
+        split_indices.append(len(input_ids))  # make sure we go to the end of the sample
+    if split_indices[0] < 2:
+        # drop the first split index if it's too close to the beginning
+        split_indices = split_indices[1:]
+
+    position_ids = []
+    prev_index = 0
+    total_skips = 0
+
+    for split_index in split_indices:
+        num_skips = (
+            random.randint(0, max_skips)  # nosec B311
+            if prev_index != 0 and max_skips
+            else 0
+        )
+        max_skips -= num_skips
+        total_skips += num_skips
+
+        segment_position_ids = list(
+            range(prev_index + total_skips, split_index + total_skips)
+        )
+
+        position_ids.extend(segment_position_ids)
+        prev_index = split_index
+
+    sample["sequence_len"] = position_ids[-1]
+    position_ids = torch.tensor(position_ids)
+
+    sample["position_ids"] = position_ids
+    sample["length"] = len(position_ids)
+    assert len(position_ids) == len(input_ids)
+
+    return sample
+
+
 def add_length(sample):
     sample["length"] = len(sample["input_ids"])
     return sample
 
 
-def drop_long_seq(sample, sequence_len=2048):
-    return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
+def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
+    return (
+        len(sample["input_ids"]) <= sequence_len
+        and len(sample["input_ids"]) >= min_sequence_len
+    )
 
 
 def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
-    drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
+    drop_long = partial(
+        drop_long_seq,
+        sequence_len=cfg.sequence_len,
+        min_sequence_len=cfg.min_sample_len or 2,
+    )
     with zero_first(is_main_process()):
         if cfg.is_preprocess:
             min_input_len = np.min(get_dataset_lengths(train_dataset))
@@ -153,7 +226,32 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
             desc="Group By Length",
         )
 
-    if cfg.sample_packing:
+    if cfg.use_pose:
+        pose_kwargs = {}
+        if cfg.pose_num_chunks is not None:
+            pose_kwargs["chunks"] = cfg.pose_num_chunks
+        pose_fn = partial(
+            add_pose_position_ids,
+            max_context_len=cfg.pose_max_context_len,
+            split_on_token_ids=cfg.pose_split_on_token_ids,
+            **pose_kwargs,
+        )
+        train_dataset = train_dataset.map(
+            pose_fn,
+            num_proc=cfg.dataset_processes,
+            load_from_cache_file=not cfg.is_preprocess,
+            desc="Add position_id column (PoSE)",
+        )
+        train_dataset = train_dataset.sort("sequence_len")
+        if cfg.eval_sample_packing is not False:
+            if eval_dataset:
+                eval_dataset = eval_dataset.map(
+                    pose_fn,
+                    num_proc=cfg.dataset_processes,
+                    load_from_cache_file=not cfg.is_preprocess,
+                    desc="Add position_id column (PoSE)",
+                )
+    elif cfg.sample_packing:
         train_dataset = train_dataset.map(
             add_position_ids,
             num_proc=cfg.dataset_processes,
@@ -340,8 +438,8 @@ def prepare_optim_env(cfg):
 
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
-    if cfg.rl in ["dpo", "ipo", "kto_pair"]:
-        trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
+    if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
+        trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
         trainer_builder.model_ref = model[1]
         trainer_builder.peft_config = model[2]
     else:
diff --git a/tests/core/test_trainer_builder.py b/tests/core/test_trainer_builder.py
index 541fdb343d..82455922ef 100644
--- a/tests/core/test_trainer_builder.py
+++ b/tests/core/test_trainer_builder.py
@@ -4,7 +4,7 @@
 
 import pytest
 
-from axolotl.core.trainer_builder import HFDPOTrainerBuilder
+from axolotl.core.trainer_builder import HFRLTrainerBuilder
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
@@ -51,13 +51,13 @@ def fixture_model(cfg, tokenizer):
     return load_model(cfg, tokenizer)
 
 
-class TestHFDPOTrainerBuilder:
+class TestHFRLTrainerBuilder:
     """
     TestCase class for DPO trainer builder
     """
 
     def test_build_training_arguments(self, cfg, model, tokenizer):
-        builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
+        builder = HFRLTrainerBuilder(cfg, model, tokenizer)
         training_arguments = builder.build_training_arguments(100)
         assert training_arguments.adam_beta1 == 0.998
         assert training_arguments.adam_beta2 == 0.9
diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py
index e28df7411f..9596b1873f 100644
--- a/tests/e2e/test_dpo.py
+++ b/tests/e2e/test_dpo.py
@@ -158,3 +158,50 @@ def test_ipo_lora(self, temp_dir):
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+
+    @with_temp_dir
+    def test_orpo_lora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 64,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "special_tokens": {},
+                "rl": "orpo",
+                "orpo_alpha": 0.1,
+                "remove_unused_columns": False,
+                "chat_template": "chatml",
+                "datasets": [
+                    {
+                        "path": "argilla/ultrafeedback-binarized-preferences-cleaned",
+                        "type": "chat_template.argilla",
+                        "split": "train",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "warmup_steps": 5, + "gradient_checkpointing": True, + "gradient_checkpointing_kwargs": {"use_reentrant": True}, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 8b7b3dae6a..a274b7b894 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -110,7 +110,7 @@ def test_load_from_save_to_disk(self): """Usual use case. Verify datasets saved via `save_to_disk` can be loaded.""" with tempfile.TemporaryDirectory() as tmp_dir: tmp_ds_name = Path(tmp_dir) / "tmp_dataset" - self.dataset.save_to_disk(tmp_ds_name) + self.dataset.save_to_disk(str(tmp_ds_name)) prepared_path = Path(tmp_dir) / "prepared" cfg = DictDefault( diff --git a/tests/test_validation.py b/tests/test_validation.py index 4865712c47..27824f2887 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1067,17 +1067,51 @@ def test_unfrozen_parameters_w_peft_layers_to_transform(self, minimal_cfg): ): validate_config(cfg) - def test_hub_model_id_save_value_warns(self, minimal_cfg): - cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg + def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg): + cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg with self._caplog.at_level(logging.WARNING): validate_config(cfg) - assert ( - "set without any models being saved" in self._caplog.records[0].message - ) + assert len(self._caplog.records) == 1 + + def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg): + cfg = ( + DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg + ) + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert len(self._caplog.records) == 1 + + def test_hub_model_id_save_value_steps(self, minimal_cfg): + cfg = ( + DictDefault({"hub_model_id": "test", "save_strategy": "steps"}) + | minimal_cfg + ) + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert len(self._caplog.records) == 0 + + def test_hub_model_id_save_value_epochs(self, minimal_cfg): + cfg = ( + DictDefault({"hub_model_id": "test", "save_strategy": "epoch"}) + | minimal_cfg + ) - def test_hub_model_id_save_value(self, minimal_cfg): - cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) | minimal_cfg + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert len(self._caplog.records) == 0 + + def test_hub_model_id_save_value_none(self, minimal_cfg): + cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert len(self._caplog.records) == 0 + + def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg): + cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg with self._caplog.at_level(logging.WARNING): validate_config(cfg)