Skip to content

Commit

Permalink
support user-defined prompt processing strategies for dpo
Browse files Browse the repository at this point in the history
  • Loading branch information
nopperl committed Feb 2, 2024
1 parent 2d65f47 commit cdde04c
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 8 deletions.
5 changes: 2 additions & 3 deletions src/axolotl/prompt_strategies/dpo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
LOG = logging.getLogger("axolotl")


def load(strategy, cfg):
def load(strategy, cfg, **kwargs):
try:
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
func = getattr(mod, load_fn)
load_kwargs = {}
return func(cfg, **load_kwargs)
return func(cfg, **kwargs)
except Exception: # pylint: disable=broad-exception-caught
LOG.warning(f"unable to load strategy {strategy}")
return None
8 changes: 5 additions & 3 deletions src/axolotl/prompt_strategies/dpo/chatml.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

def argilla(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
Expand All @@ -25,6 +26,7 @@ def transform_fn(sample):

def icr(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
chatml transforms for datasets with system, input, chosen, rejected
Expand All @@ -48,7 +50,7 @@ def transform_fn(sample):
return transform_fn


def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
"""
For Intel Orca DPO Pairs
"""
Expand All @@ -70,7 +72,7 @@ def transform_fn(sample):
return transform_fn


def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def prompt_pairs(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
Expand All @@ -88,7 +90,7 @@ def transform_fn(sample):
return transform_fn


def ultra(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
"""
for ultrafeedback binarized conversations
"""
Expand Down
31 changes: 31 additions & 0 deletions src/axolotl/prompt_strategies/dpo/user_defined.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
User-defined DPO strategies
"""


def default(cfg, dataset_idx=0, **kwargs):
ds_cfg = cfg["datasets"][dataset_idx]
field_prompt = ds_cfg.get("field_prompt", "prompt")
field_system = ds_cfg.get("field_system", "system")
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
prompt_format = ds_cfg.get("prompt_format")
if not prompt_format:
prompt_format = "{" + field_prompt + "}"
chosen_format = ds_cfg.get("chosen_format")
if not chosen_format:
chosen_format = "{" + field_chosen + "}"
rejected_format = ds_cfg.get("rejected_format")
if not rejected_format:
rejected_format = "{" + field_rejected + "}"

def transform_fn(sample):
if "{" + field_system + "}" in prompt_format and field_system in sample and sample[field_system]:
sample["prompt"] = prompt_format.format(system=sample[field_system], prompt=sample[field_prompt])
else:
sample["prompt"] = prompt_format.format(prompt=sample["prompt"])
sample["chosen"] = chosen_format.format(chosen=sample[field_chosen])
sample["rejected"] = rejected_format.format(rejected=sample[field_rejected])
return sample

return transform_fn
2 changes: 1 addition & 1 deletion src/axolotl/prompt_strategies/dpo/zephyr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""


def nectar(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def nectar(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
data = {}
data["prompt"] = (
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ def load_split(dataset_cfgs, _cfg):
for i, data_set in enumerate(split_datasets):
_type = dataset_cfgs[i]["type"]
if _type:
ds_transform_fn = load_dpo(_type, _cfg)
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
split_datasets[i] = data_set.map(
ds_transform_fn,
desc="Mapping RL Dataset",
Expand Down

0 comments on commit cdde04c

Please sign in to comment.