Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support user-defined prompt processing strategies for dpo #1248

Merged
merged 4 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/axolotl/prompt_strategies/dpo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
LOG = logging.getLogger("axolotl")


def load(strategy, cfg):
def load(strategy, cfg, **kwargs):
try:
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
func = getattr(mod, load_fn)
load_kwargs = {}
return func(cfg, **load_kwargs)
return func(cfg, **kwargs)
except Exception: # pylint: disable=broad-exception-caught
LOG.warning(f"unable to load strategy {strategy}")
return None
10 changes: 7 additions & 3 deletions src/axolotl/prompt_strategies/dpo/chatml.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

def argilla(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
Expand All @@ -25,6 +26,7 @@ def transform_fn(sample):

def icr(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
chatml transforms for datasets with system, input, chosen, rejected
Expand All @@ -48,7 +50,7 @@ def transform_fn(sample):
return transform_fn


def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
"""
For Intel Orca DPO Pairs
"""
Expand All @@ -70,7 +72,9 @@ def transform_fn(sample):
return transform_fn


def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def prompt_pairs(
cfg, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
Expand All @@ -88,7 +92,7 @@ def transform_fn(sample):
return transform_fn


def ultra(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
"""
for ultrafeedback binarized conversations
"""
Expand Down
41 changes: 41 additions & 0 deletions src/axolotl/prompt_strategies/dpo/user_defined.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
User-defined DPO strategies
"""


def default(cfg, dataset_idx=0, **kwargs):  # pylint: disable=unused-argument
    """
    Build a DPO sample transform from a user-defined dataset ``type`` mapping.

    The mapping is read from ``cfg["datasets"][dataset_idx]["type"]`` and may
    contain:

    - ``field_prompt`` / ``field_system`` / ``field_chosen`` /
      ``field_rejected``: names of the sample columns to read. They default to
      ``prompt``, ``system``, ``chosen`` and ``rejected``.
    - ``prompt_format`` / ``chosen_format`` / ``rejected_format``:
      ``str.format`` templates. When unset, each one defaults to a template
      that emits the corresponding column unchanged.

    A template may reference a value either by its canonical placeholder
    (``{prompt}``, ``{system}``, ``{chosen}``, ``{rejected}``) or by the
    configured column name, so custom column names work with both the default
    and explicitly written templates.

    Args:
        cfg: configuration mapping that holds a ``datasets`` sequence.
        dataset_idx: index of the dataset entry whose ``type`` is used.
        **kwargs: accepted for signature parity with other strategies; unused.

    Returns:
        A function that takes a sample mapping, writes the formatted
        ``prompt``, ``chosen`` and ``rejected`` keys into it and returns it.

    Raises:
        ValueError: if the dataset ``type`` is not a dictionary.
    """
    ds_cfg = cfg["datasets"][dataset_idx]["type"]
    if not isinstance(ds_cfg, dict):
        raise ValueError(
            f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
        )

    # `or` instead of a `.get` default so that keys which are present but set
    # to None (or "") fall back to the defaults as well.
    field_prompt = ds_cfg.get("field_prompt") or "prompt"
    field_system = ds_cfg.get("field_system") or "system"
    field_chosen = ds_cfg.get("field_chosen") or "chosen"
    field_rejected = ds_cfg.get("field_rejected") or "rejected"

    prompt_format = ds_cfg.get("prompt_format") or "{" + field_prompt + "}"
    chosen_format = ds_cfg.get("chosen_format") or "{" + field_chosen + "}"
    rejected_format = ds_cfg.get("rejected_format") or "{" + field_rejected + "}"

    system_placeholder = "{" + field_system + "}"

    def transform_fn(sample):
        # Read every input before writing any output: a configured column may
        # share its name with one of the output keys (e.g. swapped
        # chosen/rejected columns) and must not see an overwritten value.
        prompt_text = sample[field_prompt]
        chosen_text = sample[field_chosen]
        rejected_text = sample[field_rejected]

        # Expose each value under the canonical placeholder and under the
        # configured column name so either spelling resolves in a template.
        prompt_values = {"prompt": prompt_text, field_prompt: prompt_text}
        if (
            system_placeholder in prompt_format
            and field_system in sample
            and sample[field_system]
        ):
            system_text = sample[field_system]
            prompt_values["system"] = system_text
            prompt_values[field_system] = system_text

        sample["prompt"] = prompt_format.format(**prompt_values)
        sample["chosen"] = chosen_format.format(
            **{"chosen": chosen_text, field_chosen: chosen_text}
        )
        sample["rejected"] = rejected_format.format(
            **{"rejected": rejected_text, field_rejected: rejected_text}
        )
        return sample

    return transform_fn
2 changes: 1 addition & 1 deletion src/axolotl/prompt_strategies/dpo/zephyr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""


def nectar(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def nectar(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
data = {}
data["prompt"] = (
Expand Down
14 changes: 13 additions & 1 deletion src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,24 @@ class SFTDataset(BaseModel):
field_model: Optional[str] = None


class UserDefinedDPOType(BaseModel):
    """User defined typing for DPO"""

    # NOTE(review): these keys mirror the ones read by
    # prompt_strategies.dpo.user_defined.default; every field is optional and
    # unset (None) unless the user's config provides it.

    # Name of the sample column holding the system message.
    field_system: Optional[str] = None
    # Name of the sample column holding the prompt.
    field_prompt: Optional[str] = None
    # Name of the sample column holding the preferred ("chosen") response.
    field_chosen: Optional[str] = None
    # Name of the sample column holding the dispreferred ("rejected") response.
    field_rejected: Optional[str] = None
    # str.format-style template used to render the final prompt.
    prompt_format: Optional[str] = None
    # str.format-style template used to render the chosen response.
    chosen_format: Optional[str] = None
    # str.format-style template used to render the rejected response.
    rejected_format: Optional[str] = None


class DPODataset(BaseModel):
"""DPO configuration subset"""

path: Optional[str] = None
split: Optional[str] = None
type: Optional[str] = None
type: Optional[Union[UserDefinedDPOType, str]] = None
data_files: Optional[List[str]] = None


Expand Down
4 changes: 3 additions & 1 deletion src/axolotl/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,9 @@ def load_split(dataset_cfgs, _cfg):
for i, data_set in enumerate(split_datasets):
_type = dataset_cfgs[i]["type"]
if _type:
ds_transform_fn = load_dpo(_type, _cfg)
if isinstance(_type, DictDefault):
_type = "user_defined.default"
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
split_datasets[i] = data_set.map(
ds_transform_fn,
desc="Mapping RL Dataset",
Expand Down
Loading