Skip to content

Commit

Permalink
handle roles dynamically from conversation
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Sep 15, 2023
1 parent 712dd8b commit 718e844
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 11 deletions.
8 changes: 6 additions & 2 deletions src/axolotl/prompt_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
from typing import Dict, List, Tuple, Union

from fastchat.conversation import Conversation
from transformers import PreTrainedTokenizer

from axolotl.prompters import IGNORE_TOKEN_ID
Expand Down Expand Up @@ -357,12 +358,15 @@ def tokenize_prompt(self, prompt):
result, current_len = tokenize_prompt_default()
user_token = self._get_user_token()
assistant_token = self._get_assistant_token()
conversation: Conversation = (
self.prompter._conversation # pylint: disable=protected-access
)
try:
for _, part in enumerate(
self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if isinstance(part, tuple):
if part[0] == "USER:":
if part[0] == conversation.roles[0]:
part = part[0] + part[1] if not user_token else part[1]
# this is still the user query, we should
res = self._tokenize(
Expand All @@ -374,7 +378,7 @@ def tokenize_prompt(self, prompt):
res["input_ids"] = [user_token, *res["input_ids"]]
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif part[0] == "ASSISTANT:":
elif part[0] == conversation.roles[1]:
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
part = part[0] + part[1] if not assistant_token else part[1]
# this should be the assistent response, should end with an eos token
Expand Down
10 changes: 1 addition & 9 deletions src/axolotl/prompters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Module containing prompters"""

import logging
from enum import Enum, auto
from enum import Enum
from typing import Generator, Optional, Union

from fastchat.conversation import Conversation, get_conv_template
Expand Down Expand Up @@ -229,14 +229,6 @@ def build_prompt(
yield res


class SeparatorStyle(Enum):
"""Different separator style."""

SINGLE = auto()
TWO = auto()
DOLLY = auto()


SHAREGPT_ASSERTION_FAILED_ROLE = (
"Role did not alternate between turns (gpt and human). Please check your data."
)
Expand Down

0 comments on commit 718e844

Please sign in to comment.