Skip to content

Commit

Permalink
update tests to use dir instead of all
Browse files Browse the repository at this point in the history
  • Loading branch information
milocress committed May 24, 2024
1 parent 024d462 commit 2af9e53
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 17 deletions.
20 changes: 10 additions & 10 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@
ALLOWED_PROMPT_KEYS = {'prompt'}
ALLOWED_MESSAGES_KEYS = {'messages'}

ErrorLocation = Union[Literal['TrainDataloader'], Literal['EvalDataloader']]
ErrorAttribution = Union[Literal['UserError'], Literal['InternalError'],
Literal['NetworkError']]
FailureLocation = Union[Literal['TrainDataloader'], Literal['EvalDataloader']]
FailureAttribution = Union[Literal['UserError'], Literal['InternalError'],
Literal['NetworkError']]
TrainDataLoaderLocation = 'TrainDataloader'
EvalDataLoaderLocation = 'EvalDataloader'


class SerializableError:
class BaseSerialaziableError:

def __getstate__(self):
return self.__dict__
Expand All @@ -52,29 +52,29 @@ def __setstate__(self, state: Dict[str, Any]):
setattr(self, key, value)


class ContextualError(Exception, SerializableError):
class BaseContextualError(Exception, BaseSerialaziableError):
"""Error thrown when an error occurs in the context of a specific task."""

location: Optional[ErrorLocation] = None
error_attribution: Optional[ErrorAttribution] = None
location: Optional[FailureLocation] = None
error_attribution: Optional[FailureAttribution] = None

def __init__(self, message: str) -> None:
self.error = message


class UserError(ContextualError):
class UserError(BaseContextualError):
"""Error thrown when an error is caused by user input."""

error_attribution = 'UserError'


class NetworkError(ContextualError):
class NetworkError(BaseContextualError):
"""Error thrown when an error is caused by a network issue."""

error_attribution = 'NetworkError'


class InternalError(ContextualError):
class InternalError(BaseContextualError):
"""Error thrown when an error is caused by an internal issue."""

error_attribution = 'InternalError'
Expand Down
8 changes: 4 additions & 4 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
maybe_create_mosaicml_logger,
)
from llmfoundry.utils.exceptions import (
ContextualError,
Base,
EvalDataLoaderLocation,
TrainDataLoaderLocation,
)
Expand Down Expand Up @@ -396,7 +396,7 @@ def main(cfg: DictConfig) -> Trainer:
tokenizer,
train_cfg.device_train_batch_size,
)
except ContextualError as e:
except Base as e:
if mosaicml_logger is not None:
e.location = TrainDataLoaderLocation
mosaicml_logger.log_exception(e)
Expand Down Expand Up @@ -429,7 +429,7 @@ def main(cfg: DictConfig) -> Trainer:
)
if eval_gauntlet_callback is not None:
callbacks.append(eval_gauntlet_callback)
except ContextualError as e:
except Base as e:
if mosaicml_logger is not None:
e.location = EvalDataLoaderLocation
mosaicml_logger.log_exception(e)
Expand Down Expand Up @@ -479,7 +479,7 @@ def main(cfg: DictConfig) -> Trainer:
evaluators,
non_icl_metrics,
)
except ContextualError as e:
except Base as e:
if mosaicml_logger is not None:
e.location = EvalDataLoaderLocation
mosaicml_logger.log_exception(e)
Expand Down
10 changes: 7 additions & 3 deletions tests/utils/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pytest

from llmfoundry.utils.exceptions import __all__ as all_exceptions
import llmfoundry.utils.exceptions as foundry_exceptions


def create_exception_object(exception_name: str):
Expand Down Expand Up @@ -52,11 +52,15 @@ def get_default_value(arg_type: Optional[type] = None):
def filter_exceptions(exceptions: List[str]):
return [
exception for exception in exceptions
if ('Error' in exception or 'Exception' in exception)
if ('Error' in exception or 'Exception' in exception) and
('Base' not in exception)
]


@pytest.mark.parametrize('exception_name', filter_exceptions(all_exceptions))
@pytest.mark.parametrize(
'exception_name',
filter_exceptions(dir(foundry_exceptions)),
)
def test_exception_serialization(exception_name: str):
exception = create_exception_object(exception_name)

Expand Down

0 comments on commit 2af9e53

Please sign in to comment.