Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the exceptions serializable #1239

Merged
merged 27 commits into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
1e94e47
add unit test to identify failing tests
milocress May 23, 2024
fd53b16
clarify issue
milocress May 23, 2024
1b0c81a
dict -> pretty json
milocress May 23, 2024
ee6a3a0
revert exceptions.py
milocress May 23, 2024
0bd647a
Merge branch 'main' into milo/fix-exception-serialization
milocress May 23, 2024
26f1be4
remove multiple inheritance for most classes
milocress May 23, 2024
c630877
merged
milocress May 23, 2024
02cebd2
fix with magic
milocress May 24, 2024
413f63b
Merge branch 'main' into milo/fix-exception-serialization
dakinggg May 24, 2024
a94c39c
Merge branch 'main' into milo/fix-exception-serialization
milocress May 24, 2024
6542a11
address comments
milocress May 24, 2024
0aeb88f
merged
milocress May 24, 2024
024d462
remove parens
milocress May 24, 2024
2af9e53
update tests to use dir instead of all
milocress May 24, 2024
694df23
I've been a silly goose
milocress May 24, 2024
b1fbbef
spelling
milocress May 24, 2024
313d52b
fix version issue and spelling
milocress May 24, 2024
aef5632
more like py-wrong amirite?
milocress May 24, 2024
57b3438
more ignoring
milocress May 24, 2024
f7a9d80
Merge branch 'main' into milo/fix-exception-serialization
dakinggg May 24, 2024
3bc11f7
lots o fixes
dakinggg May 24, 2024
a65e869
Merge branch 'main' into daniels-exceptions
dakinggg May 24, 2024
8df80f6
add the kwargs back
dakinggg May 24, 2024
09e320e
more tests
dakinggg May 24, 2024
ed41e47
fix super call
dakinggg May 25, 2024
b2cc117
merge
dakinggg May 25, 2024
85d17eb
pr comments
dakinggg May 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 118 additions & 76 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,46 +28,98 @@
'InputFolderMissingDataError',
'OutputFolderNotEmptyError',
'MisconfiguredHfDatasetError',
'RunTimeoutError',
]

ALLOWED_RESPONSE_KEYS = {'response', 'completion'}
ALLOWED_PROMPT_KEYS = {'prompt'}
ALLOWED_MESSAGES_KEYS = {'messages'}

ErrorLocation = Union[Literal['TrainDataloader'], Literal['EvalDataloader']]
ErrorAttribution = Union[Literal['UserError'], Literal['InternalError'],
Literal['NetworkError']]
FailureLocation = Union[Literal['TrainDataloader'], Literal['EvalDataloader']]
FailureAttribution = Union[Literal['UserError'], Literal['InternalError'],
Literal['NetworkError']]
TrainDataLoaderLocation = 'TrainDataloader'
EvalDataLoaderLocation = 'EvalDataloader'


class ContextualError(Exception):
class BaseContextualError(Exception):
"""Error thrown when an error occurs in the context of a specific task."""

location: Optional[ErrorLocation] = None
error_attribution: Optional[ErrorAttribution] = None
location: Optional[FailureLocation] = None
error_attribution: Optional[FailureAttribution] = None

def __init__(self, message: str, **kwargs: Any) -> None:
self.error = message
self.serializable_attributes = []

class UserError(ContextualError):
for key, value in kwargs.items():
setattr(self, key, value)
self.serializable_attributes.append(key)

super().__init__(message)

def __reduce__(self):
dakinggg marked this conversation as resolved.
Show resolved Hide resolved
"""Adjust the reduce behavior for pickling.
Because we have custom exception subclasses with constructor args, we
need to adjust the reduce behavior to ensure that the exception can be
pickled. This allows error propagation across processes in
multiprocessing.
"""
if self.__class__ == BaseContextualError:
raise NotImplementedError(
'BaseContextualError is a base class and cannot be pickled.',
)
tuple_of_args = tuple([
getattr(self, key) for key in self.serializable_attributes
])
return (self.__class__, tuple_of_args)


class UserError(BaseContextualError):
"""Error thrown when an error is caused by user input."""

error_attribution = 'UserError'

def __reduce__(self):
if self.__class__ == UserError:
raise NotImplementedError(
'UserError is a base class and cannot be pickled.',
)

return super().__reduce__()

class NetworkError(ContextualError):

class NetworkError(BaseContextualError):
"""Error thrown when an error is caused by a network issue."""

error_attribution = 'NetworkError'

def __reduce__(self):
if self.__class__ == NetworkError:
raise NotImplementedError(
'NetworkError is a base class and cannot be pickled.',
)

return super().__reduce__()

class InternalError(ContextualError):

class InternalError(BaseContextualError):
"""Error thrown when an error is caused by an internal issue."""

