Skip to content

Commit

Permalink
Insufficient Permissions Error when trying to access table (mosaicml#1555)
Browse files Browse the repository at this point in the history

Co-authored-by: v-chen_data <[email protected]>
  • Loading branch information
KuuCi and v-chen_data authored Sep 27, 2024
1 parent ee45600 commit 107d246
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 99 deletions.
127 changes: 51 additions & 76 deletions llmfoundry/command_utils/data_prep/convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,27 +234,7 @@ def run_query(
elif method == 'dbconnect':
if spark == None:
raise ValueError(f'sparkSession is required for dbconnect')

try:
df = spark.sql(query)
except Exception as e:
from pyspark.errors import AnalysisException
if isinstance(e, AnalysisException):
if 'INSUFFICIENT_PERMISSIONS' in e.message: # pyright: ignore
match = re.search(
r"Schema\s+'([^']+)'",
e.message, # pyright: ignore
)
if match:
schema_name = match.group(1)
action = f'using the schema {schema_name}'
else:
action = 'using the schema'
raise InsufficientPermissionsError(action=action,) from e
raise RuntimeError(
f'Error in querying into schema. Restart sparkSession and try again',
) from e

df = spark.sql(query)
if collect:
return df.collect()
return df
Expand Down Expand Up @@ -469,71 +449,66 @@ def fetch(
"""
cursor = dbsql.cursor() if dbsql is not None else None
try:
nrows = get_total_rows(
tablename,
method,
cursor,
sparkSession,
)
except Exception as e:
from pyspark.errors import AnalysisException
if isinstance(e, AnalysisException):
if 'INSUFFICIENT_PERMISSIONS' in e.message: # pyright: ignore
raise InsufficientPermissionsError(
action=f'reading from {tablename}',
) from e
if isinstance(e, InsufficientPermissionsError):
raise e
raise RuntimeError(
f'Error in get rows from {tablename}. Restart sparkSession and try again',
) from e
# Get total rows
nrows = get_total_rows(tablename, method, cursor, sparkSession)

try:
# Get columns info
columns, order_by, columns_str = get_columns_info(
tablename,
method,
cursor,
sparkSession,
)

if method == 'dbconnect' and sparkSession is not None:
log.info(f'{processes=}')
df = sparkSession.table(tablename)

# Running the query and collecting the data as arrow or json.
signed, _, _ = df.collect_cf('arrow') # pyright: ignore
log.info(f'len(signed) = {len(signed)}')

args = get_args(signed, json_output_folder, columns)

# Stopping the SparkSession to avoid spilling connection state into the subprocesses.
sparkSession.stop()

with ProcessPoolExecutor(max_workers=processes) as executor:
list(executor.map(download_starargs, args))

elif method == 'dbsql' and cursor is not None:
for start in range(0, nrows, batch_size):
log.warning(f'batch {start}')
end = min(start + batch_size, nrows)
fetch_data(
method,
cursor,
sparkSession,
start,
end,
order_by,
tablename,
columns_str,
json_output_folder,
)

except Exception as e:
raise RuntimeError(
f'Error in get columns from {tablename}. Restart sparkSession and try again',
) from e
from databricks.sql.exc import ServerOperationError
from pyspark.errors import AnalysisException

if method == 'dbconnect' and sparkSession is not None:
log.info(f'{processes=}')
df = sparkSession.table(tablename)

# Running the query and collecting the data as arrow or json.
signed, _, _ = df.collect_cf('arrow') # pyright: ignore
log.info(f'len(signed) = {len(signed)}')

args = get_args(signed, json_output_folder, columns)

# Stopping the SparkSession to avoid spilling connection state into the subprocesses.
sparkSession.stop()

with ProcessPoolExecutor(max_workers=processes) as executor:
list(executor.map(download_starargs, args))

elif method == 'dbsql' and cursor is not None:
for start in range(0, nrows, batch_size):
log.warning(f'batch {start}')
end = min(start + batch_size, nrows)
fetch_data(
method,
cursor,
sparkSession,
start,
end,
order_by,
tablename,
columns_str,
json_output_folder,
)
if isinstance(e, (AnalysisException, ServerOperationError)):
if 'INSUFFICIENT_PERMISSIONS' in str(e):
raise InsufficientPermissionsError(str(e)) from e

if isinstance(e, InsufficientPermissionsError):
raise

# For any other exception, raise a general error
raise RuntimeError(f'Error processing {tablename}: {str(e)}') from e

if cursor is not None:
cursor.close()
finally:
if cursor is not None:
cursor.close()


def validate_and_get_cluster_info(
Expand Down
13 changes: 10 additions & 3 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,13 @@ def __init__(
class InsufficientPermissionsError(UserError):
"""Error thrown when the user does not have sufficient permissions."""

def __init__(self, action: str) -> None:
message = f'Insufficient permissions when {action}. Please check your permissions.'
super().__init__(message, action=action)
def __init__(self, message: str) -> None:
self.message = message
super().__init__(message)

def __reduce__(self):
# Return a tuple of class, a tuple of arguments, and optionally state
return (InsufficientPermissionsError, (self.message,))

def __str__(self):
return self.message
23 changes: 15 additions & 8 deletions tests/a_scripts/data_prep/test_convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
InsufficientPermissionsError,
download,
fetch,
fetch_DT,
format_tablename,
iterative_combine_jsons,
Expand All @@ -30,27 +31,33 @@ class MockAnalysisException(Exception):
def __init__(self, message: str):
self.message = message

def __str__(self):
return self.message

with patch.dict('sys.modules', {'pyspark.errors': MagicMock()}):
sys.modules[
'pyspark.errors'
].AnalysisException = MockAnalysisException # pyright: ignore
].AnalysisException = MockAnalysisException # type: ignore

mock_spark = MagicMock()
mock_spark.sql.side_effect = MockAnalysisException(error_message)

with self.assertRaises(InsufficientPermissionsError) as context:
run_query(
'SELECT * FROM table',
fetch(
method='dbconnect',
cursor=None,
spark=mock_spark,
tablename='main.oogabooga',
json_output_folder='/fake/path',
batch_size=1,
processes=1,
sparkSession=mock_spark,
dbsql=None,
)

self.assertIn(
'using the schema main.oogabooga',
self.assertEqual(
str(context.exception),
error_message,
)
mock_spark.sql.assert_called_once_with('SELECT * FROM table')
mock_spark.sql.assert_called()

@patch(
'databricks.sql.connect',
Expand Down
39 changes: 27 additions & 12 deletions tests/utils/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import contextlib
import inspect
import pickle
from typing import Any, Optional
from typing import Any, Optional, get_type_hints

import pytest

Expand All @@ -14,16 +14,30 @@
def create_exception_object(
exception_class: type[foundry_exceptions.BaseContextualError],
):
# get required arg types of exception class by inspecting its __init__ method

if hasattr(inspect, 'get_annotations'):
required_args = inspect.get_annotations( # type: ignore
exception_class.__init__,
) # type: ignore
else:
required_args = exception_class.__init__.__annotations__ # python 3.9 and below

# create a dictionary of required args with default values
def get_init_annotations(cls: type):
try:
return get_type_hints(cls.__init__)
except (AttributeError, TypeError):
# Handle cases where __init__ does not exist or has no annotations
return {}

# First, try to get annotations from the class itself
required_args = get_init_annotations(exception_class)

# If the annotations are empty, look at parent classes
if not required_args:
for parent in exception_class.__bases__:
if parent == object:
break
parent_args = get_init_annotations(parent)
if parent_args:
required_args = parent_args
break

# Remove self, return, and kwargs
required_args.pop('self', None)
required_args.pop('return', None)
required_args.pop('kwargs', None)

def get_default_value(arg_type: Optional[type] = None):
Expand Down Expand Up @@ -51,8 +65,6 @@ def get_default_value(arg_type: Optional[type] = None):
return [{'key': 'value'}]
raise ValueError(f'Unsupported arg type: {arg_type}')

required_args.pop('self', None)
required_args.pop('return', None)
kwargs = {
arg: get_default_value(arg_type)
for arg, arg_type in required_args.items()
Expand Down Expand Up @@ -80,6 +92,7 @@ def filter_exceptions(possible_exceptions: list[str]):
def test_exception_serialization(
exception_class: type[foundry_exceptions.BaseContextualError],
):
print(f'Testing serialization for {exception_class.__name__}')
excluded_base_classes = [
foundry_exceptions.InternalError,
foundry_exceptions.UserError,
Expand All @@ -88,13 +101,15 @@ def test_exception_serialization(
]

exception = create_exception_object(exception_class)
print(f'Created exception object: {exception}')

expect_reduce_error = exception.__class__ in excluded_base_classes
error_context = pytest.raises(
NotImplementedError,
) if expect_reduce_error else contextlib.nullcontext()

exc_str = str(exception)
print(f'Exception string: {exc_str}')
with error_context:
pkl = pickle.dumps(exception)
unpickled_exc = pickle.loads(pkl)
Expand Down

0 comments on commit 107d246

Please sign in to comment.