From 2c13f03c23946864afce65ea0e9ec1c5957b33e2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 15 Sep 2023 00:12:33 -0400 Subject: [PATCH] use fastchat conversations template --- src/axolotl/prompters.py | 72 +++++++++++----------------------------- 1 file changed, 20 insertions(+), 52 deletions(-) diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 5322a10182..43569d1e20 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -1,9 +1,10 @@ """Module containing prompters""" -import dataclasses import logging from enum import Enum, auto -from typing import Generator, List, Optional, Tuple, Union +from typing import Generator, Optional, Union + +from fastchat.conversation import Conversation, get_conv_template LOG = logging.getLogger("axolotl") IGNORE_TOKEN_ID = -100 @@ -236,45 +237,6 @@ class SeparatorStyle(Enum): DOLLY = auto() -# TODO clean this 💩 up -@dataclasses.dataclass -class Conversation: - """A class that keeps all conversation history.""" - - system: str - roles: List[str] - messages: List[List[str]] - offset: int - sep_style: SeparatorStyle = SeparatorStyle.SINGLE - sep: str = "###" - sep2: Optional[str] = None - - def get_prompt(self) -> Generator[Tuple[str, str], None, None]: - # seps = [self.sep, self.sep2] - preamble = self.system + self.sep - yield ("SYSTEM:", preamble) - for _, (role, message) in enumerate(self.messages): - if message: - yield (role + ":", " " + message) - else: - LOG.warning(f"role with empty message: {role}") - yield (role + ":", "") - - def copy(self): - return Conversation( - system=self.system, - roles=self.roles, - messages=[[x, y] for x, y in self.messages], - offset=self.offset, - sep_style=self.sep_style, - sep=self.sep, - sep2=self.sep2, - ) - - def append_message(self, role, message): - self.messages.append([role, message]) - - SHAREGPT_ASSERTION_FAILED_ROLE = ( "Role did not alternate between turns (gpt and human). Please check your data." ) @@ -285,28 +247,34 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods A prompter that generates prompts for the ShareGPT """ - def __init__(self, prompt_style=None, system_prompt: Optional[str] = None): + def __init__( + self, + prompt_style=None, + system_prompt: Optional[str] = None, + conversation: Optional[Union[str, Conversation]] = None, + ): if prompt_style != PromptStyle.CHAT.value: raise ValueError( f"unsupported prompt_style for ShareGPTPrompter({prompt_style})" ) system: str = ( system_prompt - if system_prompt + if system_prompt is not None else ( "A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions." ) ) - self._conversation = Conversation( - system=system, - roles=["USER", "ASSISTANT"], - messages=[], - offset=0, - sep_style=SeparatorStyle.TWO, - sep=" ", - sep2=" ", - ) + if conversation: + if isinstance(conversation, Conversation): + self._conversation = conversation + else: + self._conversation = get_conv_template(conversation) + else: + self._conversation = get_conv_template("vicuna_v1.1") + + if system: + self._conversation.system_message = system def build_prompt(self, source) -> Generator[str, None, None]: if len(source) < 2: