Skip to content

Commit

Permalink
add context to errors
Browse files Browse the repository at this point in the history
  • Loading branch information
milocress committed May 7, 2024
1 parent eccf849 commit baff2e3
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 23 deletions.
50 changes: 29 additions & 21 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

"""Custom exceptions for the LLMFoundry."""
from collections.abc import Mapping
from typing import Any, Dict, List
from typing import Any, Dict, List, Literal, Optional

__all__ = [
'ALLOWED_RESPONSE_KEYS',
Expand Down Expand Up @@ -35,9 +35,17 @@
ALLOWED_PROMPT_KEYS = {'prompt'}
ALLOWED_MESSAGES_KEYS = {'messages'}

# Closed set of phases an error can be tagged with. This has to be a
# ``Literal``: applying ``|`` to two plain string objects is not a type union
# and raises ``TypeError`` as soon as the module is imported.
ErrorContext = Literal['TrainContext', 'EvalContext']


class ContextualError(Exception):
    """Error thrown when an error occurs in the context of a specific task.

    Used as a mixin next to a concrete built-in exception type, e.g.
    ``class SomeError(ValueError, ContextualError)``, so that existing
    handlers for the built-in type keep working while handlers that catch
    ``ContextualError`` can tag the error before logging or re-raising it.

    Derives from ``Exception`` rather than ``BaseException`` so that a bare
    ``ContextualError`` is still caught by ordinary ``except Exception``
    handlers, as recommended for user-defined exceptions.
    """

    # Phase in which the error surfaced. ``None`` until a handler assigns one
    # of the ``ErrorContext`` values on the caught instance.
    context: Optional[ErrorContext] = None


# Finetuning dataloader exceptions
class MissingHuggingFaceURLSplitError(ValueError):
class MissingHuggingFaceURLSplitError(ValueError, ContextualError):
"""Error thrown when there's no split used in HF dataset config."""

def __init__(self) -> None:
Expand All @@ -46,7 +54,7 @@ def __init__(self) -> None:
super().__init__(message)


class NotEnoughDatasetSamplesError(ValueError):
class NotEnoughDatasetSamplesError(ValueError, ContextualError):
"""Error thrown when there is not enough data to train a model."""

def __init__(
Expand Down Expand Up @@ -76,7 +84,7 @@ def __init__(


## Tasks exceptions
class UnknownExampleTypeError(KeyError):
class UnknownExampleTypeError(KeyError, ContextualError):
"""Error thrown when an unknown example type is used in a task."""

def __init__(self, example: Mapping) -> None:
Expand All @@ -89,15 +97,15 @@ def __init__(self, example: Mapping) -> None:
super().__init__(message)


class NotEnoughChatDataError(ValueError):
class NotEnoughChatDataError(ValueError, ContextualError):
"""Error thrown when there is not enough chat data to train a model."""

def __init__(self) -> None:
message = 'Chat example must have at least two messages'
super().__init__(message)


class ConsecutiveRepeatedChatRolesError(ValueError):
class ConsecutiveRepeatedChatRolesError(ValueError, ContextualError):
"""Error thrown when there are consecutive repeated chat roles."""

def __init__(self, repeated_role: str) -> None:
Expand All @@ -106,7 +114,7 @@ def __init__(self, repeated_role: str) -> None:
super().__init__(message)


class InvalidLastChatMessageRoleError(ValueError):
class InvalidLastChatMessageRoleError(ValueError, ContextualError):
"""Error thrown when the last message role in a chat example is invalid."""

def __init__(self, last_role: str, expected_roles: set[str]) -> None:
Expand All @@ -116,7 +124,7 @@ def __init__(self, last_role: str, expected_roles: set[str]) -> None:
super().__init__(message)


class IncorrectMessageKeyQuantityError(ValueError):
class IncorrectMessageKeyQuantityError(ValueError, ContextualError):
"""Error thrown when a message has an incorrect number of keys."""

def __init__(self, keys: List[str]) -> None:
Expand All @@ -125,7 +133,7 @@ def __init__(self, keys: List[str]) -> None:
super().__init__(message)


class InvalidRoleError(ValueError):
class InvalidRoleError(ValueError, ContextualError):
"""Error thrown when a role is invalid."""

def __init__(self, role: str, valid_roles: set[str]) -> None:
Expand All @@ -135,7 +143,7 @@ def __init__(self, role: str, valid_roles: set[str]) -> None:
super().__init__(message)


class InvalidContentTypeError(TypeError):
class InvalidContentTypeError(TypeError, ContextualError):
"""Error thrown when the content type is invalid."""

def __init__(self, content_type: type) -> None:
Expand All @@ -144,7 +152,7 @@ def __init__(self, content_type: type) -> None:
super().__init__(message)


class InvalidPromptTypeError(TypeError):
class InvalidPromptTypeError(TypeError, ContextualError):
"""Error thrown when the prompt type is invalid."""

def __init__(self, prompt_type: type) -> None:
Expand All @@ -153,7 +161,7 @@ def __init__(self, prompt_type: type) -> None:
super().__init__(message)


class InvalidResponseTypeError(TypeError):
class InvalidResponseTypeError(TypeError, ContextualError):
"""Error thrown when the response type is invalid."""

def __init__(self, response_type: type) -> None:
Expand All @@ -162,7 +170,7 @@ def __init__(self, response_type: type) -> None:
super().__init__(message)


class InvalidPromptResponseKeysError(ValueError):
class InvalidPromptResponseKeysError(ValueError, ContextualError):
"""Error thrown when missing expected prompt and response keys."""

def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]):
Expand All @@ -171,7 +179,7 @@ def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]):
super().__init__(message)


