diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index ddd53a376c..8abe845f1b 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -429,6 +429,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
                 )
 
                 import mlflow
+                import mlflow.environment_variables
                 mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set(
                     '1GB',
                 )
@@ -870,6 +871,7 @@ def _save_and_register_peft_model(
 
         # TODO: Remove after mlflow fixes the bug that makes this necessary
         import mlflow
+        import mlflow.store
         mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
 
         model_saving_kwargs: dict[str, Any] = {
diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py
index 2b8d148781..eb50b591e6 100644
--- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py
+++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py
@@ -712,6 +712,14 @@ def fetch_DT(
                 message=
                 f'The data preparation cluster you provided is terminated. Please retry with a cluster that is healthy and alive. {e}',
             ) from e
+        if isinstance(
+            e,
+            spark_errors.SparkConnectGrpcException,
+        ) and 'is not usable' in str(e):
+            raise FaultyDataPrepCluster(
+                message=
+                f'The data preparation cluster you provided is not usable. Please retry with a cluster that is healthy and alive. {e}',
+            ) from e
         if isinstance(e, grpc.RpcError) and e.code(
         ) == grpc.StatusCode.INTERNAL and 'Job aborted due to stage failure' in e.details(
         ):
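The new branch in `fetch_DT` extends an existing pattern: dbconnect failures are matched on message substrings and re-raised as the user-facing `FaultyDataPrepCluster`, with the original exception chained via `from e`. Below is a minimal standalone sketch of that pattern; the exception classes are local stand-ins, not the real pyspark or llm-foundry types.

```python
# Stand-in for pyspark's SparkConnectGrpcException; the real fetch_DT
# matches on the actual dbconnect exception type.
class SparkConnectGrpcException(Exception):
    pass


# Stand-in for llm-foundry's user-facing FaultyDataPrepCluster error.
class FaultyDataPrepCluster(Exception):
    pass


def translate_cluster_error(e: Exception) -> None:
    """Re-raise known cluster failures as FaultyDataPrepCluster, else propagate."""
    if isinstance(e, SparkConnectGrpcException) and 'is not usable' in str(e):
        raise FaultyDataPrepCluster(
            'The data preparation cluster you provided is not usable. '
            f'Please retry with a cluster that is healthy and alive. {e}',
        ) from e
    raise e


try:
    raise SparkConnectGrpcException('cluster 1234-abcd is not usable')
except Exception as e:
    try:
        translate_cluster_error(e)
    except FaultyDataPrepCluster as faulty:
        # __cause__ still carries the original grpc-level error for debugging.
        print(faulty)
```

Chaining with `from e` is the design choice that keeps the low-level stack trace attached to the friendlier error the user actually sees.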
diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index 03f8812fa3..997273de7f 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -21,6 +21,12 @@
 import mlflow
 from composer.loggers import Logger
 from composer.utils import dist, parse_uri
+from mlflow.data import (
+    delta_dataset_source,
+    http_dataset_source,
+    huggingface_dataset_source,
+    uc_volume_dataset_source,
+)
 from omegaconf import MISSING, DictConfig, ListConfig, MissingMandatoryValue
 from omegaconf import OmegaConf as om
 from transformers import PretrainedConfig
@@ -769,15 +775,15 @@ def log_dataset_uri(cfg: dict[str, Any]) -> None:
     data_paths = _parse_source_dataset(cfg)
 
     dataset_source_mapping = {
-        's3': mlflow.data.http_dataset_source.HTTPDatasetSource,
-        'oci': mlflow.data.http_dataset_source.HTTPDatasetSource,
-        'azure': mlflow.data.http_dataset_source.HTTPDatasetSource,
-        'gs': mlflow.data.http_dataset_source.HTTPDatasetSource,
-        'https': mlflow.data.http_dataset_source.HTTPDatasetSource,
-        'hf': mlflow.data.huggingface_dataset_source.HuggingFaceDatasetSource,
-        'delta_table': mlflow.data.delta_dataset_source.DeltaDatasetSource,
-        'uc_volume': mlflow.data.uc_volume_dataset_source.UCVolumeDatasetSource,
-        'local': mlflow.data.http_dataset_source.HTTPDatasetSource,
+        's3': http_dataset_source.HTTPDatasetSource,
+        'oci': http_dataset_source.HTTPDatasetSource,
+        'azure': http_dataset_source.HTTPDatasetSource,
+        'gs': http_dataset_source.HTTPDatasetSource,
+        'https': http_dataset_source.HTTPDatasetSource,
+        'hf': huggingface_dataset_source.HuggingFaceDatasetSource,
+        'delta_table': delta_dataset_source.DeltaDatasetSource,
+        'uc_volume': uc_volume_dataset_source.UCVolumeDatasetSource,
+        'local': http_dataset_source.HTTPDatasetSource,
     }
 
     # Map data source types to their respective MLFlow DataSource.
@@ -795,7 +801,7 @@ def log_dataset_uri(cfg: dict[str, Any]) -> None:
             log.info(
                 f'{dataset_type} unknown, defaulting to http dataset source',
             )
-            source = mlflow.data.http_dataset_source.HTTPDatasetSource(url=path)
+            source = http_dataset_source.HTTPDatasetSource(url=path)
 
         mlflow.log_input(
             mlflow.data.meta_dataset.MetaDataset(source, name=split),
diff --git a/setup.py b/setup.py
index 00934a30e0..86d696ed5c 100644
--- a/setup.py
+++ b/setup.py
@@ -53,7 +53,7 @@ install_requires = [
     'mosaicml[libcloud,wandb,oci,gcs,mlflow]>=0.26.0,<0.27',
-    'mlflow>=2.14.1,<2.17',
+    'mlflow>=2.14.1,<2.18',
     'accelerate>=0.25,<0.34',  # for HF inference `device_map`
     'transformers>=4.43.2,<4.44',
     'mosaicml-streaming>=0.9.0,<0.10',
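The `config_utils.py` hunks above swap attribute-style lookups (`mlflow.data.http_dataset_source.HTTPDatasetSource`) for explicit `from mlflow.data import ...` imports, and the same goes for the explicit `import mlflow.environment_variables` / `import mlflow.store` in the checkpointer. The underlying Python behavior: importing a package does not, in general, import its submodules, so attribute access on the package raises `AttributeError` unless the package's `__init__` eagerly re-exports them. That mlflow stops guaranteeing this under the widened `<2.18` pin is an assumption here, but it is consistent with the import changes and the version bump landing together. A generic demonstration with the standard library:

```python
import xml

# `import xml` does not import `xml.dom`, so attribute access fails:
try:
    xml.dom
except AttributeError as e:
    print(f'AttributeError: {e}')

# Explicitly importing the submodule binds it to the parent package:
import xml.dom

print(xml.dom)  # now resolves to <module 'xml.dom' ...>
```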
diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py
index 5292f86e6d..74c9f2a6c6 100644
--- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py
+++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py
@@ -534,7 +534,7 @@ def test_format_tablename(self):
     @patch(
         'llmfoundry.command_utils.data_prep.convert_delta_to_json.validate_and_get_cluster_info',
     )
-    def test_fetch_DT_grpc_error_handling(
+    def test_fetch_DT_catches_grpc_errors(
         self,
         mock_validate_cluster_info: MagicMock,
         mock_fetch: MagicMock,
@@ -543,99 +543,70 @@
         # Arrange
         # Mock the validate_and_get_cluster_info to return test values
         mock_validate_cluster_info.return_value = ('dbconnect', None, None)
 
-        # Create a grpc.RpcError with StatusCode.INTERNAL and specific details
-        grpc_error = grpc.RpcError()
-        grpc_error.code = lambda: grpc.StatusCode.INTERNAL
-        grpc_error.details = lambda: 'Job aborted due to stage failure: Task failed due to an error.'
-
-        # Configure the fetch function to raise the grpc.RpcError
-        mock_fetch.side_effect = grpc_error
-
-        # Test inputs
-        delta_table_name = 'test_table'
-        json_output_folder = '/tmp/to/jsonl'
-        http_path = None
-        cluster_id = None
-        use_serverless = False
-        DATABRICKS_HOST = 'https://test-host'
-        DATABRICKS_TOKEN = 'test-token'
-
-        # Act & Assert
-        with self.assertRaises(FaultyDataPrepCluster) as context:
-            fetch_DT(
-                delta_table_name=delta_table_name,
-                json_output_folder=json_output_folder,
-                http_path=http_path,
-                cluster_id=cluster_id,
-                use_serverless=use_serverless,
-                DATABRICKS_HOST=DATABRICKS_HOST,
-                DATABRICKS_TOKEN=DATABRICKS_TOKEN,
-            )
-
-        # Verify that the FaultyDataPrepCluster contains the expected message
-        self.assertIn(
-            'Faulty data prep cluster, please try swapping data prep cluster: ',
-            str(context.exception),
-        )
-        self.assertIn(
-            'Job aborted due to stage failure',
-            str(context.exception),
-        )
-
-        # Verify that fetch was called
-        mock_fetch.assert_called_once()
-
-    @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch')
-    @patch(
-        'llmfoundry.command_utils.data_prep.convert_delta_to_json.validate_and_get_cluster_info',
-    )
-    def test_fetch_DT_catches_cluster_failed_to_start(
-        self,
-        mock_validate_cluster_info: MagicMock,
-        mock_fetch: MagicMock,
-    ):
-        # Arrange
-        # Mock the validate_and_get_cluster_info to return test values
-        mock_validate_cluster_info.return_value = ('dbconnect', None, None)
-
-        # Create a SparkConnectGrpcException indicating that the cluster failed to start
-
-        grpc_error = SparkConnectGrpcException(
-            message='Cannot start cluster etc...',
-        )
-
-        # Configure the fetch function to raise the SparkConnectGrpcException
-        mock_fetch.side_effect = grpc_error
-
-        # Test inputs
-        delta_table_name = 'test_table'
-        json_output_folder = '/tmp/to/jsonl'
-        http_path = None
-        cluster_id = None
-        use_serverless = False
-        DATABRICKS_HOST = 'https://test-host'
-        DATABRICKS_TOKEN = 'test-token'
-
-        # Act & Assert
-        with self.assertRaises(FaultyDataPrepCluster) as context:
-            fetch_DT(
-                delta_table_name=delta_table_name,
-                json_output_folder=json_output_folder,
-                http_path=http_path,
-                cluster_id=cluster_id,
-                use_serverless=use_serverless,
-                DATABRICKS_HOST=DATABRICKS_HOST,
-                DATABRICKS_TOKEN=DATABRICKS_TOKEN,
-            )
+        grpc_lib_error = grpc.RpcError()
+        grpc_lib_error.code = lambda: grpc.StatusCode.INTERNAL
+        grpc_lib_error.details = lambda: 'Job aborted due to stage failure: Task failed due to an error.'
+
+        error_contexts = [
+            (
+                SparkConnectGrpcException('Cannot start cluster etc...'),
+                FaultyDataPrepCluster,
+                [
+                    'The data preparation cluster you provided is terminated. Please retry with a cluster that is healthy and alive.',
+                ],
+            ),
+            (
+                SparkConnectGrpcException('cluster ... is not usable'),
+                FaultyDataPrepCluster,
+                [
+                    'The data preparation cluster you provided is not usable. Please retry with a cluster that is healthy and alive.',
+                ],
+            ),
+            (
+                grpc_lib_error,
+                FaultyDataPrepCluster,
+                [
+                    'Faulty data prep cluster, please try swapping data prep cluster: ',
+                    'Job aborted due to stage failure',
+                ],
+            ),
+        ]
+
+        for (
+            err_to_throw,
+            err_to_catch,
+            texts_to_check_in_error,
+        ) in error_contexts:
+            # Configure the fetch function to raise the error under test
+            mock_fetch.side_effect = err_to_throw
+
+            # Test inputs
+            delta_table_name = 'test_table'
+            json_output_folder = '/tmp/to/jsonl'
+            http_path = None
+            cluster_id = None
+            use_serverless = False
+            DATABRICKS_HOST = 'https://test-host'
+            DATABRICKS_TOKEN = 'test-token'
+
+            # Act & Assert
+            with self.assertRaises(err_to_catch) as context:
+                fetch_DT(
+                    delta_table_name=delta_table_name,
+                    json_output_folder=json_output_folder,
+                    http_path=http_path,
+                    cluster_id=cluster_id,
+                    use_serverless=use_serverless,
+                    DATABRICKS_HOST=DATABRICKS_HOST,
+                    DATABRICKS_TOKEN=DATABRICKS_TOKEN,
+                )
 
-        # Verify that the FaultyDataPrepCluster contains the expected message
-        self.assertIn(
-            'The data preparation cluster you provided is terminated. Please retry with a cluster that is healthy and alive.',
-            str(context.exception),
-        )
+            # Verify that the FaultyDataPrepCluster contains the expected message
+            for text in texts_to_check_in_error:
+                self.assertIn(text, str(context.exception))
 
         # Verify that fetch was called
-        mock_fetch.assert_called_once()
+        mock_fetch.assert_called()
 
     @patch(
         'llmfoundry.command_utils.data_prep.convert_delta_to_json.get_total_rows',
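The test refactor above folds two near-identical tests into one table-driven test: each `error_contexts` tuple pairs an exception to raise with the expected wrapper type and the message fragments it must contain. Because `fetch` now runs once per case, the final assertion relaxes from `assert_called_once()` to `assert_called()`. A minimal standalone version of the same pattern follows (names illustrative, not the llm-foundry API); wrapping each case in `self.subTest`, as sketched here, additionally reports failures per case instead of stopping at the first.

```python
import unittest
from unittest.mock import MagicMock


class ClusterError(Exception):
    """Illustrative stand-in for FaultyDataPrepCluster."""


def do_fetch(client) -> None:
    """Illustrative stand-in for fetch_DT's error translation."""
    try:
        client.fetch()
    except Exception as e:
        if 'is not usable' in str(e):
            raise ClusterError(f'cluster is not usable: {e}') from e
        raise


class TestDoFetch(unittest.TestCase):

    def test_error_translation(self):
        # Each case: exception to raise, expected wrapper, message fragments.
        cases = [
            (
                RuntimeError('node 42 is not usable'),
                ClusterError,
                ['is not usable'],
            ),
        ]
        for err_to_throw, err_to_catch, fragments in cases:
            with self.subTest(err=err_to_throw):  # per-case failure reporting
                client = MagicMock()
                client.fetch.side_effect = err_to_throw
                with self.assertRaises(err_to_catch) as ctx:
                    do_fetch(client)
                for fragment in fragments:
                    self.assertIn(fragment, str(ctx.exception))


if __name__ == '__main__':
    unittest.main()
```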