Skip to content

Commit

Permalink
dpo/kto/ipo smoke tests w lora, simplify dpo dataset type names
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Jan 23, 2024
1 parent 7141fd1 commit 44a6f2d
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,7 +1011,7 @@ def build_training_arguments(self, total_num_steps):

training_args = TrainingArguments(
per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=total_num_steps,
max_steps=self.cfg.max_steps or total_num_steps,
remove_unused_columns=False,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
learning_rate=self.cfg.learning_rate,
Expand Down
16 changes: 12 additions & 4 deletions src/axolotl/prompt_strategies/dpo/chatml.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""


def argilla_apply_chatml(
def argilla(
cfg,
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
Expand All @@ -23,7 +23,11 @@ def transform_fn(sample):
return transform_fn


def intel_apply_chatml(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
"""
For Intel Orca DPO Pairs
"""

def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
Expand All @@ -41,7 +45,7 @@ def transform_fn(sample):
return transform_fn


def apply_chatml(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
Expand All @@ -59,7 +63,11 @@ def transform_fn(sample):
return transform_fn


def ultra_apply_chatml(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def ultra(cfg): # pylint: disable=possibly-unused-variable,unused-argument
"""
for ultrafeedback binarized conversations
"""

def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
Expand Down
145 changes: 145 additions & 0 deletions tests/e2e/test_dpo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""
E2E tests for lora llama
"""

import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_rl_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestDPOLlamaLora(unittest.TestCase):
"""
Test case for DPO Llama models using LoRA
"""

@with_temp_dir
def test_dpo_lora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 64,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {},
"rl": "dpo",
"datasets": [
{
"path": "Intel/orca_dpo_pairs",
"type": "chatml.intel",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

@with_temp_dir
def test_kto_pair_lora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 64,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {},
"rl": "kto_pair",
"datasets": [
{
"path": "Intel/orca_dpo_pairs",
"type": "chatml.intel",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

@with_temp_dir
def test_ipo_lora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 64,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {},
"rl": "ipo",
"datasets": [
{
"path": "Intel/orca_dpo_pairs",
"type": "chatml.intel",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

0 comments on commit 44a6f2d

Please sign in to comment.