From 33a11c9cb64c211a53f85f180e02f4196b30b050 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 24 Sep 2023 11:57:46 -0400
Subject: [PATCH] fix up so it works with multiple conversation styles, and
 don't strip the turns

---
 .../fastchat_conversation_turns.py | 166 +++++++++++++++++-
 src/axolotl/prompt_tokenizers.py   |  14 +-
 src/axolotl/prompters.py           |   2 +
 3 files changed, 166 insertions(+), 16 deletions(-)

diff --git a/src/axolotl/monkeypatch/fastchat_conversation_turns.py b/src/axolotl/monkeypatch/fastchat_conversation_turns.py
index 00a6442540..20a3356a27 100644
--- a/src/axolotl/monkeypatch/fastchat_conversation_turns.py
+++ b/src/axolotl/monkeypatch/fastchat_conversation_turns.py
@@ -5,22 +5,170 @@
 import logging
 from typing import Generator, Tuple
 
+from fastchat.conversation import SeparatorStyle
+
 LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
 
 
-def get_turns(self) -> Generator[Tuple[str, str], None, None]:
-    # seps = [self.sep, self.sep2]
-    preamble = self.system_message + self.sep
-    yield ("SYSTEM:", preamble)
-    for _, (role, message) in enumerate(self.messages):
-        if message:
-            yield (role + ":", " " + message)
+def get_prompt(self) -> str:
+    ret = ""
+    for role, msg in self.get_turns():
+        ret += role + msg
+    return ret
+
+
+def get_turns(  # pylint: disable=too-many-return-statements
+    self,
+) -> Generator[Tuple[str, str], None, None]:
+    """Get the prompt turns for generation."""
+    system_prompt = self.system_template.format(system_message=self.system_message)
+    if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
+        yield "", system_prompt + self.sep
+        for role, message in self.messages:
+            if message:
+                yield role + ": ", message + self.sep
+            else:
+                yield role + ":", ""
+        return
+    if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
+        seps = [self.sep, self.sep2]
+        yield "", system_prompt + seps[0]
+        for i, (role, message) in enumerate(self.messages):
+            if message:
+                yield role + ": ", message + seps[i % 2]
+            else:
+                yield role + ":", ""
+        return
+    if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
+        yield "", system_prompt + self.sep
+        for role, message in self.messages:
+            if message:
+                yield role + ": ", message + self.sep
+            else:
+                yield role + ": ", ""  # must end with a space
+        return
+    if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
+        yield "", "" if system_prompt == "" else system_prompt + self.sep
+        for role, message in self.messages:
+            if message:
+                yield role + "\n", message + self.sep
+            else:
+                yield role + "\n", ""
+        return
+    if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
+        yield "", system_prompt
+        for role, message in self.messages:
+            if message:
+                yield role, message + self.sep
+            else:
+                yield role, ""
+        return
+    if self.sep_style == SeparatorStyle.NO_COLON_TWO:
+        seps = [self.sep, self.sep2]
+        yield "", system_prompt
+        for i, (role, message) in enumerate(self.messages):
+            if message:
+                yield role, message + seps[i % 2]
+            else:
+                yield role, ""
+        return
+    if self.sep_style == SeparatorStyle.RWKV:
+        yield "", system_prompt
+        for i, (role, message) in enumerate(self.messages):
+            if message:
+                yield role + ": ", message.replace("\r\n", "\n").replace(
+                    "\n\n", "\n"
+                ) + "\n\n"
+            else:
+                yield role + ":", ""
+        return
+    if self.sep_style == SeparatorStyle.LLAMA2:
+        seps = [self.sep, self.sep2]
+        if self.system_message:
+            yield "", system_prompt + self.messages[0][1] + " "
         else:
-            LOG.warning(f"role with empty message: {role}")
-            yield (role + ":", "")
+            yield "", "[INST] " + self.messages[0][1] + " "
+        for i, (role, message) in enumerate(self.messages[1:]):
+            if message:
+                yield role + " ", message + seps[i % 2]
+            else:
+                yield role, ""
+        return
+    if self.sep_style == SeparatorStyle.CHATGLM:
+        # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
+        # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
+        round_add_n = 1 if self.name == "chatglm2" else 0
+        if system_prompt:
+            yield "", system_prompt + self.sep
+
+        for i, (role, message) in enumerate(self.messages):
+            if i % 2 == 0:
+                yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
+
+            if message:
+                yield f"{role}:", f"{message}{self.sep}"
+            else:
+                yield f"{role}:", ""
+        return
+    if self.sep_style == SeparatorStyle.CHATML:
+        yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
+        for role, message in self.messages:
+            if message:
+                yield role + "\n", message + self.sep + "\n"
+            else:
+                yield role + "\n", ""
+        return
+    if self.sep_style == SeparatorStyle.CHATINTERN:
+        # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
+        seps = [self.sep, self.sep2]
+        yield "", system_prompt
+        for i, (role, message) in enumerate(self.messages):
+            prefix = "<s>" if i % 2 == 0 else ""
+            if message:
+                yield prefix + role + ":", message + seps[i % 2] + "\n"
+            else:
+                yield role + ":", ""
+        return
+    if self.sep_style == SeparatorStyle.DOLLY:
+        seps = [self.sep, self.sep2]
+        yield "", system_prompt
+        for i, (role, message) in enumerate(self.messages):
+            if message:
+                suffix = "\n\n" if i % 2 == 1 else ""
+                yield role + ":\n", message + seps[i % 2] + suffix
+            else:
+                yield role + ":\n", ""
+        return
+    if self.sep_style == SeparatorStyle.PHOENIX:
+        yield "", system_prompt
+        for role, message in self.messages:
+            if message:
+                yield role + ": ", "<s>" + message + "</s>"
+            else:
+                yield role + ": " + "<s>", ""
+        return
+    if self.sep_style == SeparatorStyle.ROBIN:
+        yield "", system_prompt + self.sep
+        for role, message in self.messages:
+            if message:
+                yield role + ":\n", message + self.sep
+            else:
+                yield role + ":\n", ""
+        return
+    if self.sep_style == SeparatorStyle.FALCON_CHAT:
+        if self.system_message:
+            yield "", system_prompt + self.sep
+        for role, message in self.messages:
+            if message:
+                yield role + ": ", message + self.sep
+            else:
+                yield role + ":", ""
+    else:
+        raise ValueError(f"Invalid style: {self.sep_style}")
 
 
 def add_get_turns_to_conversation():
     import fastchat.conversation
 
     fastchat.conversation.Conversation.get_turns = get_turns
+    fastchat.conversation.Conversation.get_prompt = get_prompt
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index 342ff64838..31bb73c2f7 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -366,13 +366,13 @@ def tokenize_prompt(self, prompt):
             self.prompter.build_prompt(self.get_conversation_thread(prompt))
         ):
             if isinstance(part, tuple):
-                if part[0] == conversation.roles[0] + ":":
+                if conversation.roles[0] in part[0]:
                     turn = part[0] + part[1] if not user_token else part[1]
                     # this is still the user query, we should
                     if not part[1].strip():
                         LOG.warning(f"user turn has empty text: {prompt}")
                     res = self._tokenize(
-                        turn.strip(),
+                        turn,
                         add_eos_token=False,
                         strip_bos_token=True,
                     )
@@ -380,14 +380,14 @@
                     if user_token:
                         res["input_ids"] = [user_token, *res["input_ids"]]
                     # everything from this is masked out from the labels
                     labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-                elif part[0] == conversation.roles[1] + ":":
+                elif conversation.roles[1] in part[0]:
                     # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
                     turn = part[0] + part[1] if not assistant_token else part[1]
                     # this should be the assistant response, should end with an eos token
                     if not part[1].strip():
                         LOG.warning(f"assistant turn has empty text: {prompt}")
                     res = self._tokenize(
-                        turn.strip(),
+                        turn,
                         add_eos_token=True,
                         strip_bos_token=True,
                     )
                     if assistant_token:
                         res["input_ids"] = [
                             assistant_token,
                             *res["input_ids"],
                         ]
                     # not masked out from labels
                     labels = copy.deepcopy(res["input_ids"])
@@ -398,11 +398,11 @@
-                elif part[0] == "SYSTEM:":
-                    part = part[1]  # Ignore the system role from preamble
+                elif part[0] == "":
+                    turn = part[1]
                     # this is only ever the first part, should include the bos token and the user query
                     res = self._tokenize(
-                        part.strip(), add_eos_token=False, strip_bos_token=False
+                        turn, add_eos_token=False, strip_bos_token=False
                     )
                     # everything from this is masked out from the labels
                     labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index e999a7eb46..afb1bb4a1c 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -290,4 +290,6 @@ def build_prompt(self, source) -> Generator[str, None, None]:
             conv.append_message(role, sentence["value"])
 
         for part in conv.get_turns():
+            if part[0] and not part[1]:
+                LOG.warning(f"role with empty message: {part[0]}")
             yield part
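
A minimal usage sketch, not part of the patch: it shows how the patched turn iteration is consumed downstream. It assumes fastchat and this branch of axolotl are importable; the "vicuna_v1.1" template name and the example messages are illustrative only.

from fastchat.conversation import get_conv_template

from axolotl.monkeypatch.fastchat_conversation_turns import (
    add_get_turns_to_conversation,
)

# Attach the get_turns()/get_prompt() from this patch onto fastchat's Conversation.
add_get_turns_to_conversation()

conv = get_conv_template("vicuna_v1.1")  # uses the ADD_COLON_TWO separator style
conv.append_message(conv.roles[0], "What does this patch change?")
conv.append_message(conv.roles[1], "get_turns() now yields un-stripped turns.")

# The first yielded turn has an empty role prefix and carries only the system
# prompt; tokenize_prompt() matches it via `part[0] == ""` and masks it out.
for role_prefix, message in conv.get_turns():
    print(repr(role_prefix), repr(message))

Because get_prompt() is rebuilt on top of get_turns(), conv.get_prompt() concatenates exactly the text yielded above, so inference-time prompts and per-turn tokenization stay byte-identical; that is the point of not stripping the turns.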
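
A companion sketch of the masking rule the prompt_tokenizers.py change relies on: substring role matching plus un-stripped turn text. The tokenizer checkpoint is a placeholder assumption, IGNORE_TOKEN_ID is -100 (PyTorch's cross-entropy ignore index, the value axolotl uses), and the real tokenize_prompt() additionally handles BOS/EOS stripping inside _tokenize().

from fastchat.conversation import get_conv_template
from transformers import AutoTokenizer

from axolotl.monkeypatch.fastchat_conversation_turns import (
    add_get_turns_to_conversation,
)

IGNORE_TOKEN_ID = -100  # positions with this label are skipped by the loss

add_get_turns_to_conversation()
conv = get_conv_template("vicuna_v1.1")
conv.append_message(conv.roles[0], "Hi")
conv.append_message(conv.roles[1], "Hello!")

# Placeholder checkpoint; any HF tokenizer with the same interface works.
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

input_ids, labels = [], []
for role_prefix, message in conv.get_turns():
    ids = tokenizer.encode(role_prefix + message, add_special_tokens=False)
    input_ids += ids
    if conv.roles[1] in role_prefix:  # mirrors `conversation.roles[1] in part[0]`
        labels += ids  # assistant turn: kept in the labels
    else:
        labels += [IGNORE_TOKEN_ID] * len(ids)  # system/user turns: masked out

assert len(input_ids) == len(labels)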