From 68b1369de9cc8b77931bc4489899216f40fdb93f Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 13 Oct 2024 15:11:13 -0400
Subject: [PATCH] Reward model (#1879)

---
 examples/gemma2/reward-model.yaml             | 63 +++++++++++++
 src/axolotl/core/trainer_builder.py           | 65 ++++++++++----
 .../prompt_strategies/bradley_terry/README.md | 10 +++
 .../bradley_terry/__init__.py                 | 35 ++++++++
 .../bradley_terry/chat_template.py            | 88 +++++++++++++++++++
 .../prompt_strategies/bradley_terry/llama3.py | 27 ++++++
 .../prompt_strategies/chat_template.py        |  1 +
 src/axolotl/train.py                          |  3 +-
 .../config/models/input/v0_4_1/__init__.py    | 12 +++
 src/axolotl/utils/data/sft.py                 | 18 +++-
 src/axolotl/utils/trainer.py                  |  7 +-
 tests/e2e/test_reward_model_llama.py          | 74 ++++++++++++++++
 12 files changed, 382 insertions(+), 21 deletions(-)
 create mode 100644 examples/gemma2/reward-model.yaml
 create mode 100644 src/axolotl/prompt_strategies/bradley_terry/README.md
 create mode 100644 src/axolotl/prompt_strategies/bradley_terry/__init__.py
 create mode 100644 src/axolotl/prompt_strategies/bradley_terry/chat_template.py
 create mode 100644 src/axolotl/prompt_strategies/bradley_terry/llama3.py
 create mode 100644 tests/e2e/test_reward_model_llama.py

diff --git a/examples/gemma2/reward-model.yaml b/examples/gemma2/reward-model.yaml
new file mode 100644
index 0000000000..c1f993c3ae
--- /dev/null
+++ b/examples/gemma2/reward-model.yaml
@@ -0,0 +1,63 @@
+base_model: google/gemma-2-2b
+model_type: AutoModelForSequenceClassification
+tokenizer_type: AutoTokenizer
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+reward_model: true
+chat_template: gemma
+datasets:
+  - path: argilla/distilabel-intel-orca-dpo-pairs
+    type: bradley_terry.chat_template
+val_set_size: 0.0
+output_dir: ./outputs/out
+remove_unused_columns: false
+
+sequence_len: 2048
+sample_packing: false
+eval_sample_packing: false
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch:
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
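For context, the config above trains `google/gemma-2-2b` with a sequence-classification head as the reward model. Once training finishes, the saved checkpoint can be scored with plain `transformers`. The snippet below is a minimal sketch, not part of the patch: the checkpoint path and example chat are placeholders, and it assumes the saved tokenizer carries the `gemma` chat template configured above.

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

checkpoint = "./outputs/out"  # output_dir from the config above (illustrative)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

# render a prompt/response pair with the same chat template used during training
text = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Name a prime number between 10 and 20."},
        {"role": "assistant", "content": "13 is a prime number between 10 and 20."},
    ],
    tokenize=False,
)
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    score = model(**inputs).logits[0]  # the classification head's logit acts as the reward score
print(score)
```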
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 9c12b6141a..599144bd34 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -43,8 +43,10 @@
     KTOTrainer,
     ORPOConfig,
     ORPOTrainer,
+    RewardConfig,
+    RewardTrainer,
 )
-from trl.trainer.utils import pad_to_length
+from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
 
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
@@ -301,6 +303,13 @@ class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
     )
 
 
+@dataclass
+class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
+    """
+    Reward config for Reward training
+    """
+
+
 class SchedulerMixin(Trainer):
     """
     Mixin class for scheduler setup in CausalTrainer.
@@ -398,12 +407,10 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
     def __init__(
         self,
         *_args,
-        num_epochs=1,
         bench_data_collator=None,
         eval_data_collator=None,
         **kwargs,
     ):
-        self.num_epochs = num_epochs
         self.bench_data_collator = bench_data_collator
         self.eval_data_collator = eval_data_collator
         super().__init__(*_args, **kwargs)
@@ -1039,6 +1046,14 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
     tag_names = ["axolotl", "cpo"]
 
 
+class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
+    """
+    Extend the base RewardTrainer for axolotl helpers
+    """
+
+    tag_names = ["axolotl", "reward"]
+
+
 class TrainerBuilderBase(abc.ABC):
     """
     Base class for trainer builder
@@ -1214,6 +1229,8 @@ def _get_trainer_cls(self):
             return ReLoRATrainer
         if self.cfg.model_config_type == "mamba":
             return AxolotlMambaTrainer
+        if self.cfg.reward_model:
+            return AxolotlRewardTrainer
         return AxolotlTrainer
 
     def build(self, total_num_steps):
@@ -1553,6 +1570,9 @@ def build(self, total_num_steps):
 
         trainer_kwargs = {}
 
+        if self.cfg.reward_model:
+            trainer_kwargs["max_length"] = self.cfg.sequence_len
+
         if self.cfg.optimizer in [
             "optimi_adamw",
             "ao_adamw_4bit",
@@ -1596,10 +1616,13 @@ def build(self, total_num_steps):
                 "accelerator_config"
             ] = self.cfg.accelerator_config
 
-        training_args = (
-            AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
-                **training_arguments_kwargs,
-            )
+        training_args_cls = (
+            AxolotlTrainingArguments
+            if not self.cfg.reward_model
+            else AxolotlRewardConfig
+        )
+        training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
+            **training_arguments_kwargs,
         )
         training_args = self.hook_post_create_training_args(training_args)
 
@@ -1621,10 +1644,24 @@ def build(self, total_num_steps):
             # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
             data_collator_kwargs["pad_to_multiple_of"] = 64
 
+        if self.cfg.reward_model:
+            data_collator_kwargs["max_length"] = self.cfg.sequence_len
+
         trainer_cls = self._get_trainer_cls()
         trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
             trainer_kwargs, trainer_cls
         )
+        if eval_data_collator := self.build_collator(
+            training_args, is_eval=True, **data_collator_kwargs
+        ):
+            if not self.cfg.reward_model:
+                trainer_kwargs["eval_data_collator"] = eval_data_collator
+        if not self.cfg.reward_model:
+            trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
+                self.tokenizer,
+                return_tensors="pt",
+                **data_collator_kwargs,
+            )
         trainer = trainer_cls(
             model=self.model,
             train_dataset=self.train_dataset,
@@ -1632,16 +1669,7 @@ def build(self, total_num_steps):
             args=training_args,
             tokenizer=self.tokenizer,
             data_collator=self.build_collator(training_args, **data_collator_kwargs),
-            eval_data_collator=self.build_collator(
-                training_args, is_eval=True, **data_collator_kwargs
-            ),
-            bench_data_collator=transformers.DataCollatorForSeq2Seq(
-                self.tokenizer,
-                return_tensors="pt",
-                **data_collator_kwargs,
-            ),
             callbacks=self.get_callbacks(),
-            num_epochs=self.cfg.num_epochs,
             **trainer_kwargs,
         )
         trainer = self.hook_post_create_trainer(trainer)
@@ -1675,9 +1703,12 @@ def build_collator(
                 V2BatchSamplerDataCollatorForSeq2Seq,
                 BatchSamplerDataCollatorForSeq2Seq,
                 DataCollatorForSeq2Seq,
+                RewardDataCollatorWithPadding,
             ]
         ]
-        if use_batch_sampler_collator:
+        if self.cfg.reward_model:
+            collator = RewardDataCollatorWithPadding
+        elif use_batch_sampler_collator:
             if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
                 collator = V2BatchSamplerDataCollatorForSeq2Seq
             elif (
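For context on what the new `AxolotlRewardTrainer` and `RewardDataCollatorWithPadding` wiring consumes: TRL's `RewardTrainer` scores the chosen and rejected sequences with the same model and minimizes a pairwise (Bradley-Terry) logistic loss, roughly as sketched below. This is an illustration, not code from the patch.

```python
import torch
import torch.nn.functional as F


def pairwise_reward_loss(
    rewards_chosen: torch.Tensor, rewards_rejected: torch.Tensor
) -> torch.Tensor:
    """Push the reward of the chosen response above that of the rejected one."""
    return -F.logsigmoid(rewards_chosen - rewards_rejected).mean()


# toy usage: two preference pairs; only the first is already ranked correctly
loss = pairwise_reward_loss(torch.tensor([1.2, 0.1]), torch.tensor([0.3, 0.4]))
print(loss)  # ~0.60; the mis-ranked second pair contributes most of the loss
```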
diff --git a/src/axolotl/prompt_strategies/bradley_terry/README.md b/src/axolotl/prompt_strategies/bradley_terry/README.md
new file mode 100644
index 0000000000..39cd16137c
--- /dev/null
+++ b/src/axolotl/prompt_strategies/bradley_terry/README.md
@@ -0,0 +1,10 @@
+### example yaml
+
+```yaml
+chat_template: gemma
+datasets:
+  - path: argilla/distilabel-intel-orca-dpo-pairs
+    type: bradley_terry.chat_template
+val_set_size: 0.0
+output_dir: ./outputs/out
+```
diff --git a/src/axolotl/prompt_strategies/bradley_terry/__init__.py b/src/axolotl/prompt_strategies/bradley_terry/__init__.py
new file mode 100644
index 0000000000..849d84e458
--- /dev/null
+++ b/src/axolotl/prompt_strategies/bradley_terry/__init__.py
@@ -0,0 +1,35 @@
+"""Module to load prompt strategies."""
+
+import importlib
+import inspect
+import logging
+
+from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
+
+LOG = logging.getLogger("axolotl.prompt_strategies")
+
+
+def load(strategy, tokenizer, cfg, ds_cfg):
+    # pylint: disable=duplicate-code
+    try:
+        load_fn = "load"
+        if strategy.split(".")[-1].startswith("load_"):
+            load_fn = strategy.split(".")[-1]
+            strategy = ".".join(strategy.split(".")[:-1])
+        mod = importlib.import_module(
+            f".{strategy}", "axolotl.prompt_strategies.bradley_terry"
+        )
+        func = getattr(mod, load_fn)
+        load_kwargs = {}
+        if strategy == "user_defined":
+            load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
+        else:
+            sig = inspect.signature(func)
+            if "ds_cfg" in sig.parameters:
+                load_kwargs["ds_cfg"] = ds_cfg
+        return func(tokenizer, cfg, **load_kwargs)
+    except ModuleNotFoundError:
+        return None
+    except Exception as exc:  # pylint: disable=broad-exception-caught
+        LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
+        return None
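The package-level `load()` above mirrors the main prompt-strategy loader: whatever follows the `bradley_terry.` prefix in a dataset `type` names a module in this package, and that module's `load` (or an explicit `load_*` function) is called with the tokenizer, config, and dataset config. A sketch of the mapping; the `some_module.load_variant` name below is hypothetical:

```python
# Illustrative dataset entry (the path is the one from the README above); the text after
# "bradley_terry." names the module in this package whose load function builds the strategy.
dataset_cfg = {
    "path": "argilla/distilabel-intel-orca-dpo-pairs",
    "type": "bradley_terry.chat_template",  # -> bradley_terry/chat_template.py: load()
}
# A hypothetical "bradley_terry.some_module.load_variant" would instead resolve to
# some_module.load_variant(tokenizer, cfg, ds_cfg).
```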
diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py
new file mode 100644
index 0000000000..ccda0a4bde
--- /dev/null
+++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py
@@ -0,0 +1,88 @@
+"""
+Bradley-Terry model with chat template prompt strategy.
+"""
+
+from typing import Any, Dict, Optional
+
+from axolotl.prompt_strategies.chat_template import (
+    ChatTemplatePrompter,
+    ChatTemplateStrategy,
+)
+from axolotl.utils.chat_templates import chat_templates
+
+
+class BTChatTemplateStrategy(ChatTemplateStrategy):
+    """
+    Bradley-Terry reward model pairwise chat template prompt strategy.
+    """
+
+    def tokenize_prompt(self, prompt):
+        """
+        :param prompt: the actual row of data from the underlying dataset
+        :return: a tokenized example with chosen/rejected input ids, attention masks, and labels
+        """
+
+        self.messages = "chosen_messages"
+        # pylint: disable=duplicate-code
+        prompt[self.messages] = []
+        if prompt["system"]:
+            prompt[self.messages].append({"from": "system", "value": prompt["system"]})
+        prompt[self.messages].append({"from": "user", "value": prompt["input"]})
+        prompt[self.messages].append({"from": "assistant", "value": prompt["chosen"]})
+        chosen_tokenized = super().tokenize_prompt(prompt)
+
+        self.messages = "rejected_messages"
+        # pylint: disable=duplicate-code
+        prompt[self.messages] = []
+        if prompt["system"]:
+            prompt[self.messages].append({"from": "system", "value": prompt["system"]})
+        prompt[self.messages].append({"from": "user", "value": prompt["input"]})
+        prompt[self.messages].append({"from": "assistant", "value": prompt["rejected"]})
+        rejected_tokenized = super().tokenize_prompt(prompt)
+
+        return {
+            "input_ids_chosen": chosen_tokenized["input_ids"],
+            "attention_mask_chosen": chosen_tokenized["attention_mask"],
+            "labels_chosen": 1.0,
+            "input_ids_rejected": rejected_tokenized["input_ids"],
+            "attention_mask_rejected": rejected_tokenized["attention_mask"],
+            "labels_rejected": 0.0,
+        }
+
+
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    ds_cfg = ds_cfg or {}
+
+    prompter_params = {
+        "tokenizer": tokenizer,
+        "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
+        "message_field_role": ds_cfg.get("message_field_role", "from"),
+        "message_field_content": ds_cfg.get("message_field_content", "value"),
+        "message_field_training": ds_cfg.get("message_field_training", "training"),
+        "message_field_training_detail": ds_cfg.get(
+            "message_field_training_detail", "train_detail"
+        ),
+        "roles": ds_cfg.get("roles"),
+        "drop_system_message": ds_cfg.get("drop_system_message", False),
+        # we need to add one for detecting sequences that exceed the `sequence_len` limit.
+        "max_length": cfg.sequence_len + 1
+        if not cfg.reward_model
+        else cfg.sequence_len,
+    }
+
+    strategy_params = {
+        "train_on_inputs": cfg.train_on_inputs,
+        "sequence_len": cfg.sequence_len,
+        "roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
+        "train_on_eos": ds_cfg.get("train_on_eos", "turn"),
+    }
+
+    strategy = BTChatTemplateStrategy(
+        ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
+    )
+
+    if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
+        strategy.messages = ds_cfg["field_messages"]
+
+    return strategy
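To make the expected data flow concrete: the strategy above expects rows with `system`, `input`, `chosen`, and `rejected` fields (as in the distilabel-intel-orca-dpo-pairs example config) and emits the paired features that the reward data collator batches. A sketch with made-up field values:

```python
# illustrative row; the field values are invented
row = {
    "system": "You are a concise assistant.",
    "input": "Name a prime number between 10 and 20.",
    "chosen": "13.",
    "rejected": "15.",
}

# BTChatTemplateStrategy.tokenize_prompt(row) renders two conversations with the
# configured chat template and returns a single example with six keys:
#   input_ids_chosen,   attention_mask_chosen,   labels_chosen   (1.0)
#   input_ids_rejected, attention_mask_rejected, labels_rejected (0.0)
```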
diff --git a/src/axolotl/prompt_strategies/bradley_terry/llama3.py b/src/axolotl/prompt_strategies/bradley_terry/llama3.py
new file mode 100644
index 0000000000..1d586fd5f4
--- /dev/null
+++ b/src/axolotl/prompt_strategies/bradley_terry/llama3.py
@@ -0,0 +1,27 @@
+"""
+transforms for datasets with system, input, chosen, rejected fields to match the llama-3 chat template
+"""
+
+
+def icr(
+    cfg,
+    **kwargs,
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    transforms for datasets with system, input, chosen, rejected fields
+    ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
+    """
+
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            prompt = (
+                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        else:
+            prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+        sample["chosen"] = prompt + f"{sample['chosen']}<|eot_id|>"
+        sample["rejected"] = prompt + f"{sample['rejected']}<|eot_id|>"
+        return sample
+
+    return transform_fn
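As a worked example of the transform (the sample values are made up):

```python
# made-up sample
sample = {"system": "", "input": "What is 2 + 2?", "chosen": "4", "rejected": "5"}
# after transform_fn(sample), with no system message the rendered prompt is
#   "<|start_header_id|>user<|end_header_id|>\n\nWhat is 2 + 2?<|eot_id|>"
#   "<|start_header_id|>assistant<|end_header_id|>\n\n"
# and the pairwise fields become
#   sample["chosen"]   == prompt + "4<|eot_id|>"
#   sample["rejected"] == prompt + "5<|eot_id|>"
```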
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index 48d52dae11..c7852a707f 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -403,6 +403,7 @@ def get_images(self, prompt):
 
 
 def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
+    # pylint: disable=duplicate-code
     ds_cfg = ds_cfg or {}
 
     prompter_params = {
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 855dbc2d3b..6ad3736557 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -102,7 +102,8 @@ def train(
     model, peft_config = load_model(
         cfg, tokenizer, processor=processor, inference=cli_args.inference
     )
-    model.generation_config.do_sample = True
+    if model.generation_config is not None:
+        model.generation_config.do_sample = True
 
     model_ref = None
     if cfg.rl and cfg.rl != "orpo":
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 3304c62f28..4831da3c8a 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -551,6 +551,7 @@ class Config:
     resize_token_embeddings_to_32x: Optional[bool] = None
 
     rl: Optional[RLType] = None
+    reward_model: Optional[bool] = None
 
     datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None  # type: ignore
     test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None  # type: ignore
@@ -856,6 +857,17 @@ def hint_sample_packing_padding(cls, data):
             )
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def hint_reward_model_pad(cls, data):
+        if data.get("reward_model") and not data.get("pad_to_sequence_len"):
+            LOG.warning(
+                "`pad_to_sequence_len: true` is recommended when using reward_model"
+            )
+            if data.get("pad_to_sequence_len") is None:
+                data["pad_to_sequence_len"] = True
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_gas_bsz(cls, data):
diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index 163059c2b8..ce01b44098 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -19,6 +19,7 @@
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.datasets import TokenizedPromptDataset
 from axolotl.prompt_strategies import load
+from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
 from axolotl.prompt_tokenizers import (
     AlpacaMultipleChoicePromptTokenizingStrategy,
     AlpacaPromptTokenizingStrategy,
@@ -459,7 +460,7 @@ def for_d_in_datasets(dataset_configs):
         else:
             LOG.debug("NOT shuffling merged datasets")
 
-    if not cfg.skip_prepare_dataset:
+    if cfg.sample_packing and not cfg.skip_prepare_dataset:
         dataset, _ = process_datasets_for_packing(cfg, dataset, None)
 
     if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
@@ -609,7 +610,20 @@ def get_dataset_wrapper(
         )
     elif cfg.skip_prepare_dataset:
         dataset_wrapper = dataset
-    elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
+    elif ds_strategy := config_dataset.type.startswith(
+        "bradley_terry"
+    ) and bradley_terry_load(
+        config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
+    ):
+        dataset_prompter = UnsupportedPrompter()
+        dataset_wrapper = TokenizedPromptDataset(
+            ds_strategy,
+            dataset,
+            **ds_kwargs,
+        )
+    elif ds_strategy := load(
+        config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
+    ):
         if isinstance(ds_strategy, DatasetWrappingStrategy):
             dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
         else:
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 30b40925f9..7ebf384aff 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -306,7 +306,11 @@ def process_pretraining_datasets_for_packing(
 
 
 def calculate_total_num_steps(cfg, train_dataset, update=True):
-    if not cfg.total_num_tokens and not cfg.skip_prepare_dataset:
+    if (
+        not cfg.total_num_tokens
+        and not cfg.skip_prepare_dataset
+        and not cfg.reward_model
+    ):
         total_num_tokens = np.sum(
             train_dataset.data.column("input_ids")
             .to_pandas()
@@ -323,6 +327,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
         not skip_estimates
         and not cfg.total_supervised_tokens
        and not cfg.skip_prepare_dataset
+        and not cfg.reward_model
     ):
         total_supervised_tokens = (
             train_dataset.data.column("labels")
diff --git a/tests/e2e/test_reward_model_llama.py b/tests/e2e/test_reward_model_llama.py
new file mode 100644
index 0000000000..27ac3e25f1
--- /dev/null
+++ b/tests/e2e/test_reward_model_llama.py
@@ -0,0 +1,74 @@
+"""
+E2E tests for reward model lora llama
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestRewardModelLoraLlama(unittest.TestCase):
+    """
+    Test case for Llama reward models using LoRA
+    """
+
+    @with_temp_dir
+    def test_rm_fft(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "model_type": "AutoModelForSequenceClassification",
+                "tokenizer_type": "LlamaTokenizer",
+                "chat_template": "alpaca",
+                "reward_model": True,
+                "sequence_len": 1024,
+                "pad_to_sequence_len": True,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.0,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "argilla/distilabel-intel-orca-dpo-pairs",
+                        "type": "bradley_terry.chat_template",
+                    },
+                ],
+                "remove_unused_columns": False,
+                "max_steps": 10,
+                "num_epochs": 1,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "gradient_checkpointing": True,
+                "warmup_ratio": 0.1,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
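The new e2e test can be run on its own; a minimal invocation sketch, assuming the repo's e2e test dependencies (e.g. bitsandbytes for the `adamw_bnb_8bit` optimizer) are installed and a GPU is available:

```python
import pytest

# run only the new reward-model e2e test, verbosely
pytest.main(["tests/e2e/test_reward_model_llama.py", "-v"])
```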