From 24a7349fea4ec8b623f09f69661ea8a247fc119b Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Fri, 2 Feb 2024 13:25:45 +0100 Subject: [PATCH 1/4] support user-defined prompt processing strategies for dpo --- src/axolotl/prompt_strategies/dpo/__init__.py | 5 ++- src/axolotl/prompt_strategies/dpo/chatml.py | 8 +++-- .../prompt_strategies/dpo/user_defined.py | 31 +++++++++++++++++++ src/axolotl/prompt_strategies/dpo/zephyr.py | 2 +- src/axolotl/utils/data.py | 2 +- 5 files changed, 40 insertions(+), 8 deletions(-) create mode 100644 src/axolotl/prompt_strategies/dpo/user_defined.py diff --git a/src/axolotl/prompt_strategies/dpo/__init__.py b/src/axolotl/prompt_strategies/dpo/__init__.py index 3c1c808005..8bd430f912 100644 --- a/src/axolotl/prompt_strategies/dpo/__init__.py +++ b/src/axolotl/prompt_strategies/dpo/__init__.py @@ -8,14 +8,13 @@ LOG = logging.getLogger("axolotl") -def load(strategy, cfg): +def load(strategy, cfg, **kwargs): try: load_fn = strategy.split(".")[-1] strategy = ".".join(strategy.split(".")[:-1]) mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo") func = getattr(mod, load_fn) - load_kwargs = {} - return func(cfg, **load_kwargs) + return func(cfg, **kwargs) except Exception: # pylint: disable=broad-exception-caught LOG.warning(f"unable to load strategy {strategy}") return None diff --git a/src/axolotl/prompt_strategies/dpo/chatml.py b/src/axolotl/prompt_strategies/dpo/chatml.py index 8f62a5088e..0d19d65cbb 100644 --- a/src/axolotl/prompt_strategies/dpo/chatml.py +++ b/src/axolotl/prompt_strategies/dpo/chatml.py @@ -5,6 +5,7 @@ def argilla( cfg, + **kwargs, ): # pylint: disable=possibly-unused-variable,unused-argument def transform_fn(sample): if "system" in sample and sample["system"]: @@ -25,6 +26,7 @@ def transform_fn(sample): def icr( cfg, + **kwargs, ): # pylint: disable=possibly-unused-variable,unused-argument """ chatml transforms for datasets with system, input, chosen, rejected @@ -48,7 +50,7 @@ def transform_fn(sample): return transform_fn -def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument +def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument """ For Intel Orca DPO Pairs """ @@ -70,7 +72,7 @@ def transform_fn(sample): return transform_fn -def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument +def prompt_pairs(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument def transform_fn(sample): if "system" in sample and sample["system"]: sample["prompt"] = ( @@ -88,7 +90,7 @@ def transform_fn(sample): return transform_fn -def ultra(cfg): # pylint: disable=possibly-unused-variable,unused-argument +def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument """ for ultrafeedback binarized conversations """ diff --git a/src/axolotl/prompt_strategies/dpo/user_defined.py b/src/axolotl/prompt_strategies/dpo/user_defined.py new file mode 100644 index 0000000000..5ba488ca3d --- /dev/null +++ b/src/axolotl/prompt_strategies/dpo/user_defined.py @@ -0,0 +1,31 @@ +""" +User-defined DPO strategies +""" + + +def default(cfg, dataset_idx=0, **kwargs): + ds_cfg = cfg["datasets"][dataset_idx] + field_prompt = ds_cfg.get("field_prompt", "prompt") + field_system = ds_cfg.get("field_system", "system") + field_chosen = ds_cfg.get("field_chosen", "chosen") + field_rejected = ds_cfg.get("field_rejected", "rejected") + prompt_format = ds_cfg.get("prompt_format") + if not prompt_format: + prompt_format = "{" + field_prompt + "}" + chosen_format = ds_cfg.get("chosen_format") + if not chosen_format: + chosen_format = "{" + field_chosen + "}" + rejected_format = ds_cfg.get("rejected_format") + if not rejected_format: + rejected_format = "{" + field_rejected + "}" + + def transform_fn(sample): + if "{" + field_system + "}" in prompt_format and field_system in sample and sample[field_system]: + sample["prompt"] = prompt_format.format(system=sample[field_system], prompt=sample[field_prompt]) + else: + sample["prompt"] = prompt_format.format(prompt=sample["prompt"]) + sample["chosen"] = chosen_format.format(chosen=sample[field_chosen]) + sample["rejected"] = rejected_format.format(rejected=sample[field_rejected]) + return sample + + return transform_fn diff --git a/src/axolotl/prompt_strategies/dpo/zephyr.py b/src/axolotl/prompt_strategies/dpo/zephyr.py index 02bce8a338..9eb8950091 100644 --- a/src/axolotl/prompt_strategies/dpo/zephyr.py +++ b/src/axolotl/prompt_strategies/dpo/zephyr.py @@ -3,7 +3,7 @@ """ -def nectar(cfg): # pylint: disable=possibly-unused-variable,unused-argument +def nectar(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument def transform_fn(sample): data = {} data["prompt"] = ( diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 66a9b0a71b..12cff999f6 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -937,7 +937,7 @@ def load_split(dataset_cfgs, _cfg): for i, data_set in enumerate(split_datasets): _type = dataset_cfgs[i]["type"] if _type: - ds_transform_fn = load_dpo(_type, _cfg) + ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i) split_datasets[i] = data_set.map( ds_transform_fn, desc="Mapping RL Dataset", From a0570769b09af9253617eb27a412ebd810504406 Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Wed, 7 Feb 2024 13:45:56 +0100 Subject: [PATCH 2/4] interpret dict dataset types as user-defined --- src/axolotl/prompt_strategies/dpo/user_defined.py | 4 +++- src/axolotl/utils/data.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/axolotl/prompt_strategies/dpo/user_defined.py b/src/axolotl/prompt_strategies/dpo/user_defined.py index 5ba488ca3d..754b674102 100644 --- a/src/axolotl/prompt_strategies/dpo/user_defined.py +++ b/src/axolotl/prompt_strategies/dpo/user_defined.py @@ -4,7 +4,9 @@ def default(cfg, dataset_idx=0, **kwargs): - ds_cfg = cfg["datasets"][dataset_idx] + ds_cfg = cfg["datasets"][dataset_idx]["type"] + if not isinstance(ds_cfg, dict): + raise ValueError(f"User-defined dataset type must be a dictionary. Got: {ds_cfg}") field_prompt = ds_cfg.get("field_prompt", "prompt") field_system = ds_cfg.get("field_system", "system") field_chosen = ds_cfg.get("field_chosen", "chosen") diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 12cff999f6..ad3a5cb2d8 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -937,6 +937,8 @@ def load_split(dataset_cfgs, _cfg): for i, data_set in enumerate(split_datasets): _type = dataset_cfgs[i]["type"] if _type: + if isinstance(_type, DictDefault): + _type = "user_defined.default" ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i) split_datasets[i] = data_set.map( ds_transform_fn, From 36397ab0562c9fd36e89e4ef52d513a4476e30d1 Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Wed, 7 Feb 2024 19:58:46 +0100 Subject: [PATCH 3/4] fix lint errors --- src/axolotl/prompt_strategies/dpo/chatml.py | 4 +++- .../prompt_strategies/dpo/user_defined.py | 16 ++++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/axolotl/prompt_strategies/dpo/chatml.py b/src/axolotl/prompt_strategies/dpo/chatml.py index 0d19d65cbb..e8c7f4088c 100644 --- a/src/axolotl/prompt_strategies/dpo/chatml.py +++ b/src/axolotl/prompt_strategies/dpo/chatml.py @@ -72,7 +72,9 @@ def transform_fn(sample): return transform_fn -def prompt_pairs(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument +def prompt_pairs( + cfg, **kwargs +): # pylint: disable=possibly-unused-variable,unused-argument def transform_fn(sample): if "system" in sample and sample["system"]: sample["prompt"] = ( diff --git a/src/axolotl/prompt_strategies/dpo/user_defined.py b/src/axolotl/prompt_strategies/dpo/user_defined.py index 754b674102..1d5f891af6 100644 --- a/src/axolotl/prompt_strategies/dpo/user_defined.py +++ b/src/axolotl/prompt_strategies/dpo/user_defined.py @@ -3,10 +3,12 @@ """ -def default(cfg, dataset_idx=0, **kwargs): +def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument ds_cfg = cfg["datasets"][dataset_idx]["type"] if not isinstance(ds_cfg, dict): - raise ValueError(f"User-defined dataset type must be a dictionary. Got: {ds_cfg}") + raise ValueError( + f"User-defined dataset type must be a dictionary. Got: {ds_cfg}" + ) field_prompt = ds_cfg.get("field_prompt", "prompt") field_system = ds_cfg.get("field_system", "system") field_chosen = ds_cfg.get("field_chosen", "chosen") @@ -22,8 +24,14 @@ def default(cfg, dataset_idx=0, **kwargs): rejected_format = "{" + field_rejected + "}" def transform_fn(sample): - if "{" + field_system + "}" in prompt_format and field_system in sample and sample[field_system]: - sample["prompt"] = prompt_format.format(system=sample[field_system], prompt=sample[field_prompt]) + if ( + "{" + field_system + "}" in prompt_format + and field_system in sample + and sample[field_system] + ): + sample["prompt"] = prompt_format.format( + system=sample[field_system], prompt=sample[field_prompt] + ) else: sample["prompt"] = prompt_format.format(prompt=sample["prompt"]) sample["chosen"] = chosen_format.format(chosen=sample[field_chosen]) From 83c0ceeedeae4afedf92eff10d783d5021ae8f79 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 26 Feb 2024 12:40:56 -0500 Subject: [PATCH 4/4] setup pydantic config for validation of User defined DPO --- .../utils/config/models/input/v0_4_1/__init__.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 433c84af1c..5896b5f403 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -85,12 +85,24 @@ class SFTDataset(BaseModel): field_model: Optional[str] = None +class UserDefinedDPOType(BaseModel): + """User defined typing for DPO""" + + field_system: Optional[str] = None + field_prompt: Optional[str] = None + field_chosen: Optional[str] = None + field_rejected: Optional[str] = None + prompt_format: Optional[str] = None + chosen_format: Optional[str] = None + rejected_format: Optional[str] = None + + class DPODataset(BaseModel): """DPO configuration subset""" path: Optional[str] = None split: Optional[str] = None - type: Optional[str] = None + type: Optional[Union[UserDefinedDPOType, str]] = None data_files: Optional[List[str]] = None