Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed May 25, 2024
2 parents ed41e47 + 2e10d95 commit b2cc117
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 20 deletions.
26 changes: 17 additions & 9 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
ALLOWED_MESSAGES_KEYS,
ALLOWED_PROMPT_KEYS,
ALLOWED_RESPONSE_KEYS,
ChatTemplateError,
ConsecutiveRepeatedChatRolesError,
IncorrectMessageKeyQuantityError,
InvalidContentTypeError,
Expand Down Expand Up @@ -245,15 +246,22 @@ def slice_out_last_turn(
messages_through_current_turn: List[Dict[str, str]],
conversation_through_previous_turn: str,
) -> Tuple[str, str]:
full_conversation = tokenizer.apply_chat_template(
messages_through_current_turn,
tokenize=False,
)
prompt_with_history = tokenizer.apply_chat_template(
messages_through_current_turn[:-1],
tokenize=False,
add_generation_prompt=True,
)
try:
full_conversation = tokenizer.apply_chat_template(
messages_through_current_turn,
tokenize=False,
)
prompt_with_history = tokenizer.apply_chat_template(
messages_through_current_turn[:-1],
tokenize=False,
add_generation_prompt=True,
)
except Exception as e:
raise ChatTemplateError(
tokenizer.chat_template,
sample=messages_through_current_turn,
inner_message=str(e),
)
if conversation_through_previous_turn != full_conversation[:len(
conversation_through_previous_turn,
)]:
Expand Down
18 changes: 18 additions & 0 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,24 @@ def __init__(self, repeated_role: str) -> None:
super().__init__(message, repeated_role=repeated_role)


class ChatTemplateError(UserError):
"""Error thrown when a chat template fails to process a sample."""

def __init__(
self,
template: str,
sample: List[Dict[str, Any]],
inner_message: str,
) -> None:
message = f'Failed to process sample {sample} with template {template}. {inner_message}'
super().__init__(
message,
template=template,
sample=sample,
inner_message=inner_message
)


class InvalidLastChatMessageRoleError(UserError):
"""Error thrown when the last message role in a chat example is invalid."""

Expand Down
38 changes: 27 additions & 11 deletions scripts/inference/hf_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
TextStreamer,
)

from llmfoundry.utils.exceptions import ChatTemplateError

DEFAULT_SYSTEM_PROMPT = 'You are a friendly chatbot who aims to be helpful and honest.'


Expand Down Expand Up @@ -125,21 +127,35 @@ def _history_to_chat_conversation(self) -> List[Dict[str, str]]:

def _history_as_formatted_str(self) -> str:
chat_conversation = self._history_to_chat_conversation()
return self.tokenizer.apply_chat_template(
chat_conversation,
tokenize=False,
add_generation_prompt=False,
)
try:
return self.tokenizer.apply_chat_template(
chat_conversation,
tokenize=False,
add_generation_prompt=False,
)
except Exception as e:
raise ChatTemplateError(
inner_message=str(e),
template=self.tokenizer.chat_template,
sample=chat_conversation,
)

def turn(self, user_inp: str) -> None:
self.history.append(ChatMessage('user', user_inp))
chat_conversation = self._history_to_chat_conversation()
tokenized_chat = self.tokenizer.apply_chat_template(
chat_conversation,
tokenize=True,
add_generation_prompt=True,
return_tensors='pt',
)
try:
tokenized_chat = self.tokenizer.apply_chat_template(
chat_conversation,
tokenize=True,
add_generation_prompt=True,
return_tensors='pt',
)
except Exception as e:
raise ChatTemplateError(
inner_message=str(e),
template=self.tokenizer.chat_template,
sample=chat_conversation,
)
tokenized_chat = tokenized_chat.to(self.model.device)
# also stream to stdout
maybe_synchronize()
Expand Down
34 changes: 34 additions & 0 deletions tests/data/test_template_tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from llmfoundry.utils.exceptions import (
ALLOWED_PROMPT_KEYS,
ALLOWED_RESPONSE_KEYS,
ChatTemplateError,
)


Expand Down Expand Up @@ -270,6 +271,39 @@ def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool):
assert reconstructed_chat == full_chat


def test_fail_chat_template():
convo = [
{
'role':
'system', # this will fail because the tokenizer doesn't have a system role
'content': 'everyone thinks you are so cool',
},
{
'role': 'user',
'content': 'hiiii',
},
{
'role': 'assistant',
'content': 'yassss',
},
]

example = {'messages': convo}

class DummyTokenizer:

def __init__(self) -> None:
self.chat_template = 'Hello, World!'

def apply_chat_template(self, **_):
raise ValueError('This tokenizer does not support the system role')

tok = DummyTokenizer()

with pytest.raises(ChatTemplateError):
_slice_chat_formatted_example(example, tok)


def test_tokenize_no_labels_bos_pr():
# This tokenizer automatically adds bos tokens
tokenizer = transformers.AutoTokenizer.from_pretrained(
Expand Down
2 changes: 2 additions & 0 deletions tests/utils/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def get_default_value(arg_type: Optional[type] = None):
return None
elif arg_type == type:
return bool
elif arg_type == List[Dict[str, Any]]:
return [{'key': 'value'}]
raise ValueError(f'Unsupported arg type: {arg_type}')

required_args.pop('self', None)
Expand Down

0 comments on commit b2cc117

Please sign in to comment.