Skip to content

Commit

Permalink
Add proper user error for accessing schema (#1548)
Browse files Browse the repository at this point in the history
Co-authored-by: v-chen_data <[email protected]>
  • Loading branch information
KuuCi and v-chen_data authored Sep 25, 2024
1 parent 722526d commit c786def
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 1 deletion.
24 changes: 23 additions & 1 deletion llmfoundry/command_utils/data_prep/convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,27 @@ def run_query(
elif method == 'dbconnect':
if spark == None:
raise ValueError(f'sparkSession is required for dbconnect')
df = spark.sql(query)

try:
df = spark.sql(query)
except Exception as e:
from pyspark.errors import AnalysisException
if isinstance(e, AnalysisException):
if 'INSUFFICIENT_PERMISSIONS' in e.message: # pyright: ignore
match = re.search(
r"Schema\s+'([^']+)'",
e.message, # pyright: ignore
)
if match:
schema_name = match.group(1)
action = f'using the schema {schema_name}'
else:
action = 'using the schema'
raise InsufficientPermissionsError(action=action,) from e
raise RuntimeError(
f'Error in querying into schema. Restart sparkSession and try again',
) from e

if collect:
return df.collect()
return df
Expand Down Expand Up @@ -461,6 +481,8 @@ def fetch(
raise InsufficientPermissionsError(
action=f'reading from {tablename}',
) from e
if isinstance(e, InsufficientPermissionsError):
raise e
raise RuntimeError(
f'Error in get rows from {tablename}. Restart sparkSession and try again',
) from e
Expand Down
35 changes: 35 additions & 0 deletions tests/a_scripts/data_prep/test_convert_delta_to_json.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import sys
import unittest
from argparse import Namespace
from typing import Any
from unittest.mock import MagicMock, mock_open, patch

from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
InsufficientPermissionsError,
download,
fetch_DT,
format_tablename,
Expand All @@ -17,6 +19,39 @@

class TestConvertDeltaToJsonl(unittest.TestCase):

def test_run_query_dbconnect_insufficient_permissions(self):
error_message = (
'[INSUFFICIENT_PERMISSIONS] Insufficient privileges: User does not have USE SCHEMA '
"on Schema 'main.oogabooga'. SQLSTATE: 42501"
)

class MockAnalysisException(Exception):

def __init__(self, message: str):
self.message = message

with patch.dict('sys.modules', {'pyspark.errors': MagicMock()}):
sys.modules[
'pyspark.errors'
].AnalysisException = MockAnalysisException # pyright: ignore

mock_spark = MagicMock()
mock_spark.sql.side_effect = MockAnalysisException(error_message)

with self.assertRaises(InsufficientPermissionsError) as context:
run_query(
'SELECT * FROM table',
method='dbconnect',
cursor=None,
spark=mock_spark,
)

self.assertIn(
'using the schema main.oogabooga',
str(context.exception),
)
mock_spark.sql.assert_called_once_with('SELECT * FROM table')

@patch(
'databricks.sql.connect',
)
Expand Down

0 comments on commit c786def

Please sign in to comment.