diff --git a/src/axolotl/prompt_strategies/dpo/user_defined.py b/src/axolotl/prompt_strategies/dpo/user_defined.py index 5ba488ca3d..754b674102 100644 --- a/src/axolotl/prompt_strategies/dpo/user_defined.py +++ b/src/axolotl/prompt_strategies/dpo/user_defined.py @@ -4,7 +4,9 @@ def default(cfg, dataset_idx=0, **kwargs): - ds_cfg = cfg["datasets"][dataset_idx] + ds_cfg = cfg["datasets"][dataset_idx]["type"] + if not isinstance(ds_cfg, dict): + raise ValueError(f"User-defined dataset type must be a dictionary. Got: {ds_cfg}") field_prompt = ds_cfg.get("field_prompt", "prompt") field_system = ds_cfg.get("field_system", "system") field_chosen = ds_cfg.get("field_chosen", "chosen") diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 12cff999f6..ad3a5cb2d8 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -937,6 +937,8 @@ def load_split(dataset_cfgs, _cfg): for i, data_set in enumerate(split_datasets): _type = dataset_cfgs[i]["type"] if _type: + if isinstance(_type, DictDefault): + _type = "user_defined.default" ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i) split_datasets[i] = data_set.map( ds_transform_fn,