From 1e94e47fde33ceb9c0901794f47d68665cf766c2 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 23 May 2024 19:49:24 +0000 Subject: [PATCH 01/19] add unit test to identify failing tests --- tests/utils/test_exceptions.py | 70 ++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 tests/utils/test_exceptions.py diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py new file mode 100644 index 0000000000..e401d652b8 --- /dev/null +++ b/tests/utils/test_exceptions.py @@ -0,0 +1,70 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import pickle +from typing import Dict + +from llmfoundry.utils.exceptions import ( + ClusterDoesNotExistError, + ConsecutiveRepeatedChatRolesError, + FailedToConnectToDatabricksError, + FailedToCreateSQLConnectionError, + IncorrectMessageKeyQuantityError, + InputFolderMissingDataError, + InvalidContentTypeError, + InvalidFileExtensionError, + InvalidLastChatMessageRoleError, + InvalidPromptResponseKeysError, + InvalidPromptTypeError, + InvalidResponseTypeError, + InvalidRoleError, + MisconfiguredHfDatasetError, + MissingHuggingFaceURLSplitError, + NotEnoughChatDataError, + NotEnoughDatasetSamplesError, + OutputFolderNotEmptyError, + RunTimeoutError, + UnableToProcessPromptResponseError, + UnknownExampleTypeError, +) + + +def test_exception_serialization(): + exceptions = [ + MissingHuggingFaceURLSplitError(), + NotEnoughDatasetSamplesError('ds_name', 'split', 1, 2, 3, 4), + UnknownExampleTypeError('my_keys'), + NotEnoughChatDataError(), + ConsecutiveRepeatedChatRolesError('role'), + InvalidLastChatMessageRoleError('role', {'other_role'}), + IncorrectMessageKeyQuantityError(['key', 'key2']), + InvalidRoleError('role', {'other_role'}), + InvalidContentTypeError(Dict), + InvalidPromptTypeError(Dict), + InvalidResponseTypeError(Dict), + InvalidPromptResponseKeysError({'prompt': 'response'}, + {'response': 'prompt'}), + InvalidFileExtensionError('dsname', ['ext1', 'ext2']), + UnableToProcessPromptResponseError({'prompt': 'response'}), + ClusterDoesNotExistError('cluster_name'), + FailedToCreateSQLConnectionError(), + FailedToConnectToDatabricksError(), + InputFolderMissingDataError('folder'), + OutputFolderNotEmptyError('folder'), + MisconfiguredHfDatasetError('dataset_name', 'split'), + RunTimeoutError(100), + ] + + failed_exceptions = {} + + for exception in exceptions: + pkl = pickle.dumps(exception) + try: + pickle.loads(pkl) + except Exception as e: + failed_exceptions[exception.__class__.__name__] = str(e) + + if failed_exceptions: + raise AssertionError( + f'Failed to serialize/deserialize the following exceptions: {failed_exceptions}', + ) From fd53b1617288267c5819163d6dc150b5d0f253a3 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 23 May 2024 20:01:14 +0000 Subject: [PATCH 02/19] clarify issue --- llmfoundry/utils/exceptions.py | 9 +++++---- tests/utils/test_exceptions.py | 7 +++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 51da8610e9..d9b5b9478c 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -49,7 +49,7 @@ class ContextualError(Exception): class MissingHuggingFaceURLSplitError(ValueError, ContextualError): """Error thrown when there's no split used in HF dataset config.""" - def __init__(self) -> None: + def __init__(self, *_) -> None: message = 'When using a HuggingFace dataset from a URL, you must set the ' + \ '`split` key in the dataset config.' super().__init__(message) @@ -66,6 +66,7 @@ def __init__( world_size: int, full_dataset_size: int, minimum_dataset_size: int, + *_, ) -> None: self.dataset_name = dataset_name self.split = split @@ -102,7 +103,7 @@ def __init__(self, example_keys: str) -> None: class NotEnoughChatDataError(ValueError, ContextualError): """Error thrown when there is not enough chat data to train a model.""" - def __init__(self) -> None: + def __init__(self, *_) -> None: message = 'Chat example must have at least two messages' super().__init__(message) @@ -119,7 +120,7 @@ def __init__(self, repeated_role: str) -> None: class InvalidLastChatMessageRoleError(ValueError, ContextualError): """Error thrown when the last message role in a chat example is invalid.""" - def __init__(self, last_role: str, expected_roles: set[str]) -> None: + def __init__(self, last_role: str, expected_roles: set[str], *_) -> None: self.last_role = last_role self.expected_roles = expected_roles message = f'Invalid last message role: {last_role}. Expected one of: {expected_roles}' @@ -138,7 +139,7 @@ def __init__(self, keys: List[str]) -> None: class InvalidRoleError(ValueError, ContextualError): """Error thrown when a role is invalid.""" - def __init__(self, role: str, valid_roles: set[str]) -> None: + def __init__(self, role: str, valid_roles: set[str], *_) -> None: self.role = role self.valid_roles = valid_roles message = f'Expected role to be one of {valid_roles} but found: {role}' diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py index e401d652b8..31c8662bf9 100644 --- a/tests/utils/test_exceptions.py +++ b/tests/utils/test_exceptions.py @@ -58,13 +58,16 @@ def test_exception_serialization(): failed_exceptions = {} for exception in exceptions: + exc_str = str(exception) pkl = pickle.dumps(exception) try: - pickle.loads(pkl) + unpickled_exc = pickle.loads(pkl) + unpickled_exc_str = str(unpickled_exc) + assert exc_str == unpickled_exc_str except Exception as e: failed_exceptions[exception.__class__.__name__] = str(e) if failed_exceptions: raise AssertionError( - f'Failed to serialize/deserialize the following exceptions: {failed_exceptions}', + f'Failed to serialize/deserialize the following exceptions: {failed_exceptions.keys()}\n\n{failed_exceptions=}', ) From 1b0c81a8dd69d4cbbe12435dcc45b3b4a8554436 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 23 May 2024 20:15:43 +0000 Subject: [PATCH 03/19] dict -> pretty json --- tests/utils/test_exceptions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py index 31c8662bf9..8ab2c545f3 100644 --- a/tests/utils/test_exceptions.py +++ b/tests/utils/test_exceptions.py @@ -1,6 +1,7 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import json import pickle from typing import Dict @@ -69,5 +70,6 @@ def test_exception_serialization(): if failed_exceptions: raise AssertionError( - f'Failed to serialize/deserialize the following exceptions: {failed_exceptions.keys()}\n\n{failed_exceptions=}', + f'Failed to serialize/deserialize the following exceptions: {failed_exceptions.keys()}\n\n' + + json.dumps(failed_exceptions, indent=2), ) From ee6a3a04c953c3cb57aa1f0f83434c7edb1ad6f7 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 23 May 2024 20:17:39 +0000 Subject: [PATCH 04/19] revert exceptions.py --- llmfoundry/utils/exceptions.py | 87 +++++++++++++++++++++++----------- 1 file changed, 59 insertions(+), 28 deletions(-) diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index d9b5b9478c..7a34430a21 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -35,6 +35,8 @@ ALLOWED_MESSAGES_KEYS = {'messages'} ErrorLocation = Union[Literal['TrainDataloader'], Literal['EvalDataloader']] +ErrorAttribution = Union[Literal['UserError'], Literal['InternalError'], + Literal['NetworkError']] TrainDataLoaderLocation = 'TrainDataloader' EvalDataLoaderLocation = 'EvalDataloader' @@ -43,19 +45,38 @@ class ContextualError(Exception): """Error thrown when an error occurs in the context of a specific task.""" location: Optional[ErrorLocation] = None + error_attribution: Optional[ErrorAttribution] = None + + +class UserError(ContextualError): + """Error thrown when an error is caused by user input.""" + + error_attribution = 'UserError' + + +class NetworkError(ContextualError): + """Error thrown when an error is caused by a network issue.""" + + error_attribution = 'NetworkError' + + +class InternalError(ContextualError): + """Error thrown when an error is caused by an internal issue.""" + + error_attribution = 'InternalError' # Finetuning dataloader exceptions -class MissingHuggingFaceURLSplitError(ValueError, ContextualError): +class MissingHuggingFaceURLSplitError(ValueError, UserError): """Error thrown when there's no split used in HF dataset config.""" - def __init__(self, *_) -> None: + def __init__(self) -> None: message = 'When using a HuggingFace dataset from a URL, you must set the ' + \ '`split` key in the dataset config.' super().__init__(message) -class NotEnoughDatasetSamplesError(ValueError, ContextualError): +class NotEnoughDatasetSamplesError(ValueError, UserError): """Error thrown when there is not enough data to train a model.""" def __init__( @@ -66,7 +87,6 @@ def __init__( world_size: int, full_dataset_size: int, minimum_dataset_size: int, - *_, ) -> None: self.dataset_name = dataset_name self.split = split @@ -86,7 +106,7 @@ def __init__( ## Tasks exceptions -class UnknownExampleTypeError(KeyError, ContextualError): +class UnknownExampleTypeError(KeyError, UserError): """Error thrown when an unknown example type is used in a task.""" def __init__(self, example_keys: str) -> None: @@ -100,15 +120,15 @@ def __init__(self, example_keys: str) -> None: super().__init__(message) -class NotEnoughChatDataError(ValueError, ContextualError): +class NotEnoughChatDataError(ValueError, UserError): """Error thrown when there is not enough chat data to train a model.""" - def __init__(self, *_) -> None: + def __init__(self) -> None: message = 'Chat example must have at least two messages' super().__init__(message) -class ConsecutiveRepeatedChatRolesError(ValueError, ContextualError): +class ConsecutiveRepeatedChatRolesError(ValueError, UserError): """Error thrown when there are consecutive repeated chat roles.""" def __init__(self, repeated_role: str) -> None: @@ -117,17 +137,17 @@ def __init__(self, repeated_role: str) -> None: super().__init__(message) -class InvalidLastChatMessageRoleError(ValueError, ContextualError): +class InvalidLastChatMessageRoleError(ValueError, UserError): """Error thrown when the last message role in a chat example is invalid.""" - def __init__(self, last_role: str, expected_roles: set[str], *_) -> None: + def __init__(self, last_role: str, expected_roles: set[str]) -> None: self.last_role = last_role self.expected_roles = expected_roles message = f'Invalid last message role: {last_role}. Expected one of: {expected_roles}' super().__init__(message) -class IncorrectMessageKeyQuantityError(ValueError, ContextualError): +class IncorrectMessageKeyQuantityError(ValueError, UserError): """Error thrown when a message has an incorrect number of keys.""" def __init__(self, keys: List[str]) -> None: @@ -136,17 +156,17 @@ def __init__(self, keys: List[str]) -> None: super().__init__(message) -class InvalidRoleError(ValueError, ContextualError): +class InvalidRoleError(ValueError, UserError): """Error thrown when a role is invalid.""" - def __init__(self, role: str, valid_roles: set[str], *_) -> None: + def __init__(self, role: str, valid_roles: set[str]) -> None: self.role = role self.valid_roles = valid_roles message = f'Expected role to be one of {valid_roles} but found: {role}' super().__init__(message) -class InvalidContentTypeError(TypeError, ContextualError): +class InvalidContentTypeError(TypeError, UserError): """Error thrown when the content type is invalid.""" def __init__(self, content_type: type) -> None: @@ -155,7 +175,7 @@ def __init__(self, content_type: type) -> None: super().__init__(message) -class InvalidPromptTypeError(TypeError, ContextualError): +class InvalidPromptTypeError(TypeError, UserError): """Error thrown when the prompt type is invalid.""" def __init__(self, prompt_type: type) -> None: @@ -164,7 +184,7 @@ def __init__(self, prompt_type: type) -> None: super().__init__(message) -class InvalidResponseTypeError(TypeError, ContextualError): +class InvalidResponseTypeError(TypeError, UserError): """Error thrown when the response type is invalid.""" def __init__(self, response_type: type) -> None: @@ -173,7 +193,7 @@ def __init__(self, response_type: type) -> None: super().__init__(message) -class InvalidPromptResponseKeysError(ValueError, ContextualError): +class InvalidPromptResponseKeysError(ValueError, UserError): """Error thrown when missing expected prompt and response keys.""" def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]): @@ -182,7 +202,7 @@ def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]): super().__init__(message) -class InvalidFileExtensionError(FileNotFoundError, ContextualError): +class InvalidFileExtensionError(FileNotFoundError, UserError): """Error thrown when a file extension is not a safe extension.""" def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None: @@ -195,7 +215,10 @@ def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None: super().__init__(message) -class UnableToProcessPromptResponseError(ValueError, ContextualError): +class UnableToProcessPromptResponseError( + ValueError, + UserError, +): """Error thrown when a prompt and response cannot be processed.""" def __init__(self, input: Dict) -> None: @@ -205,7 +228,7 @@ def __init__(self, input: Dict) -> None: ## Convert Delta to JSON exceptions -class ClusterDoesNotExistError(ValueError, ContextualError): +class ClusterDoesNotExistError(ValueError, NetworkError): """Error thrown when the cluster does not exist.""" def __init__(self, cluster_id: str) -> None: @@ -214,15 +237,22 @@ def __init__(self, cluster_id: str) -> None: super().__init__(message) -class FailedToCreateSQLConnectionError(RuntimeError, ContextualError): +class FailedToCreateSQLConnectionError( + RuntimeError, + NetworkError, +): """Error thrown when client can't sql connect to Databricks.""" def __init__(self) -> None: - message = 'Failed to create sql connection to db workspace. To use sql connect, you need to provide http_path and cluster_id!' + message = 'Failed to create sql connection to db workspace. ' + \ + 'To use sql connect, you need to provide http_path and cluster_id!' super().__init__(message) -class FailedToConnectToDatabricksError(RuntimeError, ContextualError): +class FailedToConnectToDatabricksError( + RuntimeError, + NetworkError, +): """Error thrown when the client fails to connect to Databricks.""" def __init__(self) -> None: @@ -231,7 +261,7 @@ def __init__(self) -> None: ## Convert Text to MDS exceptions -class InputFolderMissingDataError(ValueError, ContextualError): +class InputFolderMissingDataError(ValueError, UserError): """Error thrown when the input folder is missing data.""" def __init__(self, input_folder: str) -> None: @@ -240,7 +270,7 @@ def __init__(self, input_folder: str) -> None: super().__init__(message) -class OutputFolderNotEmptyError(FileExistsError, ContextualError): +class OutputFolderNotEmptyError(FileExistsError, UserError): """Error thrown when the output folder is not empty.""" def __init__(self, output_folder: str) -> None: @@ -249,17 +279,18 @@ def __init__(self, output_folder: str) -> None: super().__init__(message) -class MisconfiguredHfDatasetError(ValueError, ContextualError): +class MisconfiguredHfDatasetError(ValueError, UserError): """Error thrown when a HuggingFace dataset is misconfigured.""" def __init__(self, dataset_name: str, split: str) -> None: self.dataset_name = dataset_name self.split = split - message = f'Your dataset (name={dataset_name}, split={split}) is misconfigured. Please check your dataset format and make sure you can load your dataset locally.' + message = f'Your dataset (name={dataset_name}, split={split}) is misconfigured. ' + \ + 'Please check your dataset format and make sure you can load your dataset locally.' super().__init__(message) -class RunTimeoutError(RuntimeError): +class RunTimeoutError(RuntimeError, InternalError): """Error thrown when a run times out.""" def __init__(self, timeout: int) -> None: From 26f1be404c83bbe5d1c4b8cf45f960ea8697bbf3 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 23 May 2024 21:18:42 +0000 Subject: [PATCH 05/19] remove multiple inheritance for most classes --- llmfoundry/utils/exceptions.py | 59 +++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 22 deletions(-) diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 7a34430a21..fb5852af92 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -41,12 +41,30 @@ EvalDataLoaderLocation = 'EvalDataloader' -class ContextualError(Exception): +class SerializableError(): + + def __str__(self) -> str: + return str(self.error) + + def __setstate__(self, state: str): + super().__init__(state) + + def __getstate__(self) -> str: + return str(super()) + + +class ContextualError(Exception, SerializableError): """Error thrown when an error occurs in the context of a specific task.""" location: Optional[ErrorLocation] = None error_attribution: Optional[ErrorAttribution] = None + def __init__(self, message: str) -> None: + self.error = message + + def __str__(self) -> str: + return self.error + class UserError(ContextualError): """Error thrown when an error is caused by user input.""" @@ -67,7 +85,7 @@ class InternalError(ContextualError): # Finetuning dataloader exceptions -class MissingHuggingFaceURLSplitError(ValueError, UserError): +class MissingHuggingFaceURLSplitError(UserError): """Error thrown when there's no split used in HF dataset config.""" def __init__(self) -> None: @@ -76,7 +94,7 @@ def __init__(self) -> None: super().__init__(message) -class NotEnoughDatasetSamplesError(ValueError, UserError): +class NotEnoughDatasetSamplesError(UserError): """Error thrown when there is not enough data to train a model.""" def __init__( @@ -106,7 +124,7 @@ def __init__( ## Tasks exceptions -class UnknownExampleTypeError(KeyError, UserError): +class UnknownExampleTypeError(UserError): """Error thrown when an unknown example type is used in a task.""" def __init__(self, example_keys: str) -> None: @@ -120,7 +138,7 @@ def __init__(self, example_keys: str) -> None: super().__init__(message) -class NotEnoughChatDataError(ValueError, UserError): +class NotEnoughChatDataError(UserError): """Error thrown when there is not enough chat data to train a model.""" def __init__(self) -> None: @@ -128,7 +146,7 @@ def __init__(self) -> None: super().__init__(message) -class ConsecutiveRepeatedChatRolesError(ValueError, UserError): +class ConsecutiveRepeatedChatRolesError(UserError): """Error thrown when there are consecutive repeated chat roles.""" def __init__(self, repeated_role: str) -> None: @@ -137,7 +155,7 @@ def __init__(self, repeated_role: str) -> None: super().__init__(message) -class InvalidLastChatMessageRoleError(ValueError, UserError): +class InvalidLastChatMessageRoleError(UserError): """Error thrown when the last message role in a chat example is invalid.""" def __init__(self, last_role: str, expected_roles: set[str]) -> None: @@ -147,7 +165,7 @@ def __init__(self, last_role: str, expected_roles: set[str]) -> None: super().__init__(message) -class IncorrectMessageKeyQuantityError(ValueError, UserError): +class IncorrectMessageKeyQuantityError(UserError): """Error thrown when a message has an incorrect number of keys.""" def __init__(self, keys: List[str]) -> None: @@ -156,7 +174,7 @@ def __init__(self, keys: List[str]) -> None: super().__init__(message) -class InvalidRoleError(ValueError, UserError): +class InvalidRoleError(UserError): """Error thrown when a role is invalid.""" def __init__(self, role: str, valid_roles: set[str]) -> None: @@ -166,7 +184,7 @@ def __init__(self, role: str, valid_roles: set[str]) -> None: super().__init__(message) -class InvalidContentTypeError(TypeError, UserError): +class InvalidContentTypeError(UserError): """Error thrown when the content type is invalid.""" def __init__(self, content_type: type) -> None: @@ -175,7 +193,7 @@ def __init__(self, content_type: type) -> None: super().__init__(message) -class InvalidPromptTypeError(TypeError, UserError): +class InvalidPromptTypeError(UserError): """Error thrown when the prompt type is invalid.""" def __init__(self, prompt_type: type) -> None: @@ -184,7 +202,7 @@ def __init__(self, prompt_type: type) -> None: super().__init__(message) -class InvalidResponseTypeError(TypeError, UserError): +class InvalidResponseTypeError(UserError): """Error thrown when the response type is invalid.""" def __init__(self, response_type: type) -> None: @@ -193,7 +211,7 @@ def __init__(self, response_type: type) -> None: super().__init__(message) -class InvalidPromptResponseKeysError(ValueError, UserError): +class InvalidPromptResponseKeysError(UserError): """Error thrown when missing expected prompt and response keys.""" def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]): @@ -202,7 +220,7 @@ def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]): super().__init__(message) -class InvalidFileExtensionError(FileNotFoundError, UserError): +class InvalidFileExtensionError(UserError): """Error thrown when a file extension is not a safe extension.""" def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None: @@ -216,7 +234,6 @@ def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None: class UnableToProcessPromptResponseError( - ValueError, UserError, ): """Error thrown when a prompt and response cannot be processed.""" @@ -228,7 +245,7 @@ def __init__(self, input: Dict) -> None: ## Convert Delta to JSON exceptions -class ClusterDoesNotExistError(ValueError, NetworkError): +class ClusterDoesNotExistError(NetworkError): """Error thrown when the cluster does not exist.""" def __init__(self, cluster_id: str) -> None: @@ -238,7 +255,6 @@ def __init__(self, cluster_id: str) -> None: class FailedToCreateSQLConnectionError( - RuntimeError, NetworkError, ): """Error thrown when client can't sql connect to Databricks.""" @@ -250,7 +266,6 @@ def __init__(self) -> None: class FailedToConnectToDatabricksError( - RuntimeError, NetworkError, ): """Error thrown when the client fails to connect to Databricks.""" @@ -261,7 +276,7 @@ def __init__(self) -> None: ## Convert Text to MDS exceptions -class InputFolderMissingDataError(ValueError, UserError): +class InputFolderMissingDataError(UserError): """Error thrown when the input folder is missing data.""" def __init__(self, input_folder: str) -> None: @@ -270,7 +285,7 @@ def __init__(self, input_folder: str) -> None: super().__init__(message) -class OutputFolderNotEmptyError(FileExistsError, UserError): +class OutputFolderNotEmptyError(UserError): """Error thrown when the output folder is not empty.""" def __init__(self, output_folder: str) -> None: @@ -279,7 +294,7 @@ def __init__(self, output_folder: str) -> None: super().__init__(message) -class MisconfiguredHfDatasetError(ValueError, UserError): +class MisconfiguredHfDatasetError(UserError): """Error thrown when a HuggingFace dataset is misconfigured.""" def __init__(self, dataset_name: str, split: str) -> None: @@ -290,7 +305,7 @@ def __init__(self, dataset_name: str, split: str) -> None: super().__init__(message) -class RunTimeoutError(RuntimeError, InternalError): +class RunTimeoutError(InternalError): """Error thrown when a run times out.""" def __init__(self, timeout: int) -> None: From 02cebd217aaf6bb235268a4f593b57ba88836708 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Fri, 24 May 2024 01:20:31 +0000 Subject: [PATCH 06/19] fix with magic --- llmfoundry/utils/exceptions.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index fb5852af92..d891e15a56 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -43,14 +43,12 @@ class SerializableError(): - def __str__(self) -> str: - return str(self.error) - - def __setstate__(self, state: str): - super().__init__(state) + def __getstate__(self): + return self.__dict__ - def __getstate__(self) -> str: - return str(super()) + def __setstate__(self, state: Dict[str, Any]): + for key, value in state.items(): + setattr(self, key, value) class ContextualError(Exception, SerializableError): From 6542a118aaf5fb32c434d431bd0f3e428579d8d5 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Fri, 24 May 2024 15:45:33 +0000 Subject: [PATCH 07/19] address comments --- llmfoundry/utils/exceptions.py | 4 +- tests/utils/test_exceptions.py | 120 +++++++++++++++------------------ 2 files changed, 57 insertions(+), 67 deletions(-) diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index d891e15a56..ecfa2f51c9 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -28,6 +28,7 @@ 'InputFolderMissingDataError', 'OutputFolderNotEmptyError', 'MisconfiguredHfDatasetError', + 'RunTimeoutError', ] ALLOWED_RESPONSE_KEYS = {'response', 'completion'} @@ -60,9 +61,6 @@ class ContextualError(Exception, SerializableError): def __init__(self, message: str) -> None: self.error = message - def __str__(self) -> str: - return self.error - class UserError(ContextualError): """Error thrown when an error is caused by user input.""" diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py index 8ab2c545f3..143681f3be 100644 --- a/tests/utils/test_exceptions.py +++ b/tests/utils/test_exceptions.py @@ -1,75 +1,67 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import json +import inspect import pickle -from typing import Dict +from typing import Any, Dict, List, Optional -from llmfoundry.utils.exceptions import ( - ClusterDoesNotExistError, - ConsecutiveRepeatedChatRolesError, - FailedToConnectToDatabricksError, - FailedToCreateSQLConnectionError, - IncorrectMessageKeyQuantityError, - InputFolderMissingDataError, - InvalidContentTypeError, - InvalidFileExtensionError, - InvalidLastChatMessageRoleError, - InvalidPromptResponseKeysError, - InvalidPromptTypeError, - InvalidResponseTypeError, - InvalidRoleError, - MisconfiguredHfDatasetError, - MissingHuggingFaceURLSplitError, - NotEnoughChatDataError, - NotEnoughDatasetSamplesError, - OutputFolderNotEmptyError, - RunTimeoutError, - UnableToProcessPromptResponseError, - UnknownExampleTypeError, -) +import pytest +from llmfoundry.utils.exceptions import __all__ as all_exceptions -def test_exception_serialization(): - exceptions = [ - MissingHuggingFaceURLSplitError(), - NotEnoughDatasetSamplesError('ds_name', 'split', 1, 2, 3, 4), - UnknownExampleTypeError('my_keys'), - NotEnoughChatDataError(), - ConsecutiveRepeatedChatRolesError('role'), - InvalidLastChatMessageRoleError('role', {'other_role'}), - IncorrectMessageKeyQuantityError(['key', 'key2']), - InvalidRoleError('role', {'other_role'}), - InvalidContentTypeError(Dict), - InvalidPromptTypeError(Dict), - InvalidResponseTypeError(Dict), - InvalidPromptResponseKeysError({'prompt': 'response'}, - {'response': 'prompt'}), - InvalidFileExtensionError('dsname', ['ext1', 'ext2']), - UnableToProcessPromptResponseError({'prompt': 'response'}), - ClusterDoesNotExistError('cluster_name'), - FailedToCreateSQLConnectionError(), - FailedToConnectToDatabricksError(), - InputFolderMissingDataError('folder'), - OutputFolderNotEmptyError('folder'), - MisconfiguredHfDatasetError('dataset_name', 'split'), - RunTimeoutError(100), + +def create_exception_object(exception_name: str): + exception_class = getattr( + __import__('llmfoundry.utils.exceptions', fromlist=[exception_name]), + exception_name, + ) + # get required arg types of exception class by inspecting its __init__ method + + required_args = inspect.get_annotations(exception_class.__init__) + + # create a dictionary of required args with default values + + def get_default_value(arg_type: Optional[type] = None): + if arg_type == Dict[str, + str] or arg_type == Dict[str, + Any] or arg_type == Dict: + return {'key': 'value'} + elif arg_type == str: + return 'string' + elif arg_type == int: + return 1 + elif arg_type == set[str]: + return {'set'} + elif arg_type == List[str]: + return ['list'] + elif arg_type == None: + return None + elif arg_type == type: + return bool + raise ValueError(f'Unsupported arg type: {arg_type}') + + required_args.pop('self', None) + required_args.pop('return', None) + kwargs = { + arg: get_default_value(arg_type) + for arg, arg_type in required_args.items() + } + return exception_class(*kwargs.values()) + + +def filter_exceptions(exceptions: List[str]): + return [ + exception for exception in exceptions + if ('Error' in exception or 'Exception' in exception) ] - failed_exceptions = {} - for exception in exceptions: - exc_str = str(exception) - pkl = pickle.dumps(exception) - try: - unpickled_exc = pickle.loads(pkl) - unpickled_exc_str = str(unpickled_exc) - assert exc_str == unpickled_exc_str - except Exception as e: - failed_exceptions[exception.__class__.__name__] = str(e) +@pytest.mark.parametrize('exception_name', filter_exceptions(all_exceptions)) +def test_exception_serialization(exception_name: str): + exception = create_exception_object(exception_name) - if failed_exceptions: - raise AssertionError( - f'Failed to serialize/deserialize the following exceptions: {failed_exceptions.keys()}\n\n' - + json.dumps(failed_exceptions, indent=2), - ) + exc_str = str(exception) + pkl = pickle.dumps(exception) + unpickled_exc = pickle.loads(pkl) + unpickled_exc_str = str(unpickled_exc) + assert exc_str == unpickled_exc_str From 024d462173566300acecb4c9c08f362a754110f1 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Fri, 24 May 2024 15:46:32 +0000 Subject: [PATCH 08/19] remove parens --- llmfoundry/utils/exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index ecfa2f51c9..c86647aa12 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -42,7 +42,7 @@ EvalDataLoaderLocation = 'EvalDataloader' -class SerializableError(): +class SerializableError: def __getstate__(self): return self.__dict__ From 2af9e539f53d52f01da961e5e137c2918932d3da Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Fri, 24 May 2024 16:21:47 +0000 Subject: [PATCH 09/19] update tests to use dir instead of all --- llmfoundry/utils/exceptions.py | 20 ++++++++++---------- scripts/train/train.py | 8 ++++---- tests/utils/test_exceptions.py | 10 +++++++--- 3 files changed, 21 insertions(+), 17 deletions(-) diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index c86647aa12..546b3ca023 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -35,14 +35,14 @@ ALLOWED_PROMPT_KEYS = {'prompt'} ALLOWED_MESSAGES_KEYS = {'messages'} -ErrorLocation = Union[Literal['TrainDataloader'], Literal['EvalDataloader']] -ErrorAttribution = Union[Literal['UserError'], Literal['InternalError'], - Literal['NetworkError']] +FailureLocation = Union[Literal['TrainDataloader'], Literal['EvalDataloader']] +FailureAttribution = Union[Literal['UserError'], Literal['InternalError'], + Literal['NetworkError']] TrainDataLoaderLocation = 'TrainDataloader' EvalDataLoaderLocation = 'EvalDataloader' -class SerializableError: +class BaseSerialaziableError: def __getstate__(self): return self.__dict__ @@ -52,29 +52,29 @@ def __setstate__(self, state: Dict[str, Any]): setattr(self, key, value) -class ContextualError(Exception, SerializableError): +class BaseContextualError(Exception, BaseSerialaziableError): """Error thrown when an error occurs in the context of a specific task.""" - location: Optional[ErrorLocation] = None - error_attribution: Optional[ErrorAttribution] = None + location: Optional[FailureLocation] = None + error_attribution: Optional[FailureAttribution] = None def __init__(self, message: str) -> None: self.error = message -class UserError(ContextualError): +class UserError(BaseContextualError): """Error thrown when an error is caused by user input.""" error_attribution = 'UserError' -class NetworkError(ContextualError): +class NetworkError(BaseContextualError): """Error thrown when an error is caused by a network issue.""" error_attribution = 'NetworkError' -class InternalError(ContextualError): +class InternalError(BaseContextualError): """Error thrown when an error is caused by an internal issue.""" error_attribution = 'InternalError' diff --git a/scripts/train/train.py b/scripts/train/train.py index e0c2b8a94f..1d7dbbae03 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -30,7 +30,7 @@ maybe_create_mosaicml_logger, ) from llmfoundry.utils.exceptions import ( - ContextualError, + Base, EvalDataLoaderLocation, TrainDataLoaderLocation, ) @@ -396,7 +396,7 @@ def main(cfg: DictConfig) -> Trainer: tokenizer, train_cfg.device_train_batch_size, ) - except ContextualError as e: + except Base as e: if mosaicml_logger is not None: e.location = TrainDataLoaderLocation mosaicml_logger.log_exception(e) @@ -429,7 +429,7 @@ def main(cfg: DictConfig) -> Trainer: ) if eval_gauntlet_callback is not None: callbacks.append(eval_gauntlet_callback) - except ContextualError as e: + except Base as e: if mosaicml_logger is not None: e.location = EvalDataLoaderLocation mosaicml_logger.log_exception(e) @@ -479,7 +479,7 @@ def main(cfg: DictConfig) -> Trainer: evaluators, non_icl_metrics, ) - except ContextualError as e: + except Base as e: if mosaicml_logger is not None: e.location = EvalDataLoaderLocation mosaicml_logger.log_exception(e) diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py index 143681f3be..b3cdf96faf 100644 --- a/tests/utils/test_exceptions.py +++ b/tests/utils/test_exceptions.py @@ -7,7 +7,7 @@ import pytest -from llmfoundry.utils.exceptions import __all__ as all_exceptions +import llmfoundry.utils.exceptions as foundry_exceptions def create_exception_object(exception_name: str): @@ -52,11 +52,15 @@ def get_default_value(arg_type: Optional[type] = None): def filter_exceptions(exceptions: List[str]): return [ exception for exception in exceptions - if ('Error' in exception or 'Exception' in exception) + if ('Error' in exception or 'Exception' in exception) and + ('Base' not in exception) ] -@pytest.mark.parametrize('exception_name', filter_exceptions(all_exceptions)) +@pytest.mark.parametrize( + 'exception_name', + filter_exceptions(dir(foundry_exceptions)), +) def test_exception_serialization(exception_name: str): exception = create_exception_object(exception_name) From 694df2305b0de7b1f4ec4180f2bdcbe6c831dcf0 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Fri, 24 May 2024 16:36:15 +0000 Subject: [PATCH 10/19] I've been a silly goose --- llmfoundry/data/finetuning/dataloader.py | 12 ++++++------ scripts/train/train.py | 8 ++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 639beba6f0..df61fe7ec5 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -291,12 +291,12 @@ def build_finetuning_dataloader( full_dataset_size = len(streaming_dataset) if full_dataset_size < minimum_dataset_size: raise NotEnoughDatasetSamplesError( - dataset_name=dataset_cfg['hf_name'], - split=split, - dataloader_batch_size=dataloader_batch_size, - world_size=world_size, - full_dataset_size=full_dataset_size, - minimum_dataset_size=minimum_dataset_size, + dataset_cfg['hf_name'], + split, + dataloader_batch_size, + world_size, + full_dataset_size, + minimum_dataset_size, ) # Initialize sampler. diff --git a/scripts/train/train.py b/scripts/train/train.py index 1d7dbbae03..4bcac8f773 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -30,7 +30,7 @@ maybe_create_mosaicml_logger, ) from llmfoundry.utils.exceptions import ( - Base, + BaseContextualError, EvalDataLoaderLocation, TrainDataLoaderLocation, ) @@ -396,7 +396,7 @@ def main(cfg: DictConfig) -> Trainer: tokenizer, train_cfg.device_train_batch_size, ) - except Base as e: + except BaseContextualError as e: if mosaicml_logger is not None: e.location = TrainDataLoaderLocation mosaicml_logger.log_exception(e) @@ -429,7 +429,7 @@ def main(cfg: DictConfig) -> Trainer: ) if eval_gauntlet_callback is not None: callbacks.append(eval_gauntlet_callback) - except Base as e: + except BaseContextualError as e: if mosaicml_logger is not None: e.location = EvalDataLoaderLocation mosaicml_logger.log_exception(e) @@ -479,7 +479,7 @@ def main(cfg: DictConfig) -> Trainer: evaluators, non_icl_metrics, ) - except Base as e: + except BaseContextualError as e: if mosaicml_logger is not None: e.location = EvalDataLoaderLocation mosaicml_logger.log_exception(e) From b1fbbef2bcbdd364974a4fa8093a8cdc0e5e2702 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Fri, 24 May 2024 16:44:36 +0000 Subject: [PATCH 11/19] spelling --- llmfoundry/utils/exceptions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 546b3ca023..7d752a9c7e 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -42,7 +42,7 @@ EvalDataLoaderLocation = 'EvalDataloader' -class BaseSerialaziableError: +class BaseSerializableError: def __getstate__(self): return self.__dict__ @@ -52,7 +52,7 @@ def __setstate__(self, state: Dict[str, Any]): setattr(self, key, value) -class BaseContextualError(Exception, BaseSerialaziableError): +class BaseContextualError(Exception, BaseSerializableError): """Error thrown when an error occurs in the context of a specific task.""" location: Optional[FailureLocation] = None From 313d52b7706c3c9de8453eee31f101e111fa406c Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Fri, 24 May 2024 16:47:04 +0000 Subject: [PATCH 12/19] fix version issue and spelling --- tests/utils/test_exceptions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py index b3cdf96faf..297993d905 100644 --- a/tests/utils/test_exceptions.py +++ b/tests/utils/test_exceptions.py @@ -17,7 +17,10 @@ def create_exception_object(exception_name: str): ) # get required arg types of exception class by inspecting its __init__ method - required_args = inspect.get_annotations(exception_class.__init__) + if hasattr(inspect, 'get_annotations'): + required_args = inspect.get_annotations(exception_class.__init__) + else: + required_args = exception_class.__init__.__annotations__ # python 3.9 and below # create a dictionary of required args with default values From aef5632b9352755e244b2d9efc7f53cf2f794457 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Fri, 24 May 2024 16:58:17 +0000 Subject: [PATCH 13/19] more like py-wrong amirite? --- tests/utils/test_exceptions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py index 297993d905..f7ac995c70 100644 --- a/tests/utils/test_exceptions.py +++ b/tests/utils/test_exceptions.py @@ -18,7 +18,9 @@ def create_exception_object(exception_name: str): # get required arg types of exception class by inspecting its __init__ method if hasattr(inspect, 'get_annotations'): - required_args = inspect.get_annotations(exception_class.__init__) + required_args = inspect.get_annotations( + exception_class.__init__, + ) # type: ignore else: required_args = exception_class.__init__.__annotations__ # python 3.9 and below From 57b3438e43a6b2f7dda6c8eb2be8dbad47cdc593 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Fri, 24 May 2024 17:13:01 +0000 Subject: [PATCH 14/19] more ignoring --- tests/utils/test_exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py index f7ac995c70..1a2ef159f0 100644 --- a/tests/utils/test_exceptions.py +++ b/tests/utils/test_exceptions.py @@ -18,7 +18,7 @@ def create_exception_object(exception_name: str): # get required arg types of exception class by inspecting its __init__ method if hasattr(inspect, 'get_annotations'): - required_args = inspect.get_annotations( + required_args = inspect.get_annotations( # type: ignore exception_class.__init__, ) # type: ignore else: From 3bc11f757900ec8978354b10b1594059caa8e652 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 24 May 2024 16:23:02 -0700 Subject: [PATCH 15/19] lots o fixes --- llmfoundry/utils/exceptions.py | 132 +++++++++++++++++++-------------- tests/utils/test_exceptions.py | 58 ++++++++++----- 2 files changed, 116 insertions(+), 74 deletions(-) diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 7d752a9c7e..1a434e1cb4 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -42,24 +42,32 @@ EvalDataLoaderLocation = 'EvalDataloader' -class BaseSerializableError: - - def __getstate__(self): - return self.__dict__ - - def __setstate__(self, state: Dict[str, Any]): - for key, value in state.items(): - setattr(self, key, value) - - -class BaseContextualError(Exception, BaseSerializableError): +class BaseContextualError(Exception): """Error thrown when an error occurs in the context of a specific task.""" location: Optional[FailureLocation] = None error_attribution: Optional[FailureAttribution] = None - def __init__(self, message: str) -> None: + def __init__(self, message: str, **kwargs: Any) -> None: self.error = message + self.serializable_attributes = [] + + for key, value in kwargs.items(): + setattr(self, key, value) + self.serializable_attributes.append(key) + + def __str__(self) -> str: + return self.error + + def __reduce__(self): + if self.__class__ == BaseContextualError: + raise NotImplementedError( + 'BaseContextualError is a base class and cannot be pickled.', + ) + tuple_of_args = tuple([ + getattr(self, key) for key in self.serializable_attributes + ]) + return (self.__class__, tuple_of_args) class UserError(BaseContextualError): @@ -67,18 +75,42 @@ class UserError(BaseContextualError): error_attribution = 'UserError' + def __reduce__(self): + if self.__class__ == UserError: + raise NotImplementedError( + 'BaseContextualError is a base class and cannot be pickled.', + ) + + return super().__reduce__() + class NetworkError(BaseContextualError): """Error thrown when an error is caused by a network issue.""" error_attribution = 'NetworkError' + def __reduce__(self): + if self.__class__ == NetworkError: + raise NotImplementedError( + 'BaseContextualError is a base class and cannot be pickled.', + ) + + return super().__reduce__() + class InternalError(BaseContextualError): """Error thrown when an error is caused by an internal issue.""" error_attribution = 'InternalError' + def __reduce__(self): + if self.__class__ == InternalError: + raise NotImplementedError( + 'BaseContextualError is a base class and cannot be pickled.', + ) + + return super().__reduce__() + # Finetuning dataloader exceptions class MissingHuggingFaceURLSplitError(UserError): @@ -102,12 +134,6 @@ def __init__( full_dataset_size: int, minimum_dataset_size: int, ) -> None: - self.dataset_name = dataset_name - self.split = split - self.dataloader_batch_size = dataloader_batch_size - self.world_size = world_size - self.full_dataset_size = full_dataset_size - self.minimum_dataset_size = minimum_dataset_size message = ( f'Your dataset (name={dataset_name}, split={split}) ' + f'has {full_dataset_size} samples, but your minimum batch size ' + @@ -116,7 +142,15 @@ def __init__( f'your per device batch size is {dataloader_batch_size}. Please increase the number ' + f'of samples in your dataset to at least {minimum_dataset_size}.' ) - super().__init__(message) + super().__init__( + message, + dataset_name=dataset_name, + split=split, + dataloader_batch_size=dataloader_batch_size, + world_size=world_size, + full_dataset_size=full_dataset_size, + minimum_dataset_size=minimum_dataset_size, + ) ## Tasks exceptions @@ -124,14 +158,13 @@ class UnknownExampleTypeError(UserError): """Error thrown when an unknown example type is used in a task.""" def __init__(self, example_keys: str) -> None: - self.example = example_keys message = ( f'Found keys {example_keys} in dataset. Unknown example type. For prompt and response ' f'finetuning, the valid prompt keys are {ALLOWED_PROMPT_KEYS} and the valid response keys are ' f'{ALLOWED_RESPONSE_KEYS}. For chat finetuning, the allowed keys are {ALLOWED_MESSAGES_KEYS}' ) - super().__init__(message) + super().__init__(message, example_keys=example_keys) class NotEnoughChatDataError(UserError): @@ -146,87 +179,83 @@ class ConsecutiveRepeatedChatRolesError(UserError): """Error thrown when there are consecutive repeated chat roles.""" def __init__(self, repeated_role: str) -> None: - self.repeated_role = repeated_role message = f'Conversation roles must alternate but found {repeated_role} repeated consecutively.' - super().__init__(message) + super().__init__(message, repeated_role=repeated_role) class InvalidLastChatMessageRoleError(UserError): """Error thrown when the last message role in a chat example is invalid.""" def __init__(self, last_role: str, expected_roles: set[str]) -> None: - self.last_role = last_role - self.expected_roles = expected_roles message = f'Invalid last message role: {last_role}. Expected one of: {expected_roles}' - super().__init__(message) + super().__init__( + message, + last_role=last_role, + expected_roles=expected_roles, + ) class IncorrectMessageKeyQuantityError(UserError): """Error thrown when a message has an incorrect number of keys.""" def __init__(self, keys: List[str]) -> None: - self.keys = keys message = f'Expected 2 keys in message, but found {len(keys)}' - super().__init__(message) + super().__init__(message, keys=keys) class InvalidRoleError(UserError): """Error thrown when a role is invalid.""" def __init__(self, role: str, valid_roles: set[str]) -> None: - self.role = role - self.valid_roles = valid_roles message = f'Expected role to be one of {valid_roles} but found: {role}' - super().__init__(message) + super().__init__(message, role=role, valid_roles=valid_roles) class InvalidContentTypeError(UserError): """Error thrown when the content type is invalid.""" def __init__(self, content_type: type) -> None: - self.content_type = content_type message = f'Expected content to be a string, but found {content_type}' - super().__init__(message) + super().__init__(message, content_type=content_type) class InvalidPromptTypeError(UserError): """Error thrown when the prompt type is invalid.""" def __init__(self, prompt_type: type) -> None: - self.prompt_type = prompt_type message = f'Expected prompt to be a string, but found {prompt_type}' - super().__init__(message) + super().__init__(message, prompt_type=prompt_type) class InvalidResponseTypeError(UserError): """Error thrown when the response type is invalid.""" def __init__(self, response_type: type) -> None: - self.response_type = response_type message = f'Expected response to be a string, but found {response_type}' - super().__init__(message) + super().__init__(message, response_type=response_type) class InvalidPromptResponseKeysError(UserError): """Error thrown when missing expected prompt and response keys.""" def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]): - self.example = example message = f'Expected {mapping=} to have keys "prompt" and "response".' - super().__init__(message) + super().__init__(message, mapping=mapping, example=example) class InvalidFileExtensionError(UserError): """Error thrown when a file extension is not a safe extension.""" def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None: - self.dataset_name = dataset_name - self.valid_extensions = valid_extensions message = ( f'safe_load is set to True. No data files with safe extensions {valid_extensions} ' + f'found for dataset at local path {dataset_name}.' ) - super().__init__(message) + super().__init__( + message, + dataset_name=dataset_name, + valid_extensions=valid_extensions, + ) class UnableToProcessPromptResponseError( @@ -235,9 +264,8 @@ class UnableToProcessPromptResponseError( """Error thrown when a prompt and response cannot be processed.""" def __init__(self, input: Dict) -> None: - self.input = input message = f'Unable to extract prompt/response from {input}' - super().__init__(message) + super().__init__(message, input=input) ## Convert Delta to JSON exceptions @@ -245,9 +273,8 @@ class ClusterDoesNotExistError(NetworkError): """Error thrown when the cluster does not exist.""" def __init__(self, cluster_id: str) -> None: - self.cluster_id = cluster_id message = f'Cluster with id {cluster_id} does not exist. Check cluster id and try again!' - super().__init__(message) + super().__init__(message, cluster_id=cluster_id) class FailedToCreateSQLConnectionError( @@ -276,35 +303,30 @@ class InputFolderMissingDataError(UserError): """Error thrown when the input folder is missing data.""" def __init__(self, input_folder: str) -> None: - self.input_folder = input_folder message = f'No text files were found at {input_folder}.' - super().__init__(message) + super().__init__(message, input_folder=input_folder) class OutputFolderNotEmptyError(UserError): """Error thrown when the output folder is not empty.""" def __init__(self, output_folder: str) -> None: - self.output_folder = output_folder message = f'{output_folder} is not empty. Please remove or empty it and retry.' - super().__init__(message) + super().__init__(message, output_folder=output_folder) class MisconfiguredHfDatasetError(UserError): """Error thrown when a HuggingFace dataset is misconfigured.""" def __init__(self, dataset_name: str, split: str) -> None: - self.dataset_name = dataset_name - self.split = split message = f'Your dataset (name={dataset_name}, split={split}) is misconfigured. ' + \ 'Please check your dataset format and make sure you can load your dataset locally.' - super().__init__(message) + super().__init__(message, dataset_name=dataset_name, split=split) class RunTimeoutError(InternalError): """Error thrown when a run times out.""" def __init__(self, timeout: int) -> None: - self.timeout = timeout message = f'Run timed out after {timeout} seconds.' - super().__init__(message) + super().__init__(message, timeout=timeout) diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py index 1a2ef159f0..e351bdcfbb 100644 --- a/tests/utils/test_exceptions.py +++ b/tests/utils/test_exceptions.py @@ -1,20 +1,19 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import contextlib import inspect import pickle -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type import pytest import llmfoundry.utils.exceptions as foundry_exceptions -def create_exception_object(exception_name: str): - exception_class = getattr( - __import__('llmfoundry.utils.exceptions', fromlist=[exception_name]), - exception_name, - ) +def create_exception_object( + exception_class: Type[foundry_exceptions.BaseContextualError], +): # get required arg types of exception class by inspecting its __init__ method if hasattr(inspect, 'get_annotations'): @@ -25,6 +24,7 @@ def create_exception_object(exception_name: str): required_args = exception_class.__init__.__annotations__ # python 3.9 and below # create a dictionary of required args with default values + required_args.pop('kwargs', None) def get_default_value(arg_type: Optional[type] = None): if arg_type == Dict[str, @@ -51,26 +51,46 @@ def get_default_value(arg_type: Optional[type] = None): arg: get_default_value(arg_type) for arg, arg_type in required_args.items() } - return exception_class(*kwargs.values()) + return exception_class(**kwargs) # type: ignore -def filter_exceptions(exceptions: List[str]): - return [ - exception for exception in exceptions - if ('Error' in exception or 'Exception' in exception) and - ('Base' not in exception) +def filter_exceptions(possible_exceptions: List[str]): + attrs = [ + getattr(foundry_exceptions, exception) + for exception in possible_exceptions ] + classes = [attr for attr in attrs if inspect.isclass(attr)] + exceptions = [ + exception_class for exception_class in classes + if issubclass(exception_class, foundry_exceptions.BaseContextualError) + ] + return exceptions @pytest.mark.parametrize( - 'exception_name', + 'exception_class', filter_exceptions(dir(foundry_exceptions)), ) -def test_exception_serialization(exception_name: str): - exception = create_exception_object(exception_name) +def test_exception_serialization( + exception_class: Type[foundry_exceptions.BaseContextualError], +): + excluded_base_classes = [ + foundry_exceptions.InternalError, + foundry_exceptions.UserError, + foundry_exceptions.NetworkError, + foundry_exceptions.BaseContextualError, + ] + + exception = create_exception_object(exception_class) + + expect_reduce_error = exception.__class__ in excluded_base_classes + error_context = pytest.raises( + NotImplementedError, + ) if expect_reduce_error else contextlib.nullcontext() exc_str = str(exception) - pkl = pickle.dumps(exception) - unpickled_exc = pickle.loads(pkl) - unpickled_exc_str = str(unpickled_exc) - assert exc_str == unpickled_exc_str + with error_context: + pkl = pickle.dumps(exception) + unpickled_exc = pickle.loads(pkl) + unpickled_exc_str = str(unpickled_exc) + assert exc_str == unpickled_exc_str From 8df80f62942f8f18bfd0e3aac78b92f5f98f4082 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 24 May 2024 16:28:16 -0700 Subject: [PATCH 16/19] add the kwargs back --- llmfoundry/data/finetuning/dataloader.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index df61fe7ec5..639beba6f0 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -291,12 +291,12 @@ def build_finetuning_dataloader( full_dataset_size = len(streaming_dataset) if full_dataset_size < minimum_dataset_size: raise NotEnoughDatasetSamplesError( - dataset_cfg['hf_name'], - split, - dataloader_batch_size, - world_size, - full_dataset_size, - minimum_dataset_size, + dataset_name=dataset_cfg['hf_name'], + split=split, + dataloader_batch_size=dataloader_batch_size, + world_size=world_size, + full_dataset_size=full_dataset_size, + minimum_dataset_size=minimum_dataset_size, ) # Initialize sampler. From 09e320e74828313c793be4e82751c040c4783982 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 24 May 2024 16:30:55 -0700 Subject: [PATCH 17/19] more tests --- tests/utils/test_exceptions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py index e351bdcfbb..a3e5026d88 100644 --- a/tests/utils/test_exceptions.py +++ b/tests/utils/test_exceptions.py @@ -94,3 +94,5 @@ def test_exception_serialization( unpickled_exc = pickle.loads(pkl) unpickled_exc_str = str(unpickled_exc) assert exc_str == unpickled_exc_str + assert exception.location == unpickled_exc.location + assert exception.error_attribution == unpickled_exc.error_attribution From ed41e4772b161786a481c7f29cf2176168f1a916 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 24 May 2024 17:03:27 -0700 Subject: [PATCH 18/19] fix super call --- llmfoundry/utils/exceptions.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 1a434e1cb4..b3f0030311 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -56,8 +56,7 @@ def __init__(self, message: str, **kwargs: Any) -> None: setattr(self, key, value) self.serializable_attributes.append(key) - def __str__(self) -> str: - return self.error + super().__init__(message) def __reduce__(self): if self.__class__ == BaseContextualError: From 85d17eb457148a109c8d7b6d470db5757eb0b1c6 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 24 May 2024 17:59:04 -0700 Subject: [PATCH 19/19] pr comments --- llmfoundry/utils/exceptions.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index f1013bf7a3..76f378f8c6 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -59,6 +59,13 @@ def __init__(self, message: str, **kwargs: Any) -> None: super().__init__(message) def __reduce__(self): + """Adjust the reduce behavior for pickling. + + Because we have custom exception subclasses with constructor args, we + need to adjust the reduce behavior to ensure that the exception can be + pickled. This allows error propagation across processes in + multiprocessing. + """ if self.__class__ == BaseContextualError: raise NotImplementedError( 'BaseContextualError is a base class and cannot be pickled.', @@ -77,7 +84,7 @@ class UserError(BaseContextualError): def __reduce__(self): if self.__class__ == UserError: raise NotImplementedError( - 'BaseContextualError is a base class and cannot be pickled.', + 'UserError is a base class and cannot be pickled.', ) return super().__reduce__() @@ -91,7 +98,7 @@ class NetworkError(BaseContextualError): def __reduce__(self): if self.__class__ == NetworkError: raise NotImplementedError( - 'BaseContextualError is a base class and cannot be pickled.', + 'NetworkError is a base class and cannot be pickled.', ) return super().__reduce__() @@ -105,7 +112,7 @@ class InternalError(BaseContextualError): def __reduce__(self): if self.__class__ == InternalError: raise NotImplementedError( - 'BaseContextualError is a base class and cannot be pickled.', + 'InternalError is a base class and cannot be pickled.', ) return super().__reduce__() @@ -196,7 +203,7 @@ def __init__( message, template=template, sample=sample, - inner_message=inner_message + inner_message=inner_message, )