From baff2e35d6af2765a21b20bdec8305a27ee66553 Mon Sep 17 00:00:00 2001
From: Milo Cress
Date: Tue, 7 May 2024 21:12:37 +0000
Subject: [PATCH] add context to errors

Introduce a ContextualError mixin that carries an optional `context` tag
('TrainContext' or 'EvalContext') and mix it into the custom LLMFoundry
exceptions. scripts/train/train.py sets that tag in the two exception
handlers of main() touched by this patch before logging and re-raising.

---
Review notes (this section is ignored by `git am`):
- ErrorContext is now a typing.Literal. The previous spelling,
  `'TrainContext' | 'EvalContext'`, raises TypeError when the module is
  imported because str does not implement `|`.
- ContextualError derives from Exception instead of BaseException.
- 'EvalContext' is assigned before the logger check, matching the
  'TrainContext' handler.
- Hunk line counts and the diffstat are unchanged by these fixes; the blob
  hashes on the `index` lines predate them.
- NOTE(review): leading whitespace and blank lines were reconstructed from
  the hunk line counts; confirm indentation against the tree before applying.

 llmfoundry/utils/exceptions.py | 50 ++++++++++++++++++++--------------
 scripts/train/train.py         |  7 +++--
 2 files changed, 34 insertions(+), 23 deletions(-)

diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py
index 8e9e46a1cf..03e29811e2 100644
--- a/llmfoundry/utils/exceptions.py
+++ b/llmfoundry/utils/exceptions.py
@@ -3,7 +3,7 @@
 """Custom exceptions for the LLMFoundry."""
 
 from collections.abc import Mapping
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Literal, Optional
 
 __all__ = [
     'ALLOWED_RESPONSE_KEYS',
@@ -35,9 +35,17 @@
 ALLOWED_PROMPT_KEYS = {'prompt'}
 ALLOWED_MESSAGES_KEYS = {'messages'}
 
+ErrorContext = Literal['TrainContext', 'EvalContext']
+
+
+class ContextualError(Exception):
+    """Mixin base for errors tagged with the task context they occurred in."""
+
+    context: Optional[ErrorContext] = None
+
 
 # Finetuning dataloader exceptions
-class MissingHuggingFaceURLSplitError(ValueError):
+class MissingHuggingFaceURLSplitError(ValueError, ContextualError):
     """Error thrown when there's no split used in HF dataset config."""
 
     def __init__(self) -> None:
@@ -46,7 +54,7 @@ def __init__(self) -> None:
         super().__init__(message)
 
 
-class NotEnoughDatasetSamplesError(ValueError):
+class NotEnoughDatasetSamplesError(ValueError, ContextualError):
     """Error thrown when there is not enough data to train a model."""
 
     def __init__(
@@ -76,7 +84,7 @@ def __init__(
 
 
 ## Tasks exceptions
-class UnknownExampleTypeError(KeyError):
+class UnknownExampleTypeError(KeyError, ContextualError):
     """Error thrown when an unknown example type is used in a task."""
 
     def __init__(self, example: Mapping) -> None:
@@ -89,7 +97,7 @@ def __init__(self, example: Mapping) -> None:
         super().__init__(message)
 
 
-class NotEnoughChatDataError(ValueError):
+class NotEnoughChatDataError(ValueError, ContextualError):
     """Error thrown when there is not enough chat data to train a model."""
 
     def __init__(self) -> None:
@@ -97,7 +105,7 @@ def __init__(self) -> None:
         super().__init__(message)
 
 
-class ConsecutiveRepeatedChatRolesError(ValueError):
+class ConsecutiveRepeatedChatRolesError(ValueError, ContextualError):
     """Error thrown when there are consecutive repeated chat roles."""
 
     def __init__(self, repeated_role: str) -> None:
@@ -106,7 +114,7 @@ def __init__(self, repeated_role: str) -> None:
         super().__init__(message)
 
 
-class InvalidLastChatMessageRoleError(ValueError):
+class InvalidLastChatMessageRoleError(ValueError, ContextualError):
     """Error thrown when the last message role in a chat example is invalid."""
 
     def __init__(self, last_role: str, expected_roles: set[str]) -> None:
@@ -116,7 +124,7 @@ def __init__(self, last_role: str, expected_roles: set[str]) -> None:
         super().__init__(message)
 
 
-class IncorrectMessageKeyQuantityError(ValueError):
+class IncorrectMessageKeyQuantityError(ValueError, ContextualError):
     """Error thrown when a message has an incorrect number of keys."""
 
     def __init__(self, keys: List[str]) -> None:
@@ -125,7 +133,7 @@ def __init__(self, keys: List[str]) -> None:
         super().__init__(message)
 
 
-class InvalidRoleError(ValueError):
+class InvalidRoleError(ValueError, ContextualError):
     """Error thrown when a role is invalid."""
 
     def __init__(self, role: str, valid_roles: set[str]) -> None:
@@ -135,7 +143,7 @@ def __init__(self, role: str, valid_roles: set[str]) -> None:
         super().__init__(message)
 
 
-class InvalidContentTypeError(TypeError):
+class InvalidContentTypeError(TypeError, ContextualError):
     """Error thrown when the content type is invalid."""
 
     def __init__(self, content_type: type) -> None:
@@ -144,7 +152,7 @@ def __init__(self, content_type: type) -> None:
         super().__init__(message)
 
 
-class InvalidPromptTypeError(TypeError):
+class InvalidPromptTypeError(TypeError, ContextualError):
     """Error thrown when the prompt type is invalid."""
 
     def __init__(self, prompt_type: type) -> None:
@@ -153,7 +161,7 @@ def __init__(self, prompt_type: type) -> None:
         super().__init__(message)
 
 
-class InvalidResponseTypeError(TypeError):
+class InvalidResponseTypeError(TypeError, ContextualError):
     """Error thrown when the response type is invalid."""
 
     def __init__(self, response_type: type) -> None:
@@ -162,7 +170,7 @@ def __init__(self, response_type: type) -> None:
         super().__init__(message)
 
 
-class InvalidPromptResponseKeysError(ValueError):
+class InvalidPromptResponseKeysError(ValueError, ContextualError):
     """Error thrown when missing expected prompt and response keys."""
 
     def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]):
@@ -171,7 +179,7 @@ def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]):
         super().__init__(message)
 
 
