diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py
index fbbc5f2cd9..44e8651cdf 100644
--- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py
+++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py
@@ -234,27 +234,7 @@ def run_query(
     elif method == 'dbconnect':
         if spark == None:
             raise ValueError(f'sparkSession is required for dbconnect')
-
-        try:
-            df = spark.sql(query)
-        except Exception as e:
-            from pyspark.errors import AnalysisException
-            if isinstance(e, AnalysisException):
-                if 'INSUFFICIENT_PERMISSIONS' in e.message:  # pyright: ignore
-                    match = re.search(
-                        r"Schema\s+'([^']+)'",
-                        e.message,  # pyright: ignore
-                    )
-                    if match:
-                        schema_name = match.group(1)
-                        action = f'using the schema {schema_name}'
-                    else:
-                        action = 'using the schema'
-                    raise InsufficientPermissionsError(action=action,) from e
-            raise RuntimeError(
-                f'Error in querying into schema. Restart sparkSession and try again',
-            ) from e
-
+        df = spark.sql(query)
         if collect:
             return df.collect()
         return df
@@ -469,71 +449,66 @@ def fetch(
     """
     cursor = dbsql.cursor() if dbsql is not None else None
     try:
-        nrows = get_total_rows(
-            tablename,
-            method,
-            cursor,
-            sparkSession,
-        )
-    except Exception as e:
-        from pyspark.errors import AnalysisException
-        if isinstance(e, AnalysisException):
-            if 'INSUFFICIENT_PERMISSIONS' in e.message:  # pyright: ignore
-                raise InsufficientPermissionsError(
-                    action=f'reading from {tablename}',
-                ) from e
-        if isinstance(e, InsufficientPermissionsError):
-            raise e
-        raise RuntimeError(
-            f'Error in get rows from {tablename}. Restart sparkSession and try again',
-        ) from e
+        # Get total rows
+        nrows = get_total_rows(tablename, method, cursor, sparkSession)
 
-    try:
+        # Get columns info
         columns, order_by, columns_str = get_columns_info(
             tablename,
             method,
             cursor,
             sparkSession,
         )
+
+        if method == 'dbconnect' and sparkSession is not None:
+            log.info(f'{processes=}')
+            df = sparkSession.table(tablename)
+
+            # Running the query and collecting the data as arrow or json.
+            signed, _, _ = df.collect_cf('arrow')  # pyright: ignore
+            log.info(f'len(signed) = {len(signed)}')
+
+            args = get_args(signed, json_output_folder, columns)
+
+            # Stopping the SparkSession to avoid spilling connection state into the subprocesses.
+            sparkSession.stop()
+
+            with ProcessPoolExecutor(max_workers=processes) as executor:
+                list(executor.map(download_starargs, args))
+
+        elif method == 'dbsql' and cursor is not None:
+            for start in range(0, nrows, batch_size):
+                log.warning(f'batch {start}')
+                end = min(start + batch_size, nrows)
+                fetch_data(
+                    method,
+                    cursor,
+                    sparkSession,
+                    start,
+                    end,
+                    order_by,
+                    tablename,
+                    columns_str,
+                    json_output_folder,
+                )
+
     except Exception as e:
-        raise RuntimeError(
-            f'Error in get columns from {tablename}. Restart sparkSession and try again',
-        ) from e
+        from databricks.sql.exc import ServerOperationError
+        from pyspark.errors import AnalysisException
 
-    if method == 'dbconnect' and sparkSession is not None:
-        log.info(f'{processes=}')
-        df = sparkSession.table(tablename)
-
-        # Running the query and collecting the data as arrow or json.
-        signed, _, _ = df.collect_cf('arrow')  # pyright: ignore
-        log.info(f'len(signed) = {len(signed)}')
-
-        args = get_args(signed, json_output_folder, columns)
-
-        # Stopping the SparkSession to avoid spilling connection state into the subprocesses.
-        sparkSession.stop()
-
-        with ProcessPoolExecutor(max_workers=processes) as executor:
-            list(executor.map(download_starargs, args))
-
-    elif method == 'dbsql' and cursor is not None:
-        for start in range(0, nrows, batch_size):
-            log.warning(f'batch {start}')
-            end = min(start + batch_size, nrows)
-            fetch_data(
-                method,
-                cursor,
-                sparkSession,
-                start,
-                end,
-                order_by,
-                tablename,
-                columns_str,
-                json_output_folder,
-            )
+        if isinstance(e, (AnalysisException, ServerOperationError)):
+            if 'INSUFFICIENT_PERMISSIONS' in str(e):
+                raise InsufficientPermissionsError(str(e)) from e
+
+        if isinstance(e, InsufficientPermissionsError):
+            raise
+
+        # For any other exception, raise a general error
+        raise RuntimeError(f'Error processing {tablename}: {str(e)}') from e
 
-    if cursor is not None:
-        cursor.close()
+    finally:
+        if cursor is not None:
+            cursor.close()
 
 
 def validate_and_get_cluster_info(
diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py
index 265b9bbe8f..242ac4f32c 100644
--- a/llmfoundry/utils/exceptions.py
+++ b/llmfoundry/utils/exceptions.py
@@ -456,6 +456,13 @@ def __init__(
 class InsufficientPermissionsError(UserError):
     """Error thrown when the user does not have sufficient permissions."""
 
-    def __init__(self, action: str) -> None:
-        message = f'Insufficient permissions when {action}. Please check your permissions.'
-        super().__init__(message, action=action)
+    def __init__(self, message: str) -> None:
+        self.message = message
+        super().__init__(message)
+
+    def __reduce__(self):
+        # Return a tuple of class, a tuple of arguments, and optionally state
+        return (InsufficientPermissionsError, (self.message,))
+
+    def __str__(self):
+        return self.message
diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py
index b1a9f1e878..981f5c1ed6 100644
--- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py
+++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py
@@ -10,6 +10,7 @@
 from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
     InsufficientPermissionsError,
     download,
+    fetch,
     fetch_DT,
     format_tablename,
     iterative_combine_jsons,
@@ -30,27 +31,33 @@ class MockAnalysisException(Exception):
             def __init__(self, message: str):
                 self.message = message
 
+            def __str__(self):
+                return self.message
+
         with patch.dict('sys.modules', {'pyspark.errors': MagicMock()}):
             sys.modules[
                 'pyspark.errors'
-            ].AnalysisException = MockAnalysisException  # pyright: ignore
+            ].AnalysisException = MockAnalysisException  # type: ignore
 
             mock_spark = MagicMock()
             mock_spark.sql.side_effect = MockAnalysisException(error_message)
 
             with self.assertRaises(InsufficientPermissionsError) as context:
-                run_query(
-                    'SELECT * FROM table',
+                fetch(
                     method='dbconnect',
-                    cursor=None,
-                    spark=mock_spark,
+                    tablename='main.oogabooga',
+                    json_output_folder='/fake/path',
+                    batch_size=1,
+                    processes=1,
+                    sparkSession=mock_spark,
+                    dbsql=None,
                 )
 
-            self.assertIn(
-                'using the schema main.oogabooga',
+            self.assertEqual(
                 str(context.exception),
+                error_message,
             )
-            mock_spark.sql.assert_called_once_with('SELECT * FROM table')
+            mock_spark.sql.assert_called()
 
     @patch(
         'databricks.sql.connect',
diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py
index 8bfc7287ab..564dfa2f14 100644
--- a/tests/utils/test_exceptions.py
+++ b/tests/utils/test_exceptions.py
@@ -4,7 +4,7 @@
 import contextlib
 import inspect
 import pickle
-from typing import Any, Optional
+from typing import Any, Optional, get_type_hints
 
 import pytest
 
@@ -14,16 +14,30 @@
 def create_exception_object(
     exception_class: type[foundry_exceptions.BaseContextualError],
 ):
-    # get required arg types of exception class by inspecting its __init__ method
-    if hasattr(inspect, 'get_annotations'):
-        required_args = inspect.get_annotations(  # type: ignore
-            exception_class.__init__,
-        )  # type: ignore
-    else:
-        required_args = exception_class.__init__.__annotations__  # python 3.9 and below
-
-    # create a dictionary of required args with default values
+    def get_init_annotations(cls: type):
+        try:
+            return get_type_hints(cls.__init__)
+        except (AttributeError, TypeError):
+            # Handle cases where __init__ does not exist or has no annotations
+            return {}
+
+    # First, try to get annotations from the class itself
+    required_args = get_init_annotations(exception_class)
+
+    # If the annotations are empty, look at parent classes
+    if not required_args:
+        for parent in exception_class.__bases__:
+            if parent == object:
+                break
+            parent_args = get_init_annotations(parent)
+            if parent_args:
+                required_args = parent_args
+                break
+
+    # Remove self, return, and kwargs
+    required_args.pop('self', None)
+    required_args.pop('return', None)
     required_args.pop('kwargs', None)
 
     def get_default_value(arg_type: Optional[type] = None):
@@ -51,8 +65,6 @@ def get_default_value(arg_type: Optional[type] = None):
             return [{'key': 'value'}]
         raise ValueError(f'Unsupported arg type: {arg_type}')
 
-    required_args.pop('self', None)
-    required_args.pop('return', None)
     kwargs = {
         arg: get_default_value(arg_type)
         for arg, arg_type in required_args.items()
@@ -80,6 +92,7 @@ def filter_exceptions(possible_exceptions: list[str]):
 def test_exception_serialization(
     exception_class: type[foundry_exceptions.BaseContextualError],
 ):
+    print(f'Testing serialization for {exception_class.__name__}')
     excluded_base_classes = [
         foundry_exceptions.InternalError,
         foundry_exceptions.UserError,
@@ -88,6 +101,7 @@ def test_exception_serialization(
     ]
 
     exception = create_exception_object(exception_class)
+    print(f'Created exception object: {exception}')
 
     expect_reduce_error = exception.__class__ in excluded_base_classes
     error_context = pytest.raises(
@@ -95,6 +109,7 @@ def test_exception_serialization(
     ) if expect_reduce_error else contextlib.nullcontext()
 
     exc_str = str(exception)
+    print(f'Exception string: {exc_str}')
     with error_context:
         pkl = pickle.dumps(exception)
         unpickled_exc = pickle.loads(pkl)
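
A note on the InsufficientPermissionsError changes above: by default, pickle rebuilds an exception by calling its class with `exception.args`, so a constructor like the old `__init__(self, action: str)` that rewrites its message can double-wrap the text, or fail outright, when the error crosses a process boundary (for example, out of worker processes or other tooling that serializes exceptions). Defining `__reduce__` pins down exactly how the object is reconstructed. The listing below is a minimal standalone sketch of that failure mode and the fix, assuming only the standard library; `WrappedError` is a hypothetical stand-in, not code from this patch.

import pickle


class WrappedError(Exception):
    """Hypothetical error whose constructor rewrites its message."""

    def __init__(self, action: str) -> None:
        self.action = action
        super().__init__(f'Insufficient permissions when {action}.')

    def __reduce__(self):
        # Without this override, pickle would call WrappedError(*self.args), i.e.
        # feed the already-formatted message back in as `action` and wrap it twice.
        return (WrappedError, (self.action,))


if __name__ == '__main__':
    err = WrappedError('reading from main.some_table')
    restored = pickle.loads(pickle.dumps(err))
    assert str(restored) == str(err)
    print(str(restored))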
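
On the test_exceptions.py helper change: `typing.get_type_hints` evaluates postponed or string annotations (for example under `from __future__ import annotations`) into real type objects, whereas reading `__init__.__annotations__` directly can hand the helper plain strings that its type-based default lookup cannot use, and `inspect.get_annotations` only exists on Python 3.10+. Below is a small standalone comparison, again assuming only the standard library; `Base` is a made-up class, not from the repo.

from __future__ import annotations

from typing import get_type_hints


class Base:
    def __init__(self, message: str, count: int) -> None:
        self.message = message
        self.count = count


# Raw annotations are strings under postponed evaluation:
# {'message': 'str', 'count': 'int', 'return': 'None'}
print(Base.__init__.__annotations__)

# get_type_hints resolves them into actual classes:
# {'message': <class 'str'>, 'count': <class 'int'>, 'return': <class 'NoneType'>}
print(get_type_hints(Base.__init__))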