From 1623a5055f532313cd0331e17a1fbe543b98e0ab Mon Sep 17 00:00:00 2001
From: NanoCode012 <kevinvong@rocketmail.com>
Date: Thu, 18 Jan 2024 02:44:15 +0900
Subject: [PATCH] feat(prompt): support multiple roles for sharegpt

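Adds a load_multirole strategy and a ShareGPTPrompterV2MultiRole
prompter so that sharegpt-style conversations can use roles beyond
the default human/gpt pair (e.g. tool, tool_caller, funcresponse).
Roles outside the conversation template's defaults are rendered via
CONVERSATION_ROLE_FORMAT, which currently covers chatml and zephyr.

A sketch of the intended dataset config (the dataset path is a
placeholder, and the type string assumes the usual module.load_fn
strategy lookup):

    datasets:
      - path: org/multirole-dataset
        type: sharegpt.load_multirole
        conversation: chatml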
---
 src/axolotl/prompt_strategies/sharegpt.py | 156 +++++++++++++++++++++-
 src/axolotl/prompters.py                  |  59 ++++++++
 2 files changed, 213 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py
index 15bfee8c47..e6a2c19cb3 100644
--- a/src/axolotl/prompt_strategies/sharegpt.py
+++ b/src/axolotl/prompt_strategies/sharegpt.py
@@ -1,10 +1,23 @@
 """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
+import copy
+import logging
 from typing import Any, Dict, Optional
 
 from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
 
-from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
-from axolotl.prompters import ShareGPTPrompterV2
+from axolotl.prompt_tokenizers import (
+    InvalidDataException,
+    ShareGPTPromptTokenizingStrategy,
+    parse_tokenized_to_result,
+    tokenize_prompt_default,
+)
+from axolotl.prompters import (
+    IGNORE_TOKEN_ID,
+    ShareGPTPrompterV2,
+    ShareGPTPrompterV2MultiRole,
+)
+
+LOG = logging.getLogger("axolotl")
 
 
 def register_chatml_template(system_message=None):
@@ -77,6 +90,28 @@ def load_guanaco(tokenizer, cfg):
     )
 
 
+def load_multirole(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
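+    """loader for the multi-role sharegpt tokenization strategy"""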
+    conversation = (
+        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
+    )
+    field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
+    field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
+    strategy = MultiRoleShareGPTPromptTokenizingStrategy(
+        ShareGPTPrompterV2MultiRole(
+            conversation=conversation,
+            role_key_model=field_model,
+            role_key_human=field_human,
+        ),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+    if ds_cfg and "strict" in ds_cfg:
+        strategy.strict = ds_cfg["strict"]
+
+    return strategy
+
+
 class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     """
     basic sharegpt strategy to grab conversations from the sample row
@@ -143,3 +178,120 @@ def get_conversation_thread(self, prompt):
             {"from": role_map[t["role"]], "value": t["content"]} for t in conversations
         ]
         return turns
+
+
+class MultiRoleShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
+    """
+    sharegpt strategy with support for multiple roles beyond human/gpt
+    """
+
+    def tokenize_prompt(self, prompt):
+        # Initial values. We will append to these as we go through the conversation.
+        result, current_len = tokenize_prompt_default()
+        conversation: Conversation = (
+            self.prompter._conversation.copy()  # pylint: disable=protected-access
+        )
+        user, assistant = conversation.roles
+
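+        # turns from these roles are treated as inputs and masked out of the
+        # labels unless train_on_inputs is set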
+        input_roles = {
+            "human",
+            "funcresponse",
+            "funccaller",
+            "tool",
+            "tool_response",
+            user,
+        }
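+        # turns from these roles are treated as model outputs and trained on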
+        output_roles = {"gpt", "tool_caller", assistant}
+
+        # support for custom roles from the dataset, only useful for vicuna style prompts/roles
+        role_remap = []
+        if (
+            conversation.name == "vicuna_v1.1"
+            and "roles" in prompt
+            and len(prompt["roles"]) >= 2
+        ):
+            role_remap = [
+                {"from": conversation.roles[0], "to": prompt["roles"][0]},
+                {"from": conversation.roles[1], "to": prompt["roles"][1]},
+            ]
+
+        try:
+            for part in self.prompter.build_prompt(
+                self.get_conversation_thread(prompt)
+            ):
+                if not isinstance(part, tuple):
+                    LOG.warning(f"expected tuple, got {part}")
+                    continue
+
+                role, content = part
+
+                # Uses "in" because role contains extra characters
+                input_turn = any(r in role.lower() for r in input_roles)
+                output_turn = any(r in role.lower() for r in output_roles)
+
+                if input_turn:
+                    role = (
+                        role.replace(role_remap[0]["from"], role_remap[0]["to"])
+                        if role_remap
+                        else role
+                    )
+                    turn = role + content
+                    # this is still the user query; warn if its text is empty
+                    if not content.strip():
+                        LOG.warning(f"user turn has empty text: {prompt}")
+                    res = self._tokenize(
+                        turn,
+                        add_eos_token=False,
+                        strip_bos_token=True,
+                    )
+                    if self.train_on_inputs:
+                        labels = copy.deepcopy(res["input_ids"])
+                    else:
+                        # everything in this turn is masked out of the labels
+                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+                elif output_turn:
+                    role = (
+                        role.replace(role_remap[1]["from"], role_remap[1]["to"])
+                        if role_remap
+                        else role
+                    )
+                    turn = role + content
+                    # this should be the assistant response and should end with an eos token
+                    if not content.strip():
+                        LOG.warning(f"assistant turn has empty text: {prompt}")
+                    add_eos_token = not (
+                        conversation.name == "chatml"
+                        and conversation.sep == self.tokenizer.eos_token
+                    )
+                    res = self._tokenize(
+                        turn,
+                        add_eos_token=add_eos_token,
+                        strip_bos_token=True,
+                    )
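+                    # tokenize the role prefix alone so its token length can be
+                    # used to mask it from the labels below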
+                    role_res = self._tokenize(
+                        role.rstrip(),
+                        add_eos_token=False,
+                        strip_bos_token=True,
+                    )
+                    labels = copy.deepcopy(res["input_ids"])
+                    if not self.train_on_inputs:
+                        # mask out role tokens from the labels
+                        len_role = len(role_res["input_ids"])
+                        labels[:len_role] = [IGNORE_TOKEN_ID] * min(
+                            len_role, len(labels)
+                        )
+                else:
+                    LOG.warning(f"unhandled role: {role}")
+                    continue
+
+                # pylint: disable=duplicate-code
+                result, current_len = parse_tokenized_to_result(
+                    result,
+                    current_len,
+                    res,
+                    labels,
+                    pad_token_id=self.tokenizer.pad_token_id,
+                )
+            return result
+        except (KeyError, AssertionError, IndexError) as err:
+            raise InvalidDataException(str(err)) from err
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 748db1a162..05bf53510e 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -355,6 +355,65 @@ def __init__(
         )
 
 
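+# format strings for rendering non-default role names in conversation
+# templates that can accommodate arbitrary roles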
+CONVERSATION_ROLE_FORMAT = {
+    "chatml": "<|im_start|>{ROLE}",
+    "zephyr": "<|{ROLE}|>",
+}
+
+
+class ShareGPTPrompterV2MultiRole(ShareGPTPrompterV2):
+    """
+    A multi-role V2 prompter that generates ShareGPT prompts supporting roles beyond human/gpt
+    """
+
+    def _build_result(self, source):
+        if len(source) < 2:
+            # If there isn't a back and forth conversation, ignore it
+            # also happens on the data splitting leaving empty conversations
+            raise IndexError(
+                f"A conversation entry has less than 2 messages :\n{source}"
+            )
+
+        conv = self._conversation.copy()
+
+        # Add the conversation system prompt if provided, otherwise use the default one
+        if source[0]["from"] == "system":
+            conv.set_system_message(source[0]["value"])
+            source.pop(0)
+
+        roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
+
+        try:
+            # Apply prompt templates
+            if source[0]["from"] not in roles:
+                # Skip the first turn if it is not from a known role
+                source = source[1:]
+        except IndexError:
+            # sometimes there is only a Bing or system chat, leaving no turns
+            raise
+
+        conv.messages = []
+        for sentence in source:
+            from_role = sentence["from"]
+            if from_role in roles:
+                role = roles[from_role]
+            else:
+                if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
+                    raise NotImplementedError(
+                        f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
+                    )
+
+                role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
+                    ROLE=from_role
+                )
+
+            if len(conv.messages) > 0 and role == conv.messages[-1][0]:
+                LOG.warning(f"Roles did not alternate: {sentence}")
+            conv.append_message(role, sentence["value"])
+
+        return conv.get_turns()
+
+
 class UnsupportedPrompter(Prompter):
     """
     A dummy class for custom prompters