forked from axolotl-ai-cloud/axolotl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
use fastchat conversations template (axolotl-ai-cloud#578)
* use fastchat conversations template * require fastchat (fschat) pip install * handle roles dynamically from conversation * tweak fastchat conversation with a monkeypatch to get individual turns * fix up so it works with multiple conversation styles, and don't strip the turns * fix sharegpt fixture now that we're using a more correct tokenization * use a new prompter and support fastchat conversation type * use sharegpt from prompt strategies now * update docs, add chatml template * add a newline after im_end token * ensure we correctly set system message * update per PR feedback to handle deprecated sharegpt types * don't add duplicate wandb req * make sharegpt fields configurable from yml * llama2 fixes * don't fail fatally when turns are improper
- Loading branch information
Showing
13 changed files
with
324 additions
and
112 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,3 +31,4 @@ scipy | |
scikit-learn==1.2.2 | ||
pynvml | ||
art | ||
fschat==0.2.29 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
""" | ||
monkeypatch to add a get_turns method | ||
""" | ||
|
||
import logging | ||
from typing import Generator, Tuple | ||
|
||
from fastchat.conversation import SeparatorStyle | ||
|
||
LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns") | ||
|
||
|
||
def get_prompt(self) -> str: | ||
ret = "" | ||
for role, msg in self.get_turns(): | ||
ret += role + msg | ||
return ret | ||
|
||
|
||
def get_turns( # pylint: disable=too-many-return-statements | ||
self, | ||
) -> Generator[Tuple[str, str], None, None]: | ||
"""Get the prompt for generation.""" | ||
system_prompt = self.system_template.format(system_message=self.system_message) | ||
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: | ||
yield "", system_prompt + self.sep | ||
for role, message in self.messages: | ||
if message: | ||
yield role + ": ", message + self.sep | ||
else: | ||
yield role + ":", "" | ||
return | ||
if self.sep_style == SeparatorStyle.ADD_COLON_TWO: | ||
seps = [self.sep, self.sep2] | ||
yield "", system_prompt + seps[0] | ||
for i, (role, message) in enumerate(self.messages): | ||
if message: | ||
yield role + ": ", message + seps[i % 2] | ||
else: | ||
yield role + ":", "" | ||
return | ||
if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: | ||
yield "", system_prompt + self.sep | ||
for role, message in self.messages: | ||
if message: | ||
yield role + ": ", message + self.sep | ||
else: | ||
yield role + ": ", "" # must be end with a space | ||
return | ||
if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: | ||
yield "", "" if system_prompt == "" else system_prompt + self.sep | ||
for role, message in self.messages: | ||
if message: | ||
yield role + "\n", message + self.sep | ||
else: | ||
yield role + "\n", "" | ||
return | ||
if self.sep_style == SeparatorStyle.NO_COLON_SINGLE: | ||
yield "", system_prompt | ||
for role, message in self.messages: | ||
if message: | ||
yield role, message + self.sep | ||
else: | ||
yield role, "" | ||
return | ||
if self.sep_style == SeparatorStyle.NO_COLON_TWO: | ||
seps = [self.sep, self.sep2] | ||
yield "", system_prompt | ||
for i, (role, message) in enumerate(self.messages): | ||
if message: | ||
yield role, message + seps[i % 2] | ||
else: | ||
yield role, "" | ||
return | ||
if self.sep_style == SeparatorStyle.RWKV: | ||
yield "", system_prompt | ||
for i, (role, message) in enumerate(self.messages): | ||
if message: | ||
yield role + ": ", message.replace("\r\n", "\n").replace( | ||
"\n\n", "\n" | ||
) + "\n\n" | ||
else: | ||
yield role + ":", "" | ||
return | ||
if self.sep_style == SeparatorStyle.LLAMA2: | ||
seps = [self.sep, self.sep2] | ||
if self.system_message: | ||
yield "", system_prompt | ||
else: | ||
yield "", "[INST] " | ||
for i, (role, message) in enumerate(self.messages[1:]): | ||
if message: | ||
yield role + " ", message + seps[i % 2] | ||
else: | ||
yield role, "" | ||
return | ||
if self.sep_style == SeparatorStyle.CHATGLM: | ||
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 | ||
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 | ||
round_add_n = 1 if self.name == "chatglm2" else 0 | ||
if system_prompt: | ||
yield "", system_prompt + self.sep | ||
|
||
for i, (role, message) in enumerate(self.messages): | ||
if i % 2 == 0: | ||
yield "", f"[Round {i//2 + round_add_n}]{self.sep}" | ||
|
||
if message: | ||
yield f"{role}:", f"{message}{self.sep}" | ||
else: | ||
yield f"{role}:", "" | ||
return | ||
if self.sep_style == SeparatorStyle.CHATML: | ||
yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n" | ||
for role, message in self.messages: | ||
if message: | ||
yield role + "\n", message + self.sep + "\n" | ||
else: | ||
yield role + "\n", "" | ||
return | ||
if self.sep_style == SeparatorStyle.CHATINTERN: | ||
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 | ||
seps = [self.sep, self.sep2] | ||
yield "", system_prompt | ||
for i, (role, message) in enumerate(self.messages): | ||
prefix = "<s>" if i % 2 == 0 else "" | ||
if message: | ||
yield prefix + role + ":", message + seps[i % 2] + "\n" | ||
else: | ||
yield role + ":", "" | ||
return | ||
if self.sep_style == SeparatorStyle.DOLLY: | ||
seps = [self.sep, self.sep2] | ||
yield "", system_prompt | ||
for i, (role, message) in enumerate(self.messages): | ||
if message: | ||
suffix = "\n\n" if i % 2 == 1 else "" | ||
yield role + ":\n", message + seps[i % 2] + suffix | ||
else: | ||
yield role + ":\n", "" | ||
return | ||
if self.sep_style == SeparatorStyle.PHOENIX: | ||
yield "", system_prompt | ||
for role, message in self.messages: | ||
if message: | ||
yield role + ": ", "<s>" + message + "</s>" | ||
else: | ||
yield role + ": " + "<s>", "" | ||
return | ||
if self.sep_style == SeparatorStyle.ROBIN: | ||
yield "", system_prompt + self.sep | ||
for role, message in self.messages: | ||
if message: | ||
yield role + ":\n", message + self.sep | ||
else: | ||
yield role + ":\n", "" | ||
return | ||
if self.sep_style == SeparatorStyle.FALCON_CHAT: | ||
if self.system_message: | ||
yield "", system_prompt + self.sep | ||
for role, message in self.messages: | ||
if message: | ||
yield role + ": ", message + self.sep | ||
else: | ||
yield role + ":", "" | ||
else: | ||
raise ValueError(f"Invalid style: {self.sep_style}") | ||
|
||
|
||
def add_get_turns_to_conversation(): | ||
import fastchat.conversation | ||
|
||
fastchat.conversation.Conversation.get_turns = get_turns | ||
fastchat.conversation.Conversation.get_prompt = get_prompt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.