diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py index da36e778e0..255d9a2905 100644 --- a/src/axolotl/prompt_strategies/sharegpt.py +++ b/src/axolotl/prompt_strategies/sharegpt.py @@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): ) field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None - return SimpleShareGPTPromptTokenizingStrategy( + strategy = SimpleShareGPTPromptTokenizingStrategy( ShareGPTPrompterV2( conversation=conversation, role_key_model=field_model, @@ -34,6 +34,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): cfg.train_on_inputs, cfg.sequence_len, ) + if ds_cfg and "strict" in ds_cfg: + strategy.strict = ds_cfg["strict"] + return strategy def load_role(tokenizer, cfg): @@ -59,8 +62,26 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): basic sharegpt strategy to grab conversations from the sample row """ + _strict = True + + @property + def strict(self): + return self._strict + + @strict.setter + def strict(self, strict): + self._strict = strict + def get_conversation_thread(self, prompt): - return prompt["conversations"] + conversations = prompt["conversations"] + if self.strict: + return conversations + # remap roles - allow for assistant turn + role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"} + turns = [ + {"from": role_map[t["from"]], "value": t["value"]} for t in conversations + ] + return turns class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):