Skip to content

Commit

Permalink
add another cluster connection failure wrapper (#1630)
Browse files Browse the repository at this point in the history
  • Loading branch information
milocress authored Nov 1, 2024
1 parent 92252ce commit 7c991e9
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 91 deletions.
8 changes: 8 additions & 0 deletions llmfoundry/command_utils/data_prep/convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,14 @@ def fetch_DT(
message=
f'The data preparation cluster you provided is terminated. Please retry with a cluster that is healthy and alive. {e}',
) from e
if isinstance(
e,
spark_errors.SparkConnectGrpcException,
) and 'is not usable' in str(e):
raise FaultyDataPrepCluster(
message=
f'The data preparation cluster you provided is not usable. Please retry with a cluster that is healthy and alive. {e}',
) from e
if isinstance(e, grpc.RpcError) and e.code(
) == grpc.StatusCode.INTERNAL and 'Job aborted due to stage failure' in e.details(
):
Expand Down
153 changes: 62 additions & 91 deletions tests/a_scripts/data_prep/test_convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def test_format_tablename(self):
@patch(
'llmfoundry.command_utils.data_prep.convert_delta_to_json.validate_and_get_cluster_info',
)
def test_fetch_DT_grpc_error_handling(
def test_fetch_DT_catches_grpc_errors(
self,
mock_validate_cluster_info: MagicMock,
mock_fetch: MagicMock,
Expand All @@ -543,99 +543,70 @@ def test_fetch_DT_grpc_error_handling(
# Mock the validate_and_get_cluster_info to return test values
mock_validate_cluster_info.return_value = ('dbconnect', None, None)

# Create a grpc.RpcError with StatusCode.INTERNAL and specific details
grpc_error = grpc.RpcError()
grpc_error.code = lambda: grpc.StatusCode.INTERNAL
grpc_error.details = lambda: 'Job aborted due to stage failure: Task failed due to an error.'

# Configure the fetch function to raise the grpc.RpcError
mock_fetch.side_effect = grpc_error

# Test inputs
delta_table_name = 'test_table'
json_output_folder = '/tmp/to/jsonl'
http_path = None
cluster_id = None
use_serverless = False
DATABRICKS_HOST = 'https://test-host'
DATABRICKS_TOKEN = 'test-token'

# Act & Assert
with self.assertRaises(FaultyDataPrepCluster) as context:
fetch_DT(
delta_table_name=delta_table_name,
json_output_folder=json_output_folder,
http_path=http_path,
cluster_id=cluster_id,
use_serverless=use_serverless,
DATABRICKS_HOST=DATABRICKS_HOST,
DATABRICKS_TOKEN=DATABRICKS_TOKEN,
)

# Verify that the FaultyDataPrepCluster contains the expected message
self.assertIn(
'Faulty data prep cluster, please try swapping data prep cluster: ',
str(context.exception),
)
self.assertIn(
'Job aborted due to stage failure',
str(context.exception),
)

# Verify that fetch was called
mock_fetch.assert_called_once()

@patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch')
@patch(
'llmfoundry.command_utils.data_prep.convert_delta_to_json.validate_and_get_cluster_info',
)
def test_fetch_DT_catches_cluster_failed_to_start(
self,
mock_validate_cluster_info: MagicMock,
mock_fetch: MagicMock,
):
# Arrange
# Mock the validate_and_get_cluster_info to return test values
mock_validate_cluster_info.return_value = ('dbconnect', None, None)

# Create a SparkConnectGrpcException indicating that the cluster failed to start

grpc_error = SparkConnectGrpcException(
message='Cannot start cluster etc...',
)

# Configure the fetch function to raise the SparkConnectGrpcException
mock_fetch.side_effect = grpc_error

# Test inputs
delta_table_name = 'test_table'
json_output_folder = '/tmp/to/jsonl'
http_path = None
cluster_id = None
use_serverless = False
DATABRICKS_HOST = 'https://test-host'
DATABRICKS_TOKEN = 'test-token'

# Act & Assert
with self.assertRaises(FaultyDataPrepCluster) as context:
fetch_DT(
delta_table_name=delta_table_name,
json_output_folder=json_output_folder,
http_path=http_path,
cluster_id=cluster_id,
use_serverless=use_serverless,
DATABRICKS_HOST=DATABRICKS_HOST,
DATABRICKS_TOKEN=DATABRICKS_TOKEN,
)
grpc_lib_error = grpc.RpcError()
grpc_lib_error.code = lambda: grpc.StatusCode.INTERNAL
grpc_lib_error.details = lambda: 'Job aborted due to stage failure: Task failed due to an error.'

error_contexts = [
(
SparkConnectGrpcException('Cannot start cluster etc...'),
FaultyDataPrepCluster,
[
'The data preparation cluster you provided is terminated. Please retry with a cluster that is healthy and alive.',
],
),
(
SparkConnectGrpcException('cluster ... is not usable'),
FaultyDataPrepCluster,
[
'The data preparation cluster you provided is not usable. Please retry with a cluster that is healthy and alive.',
],
),
(
grpc_lib_error,
FaultyDataPrepCluster,
[
'Faulty data prep cluster, please try swapping data prep cluster: ',
'Job aborted due to stage failure',
],
),
]

for (
err_to_throw,
err_to_catch,
texts_to_check_in_error,
) in error_contexts:
# Configure the fetch function to raise the SparkConnectGrpcException
mock_fetch.side_effect = err_to_throw

# Test inputs
delta_table_name = 'test_table'
json_output_folder = '/tmp/to/jsonl'
http_path = None
cluster_id = None
use_serverless = False
DATABRICKS_HOST = 'https://test-host'
DATABRICKS_TOKEN = 'test-token'

# Act & Assert
with self.assertRaises(err_to_catch) as context:
fetch_DT(
delta_table_name=delta_table_name,
json_output_folder=json_output_folder,
http_path=http_path,
cluster_id=cluster_id,
use_serverless=use_serverless,
DATABRICKS_HOST=DATABRICKS_HOST,
DATABRICKS_TOKEN=DATABRICKS_TOKEN,
)

# Verify that the FaultyDataPrepCluster contains the expected message
self.assertIn(
'The data preparation cluster you provided is terminated. Please retry with a cluster that is healthy and alive.',
str(context.exception),
)
# Verify that the FaultyDataPrepCluster contains the expected message
for text in texts_to_check_in_error:
self.assertIn(text, str(context.exception))

# Verify that fetch was called
mock_fetch.assert_called_once()
mock_fetch.assert_called()

@patch(
'llmfoundry.command_utils.data_prep.convert_delta_to_json.get_total_rows',
Expand Down

0 comments on commit 7c991e9

Please sign in to comment.