error_attribution = 'InternalError'

def __reduce__(self):
if self.__class__ == InternalError:
raise NotImplementedError(
'InternalError is a base class and cannot be pickled.',
)

return super().__reduce__()


# Finetuning dataloader exceptions
class MissingHuggingFaceURLSplitError(ValueError, UserError):
class MissingHuggingFaceURLSplitError(UserError):
"""Error thrown when there's no split used in HF dataset config."""

def __init__(self) -> None:
Expand All @@ -88,12 +140,6 @@ def __init__(
full_dataset_size: int,
minimum_dataset_size: int,
) -> None:
self.dataset_name = dataset_name
self.split = split
self.dataloader_batch_size = dataloader_batch_size
self.world_size = world_size
self.full_dataset_size = full_dataset_size
self.minimum_dataset_size = minimum_dataset_size
message = (
f'Your dataset (name={dataset_name}, split={split}) ' +
f'has {full_dataset_size} samples, but your minimum batch size ' +
Expand All @@ -102,42 +148,48 @@ def __init__(
f'your per device batch size is {dataloader_batch_size}. Please increase the number '
+ f'of samples in your dataset to at least {minimum_dataset_size}.'
)
super().__init__(message)
super().__init__(
message,
dataset_name=dataset_name,
split=split,
dataloader_batch_size=dataloader_batch_size,
world_size=world_size,
full_dataset_size=full_dataset_size,
minimum_dataset_size=minimum_dataset_size,
)


## Tasks exceptions
class UnknownExampleTypeError(KeyError, UserError):
class UnknownExampleTypeError(UserError):
"""Error thrown when an unknown example type is used in a task."""

def __init__(self, example_keys: str) -> None:
self.example = example_keys
message = (
f'Found keys {example_keys} in dataset. Unknown example type. For prompt and response '
f'finetuning, the valid prompt keys are {ALLOWED_PROMPT_KEYS} and the valid response keys are '
f'{ALLOWED_RESPONSE_KEYS}. For chat finetuning, the allowed keys are {ALLOWED_MESSAGES_KEYS}'
)

super().__init__(message)
super().__init__(message, example_keys=example_keys)


class NotEnoughChatDataError(ValueError, UserError):
class NotEnoughChatDataError(UserError):
"""Error thrown when there is not enough chat data to train a model."""

def __init__(self) -> None:
message = 'Chat example must have at least two messages'
super().__init__(message)


class ConsecutiveRepeatedChatRolesError(ValueError, UserError):
class ConsecutiveRepeatedChatRolesError(UserError):
"""Error thrown when there are consecutive repeated chat roles."""

def __init__(self, repeated_role: str) -> None:
self.repeated_role = repeated_role
message = f'Conversation roles must alternate but found {repeated_role} repeated consecutively.'
super().__init__(message)
super().__init__(message, repeated_role=repeated_role)


class ChatTemplateError(ValueError, UserError):
class ChatTemplateError(UserError):
"""Error thrown when a chat template fails to process a sample."""

def __init__(
Expand All @@ -146,114 +198,110 @@ def __init__(
sample: List[Dict[str, Any]],
inner_message: str,
) -> None:
self.template = template
self.sample = sample
message = f'Failed to process sample {sample} with template {template}. {inner_message}'
super().__init__(message)
super().__init__(
message,
template=template,
sample=sample,
inner_message=inner_message,
)


class InvalidLastChatMessageRoleError(ValueError, UserError):
class InvalidLastChatMessageRoleError(UserError):
"""Error thrown when the last message role in a chat example is invalid."""

def __init__(self, last_role: str, expected_roles: set[str]) -> None:
self.last_role = last_role
self.expected_roles = expected_roles
message = f'Invalid last message role: {last_role}. Expected one of: {expected_roles}'
super().__init__(message)
super().__init__(
message,
last_role=last_role,
expected_roles=expected_roles,
)


class IncorrectMessageKeyQuantityError(ValueError, UserError):
class IncorrectMessageKeyQuantityError(UserError):
"""Error thrown when a message has an incorrect number of keys."""

def __init__(self, keys: List[str]) -> None:
self.keys = keys
message = f'Expected 2 keys in message, but found {len(keys)}'
super().__init__(message)
super().__init__(message, keys=keys)


class InvalidRoleError(ValueError, UserError):
class InvalidRoleError(UserError):
"""Error thrown when a role is invalid."""

def __init__(self, role: str, valid_roles: set[str]) -> None:
self.role = role
self.valid_roles = valid_roles
message = f'Expected role to be one of {valid_roles} but found: {role}'
super().__init__(message)
super().__init__(message, role=role, valid_roles=valid_roles)


class InvalidContentTypeError(TypeError, UserError):
class InvalidContentTypeError(UserError):
"""Error thrown when the content type is invalid."""

def __init__(self, content_type: type) -> None:
self.content_type = content_type
message = f'Expected content to be a string, but found {content_type}'
super().__init__(message)
super().__init__(message, content_type=content_type)


class InvalidPromptTypeError(TypeError, UserError):
class InvalidPromptTypeError(UserError):
"""Error thrown when the prompt type is invalid."""

def __init__(self, prompt_type: type) -> None:
self.prompt_type = prompt_type
message = f'Expected prompt to be a string, but found {prompt_type}'
super().__init__(message)
super().__init__(message, prompt_type=prompt_type)


class InvalidResponseTypeError(TypeError, UserError):
class InvalidResponseTypeError(UserError):
"""Error thrown when the response type is invalid."""

def __init__(self, response_type: type) -> None:
self.response_type = response_type
message = f'Expected response to be a string, but found {response_type}'
super().__init__(message)
super().__init__(message, response_type=response_type)


class InvalidPromptResponseKeysError(ValueError, UserError):
class InvalidPromptResponseKeysError(UserError):
"""Error thrown when missing expected prompt and response keys."""

def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]):
self.example = example
message = f'Expected {mapping=} to have keys "prompt" and "response".'
super().__init__(message)
super().__init__(message, mapping=mapping, example=example)


