From 2e10d9536f3a23169bbc787e2c7bff43d997524d Mon Sep 17 00:00:00 2001
From: Milo Cress
Date: Fri, 24 May 2024 20:44:55 -0400
Subject: [PATCH] add error when chat template fails (#1222)

* add error when chat template fails

* type

* formatting
---
 llmfoundry/data/finetuning/tasks.py      | 26 ++++++++++------
 llmfoundry/utils/exceptions.py           | 17 ++++++++++-
 scripts/inference/hf_chat.py             | 38 +++++++++++++++++-------
 tests/data/test_template_tokenization.py | 34 +++++++++++++++++++++
 4 files changed, 94 insertions(+), 21 deletions(-)

diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 7e8ab7a471..b7cce4d20a 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -69,6 +69,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
     ALLOWED_MESSAGES_KEYS,
     ALLOWED_PROMPT_KEYS,
     ALLOWED_RESPONSE_KEYS,
+    ChatTemplateError,
     ConsecutiveRepeatedChatRolesError,
     IncorrectMessageKeyQuantityError,
     InvalidContentTypeError,
@@ -245,15 +246,22 @@ def slice_out_last_turn(
         messages_through_current_turn: List[Dict[str, str]],
         conversation_through_previous_turn: str,
     ) -> Tuple[str, str]:
-        full_conversation = tokenizer.apply_chat_template(
-            messages_through_current_turn,
-            tokenize=False,
-        )
-        prompt_with_history = tokenizer.apply_chat_template(
-            messages_through_current_turn[:-1],
-            tokenize=False,
-            add_generation_prompt=True,
-        )
+        try:
+            full_conversation = tokenizer.apply_chat_template(
+                messages_through_current_turn,
+                tokenize=False,
+            )
+            prompt_with_history = tokenizer.apply_chat_template(
+                messages_through_current_turn[:-1],
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+        except Exception as e:
+            raise ChatTemplateError(
+                tokenizer.chat_template,
+                sample=messages_through_current_turn,
+                inner_message=str(e),
+            )
         if conversation_through_previous_turn != full_conversation[:len(
             conversation_through_previous_turn,
         )]:
diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py
index 7a34430a21..744a4d7b96 100644
--- a/llmfoundry/utils/exceptions.py
+++ b/llmfoundry/utils/exceptions.py
@@ -76,7 +76,7 @@ def __init__(self) -> None:
         super().__init__(message)
 
 
-class NotEnoughDatasetSamplesError(ValueError, UserError):
+class NotEnoughDatasetSamplesError(UserError):
     """Error thrown when there is not enough data to train a model."""
 
     def __init__(
@@ -137,6 +137,21 @@ def __init__(self, repeated_role: str) -> None:
         super().__init__(message)
 
 
+class ChatTemplateError(ValueError, UserError):
+    """Error thrown when a chat template fails to process a sample."""
+
+    def __init__(
+        self,
+        template: str,
+        sample: List[Dict[str, Any]],
+        inner_message: str,
+    ) -> None:
+        self.template = template
+        self.sample = sample
+        message = f'Failed to process sample {sample} with template {template}. {inner_message}'
+        super().__init__(message)
+
+
 class InvalidLastChatMessageRoleError(ValueError, UserError):
     """Error thrown when the last message role in a chat example is invalid."""
 
diff --git a/scripts/inference/hf_chat.py b/scripts/inference/hf_chat.py
index 3657bbe1b0..e992371c32 100644
--- a/scripts/inference/hf_chat.py
+++ b/scripts/inference/hf_chat.py
@@ -19,6 +19,8 @@
     TextStreamer,
 )
 
+from llmfoundry.utils.exceptions import ChatTemplateError
+
 DEFAULT_SYSTEM_PROMPT = 'You are a friendly chatbot who aims to be helpful and honest.'
@@ -125,21 +127,35 @@ def _history_to_chat_conversation(self) -> List[Dict[str, str]]:
 
     def _history_as_formatted_str(self) -> str:
         chat_conversation = self._history_to_chat_conversation()
-        return self.tokenizer.apply_chat_template(
-            chat_conversation,
-            tokenize=False,
-            add_generation_prompt=False,
-        )
+        try:
+            return self.tokenizer.apply_chat_template(
+                chat_conversation,
+                tokenize=False,
+                add_generation_prompt=False,
+            )
+        except Exception as e:
+            raise ChatTemplateError(
+                inner_message=str(e),
+                template=self.tokenizer.chat_template,
+                sample=chat_conversation,
+            )
 
     def turn(self, user_inp: str) -> None:
         self.history.append(ChatMessage('user', user_inp))
         chat_conversation = self._history_to_chat_conversation()
-        tokenized_chat = self.tokenizer.apply_chat_template(
-            chat_conversation,
-            tokenize=True,
-            add_generation_prompt=True,
-            return_tensors='pt',
-        )
+        try:
+            tokenized_chat = self.tokenizer.apply_chat_template(
+                chat_conversation,
+                tokenize=True,
+                add_generation_prompt=True,
+                return_tensors='pt',
+            )
+        except Exception as e:
+            raise ChatTemplateError(
+                inner_message=str(e),
+                template=self.tokenizer.chat_template,
+                sample=chat_conversation,
+            )
         tokenized_chat = tokenized_chat.to(self.model.device)
         # also stream to stdout
         maybe_synchronize()
diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py
index 785a124379..16447d6623 100644
--- a/tests/data/test_template_tokenization.py
+++ b/tests/data/test_template_tokenization.py
@@ -15,6 +15,7 @@
 from llmfoundry.utils.exceptions import (
     ALLOWED_PROMPT_KEYS,
     ALLOWED_RESPONSE_KEYS,
+    ChatTemplateError,
 )
@@ -270,6 +271,39 @@ def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool):
     assert reconstructed_chat == full_chat
 
 
+def test_fail_chat_template():
+    convo = [
+        {
+            'role':
+                'system',  # this will fail because the tokenizer doesn't have a system role
+            'content': 'everyone thinks you are so cool',
+        },
+        {
+            'role': 'user',
+            'content': 'hiiii',
+        },
+        {
+            'role': 'assistant',
+            'content': 'yassss',
+        },
+    ]
+
+    example = {'messages': convo}
+
+    class DummyTokenizer:
+
+        def __init__(self) -> None:
+            self.chat_template = 'Hello, World!'
+
+        def apply_chat_template(self, **_):
+            raise ValueError('This tokenizer does not support the system role')
+
+    tok = DummyTokenizer()
+
+    with pytest.raises(ChatTemplateError):
+        _slice_chat_formatted_example(example, tok)
+
+
 def test_tokenize_no_labels_bos_pr():
     # This tokenizer automatically adds bos tokens
     tokenizer = transformers.AutoTokenizer.from_pretrained(