From 98cc645b81bf37dcb3074e7f9fb451e668da9bf4 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 15 Sep 2023 01:39:34 -0400
Subject: [PATCH] tweak fastchat conversation with a monkeypatch to get
 individual turns

---
 .../fastchat_conversation_turns.py | 26 +++++++++++++++++++
 src/axolotl/prompt_tokenizers.py   | 10 +++++--
 src/axolotl/prompters.py           |  2 +-
 3 files changed, 35 insertions(+), 3 deletions(-)
 create mode 100644 src/axolotl/monkeypatch/fastchat_conversation_turns.py

diff --git a/src/axolotl/monkeypatch/fastchat_conversation_turns.py b/src/axolotl/monkeypatch/fastchat_conversation_turns.py
new file mode 100644
index 0000000000..00a6442540
--- /dev/null
+++ b/src/axolotl/monkeypatch/fastchat_conversation_turns.py
@@ -0,0 +1,26 @@
+"""
+monkeypatch to add a get_turns method
+"""
+
+import logging
+from typing import Generator, Tuple
+
+LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
+
+
+def get_turns(self) -> Generator[Tuple[str, str], None, None]:
+    # seps = [self.sep, self.sep2]
+    preamble = self.system_message + self.sep
+    yield ("SYSTEM:", preamble)
+    for _, (role, message) in enumerate(self.messages):
+        if message:
+            yield (role + ":", " " + message)
+        else:
+            LOG.warning(f"role with empty message: {role}")
+            yield (role + ":", "")
+
+
+def add_get_turns_to_conversation():
+    import fastchat.conversation
+
+    fastchat.conversation.Conversation.get_turns = get_turns
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index 7803291114..1fba9c9da4 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -9,6 +9,9 @@
 from fastchat.conversation import Conversation
 from transformers import PreTrainedTokenizer
 
+from axolotl.monkeypatch.fastchat_conversation_turns import (
+    add_get_turns_to_conversation,
+)
 from axolotl.prompters import IGNORE_TOKEN_ID
 
 LOG = logging.getLogger("axolotl")
@@ -19,6 +22,8 @@
 LLAMA_DEFAULT_BOS_TOKEN = "<s>"  # nosec
 LLAMA_DEFAULT_UNK_TOKEN = "<unk>"  # nosec
 
+add_get_turns_to_conversation()
+
 
 class InvalidDataException(Exception):
     """
@@ -366,7 +371,7 @@ def tokenize_prompt(self, prompt):
             self.prompter.build_prompt(self.get_conversation_thread(prompt))
         ):
             if isinstance(part, tuple):
-                if part[0] == conversation.roles[0]:
+                if part[0] == conversation.roles[0] + ":":
                     part = part[0] + part[1] if not user_token else part[1]
                     # this is still the user query, we should
                     res = self._tokenize(
@@ -378,7 +383,7 @@ def tokenize_prompt(self, prompt):
                         res["input_ids"] = [user_token, *res["input_ids"]]
                     # everything from this is masked out from the labels
                     labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-                elif part[0] == conversation.roles[1]:
+                elif part[0] == conversation.roles[1] + ":":
                     # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
                     part = part[0] + part[1] if not assistant_token else part[1]
                     # this should be the assistant response, should end with an eos token
@@ -404,6 +409,7 @@ def tokenize_prompt(self, prompt):
                         labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
                 else:
                     LOG.warning(f"unhandled role: {part[0]}")
+                    continue
 
             # pylint: disable=duplicate-code
             result, current_len = parse_tokenized_to_result(
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 57d88117d6..f17ec77606 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -303,5 +303,5 @@ def build_prompt(self, source) -> Generator[str, None, None]:
             assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
             conv.append_message(role, sentence["value"])
 
-        for part in conv.get_prompt():
+        for part in conv.get_turns():
             yield part
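
For reference, a minimal usage sketch (not part of the patch itself) of what the monkeypatched get_turns yields; it assumes fastchat is installed, and the "vicuna_v1.1" template name is only an illustrative choice:

# Hypothetical example, not from the patch above: shows the per-turn tuples
# that tokenize_prompt consumes.
from fastchat.conversation import get_conv_template

from axolotl.monkeypatch.fastchat_conversation_turns import (
    add_get_turns_to_conversation,
)

add_get_turns_to_conversation()  # attach get_turns to fastchat's Conversation

conv = get_conv_template("vicuna_v1.1")
conv.append_message(conv.roles[0], "What does this monkeypatch do?")
conv.append_message(conv.roles[1], "It exposes each conversation turn separately.")

# get_prompt() returns one concatenated string; get_turns() instead yields
# ("ROLE:", text) tuples, which is why tokenize_prompt now compares against
# conversation.roles[0] + ":" and can label-mask the user turns individually.
for role, text in conv.get_turns():
    print(repr(role), repr(text))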