class InvalidFileExtensionError(FileNotFoundError):
class InvalidFileExtensionError(FileNotFoundError, ContextualError):
"""Error thrown when a file extension is not a safe extension."""

def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None:
Expand All @@ -184,7 +192,7 @@ def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None:
super().__init__(message)


class UnableToProcessPromptResponseError(ValueError):
class UnableToProcessPromptResponseError(ValueError, ContextualError):
"""Error thrown when a prompt and response cannot be processed."""

def __init__(self, input: Dict) -> None:
Expand All @@ -194,7 +202,7 @@ def __init__(self, input: Dict) -> None:


## Convert Delta to JSON exceptions
class ClusterDoesNotExistError(ValueError):
class ClusterDoesNotExistError(ValueError, ContextualError):
"""Error thrown when the cluster does not exist."""

def __init__(self, cluster_id: str) -> None:
Expand All @@ -203,15 +211,15 @@ def __init__(self, cluster_id: str) -> None:
super().__init__(message)


class FailedToCreateSQLConnectionError(RuntimeError):
class FailedToCreateSQLConnectionError(RuntimeError, ContextualError):
"""Error thrown when client can't sql connect to Databricks."""

def __init__(self) -> None:
message = 'Failed to create sql connection to db workspace. To use sql connect, you need to provide http_path and cluster_id!'
super().__init__(message)


class FailedToConnectToDatabricksError(RuntimeError):
class FailedToConnectToDatabricksError(RuntimeError, ContextualError):
"""Error thrown when the client fails to connect to Databricks."""

def __init__(self) -> None:
Expand All @@ -220,7 +228,7 @@ def __init__(self) -> None:


## Convert Text to MDS exceptions
class InputFolderMissingDataError(ValueError):
class InputFolderMissingDataError(ValueError, ContextualError):
"""Error thrown when the input folder is missing data."""

def __init__(self, input_folder: str) -> None:
Expand All @@ -229,7 +237,7 @@ def __init__(self, input_folder: str) -> None:
super().__init__(message)


class OutputFolderNotEmptyError(FileExistsError):
class OutputFolderNotEmptyError(FileExistsError, ContextualError):
"""Error thrown when the output folder is not empty."""

def __init__(self, output_folder: str) -> None:
Expand All @@ -238,7 +246,7 @@ def __init__(self, output_folder: str) -> None:
super().__init__(message)


class MisconfiguredHfDatasetError(ValueError):
class MisconfiguredHfDatasetError(ValueError, ContextualError):
"""Error thrown when a HuggingFace dataset is misconfigured."""

def __init__(self, dataset_name: str, split: str) -> None:
Expand Down
7 changes: 5 additions & 2 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
log_train_analytics,
maybe_create_mosaicml_logger,
)
from llmfoundry.utils.exceptions import ContextualError

install()

Expand Down Expand Up @@ -391,7 +392,8 @@ def main(cfg: DictConfig) -> Trainer:
tokenizer,
train_cfg.device_train_batch_size,
)
except Exception as e:
except ContextualError as e:
e.context = 'TrainContext'
if mosaicml_logger is not None:
mosaicml_logger.log_exception(e)
raise e
Expand Down Expand Up @@ -467,8 +469,9 @@ def main(cfg: DictConfig) -> Trainer:
evaluators,
non_icl_metrics,
)
except Exception as e:
except ContextualError as e:
if mosaicml_logger is not None:
e.context = 'EvalContext'
mosaicml_logger.log_exception(e)
raise e

Expand Down

0 comments on commit baff2e3

Please sign in to comment.