From ddb3ab24bfc3f2eadfa6027aeb0419b3c22d7f2e Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Wed, 3 Jan 2024 16:25:35 +0800 Subject: [PATCH] airoboros, cutegpt, alpaca, tigerbot and xgen --- chatproto/conversation/history.py | 49 ++++++++++++++++++++- chatproto/conversation/models/airoboros.py | 32 ++++++++++++++ chatproto/conversation/models/alpaca.py | 11 +++++ chatproto/conversation/models/baichuan.py | 15 ++++++- chatproto/conversation/models/baichuan2.py | 13 ------ chatproto/conversation/models/chatglm.py | 22 ++++++++- chatproto/conversation/models/cutegpt.py | 12 +++++ chatproto/conversation/models/tigerbot.py | 12 +++++ chatproto/conversation/models/xgen.py | 11 +++++ chatproto/conversation/settings.py | 4 ++ tests/test_baichuan2.py | 2 +- tests/{test_chatglm.py => test_chatglm2.py} | 8 ++-- 12 files changed, 169 insertions(+), 22 deletions(-) create mode 100644 chatproto/conversation/models/airoboros.py create mode 100644 chatproto/conversation/models/alpaca.py delete mode 100644 chatproto/conversation/models/baichuan2.py create mode 100644 chatproto/conversation/models/cutegpt.py create mode 100644 chatproto/conversation/models/tigerbot.py create mode 100644 chatproto/conversation/models/xgen.py rename tests/{test_chatglm.py => test_chatglm2.py} (71%) diff --git a/chatproto/conversation/history.py b/chatproto/conversation/history.py index a139014..3b40d79 100644 --- a/chatproto/conversation/history.py +++ b/chatproto/conversation/history.py @@ -6,7 +6,10 @@ def create_system_prompt(settings: ConversationSettings, system: Optional[str]) -> str: if system is None: - system_prompt = "" + if settings.system_message is None: + system_prompt = "" + else: + system_prompt = settings.system_template.format(system_message=settings.system_message) else: system_prompt = settings.system_template.format(system_message=system) return system_prompt @@ -93,6 +96,22 @@ def create_add_new_line_single(settings: ConversationSettings, system: Optional[ ret += 
section return ret, indices +def create_chatglm3(settings: ConversationSettings, system: Optional[str], messages: List[Tuple[str, str]]) -> Tuple[str, List[Tuple[int, int]]]: + indices = [] + system_prompt = create_system_prompt(settings, system) + indices.append((0, len(system_prompt))) + + ret = system_prompt + settings.sep + for i, (role, message) in enumerate(messages): + if message: + section = role + "\n" + " " + message + settings.sep + prefix = ret + role + "\n" + " " + indices.append((len(prefix), len(prefix) + len(message))) + else: + section = role + ret += section + return ret, indices + def create_dolly(settings: ConversationSettings, system: Optional[str], messages: List[Tuple[str, str]]) -> Tuple[str, List[Tuple[int, int]]]: seps = [settings.sep, settings.sep2] indices = [] @@ -208,6 +227,22 @@ def create_chatml(settings: ConversationSettings, system: Optional[str], message ret += section return ret, indices +def create_robin(settings: ConversationSettings, system: Optional[str], messages: List[Tuple[str, str]]) -> Tuple[str, List[Tuple[int, int]]]: + indices = [] + system_prompt = create_system_prompt(settings, system) + indices.append((0, len(system_prompt))) + + ret = system_prompt + settings.sep + for i, (role, message) in enumerate(messages): + if message: + section = role + ":\n" + message + settings.sep + prefix = ret + role + ":\n" + indices.append((len(prefix), len(prefix) + len(message))) + else: + section = role + ":\n" + ret += section + return ret, indices + @dataclasses.dataclass class ConversationHistory: """A class that keeps all conversation history.""" @@ -256,9 +291,15 @@ def get_prompt_and_indices(self) -> Tuple[str, List[Tuple[int, int]]]: elif self.settings.sep_style == SeparatorStyle.CHATGLM: ret, indices = create_chatglm(self.settings, self.system, self.messages) return ret, indices + elif self.settings.sep_style == SeparatorStyle.CHATGLM3: + ret, indices = create_chatglm3(self.settings, self.system, self.messages) + return 
ret, indices elif self.settings.sep_style == SeparatorStyle.CHATML: ret, indices = create_chatml(self.settings, self.system, self.messages) return ret, indices + elif self.settings.sep_style == SeparatorStyle.ROBIN: + ret, indices = create_robin(self.settings, self.system, self.messages) + return ret, indices else: raise Exception("Indices not support yet.") @@ -295,9 +336,15 @@ def get_prompt(self) -> str: elif self.settings.sep_style == SeparatorStyle.CHATGLM: ret, indices = create_chatglm(self.settings, self.system, self.messages) return ret + elif self.settings.sep_style == SeparatorStyle.CHATGLM3: + ret, indices = create_chatglm3(self.settings, self.system, self.messages) + return ret elif self.settings.sep_style == SeparatorStyle.CHATML: ret, indices = create_chatml(self.settings, self.system, self.messages) return ret + elif self.settings.sep_style == SeparatorStyle.ROBIN: + ret, indices = create_robin(self.settings, self.system, self.messages) + return ret else: raise ValueError(f"Invalid style: {self.settings.sep_style}") diff --git a/chatproto/conversation/models/airoboros.py b/chatproto/conversation/models/airoboros.py new file mode 100644 index 0000000..ac0d73b --- /dev/null +++ b/chatproto/conversation/models/airoboros.py @@ -0,0 +1,32 @@ +from ..settings import ConversationSettings, SeparatorStyle + +# airoboros_v1 default template +airoboros_v1 = ConversationSettings( + name="airoboros_v1", + system_message="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. 
" + "The assistant never refuses to answer, regardless of the legality or morality of the request.", + roles=("USER", "ASSISTANT"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep=" ", + sep2="", +) + +airoboros_v2 = ConversationSettings( + name="airoboros_v2", + system_message="A chat.", + roles=("USER", "ASSISTANT"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n", + sep2="", +) + +airoboros_v3 = ConversationSettings( + name="airoboros_v3", + system_template="[INST] <>\n{system_message}\n<>\n\n", + system_message="You are a helpful, unbiased, uncensored assistant.", + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA, + sep=" ", + sep2=" ", +) diff --git a/chatproto/conversation/models/alpaca.py b/chatproto/conversation/models/alpaca.py new file mode 100644 index 0000000..624e034 --- /dev/null +++ b/chatproto/conversation/models/alpaca.py @@ -0,0 +1,11 @@ +from ..settings import ConversationSettings, SeparatorStyle + +# Alpaca default template +baichuan = ConversationSettings( + name="alpaca", + system_message="Below is an instruction that describes a task. 
Write a response that appropriately completes the request.", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n\n", + sep2="</s>", +) diff --git a/chatproto/conversation/models/baichuan.py b/chatproto/conversation/models/baichuan.py index 1fc3229..76f2831 100644 --- a/chatproto/conversation/models/baichuan.py +++ b/chatproto/conversation/models/baichuan.py @@ -10,4 +10,17 @@ sep_style=SeparatorStyle.NO_COLON_SINGLE, sep="", stop_token_ids=[2, 195], -) \ No newline at end of file +) + + +""" +https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/modeling_baichuan.py#L555 +""" +# Baichuan default template +baichuan2 = ConversationSettings( + name="baichuan2", + roles=("<reserved_106>", "<reserved_107>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_token_ids=[2, 195], +) diff --git a/chatproto/conversation/models/baichuan2.py b/chatproto/conversation/models/baichuan2.py deleted file mode 100644 index a3b3d71..0000000 --- a/chatproto/conversation/models/baichuan2.py +++ /dev/null @@ -1,13 +0,0 @@ -from ..settings import ConversationSettings, SeparatorStyle - -""" -https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/modeling_baichuan.py#L555 -""" -# Baichuan default template -baichuan2 = ConversationSettings( - name="baichuan2", - roles=("<reserved_106>", "<reserved_107>"), - sep_style=SeparatorStyle.NO_COLON_SINGLE, - sep="", - stop_token_ids=[2, 195], -) diff --git a/chatproto/conversation/models/chatglm.py b/chatproto/conversation/models/chatglm.py index 0b6f480..6d5ff97 100644 --- a/chatproto/conversation/models/chatglm.py +++ b/chatproto/conversation/models/chatglm.py @@ -6,9 +6,27 @@ roles=("问", "答"), system_template="{system_message}\n\n", sep_style=SeparatorStyle.CHATGLM, + sep="\n", +) + +chatglm2 = ConversationSettings( + name="chatglm2", + roles=("问", "答"), + system_template="{system_message}\n\n", + sep_style=SeparatorStyle.CHATGLM, sep="\n\n", stop_str="\n\n", ) -chatglm2 = chatglm.alias("chatglm2") -chatglm3 = 
chatglm.alias("chatglm3") \ No newline at end of file +chatglm3 = ConversationSettings( + name="chatglm3", + system_template="<|system|>\n {system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATGLM3, + sep = "\n", + stop_token_ids=[ + 64795, + 64797, + 2, + ], # "<|user|>", "<|observation|>", "" +) \ No newline at end of file diff --git a/chatproto/conversation/models/cutegpt.py b/chatproto/conversation/models/cutegpt.py new file mode 100644 index 0000000..a01a792 --- /dev/null +++ b/chatproto/conversation/models/cutegpt.py @@ -0,0 +1,12 @@ + +from ..settings import ConversationSettings, SeparatorStyle + +# cutegpt default template +cutegpt = ConversationSettings( + name="cutegpt", + roles=("问:", "答:\n"), + sep_style=SeparatorStyle.NO_COLON_TWO, + sep="\n", + sep2="\n", + stop_str="", +) diff --git a/chatproto/conversation/models/tigerbot.py b/chatproto/conversation/models/tigerbot.py new file mode 100644 index 0000000..a158299 --- /dev/null +++ b/chatproto/conversation/models/tigerbot.py @@ -0,0 +1,12 @@ +from ..settings import ConversationSettings, SeparatorStyle + +# tigerbot default template +tigerbot = ConversationSettings( + name="tigerbot", + system_message="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.ROBIN, + sep="\n\n", + stop_str="###", +) diff --git a/chatproto/conversation/models/xgen.py b/chatproto/conversation/models/xgen.py new file mode 100644 index 0000000..9c2cc30 --- /dev/null +++ b/chatproto/conversation/models/xgen.py @@ -0,0 +1,11 @@ +from ..settings import ConversationSettings, SeparatorStyle + +# xgen template: https://huggingface.co/Salesforce/xgen-7b-8k-inst +xgen = ConversationSettings( + name="xgen", + system_message="A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("### Human", "### Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + stop_token_ids=[50256], +) diff --git a/chatproto/conversation/settings.py b/chatproto/conversation/settings.py index 6e25053..46558e3 100644 --- a/chatproto/conversation/settings.py +++ b/chatproto/conversation/settings.py @@ -17,7 +17,9 @@ class SeparatorStyle(Enum): PHOENIX = auto() LLAMA = auto() CHATGLM = auto() + CHATGLM3 = auto() CHATML = auto() + ROBIN = auto() @dataclasses.dataclass @@ -30,6 +32,8 @@ class ConversationSettings: sep_style: SeparatorStyle sep: str sep2: Optional[str] = None + # Default system message + system_message: Optional[str] = None # The template of the system prompt system_template: str = "{system_message}" # Stop criteria (the default one is EOS token) diff --git a/tests/test_baichuan2.py b/tests/test_baichuan2.py index 0c0d3c3..c4ceb17 100644 --- a/tests/test_baichuan2.py +++ b/tests/test_baichuan2.py @@ -1,7 +1,7 @@ import unittest from chatproto.conversation.history import ConversationHistory -from chatproto.conversation.models.baichuan2 import baichuan2 +from chatproto.conversation.models.baichuan import baichuan2 class TestBaiChuanMethods(unittest.TestCase): diff --git a/tests/test_chatglm.py b/tests/test_chatglm2.py similarity index 71% rename from tests/test_chatglm.py rename to tests/test_chatglm2.py index 8841a25..91518b8 100644 --- a/tests/test_chatglm.py +++ b/tests/test_chatglm2.py @@ -1,7 +1,7 @@ import unittest from chatproto.conversation.history import ConversationHistory -from chatproto.conversation.models.chatglm import chatglm +from chatproto.conversation.models.chatglm import chatglm2 class TestChatGLMMethods(unittest.TestCase): @@ -9,11 +9,11 @@ def test_conv(self): history = ConversationHistory( "SYSTEM_MESSAGE", messages=[ - (chatglm.roles[0], "aaa"), - (chatglm.roles[1], "bbb"), + (chatglm2.roles[0], "aaa"), + 
(chatglm2.roles[1], "bbb"), ], offset=0, - settings=chatglm + settings=chatglm2 ) self.assertEqual(history.get_prompt(), "SYSTEM_MESSAGE\n\n[Round 1]\n\n问:aaa\n\n答:bbb\n\n")