Skip to content

Commit

Permalink
[MCLOUD-4910] Escape UC names during data prep (#1343)
Browse files Browse the repository at this point in the history

Co-authored-by: Naren Loganathan <[email protected]>
  • Loading branch information
naren-loganathan and narenlog-db authored Jul 10, 2024
1 parent 304bf28 commit 08a3624
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
25 changes: 25 additions & 0 deletions scripts/data_prep/convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
MINIMUM_DB_CONNECT_DBR_VERSION = '14.1'
MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2'

TABLENAME_PATTERN = re.compile(r'(\S+)\.(\S+)\.(\S+)')

log = logging.getLogger(__name__)

Result = namedtuple(
Expand Down Expand Up @@ -284,6 +286,27 @@ def download_starargs(args: Tuple) -> None:
return download(*args)


def format_tablename(table_name: str) -> str:
"""Escape catalog, schema and table names with backticks.
This needs to be done when running SQL queries/setting spark sessions to prevent invalid identifier errors.
Args:
table_name (str): catalog.scheme.tablename on UC
"""
match = re.match(TABLENAME_PATTERN, table_name)

if match is None:
return table_name

formatted_identifiers = []
for i in range(1, 4):
identifier = f'`{match.group(i)}`'
formatted_identifiers.append(identifier)

return '.'.join(formatted_identifiers)


def fetch_data(
method: str,
cursor: Optional[Cursor],
Expand Down Expand Up @@ -582,6 +605,8 @@ def fetch_DT(args: Namespace) -> None:
use_serverless=args.use_serverless,
)

args.delta_table_name = format_tablename(args.delta_table_name)

fetch(
method,
args.delta_table_name,
Expand Down
15 changes: 15 additions & 0 deletions tests/a_scripts/data_prep/test_convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from scripts.data_prep.convert_delta_to_json import (
download,
fetch_DT,
format_tablename,
iterative_combine_jsons,
run_query,
)
Expand Down Expand Up @@ -359,3 +360,17 @@ def test_serverless(
fetch_DT(args)
assert not mock_sql_connect.called
assert not mock_databricks_session.builder.remote.called

def test_format_tablename(self):
self.assertEqual(
format_tablename('test_catalog.hyphenated-schema.test_table'),
'`test_catalog`.`hyphenated-schema`.`test_table`',
)
self.assertEqual(
format_tablename('catalog.schema.table'),
'`catalog`.`schema`.`table`',
)
self.assertEqual(
format_tablename('hyphenated-catalog.schema.test_table'),
'`hyphenated-catalog`.`schema`.`test_table`',
)

0 comments on commit 08a3624

Please sign in to comment.