-class InvalidFileExtensionError(FileNotFoundError):
+class InvalidFileExtensionError(FileNotFoundError, ContextualError):
     """Error thrown when a file extension is not a safe extension."""
 
     def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None:
@@ -184,7 +192,7 @@ def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None:
         super().__init__(message)
 
 
-class UnableToProcessPromptResponseError(ValueError):
+class UnableToProcessPromptResponseError(ValueError, ContextualError):
     """Error thrown when a prompt and response cannot be processed."""
 
     def __init__(self, input: Dict) -> None:
@@ -194,7 +202,7 @@ def __init__(self, input: Dict) -> None:
 
 
 ## Convert Delta to JSON exceptions
-class ClusterDoesNotExistError(ValueError):
+class ClusterDoesNotExistError(ValueError, ContextualError):
     """Error thrown when the cluster does not exist."""
 
     def __init__(self, cluster_id: str) -> None:
@@ -203,7 +211,7 @@ def __init__(self, cluster_id: str) -> None:
         super().__init__(message)
 
 
-class FailedToCreateSQLConnectionError(RuntimeError):
+class FailedToCreateSQLConnectionError(RuntimeError, ContextualError):
     """Error thrown when client can't sql connect to Databricks."""
 
     def __init__(self) -> None:
@@ -211,7 +219,7 @@ def __init__(self) -> None:
         super().__init__(message)
 
 
-class FailedToConnectToDatabricksError(RuntimeError):
+class FailedToConnectToDatabricksError(RuntimeError, ContextualError):
     """Error thrown when the client fails to connect to Databricks."""
 
     def __init__(self) -> None:
@@ -220,7 +228,7 @@ def __init__(self) -> None:
 
 
 ## Convert Text to MDS exceptions
-class InputFolderMissingDataError(ValueError):
+class InputFolderMissingDataError(ValueError, ContextualError):
     """Error thrown when the input folder is missing data."""
 
     def __init__(self, input_folder: str) -> None:
@@ -229,7 +237,7 @@ def __init__(self, input_folder: str) -> None:
         super().__init__(message)
 
 
-class OutputFolderNotEmptyError(FileExistsError):
+class OutputFolderNotEmptyError(FileExistsError, ContextualError):
     """Error thrown when the output folder is not empty."""
 
     def __init__(self, output_folder: str) -> None:
@@ -238,7 +246,7 @@ def __init__(self, output_folder: str) -> None:
         super().__init__(message)
 
 
-class MisconfiguredHfDatasetError(ValueError):
+class MisconfiguredHfDatasetError(ValueError, ContextualError):
     """Error thrown when a HuggingFace dataset is misconfigured."""
 
     def __init__(self, dataset_name: str, split: str) -> None:
diff --git a/scripts/train/train.py b/scripts/train/train.py
index e8f5b8220a..e0cd01e7f4 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -29,6 +29,7 @@
     log_train_analytics,
     maybe_create_mosaicml_logger,
 )
+from llmfoundry.utils.exceptions import ContextualError
 
 install()
 
@@ -391,7 +392,8 @@ def main(cfg: DictConfig) -> Trainer:
             tokenizer,
             train_cfg.device_train_batch_size,
         )
-    except Exception as e:
+    except ContextualError as e:
+        e.context = 'TrainContext'
         if mosaicml_logger is not None:
             mosaicml_logger.log_exception(e)
         raise e
@@ -467,8 +469,9 @@ def main(cfg: DictConfig) -> Trainer:
                 evaluators,
                 non_icl_metrics,
             )
-    except Exception as e:
+    except ContextualError as e:
+        e.context = 'EvalContext'
         if mosaicml_logger is not None:
             mosaicml_logger.log_exception(e)
         raise e
 