Skip to content

Commit

Permalink
tweak fastchat conversation with a monkeypatch to get individual turns
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Sep 24, 2023
1 parent f13ce27 commit 4738fea
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 3 deletions.
26 changes: 26 additions & 0 deletions src/axolotl/monkeypatch/fastchat_conversation_turns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
monkeypatch to add a get_turns method
"""

import logging
from typing import Generator, Tuple

LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")


def get_turns(self) -> Generator[Tuple[str, str], None, None]:
# seps = [self.sep, self.sep2]
preamble = self.system_message + self.sep
yield ("SYSTEM:", preamble)
for _, (role, message) in enumerate(self.messages):
if message:
yield (role + ":", " " + message)
else:
LOG.warning(f"role with empty message: {role}")
yield (role + ":", "")


def add_get_turns_to_conversation():
import fastchat.conversation

fastchat.conversation.Conversation.get_turns = get_turns
10 changes: 8 additions & 2 deletions src/axolotl/prompt_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from fastchat.conversation import Conversation
from transformers import BatchEncoding, PreTrainedTokenizer

from axolotl.monkeypatch.fastchat_conversation_turns import (
add_get_turns_to_conversation,
)
from axolotl.prompters import IGNORE_TOKEN_ID

LOG = logging.getLogger("axolotl")
Expand All @@ -19,6 +22,8 @@
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec

add_get_turns_to_conversation()


class InvalidDataException(Exception):
"""
Expand Down Expand Up @@ -361,7 +366,7 @@ def tokenize_prompt(self, prompt):
self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if isinstance(part, tuple):
if part[0] == conversation.roles[0]:
if part[0] == conversation.roles[0] + ":":
turn = part[0] + part[1] if not user_token else part[1]
# this is still the user query, we should
if not part[1].strip():
Expand All @@ -375,7 +380,7 @@ def tokenize_prompt(self, prompt):
res["input_ids"] = [user_token, *res["input_ids"]]
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif part[0] == conversation.roles[1]:
elif part[0] == conversation.roles[1] + ":":
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
turn = part[0] + part[1] if not assistant_token else part[1]
# this should be the assistant response, should end with an eos token
Expand Down Expand Up @@ -403,6 +408,7 @@ def tokenize_prompt(self, prompt):
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {part[0]}")
continue

# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/prompters.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,5 +289,5 @@ def build_prompt(self, source) -> Generator[str, None, None]:
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
conv.append_message(role, sentence["value"])

for part in conv.get_prompt():
for part in conv.get_turns():
yield part

0 comments on commit 4738fea

Please sign in to comment.