diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 744a4d7b96..76f378f8c6 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -28,46 +28,98 @@ 'InputFolderMissingDataError', 'OutputFolderNotEmptyError', 'MisconfiguredHfDatasetError', + 'RunTimeoutError', ] ALLOWED_RESPONSE_KEYS = {'response', 'completion'} ALLOWED_PROMPT_KEYS = {'prompt'} ALLOWED_MESSAGES_KEYS = {'messages'} -ErrorLocation = Union[Literal['TrainDataloader'], Literal['EvalDataloader']] -ErrorAttribution = Union[Literal['UserError'], Literal['InternalError'], - Literal['NetworkError']] +FailureLocation = Union[Literal['TrainDataloader'], Literal['EvalDataloader']] +FailureAttribution = Union[Literal['UserError'], Literal['InternalError'], + Literal['NetworkError']] TrainDataLoaderLocation = 'TrainDataloader' EvalDataLoaderLocation = 'EvalDataloader' -class ContextualError(Exception): +class BaseContextualError(Exception): """Error thrown when an error occurs in the context of a specific task.""" - location: Optional[ErrorLocation] = None - error_attribution: Optional[ErrorAttribution] = None + location: Optional[FailureLocation] = None + error_attribution: Optional[FailureAttribution] = None + def __init__(self, message: str, **kwargs: Any) -> None: + self.error = message + self.serializable_attributes = [] -class UserError(ContextualError): + for key, value in kwargs.items(): + setattr(self, key, value) + self.serializable_attributes.append(key) + + super().__init__(message) + + def __reduce__(self): + """Adjust the reduce behavior for pickling. + + Because we have custom exception subclasses with constructor args, we + need to adjust the reduce behavior to ensure that the exception can be + pickled. This allows error propagation across processes in + multiprocessing. + """ + if self.__class__ == BaseContextualError: + raise NotImplementedError( + 'BaseContextualError is a base class and cannot be pickled.', + ) + tuple_of_args = tuple([ + getattr(self, key) for key in self.serializable_attributes + ]) + return (self.__class__, tuple_of_args) + + +class UserError(BaseContextualError): """Error thrown when an error is caused by user input.""" error_attribution = 'UserError' + def __reduce__(self): + if self.__class__ == UserError: + raise NotImplementedError( + 'UserError is a base class and cannot be pickled.', + ) + + return super().__reduce__() -class NetworkError(ContextualError): + +class NetworkError(BaseContextualError): """Error thrown when an error is caused by a network issue.""" error_attribution = 'NetworkError' + def __reduce__(self): + if self.__class__ == NetworkError: + raise NotImplementedError( + 'NetworkError is a base class and cannot be pickled.', + ) + + return super().__reduce__() -class InternalError(ContextualError): + +class InternalError(BaseContextualError): """Error thrown when an error is caused by an internal issue.""" error_attribution = 'InternalError' + def __reduce__(self): + if self.__class__ == InternalError: + raise NotImplementedError( + 'InternalError is a base class and cannot be pickled.', + ) + + return super().__reduce__() + # Finetuning dataloader exceptions -class MissingHuggingFaceURLSplitError(ValueError, UserError): +class MissingHuggingFaceURLSplitError(UserError): """Error thrown when there's no split used in HF dataset config.""" def __init__(self) -> None: @@ -88,12 +140,6 @@ def __init__( full_dataset_size: int, minimum_dataset_size: int, ) -> None: - self.dataset_name = dataset_name - self.split = split - self.dataloader_batch_size = dataloader_batch_size - self.world_size = world_size - self.full_dataset_size = full_dataset_size - self.minimum_dataset_size = minimum_dataset_size message = ( f'Your dataset (name={dataset_name}, split={split}) ' + f'has {full_dataset_size} samples, but your minimum batch size ' + @@ -102,25 +148,32 @@ def __init__( f'your per device batch size is {dataloader_batch_size}. Please increase the number ' + f'of samples in your dataset to at least {minimum_dataset_size}.' ) - super().__init__(message) + super().__init__( + message, + dataset_name=dataset_name, + split=split, + dataloader_batch_size=dataloader_batch_size, + world_size=world_size, + full_dataset_size=full_dataset_size, + minimum_dataset_size=minimum_dataset_size, + ) ## Tasks exceptions -class UnknownExampleTypeError(KeyError, UserError): +class UnknownExampleTypeError(UserError): """Error thrown when an unknown example type is used in a task.""" def __init__(self, example_keys: str) -> None: - self.example = example_keys message = ( f'Found keys {example_keys} in dataset. Unknown example type. For prompt and response ' f'finetuning, the valid prompt keys are {ALLOWED_PROMPT_KEYS} and the valid response keys are ' f'{ALLOWED_RESPONSE_KEYS}. For chat finetuning, the allowed keys are {ALLOWED_MESSAGES_KEYS}' ) - super().__init__(message) + super().__init__(message, example_keys=example_keys) -class NotEnoughChatDataError(ValueError, UserError): +class NotEnoughChatDataError(UserError): """Error thrown when there is not enough chat data to train a model.""" def __init__(self) -> None: @@ -128,16 +181,15 @@ def __init__(self) -> None: super().__init__(message) -class ConsecutiveRepeatedChatRolesError(ValueError, UserError): +class ConsecutiveRepeatedChatRolesError(UserError): """Error thrown when there are consecutive repeated chat roles.""" def __init__(self, repeated_role: str) -> None: - self.repeated_role = repeated_role message = f'Conversation roles must alternate but found {repeated_role} repeated consecutively.' - super().__init__(message) + super().__init__(message, repeated_role=repeated_role) -class ChatTemplateError(ValueError, UserError): +class ChatTemplateError(UserError): """Error thrown when a chat template fails to process a sample.""" def __init__( @@ -146,114 +198,110 @@ def __init__( sample: List[Dict[str, Any]], inner_message: str, ) -> None: - self.template = template - self.sample = sample message = f'Failed to process sample {sample} with template {template}. {inner_message}' - super().__init__(message) + super().__init__( + message, + template=template, + sample=sample, + inner_message=inner_message, + ) -class InvalidLastChatMessageRoleError(ValueError, UserError): +class InvalidLastChatMessageRoleError(UserError): """Error thrown when the last message role in a chat example is invalid.""" def __init__(self, last_role: str, expected_roles: set[str]) -> None: - self.last_role = last_role - self.expected_roles = expected_roles message = f'Invalid last message role: {last_role}. Expected one of: {expected_roles}' - super().__init__(message) + super().__init__( + message, + last_role=last_role, + expected_roles=expected_roles, + ) -class IncorrectMessageKeyQuantityError(ValueError, UserError): +class IncorrectMessageKeyQuantityError(UserError): """Error thrown when a message has an incorrect number of keys.""" def __init__(self, keys: List[str]) -> None: - self.keys = keys message = f'Expected 2 keys in message, but found {len(keys)}' - super().__init__(message) + super().__init__(message, keys=keys) -class InvalidRoleError(ValueError, UserError): +class InvalidRoleError(UserError): """Error thrown when a role is invalid.""" def __init__(self, role: str, valid_roles: set[str]) -> None: - self.role = role - self.valid_roles = valid_roles message = f'Expected role to be one of {valid_roles} but found: {role}' - super().__init__(message) + super().__init__(message, role=role, valid_roles=valid_roles) -class InvalidContentTypeError(TypeError, UserError): +class InvalidContentTypeError(UserError): """Error thrown when the content type is invalid.""" def __init__(self, content_type: type) -> None: - self.content_type = content_type message = f'Expected content to be a string, but found {content_type}' - super().__init__(message) + super().__init__(message, content_type=content_type) -class InvalidPromptTypeError(TypeError, UserError): +class InvalidPromptTypeError(UserError): """Error thrown when the prompt type is invalid.""" def __init__(self, prompt_type: type) -> None: - self.prompt_type = prompt_type message = f'Expected prompt to be a string, but found {prompt_type}' - super().__init__(message) + super().__init__(message, prompt_type=prompt_type) -class InvalidResponseTypeError(TypeError, UserError): +class InvalidResponseTypeError(UserError): """Error thrown when the response type is invalid.""" def __init__(self, response_type: type) -> None: - self.response_type = response_type message = f'Expected response to be a string, but found {response_type}' - super().__init__(message) + super().__init__(message, response_type=response_type) -class InvalidPromptResponseKeysError(ValueError, UserError): +class InvalidPromptResponseKeysError(UserError): """Error thrown when missing expected prompt and response keys.""" def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]): - self.example = example message = f'Expected {mapping=} to have keys "prompt" and "response".' - super().__init__(message) + super().__init__(message, mapping=mapping, example=example) -class InvalidFileExtensionError(FileNotFoundError, UserError): +class InvalidFileExtensionError(UserError): """Error thrown when a file extension is not a safe extension.""" def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None: - self.dataset_name = dataset_name - self.valid_extensions = valid_extensions message = ( f'safe_load is set to True. No data files with safe extensions {valid_extensions} ' + f'found for dataset at local path {dataset_name}.' ) - super().__init__(message) + super().__init__( + message, + dataset_name=dataset_name, + valid_extensions=valid_extensions, + ) class UnableToProcessPromptResponseError( - ValueError, UserError, ): """Error thrown when a prompt and response cannot be processed.""" def __init__(self, input: Dict) -> None: - self.input = input message = f'Unable to extract prompt/response from {input}' - super().__init__(message) + super().__init__(message, input=input) ## Convert Delta to JSON exceptions -class ClusterDoesNotExistError(ValueError, NetworkError): +class ClusterDoesNotExistError(NetworkError): """Error thrown when the cluster does not exist.""" def __init__(self, cluster_id: str) -> None: - self.cluster_id = cluster_id message = f'Cluster with id {cluster_id} does not exist. Check cluster id and try again!' - super().__init__(message) + super().__init__(message, cluster_id=cluster_id) class FailedToCreateSQLConnectionError( - RuntimeError, NetworkError, ): """Error thrown when client can't sql connect to Databricks.""" @@ -265,7 +313,6 @@ def __init__(self) -> None: class FailedToConnectToDatabricksError( - RuntimeError, NetworkError, ): """Error thrown when the client fails to connect to Databricks.""" @@ -276,39 +323,34 @@ def __init__(self) -> None: ## Convert Text to MDS exceptions -class InputFolderMissingDataError(ValueError, UserError): +class InputFolderMissingDataError(UserError): """Error thrown when the input folder is missing data.""" def __init__(self, input_folder: str) -> None: - self.input_folder = input_folder message = f'No text files were found at {input_folder}.' - super().__init__(message) + super().__init__(message, input_folder=input_folder) -class OutputFolderNotEmptyError(FileExistsError, UserError): +class OutputFolderNotEmptyError(UserError): """Error thrown when the output folder is not empty.""" def __init__(self, output_folder: str) -> None: - self.output_folder = output_folder message = f'{output_folder} is not empty. Please remove or empty it and retry.' - super().__init__(message) + super().__init__(message, output_folder=output_folder) -class MisconfiguredHfDatasetError(ValueError, UserError): +class MisconfiguredHfDatasetError(UserError): """Error thrown when a HuggingFace dataset is misconfigured.""" def __init__(self, dataset_name: str, split: str) -> None: - self.dataset_name = dataset_name - self.split = split message = f'Your dataset (name={dataset_name}, split={split}) is misconfigured. ' + \ 'Please check your dataset format and make sure you can load your dataset locally.' - super().__init__(message) + super().__init__(message, dataset_name=dataset_name, split=split) -class RunTimeoutError(RuntimeError, InternalError): +class RunTimeoutError(InternalError): """Error thrown when a run times out.""" def __init__(self, timeout: int) -> None: - self.timeout = timeout message = f'Run timed out after {timeout} seconds.' - super().__init__(message) + super().__init__(message, timeout=timeout) diff --git a/scripts/train/train.py b/scripts/train/train.py index bfeec14e0b..c07a1898f8 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -30,7 +30,7 @@ maybe_create_mosaicml_logger, ) from llmfoundry.utils.exceptions import ( - ContextualError, + BaseContextualError, EvalDataLoaderLocation, TrainDataLoaderLocation, ) @@ -397,7 +397,7 @@ def main(cfg: DictConfig) -> Trainer: tokenizer, train_cfg.device_train_batch_size, ) - except ContextualError as e: + except BaseContextualError as e: if mosaicml_logger is not None: e.location = TrainDataLoaderLocation mosaicml_logger.log_exception(e) @@ -430,7 +430,7 @@ def main(cfg: DictConfig) -> Trainer: ) if eval_gauntlet_callback is not None: callbacks.append(eval_gauntlet_callback) - except ContextualError as e: + except BaseContextualError as e: if mosaicml_logger is not None: e.location = EvalDataLoaderLocation mosaicml_logger.log_exception(e) @@ -480,7 +480,7 @@ def main(cfg: DictConfig) -> Trainer: evaluators, non_icl_metrics, ) - except ContextualError as e: + except BaseContextualError as e: if mosaicml_logger is not None: e.location = EvalDataLoaderLocation mosaicml_logger.log_exception(e) diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py new file mode 100644 index 0000000000..90841c5222 --- /dev/null +++ b/tests/utils/test_exceptions.py @@ -0,0 +1,100 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import contextlib +import inspect +import pickle +from typing import Any, Dict, List, Optional, Type + +import pytest + +import llmfoundry.utils.exceptions as foundry_exceptions + + +def create_exception_object( + exception_class: Type[foundry_exceptions.BaseContextualError], +): + # get required arg types of exception class by inspecting its __init__ method + + if hasattr(inspect, 'get_annotations'): + required_args = inspect.get_annotations( # type: ignore + exception_class.__init__, + ) # type: ignore + else: + required_args = exception_class.__init__.__annotations__ # python 3.9 and below + + # create a dictionary of required args with default values + required_args.pop('kwargs', None) + + def get_default_value(arg_type: Optional[type] = None): + if arg_type == Dict[str, + str] or arg_type == Dict[str, + Any] or arg_type == Dict: + return {'key': 'value'} + elif arg_type == str: + return 'string' + elif arg_type == int: + return 1 + elif arg_type == set[str]: + return {'set'} + elif arg_type == List[str]: + return ['list'] + elif arg_type == None: + return None + elif arg_type == type: + return bool + elif arg_type == List[Dict[str, Any]]: + return [{'key': 'value'}] + raise ValueError(f'Unsupported arg type: {arg_type}') + + required_args.pop('self', None) + required_args.pop('return', None) + kwargs = { + arg: get_default_value(arg_type) + for arg, arg_type in required_args.items() + } + return exception_class(**kwargs) # type: ignore + + +def filter_exceptions(possible_exceptions: List[str]): + attrs = [ + getattr(foundry_exceptions, exception) + for exception in possible_exceptions + ] + classes = [attr for attr in attrs if inspect.isclass(attr)] + exceptions = [ + exception_class for exception_class in classes + if issubclass(exception_class, foundry_exceptions.BaseContextualError) + ] + return exceptions + + +@pytest.mark.parametrize( + 'exception_class', + filter_exceptions(dir(foundry_exceptions)), +) +def test_exception_serialization( + exception_class: Type[foundry_exceptions.BaseContextualError], +): + excluded_base_classes = [ + foundry_exceptions.InternalError, + foundry_exceptions.UserError, + foundry_exceptions.NetworkError, + foundry_exceptions.BaseContextualError, + ] + + exception = create_exception_object(exception_class) + + expect_reduce_error = exception.__class__ in excluded_base_classes + error_context = pytest.raises( + NotImplementedError, + ) if expect_reduce_error else contextlib.nullcontext() + + exc_str = str(exception) + with error_context: + pkl = pickle.dumps(exception) + unpickled_exc = pickle.loads(pkl) + unpickled_exc_str = str(unpickled_exc) + assert exc_str == unpickled_exc_str + assert exception.location == unpickled_exc.location + assert exception.error_attribution == unpickled_exc.error_attribution