diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py index 44e8651cdf..dac57dcd3d 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -12,6 +12,7 @@ from uuid import uuid4 import google.protobuf.any_pb2 as any_pb2 +import grpc import pandas as pd import pyarrow as pa import requests @@ -24,6 +25,7 @@ FailedToConnectToDatabricksError, FailedToCreateSQLConnectionError, InsufficientPermissionsError, + InternalError, ) if TYPE_CHECKING: @@ -660,16 +662,24 @@ def fetch_DT( ) formatted_delta_table_name = format_tablename(delta_table_name) - - fetch( - method, - formatted_delta_table_name, - json_output_folder, - batch_size, - processes, - sparkSession, - dbsql, - ) + try: + fetch( + method, + formatted_delta_table_name, + json_output_folder, + batch_size, + processes, + sparkSession, + dbsql, + ) + except grpc.RpcError as e: + if e.code( + ) == grpc.StatusCode.INTERNAL and 'Job aborted due to stage failure' in e.details( + ): + raise InternalError( + message=f'Possible Hardware Failure: {e.details()}' + ) from e + raise e if dbsql is not None: dbsql.close() diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index 981f5c1ed6..ce6274974d 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -7,8 +7,11 @@ from typing import Any from unittest.mock import MagicMock, mock_open, patch +import grpc + from llmfoundry.command_utils.data_prep.convert_delta_to_json import ( InsufficientPermissionsError, + InternalError, download, fetch, fetch_DT, @@ -524,3 +527,52 @@ def test_format_tablename(self): format_tablename('hyphenated-catalog.schema.test_table'), '`hyphenated-catalog`.`schema`.`test_table`', ) + + 
@patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch')
@patch(
    'llmfoundry.command_utils.data_prep.convert_delta_to_json.validate_and_get_cluster_info',
)
def test_fetch_DT_grpc_error_handling(
    self,
    mock_validate_cluster_info: MagicMock,
    mock_fetch: MagicMock,
):
    # Checks that fetch_DT surfaces an INTERNAL gRPC "stage failure" raised
    # by fetch() as an InternalError carrying the original details.

    # Stub out cluster validation so no real Databricks call is made.
    mock_validate_cluster_info.return_value = ('dbconnect', None, None)

    # Build an RpcError that reports StatusCode.INTERNAL together with the
    # stage-failure text, and have the patched fetch() raise it.
    rpc_failure = grpc.RpcError()
    rpc_failure.code = lambda: grpc.StatusCode.INTERNAL
    rpc_failure.details = lambda: 'Job aborted due to stage failure: Task failed due to an error.'
    mock_fetch.side_effect = rpc_failure

    # The gRPC failure must come out of fetch_DT as an InternalError.
    with self.assertRaises(InternalError) as context:
        fetch_DT(
            delta_table_name='test_table',
            json_output_folder='/tmp/to/jsonl',
            http_path=None,
            cluster_id=None,
            use_serverless=False,
            DATABRICKS_HOST='https://test-host',
            DATABRICKS_TOKEN='test-token',
        )

    # The message must contain both the hardware-failure prefix and the
    # details reported by the underlying gRPC error.
    raised_text = str(context.exception)
    for expected_fragment in (
        'Possible Hardware Failure',
        'Job aborted due to stage failure',
    ):
        self.assertIn(expected_fragment, raised_text)

    # fetch() must have been attempted exactly once.
    mock_fetch.assert_called_once()