diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py
index 5904062285..242ac4f32c 100644
--- a/llmfoundry/utils/exceptions.py
+++ b/llmfoundry/utils/exceptions.py
@@ -457,4 +457,12 @@ class InsufficientPermissionsError(UserError):
     """Error thrown when the user does not have sufficient permissions."""
 
     def __init__(self, message: str) -> None:
+        self.message = message
         super().__init__(message)
+
+    def __reduce__(self):
+        # Return a tuple of class, a tuple of arguments, and optionally state
+        return (InsufficientPermissionsError, (self.message,))
+
+    def __str__(self):
+        return self.message
diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py
index b1a9f1e878..981f5c1ed6 100644
--- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py
+++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py
@@ -10,6 +10,7 @@
 from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
     InsufficientPermissionsError,
     download,
+    fetch,
     fetch_DT,
     format_tablename,
     iterative_combine_jsons,
@@ -30,27 +31,33 @@ class MockAnalysisException(Exception):
 
             def __init__(self, message: str):
                 self.message = message
 
+            def __str__(self):
+                return self.message
+
         with patch.dict('sys.modules', {'pyspark.errors': MagicMock()}):
             sys.modules[
                 'pyspark.errors'
-            ].AnalysisException = MockAnalysisException  # pyright: ignore
+            ].AnalysisException = MockAnalysisException  # type: ignore
             mock_spark = MagicMock()
             mock_spark.sql.side_effect = MockAnalysisException(error_message)
 
             with self.assertRaises(InsufficientPermissionsError) as context:
-                run_query(
-                    'SELECT * FROM table',
+                fetch(
                     method='dbconnect',
-                    cursor=None,
-                    spark=mock_spark,
+                    tablename='main.oogabooga',
+                    json_output_folder='/fake/path',
+                    batch_size=1,
+                    processes=1,
+                    sparkSession=mock_spark,
+                    dbsql=None,
                 )
 
-            self.assertIn(
-                'using the schema main.oogabooga',
+            self.assertEqual(
                 str(context.exception),
+                error_message,
             )
-            mock_spark.sql.assert_called_once_with('SELECT * FROM table')
+            mock_spark.sql.assert_called()
 
     @patch(
         'databricks.sql.connect',
diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py
index 8bfc7287ab..44c6354998 100644
--- a/tests/utils/test_exceptions.py
+++ b/tests/utils/test_exceptions.py
@@ -14,16 +14,29 @@
 def create_exception_object(
     exception_class: type[foundry_exceptions.BaseContextualError],
 ):
-    # get required arg types of exception class by inspecting its __init__ method
-    if hasattr(inspect, 'get_annotations'):
-        required_args = inspect.get_annotations(  # type: ignore
-            exception_class.__init__,
-        )  # type: ignore
-    else:
-        required_args = exception_class.__init__.__annotations__  # python 3.9 and below
-
-    # create a dictionary of required args with default values
+    def get_init_annotations(cls: type):
+        if hasattr(inspect, 'get_annotations'):
+            return inspect.get_annotations(cls.__init__)
+        else:
+            return getattr(cls.__init__, '__annotations__', {})
+
+    # First, try to get annotations from the class itself
+    required_args = get_init_annotations(exception_class)
+
+    # If the annotations are empty, look at parent classes
+    if not required_args:
+        for parent in exception_class.__bases__:
+            if parent == object:
+                break
+            parent_args = get_init_annotations(parent)
+            if parent_args:
+                required_args = parent_args
+                break
+
+    # Remove self, return, and kwargs
+    required_args.pop('self', None)
+    required_args.pop('return', None)
     required_args.pop('kwargs', None)
 
     def get_default_value(arg_type: Optional[type] = None):
@@ -51,8 +64,6 @@ def get_default_value(arg_type: Optional[type] = None):
             return [{'key': 'value'}]
         raise ValueError(f'Unsupported arg type: {arg_type}')
 
-    required_args.pop('self', None)
-    required_args.pop('return', None)
     kwargs = {
         arg: get_default_value(arg_type)
         for arg, arg_type in required_args.items()
@@ -80,6 +91,7 @@ def filter_exceptions(possible_exceptions: list[str]):
 def test_exception_serialization(
     exception_class: type[foundry_exceptions.BaseContextualError],
 ):
+    print(f'Testing serialization for {exception_class.__name__}')
     excluded_base_classes = [
         foundry_exceptions.InternalError,
         foundry_exceptions.UserError,
@@ -88,6 +100,7 @@
     ]
 
     exception = create_exception_object(exception_class)
+    print(f'Created exception object: {exception}')
 
     expect_reduce_error = exception.__class__ in excluded_base_classes
     error_context = pytest.raises(
@@ -95,6 +108,7 @@
     ) if expect_reduce_error else contextlib.nullcontext()
 
     exc_str = str(exception)
+    print(f'Exception string: {exc_str}')
     with error_context:
         pkl = pickle.dumps(exception)
         unpickled_exc = pickle.loads(pkl)
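
For context on the exceptions.py change above, here is a minimal standalone sketch (toy class name, not the real llm-foundry type) of how an explicit __reduce__ keeps a custom exception picklable, which is what the pickle.dumps/pickle.loads round-trip in test_exception_serialization exercises:

import pickle


class PermError(Exception):
    """Toy stand-in for InsufficientPermissionsError."""

    def __init__(self, message: str) -> None:
        self.message = message
        # The base class receives a decorated string, so the default pickle
        # behaviour (re-calling PermError(*self.args)) would not reproduce
        # the original constructor call.
        super().__init__(f'Insufficient permissions: {message}')

    def __reduce__(self):
        # Pin the exact constructor arguments used to rebuild the object.
        return (PermError, (self.message,))

    def __str__(self):
        return self.message


exc = PermError('cannot query main.oogabooga')
restored = pickle.loads(pickle.dumps(exc))
assert str(restored) == 'cannot query main.oogabooga'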