Skip to content

Commit

Permalink
use fastchat conversations template
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Sep 15, 2023
1 parent 2414673 commit 2c13f03
Showing 1 changed file with 20 additions and 52 deletions.
72 changes: 20 additions & 52 deletions src/axolotl/prompters.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Module containing prompters"""

import dataclasses
import logging
from enum import Enum, auto
from typing import Generator, List, Optional, Tuple, Union
from typing import Generator, Optional, Union

from fastchat.conversation import Conversation, get_conv_template

LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100
Expand Down Expand Up @@ -236,45 +237,6 @@ class SeparatorStyle(Enum):
DOLLY = auto()


# TODO clean this 💩 up
@dataclasses.dataclass
class Conversation:
    """Running record of a chat: a system preamble plus ordered turns.

    ``messages`` holds ``[role, message]`` pairs in the order they were
    appended. ``roles``, ``offset``, ``sep_style`` and ``sep2`` are stored
    and carried through ``copy`` but are not read by the methods below.
    """

    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: Optional[str] = None

    def append_message(self, role, message):
        """Add one ``[role, message]`` turn to the end of the history."""
        turn = [role, message]
        self.messages.append(turn)

    def copy(self):
        """Return a new Conversation with freshly built message pairs.

        Each ``[role, message]`` pair is rebuilt as a new list; the
        ``roles`` list itself is passed through unchanged (shared).
        """
        cloned_turns = []
        for speaker, text in self.messages:
            cloned_turns.append([speaker, text])
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=cloned_turns,
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
        )

    def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
        """Yield ``(label, text)`` pairs for the system preamble and each turn.

        The first pair is ``("SYSTEM:", system + sep)``. Every turn then
        yields ``(role + ":", " " + message)``; a turn with an empty
        message logs a warning and yields an empty text instead.
        """
        yield ("SYSTEM:", self.system + self.sep)
        for speaker, text in self.messages:
            if not text:
                LOG.warning(f"role with empty message: {speaker}")
                yield (speaker + ":", "")
                continue
            yield (speaker + ":", " " + text)


# Message text for a ShareGPT conversation whose gpt/human roles fail to alternate.
SHAREGPT_ASSERTION_FAILED_ROLE = (
    "Role did not alternate between turns (gpt and human). "
    "Please check your data."
)
Expand All @@ -285,28 +247,34 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
A prompter that generates prompts for the ShareGPT
"""

# Diff view: this span interleaves two versions of the constructor. The
# one-line signature appears to be the pre-commit form; the multi-line
# signature that follows adds the optional `conversation` parameter, which
# may be a template name (str) or a ready-made Conversation object.
def __init__(self, prompt_style=None, system_prompt: Optional[str] = None):
def __init__(
self,
prompt_style=None,
system_prompt: Optional[str] = None,
conversation: Optional[Union[str, Conversation]] = None,
):
# Only the chat prompt style is accepted; anything else raises ValueError.
if prompt_style != PromptStyle.CHAT.value:
raise ValueError(
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
)
# Fall back to a built-in default system prompt when none is supplied.
# The two `if` lines below are two forms of the same condition: a
# truthiness test and an explicit `is not None` test (the latter lets an
# empty string through as the chosen system prompt).
system: str = (
system_prompt
if system_prompt
if system_prompt is not None
else (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
)
)
# Direct construction with fixed USER/ASSISTANT roles and single-space
# separators — presumably the pre-commit path that the template lookup
# below replaces.
self._conversation = Conversation(
system=system,
roles=["USER", "ASSISTANT"],
messages=[],
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2=" ",
)
# Resolve the conversation: use a supplied Conversation object as-is,
# look a string up through get_conv_template, or default to the
# "vicuna_v1.1" template when nothing (or a falsy value) is given.
if conversation:
if isinstance(conversation, Conversation):
self._conversation = conversation
else:
self._conversation = get_conv_template(conversation)
else:
self._conversation = get_conv_template("vicuna_v1.1")

# NOTE(review): with the `is not None` form, `system` is the default text
# whenever system_prompt is None, so this branch runs for every call except
# system_prompt == "". It therefore overwrites the system message of
# whichever template was chosen and mutates a caller-supplied Conversation
# object in place — confirm that is intended.
if system:
self._conversation.system_message = system

def build_prompt(self, source) -> Generator[str, None, None]:
if len(source) < 2:
Expand Down

0 comments on commit 2c13f03

Please sign in to comment.