From ddb3ab24bfc3f2eadfa6027aeb0419b3c22d7f2e Mon Sep 17 00:00:00 2001
From: jstzwj <1103870790@qq.com>
Date: Wed, 3 Jan 2024 16:25:35 +0800
Subject: [PATCH] airoboros, cutegpt, alpaca, tigerbot and xgen
---
chatproto/conversation/history.py | 49 ++++++++++++++++++++-
chatproto/conversation/models/airoboros.py | 32 ++++++++++++++
chatproto/conversation/models/alpaca.py | 11 +++++
chatproto/conversation/models/baichuan.py | 15 ++++++-
chatproto/conversation/models/baichuan2.py | 13 ------
chatproto/conversation/models/chatglm.py | 22 ++++++++-
chatproto/conversation/models/cutegpt.py | 12 +++++
chatproto/conversation/models/tigerbot.py | 12 +++++
chatproto/conversation/models/xgen.py | 11 +++++
chatproto/conversation/settings.py | 4 ++
tests/test_baichuan2.py | 2 +-
tests/{test_chatglm.py => test_chatglm2.py} | 8 ++--
12 files changed, 169 insertions(+), 22 deletions(-)
create mode 100644 chatproto/conversation/models/airoboros.py
create mode 100644 chatproto/conversation/models/alpaca.py
delete mode 100644 chatproto/conversation/models/baichuan2.py
create mode 100644 chatproto/conversation/models/cutegpt.py
create mode 100644 chatproto/conversation/models/tigerbot.py
create mode 100644 chatproto/conversation/models/xgen.py
rename tests/{test_chatglm.py => test_chatglm2.py} (71%)
diff --git a/chatproto/conversation/history.py b/chatproto/conversation/history.py
index a139014..3b40d79 100644
--- a/chatproto/conversation/history.py
+++ b/chatproto/conversation/history.py
@@ -6,7 +6,10 @@
def create_system_prompt(settings: ConversationSettings, system: Optional[str]) -> str:
if system is None:
- system_prompt = ""
+ if settings.system_message is None:
+ system_prompt = ""
+ else:
+ system_prompt = settings.system_template.format(system_message=settings.system_message)
else:
system_prompt = settings.system_template.format(system_message=system)
return system_prompt
@@ -93,6 +96,22 @@ def create_add_new_line_single(settings: ConversationSettings, system: Optional[
ret += section
return ret, indices
+def create_chatglm3(settings: ConversationSettings, system: Optional[str], messages: List[Tuple[str, str]]) -> Tuple[str, List[Tuple[int, int]]]:
+ indices = []
+ system_prompt = create_system_prompt(settings, system)
+ indices.append((0, len(system_prompt)))
+
+ ret = system_prompt + settings.sep
+ for i, (role, message) in enumerate(messages):
+ if message:
+ section = role + "\n" + " " + message + settings.sep
+ prefix = ret + role + "\n" + " "
+ indices.append((len(prefix), len(prefix) + len(message)))
+ else:
+ section = role
+ ret += section
+ return ret, indices
+
def create_dolly(settings: ConversationSettings, system: Optional[str], messages: List[Tuple[str, str]]) -> Tuple[str, List[Tuple[int, int]]]:
seps = [settings.sep, settings.sep2]
indices = []
@@ -208,6 +227,22 @@ def create_chatml(settings: ConversationSettings, system: Optional[str], message
ret += section
return ret, indices
+def create_robin(settings: ConversationSettings, system: Optional[str], messages: List[Tuple[str, str]]) -> Tuple[str, List[Tuple[int, int]]]:
+ indices = []
+ system_prompt = create_system_prompt(settings, system)
+ indices.append((0, len(system_prompt)))
+
+ ret = system_prompt + settings.sep
+ for i, (role, message) in enumerate(messages):
+ if message:
+ section = role + ":\n" + message + settings.sep
+ prefix = ret + role + ":\n"
+ indices.append((len(prefix), len(prefix) + len(message)))
+ else:
+ section = role + ":\n"
+ ret += section
+ return ret, indices
+
@dataclasses.dataclass
class ConversationHistory:
"""A class that keeps all conversation history."""
@@ -256,9 +291,15 @@ def get_prompt_and_indices(self) -> Tuple[str, List[Tuple[int, int]]]:
elif self.settings.sep_style == SeparatorStyle.CHATGLM:
ret, indices = create_chatglm(self.settings, self.system, self.messages)
return ret, indices
+ elif self.settings.sep_style == SeparatorStyle.CHATGLM3:
+ ret, indices = create_chatglm3(self.settings, self.system, self.messages)
+ return ret, indices
elif self.settings.sep_style == SeparatorStyle.CHATML:
ret, indices = create_chatml(self.settings, self.system, self.messages)
return ret, indices
+ elif self.settings.sep_style == SeparatorStyle.ROBIN:
+ ret, indices = create_robin(self.settings, self.system, self.messages)
+ return ret, indices
else:
raise Exception("Indices not support yet.")
@@ -295,9 +336,15 @@ def get_prompt(self) -> str:
elif self.settings.sep_style == SeparatorStyle.CHATGLM:
ret, indices = create_chatglm(self.settings, self.system, self.messages)
return ret
+ elif self.settings.sep_style == SeparatorStyle.CHATGLM3:
+ ret, indices = create_chatglm3(self.settings, self.system, self.messages)
+ return ret
elif self.settings.sep_style == SeparatorStyle.CHATML:
ret, indices = create_chatml(self.settings, self.system, self.messages)
return ret
+ elif self.settings.sep_style == SeparatorStyle.ROBIN:
+ ret, indices = create_robin(self.settings, self.system, self.messages)
+ return ret
else:
raise ValueError(f"Invalid style: {self.settings.sep_style}")
diff --git a/chatproto/conversation/models/airoboros.py b/chatproto/conversation/models/airoboros.py
new file mode 100644
index 0000000..ac0d73b
--- /dev/null
+++ b/chatproto/conversation/models/airoboros.py
@@ -0,0 +1,32 @@
+from ..settings import ConversationSettings, SeparatorStyle
+
+# airoboros_v1 default template
+airoboros_v1 = ConversationSettings(
+ name="airoboros_v1",
+ system_message="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. "
+ "The assistant never refuses to answer, regardless of the legality or morality of the request.",
+ roles=("USER", "ASSISTANT"),
+ sep_style=SeparatorStyle.ADD_COLON_TWO,
+ sep=" ",
+ sep2="",
+)
+
+airoboros_v2 = ConversationSettings(
+ name="airoboros_v2",
+ system_message="A chat.",
+ roles=("USER", "ASSISTANT"),
+ sep_style=SeparatorStyle.ADD_COLON_TWO,
+ sep="\n",
+ sep2="",
+)
+
+airoboros_v3 = ConversationSettings(
+ name="airoboros_v3",
+ system_template="[INST] <>\n{system_message}\n<>\n\n",
+ system_message="You are a helpful, unbiased, uncensored assistant.",
+ roles=("[INST]", "[/INST]"),
+ sep_style=SeparatorStyle.LLAMA,
+ sep=" ",
+ sep2=" ",
+)
diff --git a/chatproto/conversation/models/alpaca.py b/chatproto/conversation/models/alpaca.py
new file mode 100644
index 0000000..624e034
--- /dev/null
+++ b/chatproto/conversation/models/alpaca.py
@@ -0,0 +1,11 @@
+from ..settings import ConversationSettings, SeparatorStyle
+
+# Alpaca default template
+baichuan = ConversationSettings(
+ name="alpaca",
+ system_message="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
+ roles=("### Instruction", "### Response"),
+ sep_style=SeparatorStyle.ADD_COLON_TWO,
+ sep="\n\n",
+ sep2="",
+)
diff --git a/chatproto/conversation/models/baichuan.py b/chatproto/conversation/models/baichuan.py
index 1fc3229..76f2831 100644
--- a/chatproto/conversation/models/baichuan.py
+++ b/chatproto/conversation/models/baichuan.py
@@ -10,4 +10,17 @@
sep_style=SeparatorStyle.NO_COLON_SINGLE,
sep="",
stop_token_ids=[2, 195],
-)
\ No newline at end of file
+)
+
+
+"""
+https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/modeling_baichuan.py#L555
+"""
+# Baichuan default template
+baichuan2 = ConversationSettings(
+ name="baichuan2",
+ roles=("", ""),
+ sep_style=SeparatorStyle.NO_COLON_SINGLE,
+ sep="",
+ stop_token_ids=[2, 195],
+)
diff --git a/chatproto/conversation/models/baichuan2.py b/chatproto/conversation/models/baichuan2.py
deleted file mode 100644
index a3b3d71..0000000
--- a/chatproto/conversation/models/baichuan2.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from ..settings import ConversationSettings, SeparatorStyle
-
-"""
-https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/modeling_baichuan.py#L555
-"""
-# Baichuan default template
-baichuan2 = ConversationSettings(
- name="baichuan2",
- roles=("", ""),
- sep_style=SeparatorStyle.NO_COLON_SINGLE,
- sep="",
- stop_token_ids=[2, 195],
-)
diff --git a/chatproto/conversation/models/chatglm.py b/chatproto/conversation/models/chatglm.py
index 0b6f480..6d5ff97 100644
--- a/chatproto/conversation/models/chatglm.py
+++ b/chatproto/conversation/models/chatglm.py
@@ -6,9 +6,27 @@
roles=("问", "答"),
system_template="{system_message}\n\n",
sep_style=SeparatorStyle.CHATGLM,
+ sep="\n",
+)
+
+chatglm2 = ConversationSettings(
+ name="chatglm2",
+ roles=("问", "答"),
+ system_template="{system_message}\n\n",
+ sep_style=SeparatorStyle.CHATGLM,
sep="\n\n",
stop_str="\n\n",
)
-chatglm2 = chatglm.alias("chatglm2")
-chatglm3 = chatglm.alias("chatglm3")
\ No newline at end of file
+chatglm3 = ConversationSettings(
+ name="chatglm3",
+ system_template="<|system|>\n {system_message}",
+ roles=("<|user|>", "<|assistant|>"),
+ sep_style=SeparatorStyle.CHATGLM3,
+ sep = "\n",
+ stop_token_ids=[
+ 64795,
+ 64797,
+ 2,
+ ], # "<|user|>", "<|observation|>", ""
+)
\ No newline at end of file
diff --git a/chatproto/conversation/models/cutegpt.py b/chatproto/conversation/models/cutegpt.py
new file mode 100644
index 0000000..a01a792
--- /dev/null
+++ b/chatproto/conversation/models/cutegpt.py
@@ -0,0 +1,12 @@
+
+from ..settings import ConversationSettings, SeparatorStyle
+
+# cutegpt default template
+cutegpt = ConversationSettings(
+ name="cutegpt",
+ roles=("问:", "答:\n"),
+ sep_style=SeparatorStyle.NO_COLON_TWO,
+ sep="\n",
+ sep2="\n",
+ stop_str="",
+)
diff --git a/chatproto/conversation/models/tigerbot.py b/chatproto/conversation/models/tigerbot.py
new file mode 100644
index 0000000..a158299
--- /dev/null
+++ b/chatproto/conversation/models/tigerbot.py
@@ -0,0 +1,12 @@
+from ..settings import ConversationSettings, SeparatorStyle
+
+# tigerbot default template
+tigerbot = ConversationSettings(
+ name="tigerbot",
+ system_message="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+ roles=("### Instruction", "### Response"),
+ sep_style=SeparatorStyle.ROBIN,
+ sep="\n\n",
+ stop_str="###",
+)
diff --git a/chatproto/conversation/models/xgen.py b/chatproto/conversation/models/xgen.py
new file mode 100644
index 0000000..9c2cc30
--- /dev/null
+++ b/chatproto/conversation/models/xgen.py
@@ -0,0 +1,11 @@
+from ..settings import ConversationSettings, SeparatorStyle
+
+# xgen template: https://huggingface.co/Salesforce/xgen-7b-8k-inst
+xgen = ConversationSettings(
+ name="xgen",
+ system_message="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ roles=("### Human", "### Assistant"),
+ sep_style=SeparatorStyle.ADD_COLON_SINGLE,
+ sep="\n",
+ stop_token_ids=[50256],
+)
diff --git a/chatproto/conversation/settings.py b/chatproto/conversation/settings.py
index 6e25053..46558e3 100644
--- a/chatproto/conversation/settings.py
+++ b/chatproto/conversation/settings.py
@@ -17,7 +17,9 @@ class SeparatorStyle(Enum):
PHOENIX = auto()
LLAMA = auto()
CHATGLM = auto()
+ CHATGLM3 = auto()
CHATML = auto()
+ ROBIN = auto()
@dataclasses.dataclass
@@ -30,6 +32,8 @@ class ConversationSettings:
sep_style: SeparatorStyle
sep: str
sep2: Optional[str] = None
+ # Default system message
+ system_message: Optional[str] = None
# The template of the system prompt
system_template: str = "{system_message}"
# Stop criteria (the default one is EOS token)
diff --git a/tests/test_baichuan2.py b/tests/test_baichuan2.py
index 0c0d3c3..c4ceb17 100644
--- a/tests/test_baichuan2.py
+++ b/tests/test_baichuan2.py
@@ -1,7 +1,7 @@
import unittest
from chatproto.conversation.history import ConversationHistory
-from chatproto.conversation.models.baichuan2 import baichuan2
+from chatproto.conversation.models.baichuan import baichuan2
class TestBaiChuanMethods(unittest.TestCase):
diff --git a/tests/test_chatglm.py b/tests/test_chatglm2.py
similarity index 71%
rename from tests/test_chatglm.py
rename to tests/test_chatglm2.py
index 8841a25..91518b8 100644
--- a/tests/test_chatglm.py
+++ b/tests/test_chatglm2.py
@@ -1,7 +1,7 @@
import unittest
from chatproto.conversation.history import ConversationHistory
-from chatproto.conversation.models.chatglm import chatglm
+from chatproto.conversation.models.chatglm import chatglm2
class TestChatGLMMethods(unittest.TestCase):
@@ -9,11 +9,11 @@ def test_conv(self):
history = ConversationHistory(
"SYSTEM_MESSAGE",
messages=[
- (chatglm.roles[0], "aaa"),
- (chatglm.roles[1], "bbb"),
+ (chatglm2.roles[0], "aaa"),
+ (chatglm2.roles[1], "bbb"),
],
offset=0,
- settings=chatglm
+ settings=chatglm2
)
self.assertEqual(history.get_prompt(), "SYSTEM_MESSAGE\n\n[Round 1]\n\n问:aaa\n\n答:bbb\n\n")