From a2ccc9ee350863e42e36119b8c60dc5e6ce85860 Mon Sep 17 00:00:00 2001 From: v-chen_data Date: Thu, 26 Sep 2024 15:32:53 -0700 Subject: [PATCH 01/10] insufficient errors --- .../data_prep/convert_delta_to_json.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py index fbbc5f2cd9..b23655d5f5 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -228,7 +228,22 @@ def run_query( if method == 'dbsql': if cursor is None: raise ValueError(f'cursor cannot be None if using method dbsql') - cursor.execute(query) + try: + cursor.execute(query) + except Exception as e: + from databricks.sql.exc import ServerOperationError + if isinstance(e, ServerOperationError): + if 'INSUFFICIENT_PERMISSIONS' in e.message: # pyright: ignore + match = re.search( + r"'([^']+)'", + e.message, # pyright: ignore + ) + if match: + table_name = match.group(1) + action = f'accessing table {table_name}' + else: + action = 'accessing table' + raise InsufficientPermissionsError(action=action,) from e if collect: return cursor.fetchall() elif method == 'dbconnect': From cf5d5197888f750cdeb7840bd8761be4e2fe4373 Mon Sep 17 00:00:00 2001 From: v-chen_data Date: Thu, 26 Sep 2024 22:31:44 -0700 Subject: [PATCH 02/10] test --- .../data_prep/convert_delta_to_json.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py index b23655d5f5..5cce321ba9 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -491,14 +491,22 @@ def fetch( sparkSession, ) except Exception as e: + from databricks.sql.exc import ServerOperationError from pyspark.errors import AnalysisException - if isinstance(e, AnalysisException): - if 'INSUFFICIENT_PERMISSIONS' in e.message: # pyright: ignore - raise InsufficientPermissionsError( - action=f'reading from {tablename}', - ) from e + + if isinstance(e, (AnalysisException, ServerOperationError)): + if 'INSUFFICIENT_PERMISSIONS' in str(e): + if isinstance( + e, + AnalysisException, + ) or isinstance(e, ServerOperationError): + raise InsufficientPermissionsError( + action=f'reading from {tablename}', + ) from e + if isinstance(e, InsufficientPermissionsError): - raise e + raise + raise RuntimeError( f'Error in get rows from {tablename}. Restart sparkSession and try again', ) from e From 2cb38b590dc36db347b859b4ae0db62def200317 Mon Sep 17 00:00:00 2001 From: v-chen_data Date: Thu, 26 Sep 2024 22:38:28 -0700 Subject: [PATCH 03/10] test --- .../data_prep/convert_delta_to_json.py | 50 +++++++------------ 1 file changed, 17 insertions(+), 33 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py index 5cce321ba9..808a18a8b6 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -216,63 +216,47 @@ def run_query( spark: Optional['SparkSession'] = None, collect: bool = True, ) -> Optional[Union[list['Row'], 'DataFrame', 'SparkDataFrame']]: - """Run SQL query via databricks-connect or databricks-sql. - - Args: - query (str): sql query - method (str): select from dbsql and dbconnect - cursor (Optional[Cursor]): connection.cursor - spark (Optional[SparkSession]): spark session - collect (bool): whether to get the underlying data from spark dataframe - """ + """Run SQL query via databricks-connect or databricks-sql.""" if method == 'dbsql': if cursor is None: raise ValueError(f'cursor cannot be None if using method dbsql') try: cursor.execute(query) + if collect: + return cursor.fetchall() except Exception as e: from databricks.sql.exc import ServerOperationError if isinstance(e, ServerOperationError): - if 'INSUFFICIENT_PERMISSIONS' in e.message: # pyright: ignore - match = re.search( - r"'([^']+)'", - e.message, # pyright: ignore - ) + if 'INSUFFICIENT_PERMISSIONS' in str(e): + match = re.search(r"'([^']+)'", str(e)) if match: table_name = match.group(1) action = f'accessing table {table_name}' else: action = 'accessing table' - raise InsufficientPermissionsError(action=action,) from e - if collect: - return cursor.fetchall() + raise InsufficientPermissionsError(action=action) from e + raise elif method == 'dbconnect': if spark == None: raise ValueError(f'sparkSession is required for dbconnect') try: df = spark.sql(query) + if collect: + return df.collect() + return df except Exception as e: from pyspark.errors import AnalysisException if isinstance(e, AnalysisException): - if 'INSUFFICIENT_PERMISSIONS' in e.message: # pyright: ignore - match = re.search( - r"Schema\s+'([^']+)'", - e.message, # pyright: ignore - ) + if 'INSUFFICIENT_PERMISSIONS' in str(e): + match = re.search(r"Table '([^']+)'", str(e)) if match: - schema_name = match.group(1) - action = f'using the schema {schema_name}' + table_name = match.group(1) + action = f'accessing table {table_name}' else: - action = 'using the schema' - raise InsufficientPermissionsError(action=action,) from e - raise RuntimeError( - f'Error in querying into schema. Restart sparkSession and try again', - ) from e - - if collect: - return df.collect() - return df + action = 'accessing table' + raise InsufficientPermissionsError(action=action) from e + raise else: raise ValueError(f'Unrecognized method: {method}') From ae1810aea22fdc2d1051b51db83161d748fdb7f2 Mon Sep 17 00:00:00 2001 From: v-chen_data Date: Thu, 26 Sep 2024 22:39:57 -0700 Subject: [PATCH 04/10] test --- .../command_utils/data_prep/convert_delta_to_json.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py index 808a18a8b6..67f73c8b06 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -216,7 +216,14 @@ def run_query( spark: Optional['SparkSession'] = None, collect: bool = True, ) -> Optional[Union[list['Row'], 'DataFrame', 'SparkDataFrame']]: - """Run SQL query via databricks-connect or databricks-sql.""" + """Run SQL query via databricks-connect or databricks-sql. + Args: + query (str): sql query + method (str): select from dbsql and dbconnect + cursor (Optional[Cursor]): connection.cursor + spark (Optional[SparkSession]): spark session + collect (bool): whether to get the underlying data from spark dataframe + """ if method == 'dbsql': if cursor is None: raise ValueError(f'cursor cannot be None if using method dbsql') From d8af01543da8f1ec1ed8fba5eb7702626b17ddc5 Mon Sep 17 00:00:00 2001 From: v-chen_data Date: Thu, 26 Sep 2024 22:51:07 -0700 Subject: [PATCH 05/10] test central erroring --- .../data_prep/convert_delta_to_json.py | 164 ++++++++---------- 1 file changed, 73 insertions(+), 91 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py index 67f73c8b06..73af7a7769 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -217,6 +217,7 @@ def run_query( collect: bool = True, ) -> Optional[Union[list['Row'], 'DataFrame', 'SparkDataFrame']]: """Run SQL query via databricks-connect or databricks-sql. + Args: query (str): sql query method (str): select from dbsql and dbconnect @@ -227,43 +228,36 @@ def run_query( if method == 'dbsql': if cursor is None: raise ValueError(f'cursor cannot be None if using method dbsql') - try: - cursor.execute(query) - if collect: - return cursor.fetchall() - except Exception as e: - from databricks.sql.exc import ServerOperationError - if isinstance(e, ServerOperationError): - if 'INSUFFICIENT_PERMISSIONS' in str(e): - match = re.search(r"'([^']+)'", str(e)) - if match: - table_name = match.group(1) - action = f'accessing table {table_name}' - else: - action = 'accessing table' - raise InsufficientPermissionsError(action=action) from e - raise + cursor.execute(query) + if collect: + return cursor.fetchall() elif method == 'dbconnect': if spark == None: raise ValueError(f'sparkSession is required for dbconnect') try: df = spark.sql(query) - if collect: - return df.collect() - return df except Exception as e: from pyspark.errors import AnalysisException if isinstance(e, AnalysisException): - if 'INSUFFICIENT_PERMISSIONS' in str(e): - match = re.search(r"Table '([^']+)'", str(e)) + if 'INSUFFICIENT_PERMISSIONS' in e.message: # pyright: ignore + match = re.search( + r"Schema\s+'([^']+)'", + e.message, # pyright: ignore + ) if match: - table_name = match.group(1) - action = f'accessing table {table_name}' + schema_name = match.group(1) + action = f'using the schema {schema_name}' else: - action = 'accessing table' - raise InsufficientPermissionsError(action=action) from e - raise + action = 'using the schema' + raise InsufficientPermissionsError(action=action,) from e + raise RuntimeError( + f'Error in querying into schema. Restart sparkSession and try again', + ) from e + + if collect: + return df.collect() + return df else: raise ValueError(f'Unrecognized method: {method}') @@ -475,79 +469,67 @@ def fetch( """ cursor = dbsql.cursor() if dbsql is not None else None try: - nrows = get_total_rows( - tablename, - method, - cursor, - sparkSession, - ) + # Get total rows + nrows = get_total_rows(tablename, method, cursor, sparkSession) + + # Get columns info + columns, order_by, columns_str = get_columns_info(tablename, method, cursor, sparkSession) + + if method == 'dbconnect' and sparkSession is not None: + log.info(f'{processes=}') + df = sparkSession.table(tablename) + + # Running the query and collecting the data as arrow or json. + signed, _, _ = df.collect_cf('arrow') # pyright: ignore + log.info(f'len(signed) = {len(signed)}') + + args = get_args(signed, json_output_folder, columns) + + # Stopping the SparkSession to avoid spilling connection state into the subprocesses. + sparkSession.stop() + + with ProcessPoolExecutor(max_workers=processes) as executor: + list(executor.map(download_starargs, args)) + + elif method == 'dbsql' and cursor is not None: + for start in range(0, nrows, batch_size): + log.warning(f'batch {start}') + end = min(start + batch_size, nrows) + fetch_data( + method, + cursor, + sparkSession, + start, + end, + order_by, + tablename, + columns_str, + json_output_folder, + ) + except Exception as e: - from databricks.sql.exc import ServerOperationError from pyspark.errors import AnalysisException + from databricks.sql.exc import ServerOperationError if isinstance(e, (AnalysisException, ServerOperationError)): if 'INSUFFICIENT_PERMISSIONS' in str(e): - if isinstance( - e, - AnalysisException, - ) or isinstance(e, ServerOperationError): - raise InsufficientPermissionsError( - action=f'reading from {tablename}', - ) from e - + match = re.search(r"(?:Table|Schema)\s+'([^']+)'", str(e)) + if match: + object_name = match.group(1) + action = f'accessing {object_name}' + else: + action = f'accessing {tablename}' + raise InsufficientPermissionsError(action=action) from e + if isinstance(e, InsufficientPermissionsError): raise - raise RuntimeError( - f'Error in get rows from {tablename}. Restart sparkSession and try again', - ) from e - - try: - columns, order_by, columns_str = get_columns_info( - tablename, - method, - cursor, - sparkSession, - ) - except Exception as e: - raise RuntimeError( - f'Error in get columns from {tablename}. Restart sparkSession and try again', - ) from e - - if method == 'dbconnect' and sparkSession is not None: - log.info(f'{processes=}') - df = sparkSession.table(tablename) - - # Running the query and collecting the data as arrow or json. - signed, _, _ = df.collect_cf('arrow') # pyright: ignore - log.info(f'len(signed) = {len(signed)}') - - args = get_args(signed, json_output_folder, columns) - - # Stopping the SparkSession to avoid spilling connection state into the subprocesses. - sparkSession.stop() - - with ProcessPoolExecutor(max_workers=processes) as executor: - list(executor.map(download_starargs, args)) - - elif method == 'dbsql' and cursor is not None: - for start in range(0, nrows, batch_size): - log.warning(f'batch {start}') - end = min(start + batch_size, nrows) - fetch_data( - method, - cursor, - sparkSession, - start, - end, - order_by, - tablename, - columns_str, - json_output_folder, - ) + # For any other exception, raise a general error + raise RuntimeError(f"Error processing {tablename}: {str(e)}") from e - if cursor is not None: - cursor.close() + finally: + if cursor is not None: + cursor.close() def validate_and_get_cluster_info( @@ -814,4 +796,4 @@ def convert_delta_to_json_from_args( DATABRICKS_HOST=DATABRICKS_HOST, DATABRICKS_TOKEN=DATABRICKS_TOKEN, ) - log.info(f'Elapsed time {time.time() - tik}') + log.info(f'Elapsed time {time.time() - tik}') \ No newline at end of file From 79fcbc63ef6f1bfa88b36945ac7db1dd83b5619e Mon Sep 17 00:00:00 2001 From: v-chen_data Date: Thu, 26 Sep 2024 22:53:25 -0700 Subject: [PATCH 06/10] test central erroring --- .../data_prep/convert_delta_to_json.py | 22 +------------------ 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py index 73af7a7769..7b3ab7a5e8 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -234,27 +234,7 @@ def run_query( elif method == 'dbconnect': if spark == None: raise ValueError(f'sparkSession is required for dbconnect') - - try: - df = spark.sql(query) - except Exception as e: - from pyspark.errors import AnalysisException - if isinstance(e, AnalysisException): - if 'INSUFFICIENT_PERMISSIONS' in e.message: # pyright: ignore - match = re.search( - r"Schema\s+'([^']+)'", - e.message, # pyright: ignore - ) - if match: - schema_name = match.group(1) - action = f'using the schema {schema_name}' - else: - action = 'using the schema' - raise InsufficientPermissionsError(action=action,) from e - raise RuntimeError( - f'Error in querying into schema. Restart sparkSession and try again', - ) from e - + df = spark.sql(query) if collect: return df.collect() return df From 5a04f7c0ac9a91be9af32c84b24dd8547c14b3ed Mon Sep 17 00:00:00 2001 From: v-chen_data Date: Thu, 26 Sep 2024 23:00:17 -0700 Subject: [PATCH 07/10] test central erroring --- .../command_utils/data_prep/convert_delta_to_json.py | 8 +------- llmfoundry/utils/exceptions.py | 5 ++--- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py index 7b3ab7a5e8..81d4cd2715 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -493,13 +493,7 @@ def fetch( if isinstance(e, (AnalysisException, ServerOperationError)): if 'INSUFFICIENT_PERMISSIONS' in str(e): - match = re.search(r"(?:Table|Schema)\s+'([^']+)'", str(e)) - if match: - object_name = match.group(1) - action = f'accessing {object_name}' - else: - action = f'accessing {tablename}' - raise InsufficientPermissionsError(action=action) from e + raise InsufficientPermissionsError(str(e)) from e if isinstance(e, InsufficientPermissionsError): raise diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 265b9bbe8f..5904062285 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -456,6 +456,5 @@ def __init__( class InsufficientPermissionsError(UserError): """Error thrown when the user does not have sufficient permissions.""" - def __init__(self, action: str) -> None: - message = f'Insufficient permissions when {action}. Please check your permissions.' - super().__init__(message, action=action) + def __init__(self, message: str) -> None: + super().__init__(message) From 1494c913f59ab1257a757e7b86b480d8483093d2 Mon Sep 17 00:00:00 2001 From: v-chen_data Date: Thu, 26 Sep 2024 23:16:26 -0700 Subject: [PATCH 08/10] precommit --- .../data_prep/convert_delta_to_json.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py index 81d4cd2715..44e8651cdf 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -451,10 +451,15 @@ def fetch( try: # Get total rows nrows = get_total_rows(tablename, method, cursor, sparkSession) - + # Get columns info - columns, order_by, columns_str = get_columns_info(tablename, method, cursor, sparkSession) - + columns, order_by, columns_str = get_columns_info( + tablename, + method, + cursor, + sparkSession, + ) + if method == 'dbconnect' and sparkSession is not None: log.info(f'{processes=}') df = sparkSession.table(tablename) @@ -488,18 +493,18 @@ def fetch( ) except Exception as e: - from pyspark.errors import AnalysisException from databricks.sql.exc import ServerOperationError + from pyspark.errors import AnalysisException if isinstance(e, (AnalysisException, ServerOperationError)): if 'INSUFFICIENT_PERMISSIONS' in str(e): raise InsufficientPermissionsError(str(e)) from e - + if isinstance(e, InsufficientPermissionsError): raise # For any other exception, raise a general error - raise RuntimeError(f"Error processing {tablename}: {str(e)}") from e + raise RuntimeError(f'Error processing {tablename}: {str(e)}') from e finally: if cursor is not None: @@ -770,4 +775,4 @@ def convert_delta_to_json_from_args( DATABRICKS_HOST=DATABRICKS_HOST, DATABRICKS_TOKEN=DATABRICKS_TOKEN, ) - log.info(f'Elapsed time {time.time() - tik}') \ No newline at end of file + log.info(f'Elapsed time {time.time() - tik}') From dcf4569a6382be80540d81905795bf27a9450c6c Mon Sep 17 00:00:00 2001 From: v-chen_data Date: Fri, 27 Sep 2024 12:35:52 -0700 Subject: [PATCH 09/10] update tests --- llmfoundry/utils/exceptions.py | 8 +++++ .../data_prep/test_convert_delta_to_json.py | 23 +++++++----- tests/utils/test_exceptions.py | 36 +++++++++++++------ 3 files changed, 48 insertions(+), 19 deletions(-) diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 5904062285..242ac4f32c 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -457,4 +457,12 @@ class InsufficientPermissionsError(UserError): """Error thrown when the user does not have sufficient permissions.""" def __init__(self, message: str) -> None: + self.message = message super().__init__(message) + + def __reduce__(self): + # Return a tuple of class, a tuple of arguments, and optionally state + return (InsufficientPermissionsError, (self.message,)) + + def __str__(self): + return self.message diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index b1a9f1e878..981f5c1ed6 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -10,6 +10,7 @@ from llmfoundry.command_utils.data_prep.convert_delta_to_json import ( InsufficientPermissionsError, download, + fetch, fetch_DT, format_tablename, iterative_combine_jsons, @@ -30,27 +31,33 @@ class MockAnalysisException(Exception): def __init__(self, message: str): self.message = message + def __str__(self): + return self.message + with patch.dict('sys.modules', {'pyspark.errors': MagicMock()}): sys.modules[ 'pyspark.errors' - ].AnalysisException = MockAnalysisException # pyright: ignore + ].AnalysisException = MockAnalysisException # type: ignore mock_spark = MagicMock() mock_spark.sql.side_effect = MockAnalysisException(error_message) with self.assertRaises(InsufficientPermissionsError) as context: - run_query( - 'SELECT * FROM table', + fetch( method='dbconnect', - cursor=None, - spark=mock_spark, + tablename='main.oogabooga', + json_output_folder='/fake/path', + batch_size=1, + processes=1, + sparkSession=mock_spark, + dbsql=None, ) - self.assertIn( - 'using the schema main.oogabooga', + self.assertEqual( str(context.exception), + error_message, ) - mock_spark.sql.assert_called_once_with('SELECT * FROM table') + mock_spark.sql.assert_called() @patch( 'databricks.sql.connect', diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py index 8bfc7287ab..44c6354998 100644 --- a/tests/utils/test_exceptions.py +++ b/tests/utils/test_exceptions.py @@ -14,16 +14,29 @@ def create_exception_object( exception_class: type[foundry_exceptions.BaseContextualError], ): - # get required arg types of exception class by inspecting its __init__ method - if hasattr(inspect, 'get_annotations'): - required_args = inspect.get_annotations( # type: ignore - exception_class.__init__, - ) # type: ignore - else: - required_args = exception_class.__init__.__annotations__ # python 3.9 and below - - # create a dictionary of required args with default values + def get_init_annotations(cls: type): + if hasattr(inspect, 'get_annotations'): + return inspect.get_annotations(cls.__init__) + else: + return getattr(cls.__init__, '__annotations__', {}) + + # First, try to get annotations from the class itself + required_args = get_init_annotations(exception_class) + + # If the annotations are empty, look at parent classes + if not required_args: + for parent in exception_class.__bases__: + if parent == object: + break + parent_args = get_init_annotations(parent) + if parent_args: + required_args = parent_args + break + + # Remove self, return, and kwargs + required_args.pop('self', None) + required_args.pop('return', None) required_args.pop('kwargs', None) def get_default_value(arg_type: Optional[type] = None): @@ -51,8 +64,6 @@ def get_default_value(arg_type: Optional[type] = None): return [{'key': 'value'}] raise ValueError(f'Unsupported arg type: {arg_type}') - required_args.pop('self', None) - required_args.pop('return', None) kwargs = { arg: get_default_value(arg_type) for arg, arg_type in required_args.items() @@ -80,6 +91,7 @@ def filter_exceptions(possible_exceptions: list[str]): def test_exception_serialization( exception_class: type[foundry_exceptions.BaseContextualError], ): + print(f'Testing serialization for {exception_class.__name__}') excluded_base_classes = [ foundry_exceptions.InternalError, foundry_exceptions.UserError, @@ -88,6 +100,7 @@ def test_exception_serialization( ] exception = create_exception_object(exception_class) + print(f'Created exception object: {exception}') expect_reduce_error = exception.__class__ in excluded_base_classes error_context = pytest.raises( @@ -95,6 +108,7 @@ def test_exception_serialization( ) if expect_reduce_error else contextlib.nullcontext() exc_str = str(exception) + print(f'Exception string: {exc_str}') with error_context: pkl = pickle.dumps(exception) unpickled_exc = pickle.loads(pkl) From 5edb5a59b133b52489d6a1c73afd83c9342b14e6 Mon Sep 17 00:00:00 2001 From: v-chen_data Date: Fri, 27 Sep 2024 12:48:29 -0700 Subject: [PATCH 10/10] precommit --- tests/utils/test_exceptions.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py index 44c6354998..564dfa2f14 100644 --- a/tests/utils/test_exceptions.py +++ b/tests/utils/test_exceptions.py @@ -4,7 +4,7 @@ import contextlib import inspect import pickle -from typing import Any, Optional +from typing import Any, Optional, get_type_hints import pytest @@ -16,10 +16,11 @@ def create_exception_object( ): def get_init_annotations(cls: type): - if hasattr(inspect, 'get_annotations'): - return inspect.get_annotations(cls.__init__) - else: - return getattr(cls.__init__, '__annotations__', {}) + try: + return get_type_hints(cls.__init__) + except (AttributeError, TypeError): + # Handle cases where __init__ does not exist or has no annotations + return {} # First, try to get annotations from the class itself required_args = get_init_annotations(exception_class)