class InvalidFileExtensionError(FileNotFoundError, UserError):
class InvalidFileExtensionError(UserError):
"""Error thrown when a file extension is not a safe extension."""

def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None:
self.dataset_name = dataset_name
self.valid_extensions = valid_extensions
message = (
f'safe_load is set to True. No data files with safe extensions {valid_extensions} '
+ f'found for dataset at local path {dataset_name}.'
)
super().__init__(message)
super().__init__(
message,
dataset_name=dataset_name,
valid_extensions=valid_extensions,
)


class UnableToProcessPromptResponseError(
ValueError,
UserError,
):
"""Error thrown when a prompt and response cannot be processed."""

def __init__(self, input: Dict) -> None:
self.input = input
message = f'Unable to extract prompt/response from {input}'
super().__init__(message)
super().__init__(message, input=input)


## Convert Delta to JSON exceptions
class ClusterDoesNotExistError(ValueError, NetworkError):
class ClusterDoesNotExistError(NetworkError):
"""Error thrown when the cluster does not exist."""

def __init__(self, cluster_id: str) -> None:
self.cluster_id = cluster_id
message = f'Cluster with id {cluster_id} does not exist. Check cluster id and try again!'
super().__init__(message)
super().__init__(message, cluster_id=cluster_id)


class FailedToCreateSQLConnectionError(
RuntimeError,
NetworkError,
):
"""Error thrown when client can't sql connect to Databricks."""
Expand All @@ -265,7 +313,6 @@ def __init__(self) -> None:


class FailedToConnectToDatabricksError(
RuntimeError,
NetworkError,
):
"""Error thrown when the client fails to connect to Databricks."""
Expand All @@ -276,39 +323,34 @@ def __init__(self) -> None:


## Convert Text to MDS exceptions
class InputFolderMissingDataError(ValueError, UserError):
class InputFolderMissingDataError(UserError):
"""Error thrown when the input folder is missing data."""

def __init__(self, input_folder: str) -> None:
self.input_folder = input_folder
message = f'No text files were found at {input_folder}.'
super().__init__(message)
super().__init__(message, input_folder=input_folder)


class OutputFolderNotEmptyError(FileExistsError, UserError):
class OutputFolderNotEmptyError(UserError):
"""Error thrown when the output folder is not empty."""

def __init__(self, output_folder: str) -> None:
self.output_folder = output_folder
message = f'{output_folder} is not empty. Please remove or empty it and retry.'
super().__init__(message)
super().__init__(message, output_folder=output_folder)


class MisconfiguredHfDatasetError(ValueError, UserError):
class MisconfiguredHfDatasetError(UserError):
"""Error thrown when a HuggingFace dataset is misconfigured."""

def __init__(self, dataset_name: str, split: str) -> None:
self.dataset_name = dataset_name
self.split = split
message = f'Your dataset (name={dataset_name}, split={split}) is misconfigured. ' + \
'Please check your dataset format and make sure you can load your dataset locally.'
super().__init__(message)
super().__init__(message, dataset_name=dataset_name, split=split)


class RunTimeoutError(RuntimeError, InternalError):
class RunTimeoutError(InternalError):
"""Error thrown when a run times out."""

def __init__(self, timeout: int) -> None:
self.timeout = timeout
message = f'Run timed out after {timeout} seconds.'
super().__init__(message)
super().__init__(message, timeout=timeout)
Loading
Loading