From 0800885e2fb7e785f553b540ac63a46c149588d6 Mon Sep 17 00:00:00 2001
From: MilesQLi
Date: Fri, 27 Oct 2023 22:00:16 -0400
Subject: [PATCH] =?UTF-8?q?Update=20to=20adapt=20to=20sharegpt=20datasets?=
 =?UTF-8?q?=20with=20"assistant"=20rather=20than=20"gp=E2=80=A6=20(#774)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Update to adapt to sharegpt datasets with "assistant" rather than "gpt" as the machine answers.

* use a strict option for handling incorrect turn data

* chore: lint

---------

Co-authored-by: Wing Lian
---
 src/axolotl/prompt_strategies/sharegpt.py | 25 +++++++++++++++++++++--
 1 file changed, 23 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py
index da36e778e0..255d9a2905 100644
--- a/src/axolotl/prompt_strategies/sharegpt.py
+++ b/src/axolotl/prompt_strategies/sharegpt.py
@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     )
     field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
     field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
-    return SimpleShareGPTPromptTokenizingStrategy(
+    strategy = SimpleShareGPTPromptTokenizingStrategy(
         ShareGPTPrompterV2(
             conversation=conversation,
             role_key_model=field_model,
@@ -34,6 +34,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
         cfg.train_on_inputs,
         cfg.sequence_len,
     )
+    if ds_cfg and "strict" in ds_cfg:
+        strategy.strict = ds_cfg["strict"]
+    return strategy
 
 
 def load_role(tokenizer, cfg):
@@ -59,8 +62,26 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     basic sharegpt strategy to grab conversations from the sample row
     """
 
+    _strict = True
+
+    @property
+    def strict(self):
+        return self._strict
+
+    @strict.setter
+    def strict(self, strict):
+        self._strict = strict
+
     def get_conversation_thread(self, prompt):
-        return prompt["conversations"]
+        conversations = prompt["conversations"]
+        if self.strict:
+            return conversations
+        # remap roles - allow for assistant turn
+        role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
+        turns = [
+            {"from": role_map[t["from"]], "value": t["value"]} for t in conversations
+        ]
+        return turns
 
 
 class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
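
Usage note (illustrative, not part of the patch): when strict is disabled (e.g. via a
"strict" key in the dataset config, which load() now copies onto the strategy),
get_conversation_thread remaps an "assistant" turn to "gpt" before prompting. A minimal
standalone sketch of that remapping, using a made-up two-turn conversation:

# Sketch of the non-strict role remapping added above; the sample data is
# hypothetical, and the role_map mirrors the one introduced in the patch.
conversations = [
    {"from": "human", "value": "Hello"},
    {"from": "assistant", "value": "Hi there"},  # "assistant" instead of "gpt"
]
role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
turns = [{"from": role_map[t["from"]], "value": t["value"]} for t in conversations]
print(turns)
# [{'from': 'human', 'value': 'Hello'}, {'from': 'gpt', 'value': 'Hi there'}]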