ORPO (#1419)
* orpo trainer

* rl handling for orpo

* support for remove_unused_columns

* orpo fixes

* fix loader for orpo

* chore: lint

* fix default for remove_unused_columns

* roll ORPO into the main AxolotlTrainer so it can be compatible with some of the other techniques like relora

* better handling of system message for orpo

* revert system prompt changes for chat templates

* no need for else condition

* split dataset parsing into its own component
winglian authored Mar 18, 2024
1 parent 9e032b4 commit 35202dd
Showing 14 changed files with 451 additions and 24 deletions.
15 changes: 15 additions & 0 deletions docs/rlhf.md
@@ -34,6 +34,21 @@ datasets:
rl: ipo
```
rl: ipo
```
#### ORPO
Paper: https://arxiv.org/abs/2403.07691
```yaml
rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false

chat_template: chatml
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned
type: orpo.chat_template
```
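ORPO folds preference optimization into supervised fine-tuning: alongside the usual NLL loss on the chosen response, the trainer adds an odds-ratio penalty weighted by `orpo_alpha`. A sketch of the objective as implemented in `AxolotlTrainer.orpo_compute_loss` below, with $y_w$/$y_l$ the chosen/rejected completions:

$$
\mathcal{L}_{\mathrm{ORPO}} = \mathcal{L}_{\mathrm{NLL}}(y_w) - \alpha \cdot \log \sigma\!\left(\log \frac{\mathrm{odds}(y_w \mid x)}{\mathrm{odds}(y_l \mid x)}\right),
\qquad \mathrm{odds}(y \mid x) = \frac{P(y \mid x)}{1 - P(y \mid x)}
$$

Training then runs as usual, e.g. `accelerate launch -m axolotl.cli.train your_orpo_config.yml` (config filename hypothetical).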
#### Using local dataset files
```yaml
datasets:
2 changes: 1 addition & 1 deletion src/axolotl/cli/preprocess.py
@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH

-    if parsed_cfg.rl:
+    if parsed_cfg.rl and parsed_cfg.rl != "orpo":
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
2 changes: 1 addition & 1 deletion src/axolotl/cli/train.py
@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
else:
register_chatml_template()

-    if cfg.rl:
+    if cfg.rl and cfg.rl != "orpo":
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
143 changes: 142 additions & 1 deletion src/axolotl/core/trainer_builder.py
@@ -11,10 +11,11 @@
import os
import sys
from abc import abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
-from typing import List, Optional, Type, Union
+from typing import Dict, List, Literal, Optional, Type, Union

import torch
import transformers
@@ -200,6 +201,9 @@ class AxolotlTrainingArguments(TrainingArguments):
default=False,
metadata={"help": "whether this is a qlora training"},
)
orpo_alpha: Optional[float] = field(
default=None,
)


class AxolotlTrainer(Trainer):
@@ -223,6 +227,9 @@ def __init__(
self.eval_data_collator = eval_data_collator
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
@@ -465,8 +472,112 @@ def compute_loss(self, model, inputs, return_outputs=False):
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
if self.args.orpo_alpha:
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
return super().compute_loss(model, inputs, return_outputs=return_outputs)

def orpo_compute_custom_loss(self, logits, labels):
logits = logits.contiguous()
loss = 0.0

if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# Flatten the tokens
loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(
dim=-1
)

return loss

def orpo_compute_logps(
self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits
):
# Get the shape of chosen_attention_mask[:, :-1]
chosen_shape = chosen_attention_mask[:, :-1].shape

# Calculate the padding size
pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1)

# Pad prompt_attention_mask with zeros to match the desired shape
prompt_attention_mask_padded = torch.nn.functional.pad(
prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0
)

# Perform the subtraction operation
mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded

per_token_logps = torch.gather(
logits[:, :-1, :].log_softmax(-1),
dim=2,
index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
).squeeze(2)
return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(
dtype=torch.float64
) / mask.sum(dim=1).to(dtype=torch.float64)
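
A minimal standalone sketch of what this masking computes, on toy shapes with hypothetical values: prompt tokens are excluded so only completion tokens contribute to the length-averaged log-probability.

```python
import torch

# Toy shapes: batch=1, seq_len=4, vocab=5; first two tokens are the prompt.
logits = torch.randn(1, 4, 5)
chosen_inputs = torch.tensor([[1, 2, 3, 4]])
chosen_attention_mask = torch.ones(1, 4, dtype=torch.long)
prompt_attention_mask = torch.tensor([[1, 1, 0, 0]])

# Same steps as orpo_compute_logps above
pad_length = chosen_attention_mask[:, :-1].shape[1] - (
    prompt_attention_mask.shape[1] - 1
)
prompt_padded = torch.nn.functional.pad(
    prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0
)
mask = chosen_attention_mask[:, :-1] > prompt_padded  # True only on completion
per_token_logps = torch.gather(
    logits[:, :-1, :].log_softmax(-1),
    dim=2,
    index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
).squeeze(2)
avg_logp = (per_token_logps * mask).sum(dim=1) / mask.sum(dim=1)
```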

def orpo_compute_loss(self, model, inputs, return_outputs=False):
outputs_neg = model(
**{
"input_ids": inputs["rejected_input_ids"],
"attention_mask": inputs["rejected_attention_mask"],
"labels": inputs["rejected_labels"],
},
output_hidden_states=True,
)
outputs_pos = model(
**{
"input_ids": inputs["input_ids"],
"attention_mask": inputs["attention_mask"],
"labels": inputs["labels"],
},
output_hidden_states=True,
)

# Calculate NLL loss
pos_loss = self.orpo_compute_custom_loss(
logits=outputs_pos.logits, labels=inputs["input_ids"]
)

# Calculate Log Probability
pos_prob = self.orpo_compute_logps(
prompt_attention_mask=inputs["prompt_attention_mask"],
chosen_inputs=inputs["input_ids"],
chosen_attention_mask=inputs["attention_mask"],
logits=outputs_pos.logits,
)
neg_prob = self.orpo_compute_logps(
prompt_attention_mask=inputs["prompt_attention_mask"],
chosen_inputs=inputs["rejected_input_ids"],
chosen_attention_mask=inputs["rejected_attention_mask"],
logits=outputs_neg.logits,
)

# Calculate log odds
log_odds = (pos_prob - neg_prob) - (
torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob))
)
        sig_ratio = torch.sigmoid(log_odds)
ratio = torch.log(sig_ratio)

# Calculate the Final Loss
loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to(
dtype=torch.bfloat16
)

metrics = {}
metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item()
metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item()
metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item()
metrics["log_odds"] = torch.mean(log_odds).cpu().item()
self.store_metrics(metrics, train_eval="train")

return (loss, outputs_pos) if return_outputs else loss
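
As a quick sanity check of the odds-ratio term with hypothetical log-probabilities: both inputs are negative length-averaged log-probs, and `ratio = log σ(log_odds)` is always ≤ 0, so `-orpo_alpha * ratio` adds a non-negative penalty that shrinks as the chosen completion becomes relatively more likely than the rejected one.

```python
import torch

# Hypothetical length-averaged log-probs for chosen/rejected completions.
pos_prob = torch.tensor([-0.5], dtype=torch.float64)
neg_prob = torch.tensor([-2.0], dtype=torch.float64)

log_odds = (pos_prob - neg_prob) - (
    torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob))
)
ratio = torch.log(torch.sigmoid(log_odds))
print(log_odds.item(), ratio.item())  # ≈ 2.287, ≈ -0.097
```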

@wraps(Trainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
@@ -527,6 +638,28 @@ def create_accelerator_and_postprocess(self):

return res

def log(self, logs: Dict[str, float]) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs (`Dict[str, float]`):
The values to log.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs)

def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
) -> None:
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)


class AxolotlMambaTrainer(AxolotlTrainer):
"""
@@ -903,6 +1036,11 @@ def build(self, total_num_steps):
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
training_arguments_kwargs["dataloader_drop_last"] = True

if self.cfg.remove_unused_columns is not None:
training_arguments_kwargs[
"remove_unused_columns"
] = self.cfg.remove_unused_columns

if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
@@ -1070,6 +1208,9 @@ def build(self, total_num_steps):
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)

if self.cfg.rl == "orpo":
training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha

if self.cfg.neftune_noise_alpha is not None:
training_arguments_kwargs[
"neftune_noise_alpha"
20 changes: 20 additions & 0 deletions src/axolotl/prompt_strategies/base.py
@@ -0,0 +1,20 @@
"""
module for base dataset transform strategies
"""

import importlib
import logging

LOG = logging.getLogger("axolotl")


def load(strategy, cfg, module_base=None, **kwargs):
try:
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", module_base)
func = getattr(mod, load_fn)
return func(cfg, **kwargs)
except Exception: # pylint: disable=broad-exception-caught
LOG.warning(f"unable to load strategy {strategy}")
return None
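
A hedged usage sketch of this loader (the strategy string and target function are hypothetical; assumes axolotl is importable): the last dotted segment names a function, the rest names a module imported relative to `module_base`, extra kwargs are forwarded to the strategy function, and any failure is swallowed with a warning and `None`.

```python
from types import SimpleNamespace

from axolotl.prompt_strategies.base import load

cfg = SimpleNamespace()  # hypothetical minimal config object

# Would resolve to axolotl.prompt_strategies.orpo.some_module.some_fn(cfg);
# since some_module does not exist, this logs a warning and returns None.
strategy = load(
    "some_module.some_fn", cfg, module_base="axolotl.prompt_strategies.orpo"
)
assert strategy is None
```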
18 changes: 3 additions & 15 deletions src/axolotl/prompt_strategies/dpo/__init__.py
@@ -1,20 +1,8 @@
"""
module for DPO style dataset transform strategies
"""
+from functools import partial

-import importlib
-import logging
+from ..base import load as load_base

-LOG = logging.getLogger("axolotl")
-
-
-def load(strategy, cfg, **kwargs):
-    try:
-        load_fn = strategy.split(".")[-1]
-        strategy = ".".join(strategy.split(".")[:-1])
-        mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
-        func = getattr(mod, load_fn)
-        return func(cfg, **kwargs)
-    except Exception:  # pylint: disable=broad-exception-caught
-        LOG.warning(f"unable to load strategy {strategy}")
-        return None
+load = partial(load_base, module_base="axolotl.prompt_strategies.dpo")
9 changes: 9 additions & 0 deletions src/axolotl/prompt_strategies/orpo/__init__.py
@@ -0,0 +1,9 @@
"""
module for ORPO style dataset transform strategies
"""

from functools import partial

from ..base import load as load_base

load = partial(load_base, module_base="axolotl.prompt_strategies.orpo")
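
With the partial in place, the docs' `type: orpo.chat_template` entry resolves strategies inside this package; a hedged sketch of the equivalent call (the function name is hypothetical):

```python
from axolotl.prompt_strategies.orpo import load
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"chat_template": "chatml"})  # minimal hypothetical cfg
# Imports axolotl.prompt_strategies.orpo.chat_template and calls some_fn(cfg)
strategy = load("chat_template.some_fn", cfg)
```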