Skip to content

Commit

Permalink
Merge branch 'main' into tp
Browse files Browse the repository at this point in the history
  • Loading branch information
eitanturok authored Sep 25, 2024
2 parents 5004fe5 + e6b8d14 commit d0f6751
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 4 deletions.
12 changes: 12 additions & 0 deletions llmfoundry/command_utils/data_prep/convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from llmfoundry.utils.exceptions import (
ClusterDoesNotExistError,
ClusterInvalidAccessMode,
FailedToConnectToDatabricksError,
FailedToCreateSQLConnectionError,
InsufficientPermissionsError,
Expand Down Expand Up @@ -568,6 +569,17 @@ def validate_and_get_cluster_info(
if res is None:
raise ClusterDoesNotExistError(cluster_id)

data_security_mode = str(
res.data_security_mode,
).upper()[len('DATASECURITYMODE.'):]

# NONE stands for No Isolation Shared
if data_security_mode == 'NONE':
raise ClusterInvalidAccessMode(
cluster_id=cluster_id,
access_mode=data_security_mode,
)

assert res.spark_version is not None
stripped_runtime = re.sub(
r'[a-zA-Z]',
Expand Down
13 changes: 13 additions & 0 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,19 @@ def __init__(self, cluster_id: str) -> None:
super().__init__(message, cluster_id=cluster_id)


class ClusterInvalidAccessMode(NetworkError):
"""Error thrown when the cluster does not exist."""

def __init__(self, cluster_id: str, access_mode: str) -> None:
message = f'Cluster with id {cluster_id} has access mode {access_mode}. ' + \
'please make sure the cluster used has access mode Shared or Single User!'
super().__init__(
message,
cluster_id=cluster_id,
access_mode=access_mode,
)


class FailedToCreateSQLConnectionError(
NetworkError,
):
Expand Down
20 changes: 16 additions & 4 deletions tests/a_scripts/data_prep/test_convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,10 @@ def test_dbconnect_called(
DATABRICKS_TOKEN = 'token'
use_serverless = False

mock_cluster_response = Namespace(spark_version='14.1.0-scala2.12')
mock_cluster_response = Namespace(
spark_version='14.1.0-scala2.12',
data_security_mode='SINGLE_USER',
)
mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response

mock_remote = MagicMock()
Expand Down Expand Up @@ -321,7 +324,10 @@ def test_sqlconnect_called_dbr13(
DATABRICKS_TOKEN = 'token'
use_serverless = False

mock_cluster_response = Namespace(spark_version='13.0.0-scala2.12')
mock_cluster_response = Namespace(
spark_version='13.0.0-scala2.12',
data_security_mode='SINGLE_USER',
)
mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response

fetch_DT(
Expand Down Expand Up @@ -373,7 +379,10 @@ def test_sqlconnect_called_dbr14(
DATABRICKS_TOKEN = 'token'
use_serverless = False

mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12')
mock_cluster_response = Namespace(
spark_version='14.2.0-scala2.12',
data_security_mode='SINGLE_USER',
)
mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response

fetch_DT(
Expand Down Expand Up @@ -425,7 +434,10 @@ def test_sqlconnect_called_https(
DATABRICKS_TOKEN = 'token'
use_serverless = False

mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12')
mock_cluster_response = Namespace(
spark_version='14.2.0-scala2.12',
data_security_mode='SINGLE_USER',
)
mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response

fetch_DT(
Expand Down

0 comments on commit d0f6751

Please sign in to comment.