From 42fcd8db27a936c29708f36e3058012fdb919258 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 30 May 2024 20:47:27 +0000 Subject: [PATCH] fix allow case insensitive exampple keys tmp fix fix revert fix fix formatting fix --- llmfoundry/data/finetuning/tasks.py | 36 ++++++++++++++++++----------- llmfoundry/utils/exceptions.py | 13 +++++++++++ 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index b7cce4d20a..41b9d4adda 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -35,7 +35,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: import logging import os import warnings -from collections.abc import Mapping +from collections.abc import KeysView, Mapping from functools import partial from pathlib import Path from typing import ( @@ -71,6 +71,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: ALLOWED_RESPONSE_KEYS, ChatTemplateError, ConsecutiveRepeatedChatRolesError, + ExampleDatasetKeyCaseError, IncorrectMessageKeyQuantityError, InvalidContentTypeError, InvalidFileExtensionError, @@ -134,22 +135,31 @@ def _get_example_type(example: Example) -> ExampleType: raise TypeError( f'Expected example to be a Mapping, but found {type(example)}', ) - if ( - len(example.keys()) == 1 and any( + + def match_keys(keys: KeysView) -> ExampleType: + if len(keys) == 1 and any( allowed_message_key in example for allowed_message_key in ALLOWED_MESSAGES_KEYS - ) - ): - return 'chat' - elif ( - len(example.keys()) == 2 and - any(p in example for p in ALLOWED_PROMPT_KEYS) and - any(r in example for r in ALLOWED_RESPONSE_KEYS) - ): - return 'prompt_response' - else: + ): + return 'chat' + elif ( + len(example.keys()) == 2 and + any(p in example for p in ALLOWED_PROMPT_KEYS) and + any(r in example for r in ALLOWED_RESPONSE_KEYS) + ): + return 'prompt_response' raise UnknownExampleTypeError(str(example.keys())) + try: + example_type = match_keys(example.keys()) + except UnknownExampleTypeError: + # We try to match the keys in lower case again. + example_lower = {key.lower(): value for key, value in example.items()} + match_keys(example_lower.keys()) + # If there is a match then we let the user know that the keys are case senssitive. + raise ExampleDatasetKeyCaseError(str(example.keys())) + return example_type + def _is_empty_or_nonexistent(dirpath: str) -> bool: """Check if a directory is empty or non-existent. diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 76f378f8c6..ac068f87d4 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -172,6 +172,19 @@ def __init__(self, example_keys: str) -> None: super().__init__(message, example_keys=example_keys) +class ExampleDatasetKeyCaseError(UserError): + """Error thrown when keys in a dataset example are not in lowercase. + + This error checks for keys that could potentially match the expected example types if corrected. + """ + + + def __init__(self, example_keys: str) -> None: + message = ( + f"Found keys {example_keys} in the dataset. All keys in datasets must be in lowercase. " + f"Please ensure all keys are formatted correctly." + ) + super().__init__(message, example_keys=example_keys) class NotEnoughChatDataError(UserError): """Error thrown when there is not enough chat data to train a model."""