From ec5d85374dee7de19f81c47c24cc6f8cbba52121 Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Wed, 7 Feb 2024 19:58:46 +0100 Subject: [PATCH] fix lint errors --- src/axolotl/prompt_strategies/dpo/chatml.py | 4 +++- .../prompt_strategies/dpo/user_defined.py | 16 ++++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/axolotl/prompt_strategies/dpo/chatml.py b/src/axolotl/prompt_strategies/dpo/chatml.py index 0d19d65cbb..e8c7f4088c 100644 --- a/src/axolotl/prompt_strategies/dpo/chatml.py +++ b/src/axolotl/prompt_strategies/dpo/chatml.py @@ -72,7 +72,9 @@ def transform_fn(sample): return transform_fn -def prompt_pairs(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument +def prompt_pairs( + cfg, **kwargs +): # pylint: disable=possibly-unused-variable,unused-argument def transform_fn(sample): if "system" in sample and sample["system"]: sample["prompt"] = ( diff --git a/src/axolotl/prompt_strategies/dpo/user_defined.py b/src/axolotl/prompt_strategies/dpo/user_defined.py index 754b674102..1d5f891af6 100644 --- a/src/axolotl/prompt_strategies/dpo/user_defined.py +++ b/src/axolotl/prompt_strategies/dpo/user_defined.py @@ -3,10 +3,12 @@ """ -def default(cfg, dataset_idx=0, **kwargs): +def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument ds_cfg = cfg["datasets"][dataset_idx]["type"] if not isinstance(ds_cfg, dict): - raise ValueError(f"User-defined dataset type must be a dictionary. Got: {ds_cfg}") + raise ValueError( + f"User-defined dataset type must be a dictionary. Got: {ds_cfg}" + ) field_prompt = ds_cfg.get("field_prompt", "prompt") field_system = ds_cfg.get("field_system", "system") field_chosen = ds_cfg.get("field_chosen", "chosen") @@ -22,8 +24,14 @@ def default(cfg, dataset_idx=0, **kwargs): rejected_format = "{" + field_rejected + "}" def transform_fn(sample): - if "{" + field_system + "}" in prompt_format and field_system in sample and sample[field_system]: - sample["prompt"] = prompt_format.format(system=sample[field_system], prompt=sample[field_prompt]) + if ( + "{" + field_system + "}" in prompt_format + and field_system in sample + and sample[field_system] + ): + sample["prompt"] = prompt_format.format( + system=sample[field_system], prompt=sample[field_prompt] + ) else: sample["prompt"] = prompt_format.format(prompt=sample["prompt"]) sample["chosen"] = chosen_format.format(chosen=sample[field_chosen])