Reward model (#1879)
winglian authored Oct 13, 2024
1 parent cd2d89f commit 68b1369
Showing 12 changed files with 382 additions and 21 deletions.
63 changes: 63 additions & 0 deletions examples/gemma2/reward-model.yaml
@@ -0,0 +1,63 @@
base_model: google/gemma-2-2b
model_type: AutoModelForSequenceClassification
tokenizer_type: AutoTokenizer

load_in_8bit: false
load_in_4bit: false
strict: false

reward_model: true
chat_template: gemma
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
remove_unused_columns: false

sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:


gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
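The config above fine-tunes `google/gemma-2-2b` as a sequence classifier with a reward head. As a minimal sketch (not part of the commit; the checkpoint path mirrors `output_dir` above and the prompt text is invented), a model trained this way can be scored with plain `transformers`:

```python
# Minimal scoring sketch; "./outputs/out" mirrors output_dir above, the text is invented.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("./outputs/out")
tokenizer = AutoTokenizer.from_pretrained("./outputs/out")

text = "user: What is the capital of France?\nassistant: Paris."
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=2048)
with torch.no_grad():
    # assumes the reward head is a single logit (num_labels=1); higher = more preferred
    score = model(**inputs).logits[0, 0].item()
print(f"reward: {score:.4f}")
```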
65 changes: 48 additions & 17 deletions src/axolotl/core/trainer_builder.py
@@ -43,8 +43,10 @@
KTOTrainer,
ORPOConfig,
ORPOTrainer,
RewardConfig,
RewardTrainer,
)
from trl.trainer.utils import pad_to_length
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length

from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
@@ -301,6 +303,13 @@ class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
)


@dataclass
class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
"""
Reward config for Reward training
"""


class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
@@ -398,12 +407,10 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
def __init__(
self,
*_args,
num_epochs=1,
bench_data_collator=None,
eval_data_collator=None,
**kwargs,
):
self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
super().__init__(*_args, **kwargs)
Expand Down Expand Up @@ -1039,6 +1046,14 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
tag_names = ["axolotl", "cpo"]


class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""

tag_names = ["axolotl", "reward"]


class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
@@ -1214,6 +1229,8 @@ def _get_trainer_cls(self):
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
return AxolotlTrainer

def build(self, total_num_steps):
@@ -1553,6 +1570,9 @@ def build(self, total_num_steps):

trainer_kwargs = {}

if self.cfg.reward_model:
trainer_kwargs["max_length"] = self.cfg.sequence_len

if self.cfg.optimizer in [
"optimi_adamw",
"ao_adamw_4bit",
@@ -1596,10 +1616,13 @@ def build(self, total_num_steps):
"accelerator_config"
] = self.cfg.accelerator_config

training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
training_args_cls = (
AxolotlTrainingArguments
if not self.cfg.reward_model
else AxolotlRewardConfig
)
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
training_args = self.hook_post_create_training_args(training_args)

@@ -1621,27 +1644,32 @@ def build(self, total_num_steps):
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64

if self.cfg.reward_model:
data_collator_kwargs["max_length"] = self.cfg.sequence_len

trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
)
if eval_data_collator := self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
):
if not self.cfg.reward_model:
trainer_kwargs["eval_data_collator"] = eval_data_collator
if not self.cfg.reward_model:
trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
)
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
tokenizer=self.tokenizer,
data_collator=self.build_collator(training_args, **data_collator_kwargs),
eval_data_collator=self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
callbacks=self.get_callbacks(),
num_epochs=self.cfg.num_epochs,
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
@@ -1675,9 +1703,12 @@ def build_collator(
V2BatchSamplerDataCollatorForSeq2Seq,
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
RewardDataCollatorWithPadding,
]
]
if use_batch_sampler_collator:
if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
elif use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
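When `reward_model` is enabled, `build_collator` now returns TRL's `RewardDataCollatorWithPadding`, which pads paired chosen/rejected features rather than a single sequence. A hedged sketch of the feature shape it consumes (the token ids are placeholders, and the exact constructor signature depends on the installed TRL version):

```python
# Illustrative only: paired features as produced by the Bradley-Terry strategy below;
# token ids are placeholders and the collator signature varies across TRL versions.
from transformers import AutoTokenizer
from trl.trainer.utils import RewardDataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
collator = RewardDataCollatorWithPadding(tokenizer, max_length=2048)

features = [
    {
        "input_ids_chosen": [2, 9413, 1596],
        "attention_mask_chosen": [1, 1, 1],
        "input_ids_rejected": [2, 9413],
        "attention_mask_rejected": [1, 1],
    }
]
batch = collator(features)
# batch holds padded input_ids_chosen / attention_mask_chosen /
# input_ids_rejected / attention_mask_rejected tensors plus return_loss=True.
```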
10 changes: 10 additions & 0 deletions src/axolotl/prompt_strategies/bradley_terry/README.md
@@ -0,0 +1,10 @@
### example yaml

```yaml
chat_template: gemma
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
```
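For context (not part of the README itself), the `bradley_terry.chat_template` strategy expects each dataset row to provide `system`, `input`, `chosen`, and `rejected` fields, as in `argilla/distilabel-intel-orca-dpo-pairs`. A hypothetical row:

```python
# Hypothetical dataset row in the shape the Bradley-Terry strategies consume.
row = {
    "system": "You are a helpful assistant.",
    "input": "Name a prime number between 10 and 20.",
    "chosen": "13 is a prime number between 10 and 20.",
    "rejected": "15.",
}
```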
35 changes: 35 additions & 0 deletions src/axolotl/prompt_strategies/bradley_terry/__init__.py
@@ -0,0 +1,35 @@
"""Module to load prompt strategies."""

import importlib
import inspect
import logging

from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig

LOG = logging.getLogger("axolotl.prompt_strategies")


def load(strategy, tokenizer, cfg, ds_cfg):
# pylint: disable=duplicate-code
try:
load_fn = "load"
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(
f".{strategy}", "axolotl.prompt_strategies.bradley_terry"
)
func = getattr(mod, load_fn)
load_kwargs = {}
if strategy == "user_defined":
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
else:
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
return None
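This mirrors the top-level strategy loader: a dataset `type` such as `bradley_terry.chat_template` resolves to a sub-module of this package and its module-level `load` (or `load_*`) function. A hedged sketch of resolving it by hand (the cfg values and ds_cfg below are placeholders, not from the commit):

```python
# Hypothetical resolution of `type: bradley_terry.chat_template`; cfg values are placeholders.
from transformers import AutoTokenizer
from axolotl.prompt_strategies.bradley_terry import load
from axolotl.utils.dict import DictDefault

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
cfg = DictDefault({"sequence_len": 2048, "train_on_inputs": False, "reward_model": True})
strategy = load("chat_template", tokenizer, cfg, {"chat_template": "gemma"})
# Internally this imports axolotl.prompt_strategies.bradley_terry.chat_template and
# calls its load(tokenizer, cfg, ds_cfg=...), returning a BTChatTemplateStrategy
# (or None if the sub-module fails to import).
```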
88 changes: 88 additions & 0 deletions src/axolotl/prompt_strategies/bradley_terry/chat_template.py
@@ -0,0 +1,88 @@
"""
Bradley-Terry model with chat template prompt strategy.
"""

from typing import Any, Dict, Optional

from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter,
ChatTemplateStrategy,
)
from axolotl.utils.chat_templates import chat_templates


class BTChatTemplateStrategy(ChatTemplateStrategy):
"""
Bradley-Terry reward model pairwise chat template prompt strategy.
"""

def tokenize_prompt(self, prompt):
"""
:param prompt: the actual row of data from the underlying dataset
:return:
"""

self.messages = "chosen_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append({"from": "system", "value": prompt["system"]})
prompt[self.messages].append({"from": "user", "value": prompt["input"]})
prompt[self.messages].append({"from": "assistant", "value": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt)

self.messages = "rejected_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append({"from": "system", "value": prompt["system"]})
prompt[self.messages].append({"from": "user", "value": prompt["input"]})
prompt[self.messages].append({"from": "assistant", "value": prompt["rejected"]})
rejected_tokenized = super().tokenize_prompt(prompt)

return {
"input_ids_chosen": chosen_tokenized["input_ids"],
"attention_mask_chosen": chosen_tokenized["attention_mask"],
"labels_chosen": 1.0,
"input_ids_rejected": rejected_tokenized["input_ids"],
"attention_mask_rejected": rejected_tokenized["attention_mask"],
"labels_rejected": 0.0,
}


def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}

prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"message_field_role": ds_cfg.get("message_field_role", "from"),
"message_field_content": ds_cfg.get("message_field_content", "value"),
"message_field_training": ds_cfg.get("message_field_training", "training"),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail", "train_detail"
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
# add one so sequences that exceed the `sequence_len` limit can be detected.
"max_length": cfg.sequence_len + 1
if not cfg.reward_model
else cfg.sequence_len,
}

strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
}

strategy = BTChatTemplateStrategy(
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
)

if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]

return strategy
27 changes: 27 additions & 0 deletions src/axolotl/prompt_strategies/bradley_terry/llama3.py
@@ -0,0 +1,27 @@
"""
transforms for datasets with system, input, chosen, rejected fields to match the llama3 chat template
"""


def icr(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
llama3 chat template transforms for datasets with system, input, chosen, rejected
ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
"""

def transform_fn(sample):
if "system" in sample and sample["system"]:
prompt = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = prompt + f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = prompt + f"{sample['rejected']}<|eot_id|>"
return sample

return transform_fn
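As a usage sketch (the sample values are invented), `icr` returns a transform that builds the shared llama3 prompt from `system`/`input` and appends it, with header and `<|eot_id|>` markers, to both completions:

```python
# Hedged usage sketch of the icr transform; the sample values are invented.
from axolotl.prompt_strategies.bradley_terry.llama3 import icr

transform_fn = icr(cfg=None)  # cfg is accepted but unused by the transform itself
sample = {
    "system": "You are a concise assistant.",
    "input": "What is 2 + 2?",
    "chosen": "4",
    "rejected": "Probably 5.",
}
out = transform_fn(sample)
# out["chosen"] and out["rejected"] now share the llama3 prompt prefix
# (system/user/assistant header markers) and each end with <|eot_id|>.
```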
1 change: 1 addition & 0 deletions src/axolotl/prompt_strategies/chat_template.py
@@ -403,6 +403,7 @@ def get_images(self, prompt):


def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
# pylint: disable=duplicate-code
ds_cfg = ds_cfg or {}

prompter_params = {
3 changes: 2 additions & 1 deletion src/axolotl/train.py
@@ -102,7 +102,8 @@ def train(
model, peft_config = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
model.generation_config.do_sample = True
if model.generation_config is not None:
model.generation_config.do_sample = True

model_ref = None
if cfg.rl and cfg.rl != "orpo":