Skip to content

Commit

Permalink
interpret dict dataset types as user-defined
Browse files Browse the repository at this point in the history
  • Loading branch information
nopperl authored and winglian committed Feb 26, 2024
1 parent 24a7349 commit a057076
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/axolotl/prompt_strategies/dpo/user_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@


def default(cfg, dataset_idx=0, **kwargs):
ds_cfg = cfg["datasets"][dataset_idx]
ds_cfg = cfg["datasets"][dataset_idx]["type"]
if not isinstance(ds_cfg, dict):
raise ValueError(f"User-defined dataset type must be a dictionary. Got: {ds_cfg}")
field_prompt = ds_cfg.get("field_prompt", "prompt")
field_system = ds_cfg.get("field_system", "system")
field_chosen = ds_cfg.get("field_chosen", "chosen")
Expand Down
2 changes: 2 additions & 0 deletions src/axolotl/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,8 @@ def load_split(dataset_cfgs, _cfg):
for i, data_set in enumerate(split_datasets):
_type = dataset_cfgs[i]["type"]
if _type:
if isinstance(_type, DictDefault):
_type = "user_defined.default"
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
split_datasets[i] = data_set.map(
ds_transform_fn,
Expand Down

0 comments on commit a057076

Please sign in to comment.