diff --git a/requirements.txt b/requirements.txt index 9289a40f39..26525be15d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,7 +28,7 @@ scipy scikit-learn==1.2.2 pynvml art -fschat==0.2.36 +fschat @ git+https://github.com/lm-sys/FastChat.git@5095615810cf613dba7f27dd155f571fcff976d8 gradio==3.50.2 tensorboard diff --git a/src/axolotl/monkeypatch/fastchat_conversation_turns.py b/src/axolotl/monkeypatch/fastchat_conversation_turns.py index d09ab5075d..7ab07d4854 100644 --- a/src/axolotl/monkeypatch/fastchat_conversation_turns.py +++ b/src/axolotl/monkeypatch/fastchat_conversation_turns.py @@ -123,6 +123,14 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role, "" return + if self.sep_style == SeparatorStyle.GEMMA: + if self.system_message: + raise ValueError("Gemma chat template does not support system messages") + for i, (role, message) in enumerate(self.messages): + prefix = "<bos>" if i == 0 else "" + message_str = message if message else "" + yield prefix + "<start_of_turn>" + role + "\n", message_str + "<end_of_turn>\n" + return if self.sep_style == SeparatorStyle.CHATGLM: # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926