Commit: fix conflict
dakinggg committed Nov 1, 2024
2 parents 674506d + 7c991e9 commit 7a423b1
Showing 5 changed files with 89 additions and 102 deletions.
2 changes: 2 additions & 0 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -429,6 +429,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
)

import mlflow
+import mlflow.environment_variables
mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set(
'1GB',
)
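Why the new import matters: `import mlflow` does not guarantee that the `mlflow.environment_variables` submodule is loaded, so referencing it without the explicit import can raise AttributeError. A minimal standalone sketch of the same pattern; the '1GB' value comes from the hunk above, while the comment about intent is an assumption:

import mlflow
import mlflow.environment_variables  # explicit import; `import mlflow` alone may not load the submodule

# Presumably caps each saved Hugging Face model shard at 1GB.
mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set('1GB')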
@@ -870,6 +871,7 @@ def _save_and_register_peft_model(

# TODO: Remove after mlflow fixes the bug that makes this necessary
import mlflow
+import mlflow.store
mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''

model_saving_kwargs: dict[str, Any] = {
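The second hunk applies the same fix before monkeypatching an MLflow internal. A condensed sketch; the no-op lambda is copied from the diff, and the dotted attribute path should be treated as a private, version-dependent detail:

import mlflow
import mlflow.store  # explicit import, for the same lazy-loading reason as above

# Stopgap from the diff: stub out the buggy helper until mlflow ships a fix.
mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = (
    lambda *args, **kwargs: ''
)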
8 changes: 8 additions & 0 deletions llmfoundry/command_utils/data_prep/convert_delta_to_json.py
@@ -712,6 +712,14 @@ def fetch_DT(
message=
f'The data preparation cluster you provided is terminated. Please retry with a cluster that is healthy and alive. {e}',
) from e
+if isinstance(
+e,
+spark_errors.SparkConnectGrpcException,
+) and 'is not usable' in str(e):
+raise FaultyDataPrepCluster(
+message=
+f'The data preparation cluster you provided is not usable. Please retry with a cluster that is healthy and alive. {e}',
+) from e
if isinstance(e, grpc.RpcError) and e.code(
) == grpc.StatusCode.INTERNAL and 'Job aborted due to stage failure' in e.details(
):
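The added branch follows the established pattern in `fetch_DT`: match on the exception type plus a message substring, then re-raise as the user-facing `FaultyDataPrepCluster` with `from e` so the original traceback is preserved. A self-contained sketch of that control flow, with stand-in exception classes instead of the real pyspark and llm-foundry types:

class SparkConnectGrpcException(Exception):
    """Stand-in for pyspark's Spark Connect error type."""

class FaultyDataPrepCluster(Exception):
    """Stand-in for llm-foundry's user-facing error."""

def translate_cluster_error(e: Exception) -> None:
    # Type check plus substring match, mirroring the hunk above.
    if isinstance(e, SparkConnectGrpcException) and 'is not usable' in str(e):
        raise FaultyDataPrepCluster(
            'The data preparation cluster you provided is not usable. '
            f'Please retry with a cluster that is healthy and alive. {e}',
        ) from e
    raise e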
26 changes: 16 additions & 10 deletions llmfoundry/utils/config_utils.py
@@ -21,6 +21,12 @@
import mlflow
from composer.loggers import Logger
from composer.utils import dist, parse_uri
+from mlflow.data import (
+delta_dataset_source,
+http_dataset_source,
+huggingface_dataset_source,
+uc_volume_dataset_source,
+)
from omegaconf import MISSING, DictConfig, ListConfig, MissingMandatoryValue
from omegaconf import OmegaConf as om
from transformers import PretrainedConfig
@@ -769,15 +775,15 @@ def log_dataset_uri(cfg: dict[str, Any]) -> None:
data_paths = _parse_source_dataset(cfg)

dataset_source_mapping = {
-'s3': mlflow.data.http_dataset_source.HTTPDatasetSource,
-'oci': mlflow.data.http_dataset_source.HTTPDatasetSource,
-'azure': mlflow.data.http_dataset_source.HTTPDatasetSource,
-'gs': mlflow.data.http_dataset_source.HTTPDatasetSource,
-'https': mlflow.data.http_dataset_source.HTTPDatasetSource,
-'hf': mlflow.data.huggingface_dataset_source.HuggingFaceDatasetSource,
-'delta_table': mlflow.data.delta_dataset_source.DeltaDatasetSource,
-'uc_volume': mlflow.data.uc_volume_dataset_source.UCVolumeDatasetSource,
-'local': mlflow.data.http_dataset_source.HTTPDatasetSource,
+'s3': http_dataset_source.HTTPDatasetSource,
+'oci': http_dataset_source.HTTPDatasetSource,
+'azure': http_dataset_source.HTTPDatasetSource,
+'gs': http_dataset_source.HTTPDatasetSource,
+'https': http_dataset_source.HTTPDatasetSource,
+'hf': huggingface_dataset_source.HuggingFaceDatasetSource,
+'delta_table': delta_dataset_source.DeltaDatasetSource,
+'uc_volume': uc_volume_dataset_source.UCVolumeDatasetSource,
+'local': http_dataset_source.HTTPDatasetSource,
}

# Map data source types to their respective MLFlow DataSource.
@@ -795,7 +801,7 @@ def log_dataset_uri(cfg: dict[str, Any]) -> None:
log.info(
f'{dataset_type} unknown, defaulting to http dataset source',
)
-source = mlflow.data.http_dataset_source.HTTPDatasetSource(url=path)
+source = http_dataset_source.HTTPDatasetSource(url=path)

mlflow.log_input(
mlflow.data.meta_dataset.MetaDataset(source, name=split),
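This refactor swaps runtime attribute chains like `mlflow.data.http_dataset_source.HTTPDatasetSource` for names imported explicitly from `mlflow.data`, which avoids relying on MLflow loading those submodules lazily. The lookup itself is a plain dict dispatch with an HTTP fallback; a minimal sketch of what `log_dataset_uri` does, with a hypothetical scheme and URL (the `url=` keyword matches the call in the diff):

from mlflow.data import http_dataset_source, huggingface_dataset_source

dataset_source_mapping = {
    's3': http_dataset_source.HTTPDatasetSource,
    'hf': huggingface_dataset_source.HuggingFaceDatasetSource,
    # ... delta_table, uc_volume, and the other schemes from the diff
}

# Unknown schemes fall back to the HTTP source, as in log_dataset_uri.
source_cls = dataset_source_mapping.get('ftp', http_dataset_source.HTTPDatasetSource)
source = source_cls(url='ftp://host/path')  # hypothetical URL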
2 changes: 1 addition & 1 deletion setup.py
@@ -53,7 +53,7 @@

install_requires = [
'mosaicml[libcloud,wandb,oci,gcs,mlflow]>=0.26.0,<0.27',
-'mlflow>=2.14.1,<2.17',
+'mlflow>=2.14.1,<2.18',
'accelerate>=0.25,<0.34', # for HF inference `device_map`
'transformers>=4.43.2,<4.44',
'mosaicml-streaming>=0.9.0,<0.10',
153 changes: 62 additions & 91 deletions tests/a_scripts/data_prep/test_convert_delta_to_json.py
@@ -534,7 +534,7 @@ def test_format_tablename(self):
@patch(
'llmfoundry.command_utils.data_prep.convert_delta_to_json.validate_and_get_cluster_info',
)
-def test_fetch_DT_grpc_error_handling(
+def test_fetch_DT_catches_grpc_errors(
self,
mock_validate_cluster_info: MagicMock,
mock_fetch: MagicMock,
@@ -543,99 +543,70 @@ def test_fetch_DT_grpc_error_handling(
# Mock the validate_and_get_cluster_info to return test values
mock_validate_cluster_info.return_value = ('dbconnect', None, None)

-# Create a grpc.RpcError with StatusCode.INTERNAL and specific details
-grpc_error = grpc.RpcError()
-grpc_error.code = lambda: grpc.StatusCode.INTERNAL
-grpc_error.details = lambda: 'Job aborted due to stage failure: Task failed due to an error.'
-
-# Configure the fetch function to raise the grpc.RpcError
-mock_fetch.side_effect = grpc_error
-
-# Test inputs
-delta_table_name = 'test_table'
-json_output_folder = '/tmp/to/jsonl'
-http_path = None
-cluster_id = None
-use_serverless = False
-DATABRICKS_HOST = 'https://test-host'
-DATABRICKS_TOKEN = 'test-token'
-
-# Act & Assert
-with self.assertRaises(FaultyDataPrepCluster) as context:
-fetch_DT(
-delta_table_name=delta_table_name,
-json_output_folder=json_output_folder,
-http_path=http_path,
-cluster_id=cluster_id,
-use_serverless=use_serverless,
-DATABRICKS_HOST=DATABRICKS_HOST,
-DATABRICKS_TOKEN=DATABRICKS_TOKEN,
-)
-
-# Verify that the FaultyDataPrepCluster contains the expected message
-self.assertIn(
-'Faulty data prep cluster, please try swapping data prep cluster: ',
-str(context.exception),
-)
-self.assertIn(
-'Job aborted due to stage failure',
-str(context.exception),
-)
-
-# Verify that fetch was called
-mock_fetch.assert_called_once()
-
-@patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch')
-@patch(
-'llmfoundry.command_utils.data_prep.convert_delta_to_json.validate_and_get_cluster_info',
-)
-def test_fetch_DT_catches_cluster_failed_to_start(
-self,
-mock_validate_cluster_info: MagicMock,
-mock_fetch: MagicMock,
-):
-# Arrange
-# Mock the validate_and_get_cluster_info to return test values
-mock_validate_cluster_info.return_value = ('dbconnect', None, None)
-
-# Create a SparkConnectGrpcException indicating that the cluster failed to start
-
-grpc_error = SparkConnectGrpcException(
-message='Cannot start cluster etc...',
-)
-
-# Configure the fetch function to raise the SparkConnectGrpcException
-mock_fetch.side_effect = grpc_error
-
-# Test inputs
-delta_table_name = 'test_table'
-json_output_folder = '/tmp/to/jsonl'
-http_path = None
-cluster_id = None
-use_serverless = False
-DATABRICKS_HOST = 'https://test-host'
-DATABRICKS_TOKEN = 'test-token'
-
-# Act & Assert
-with self.assertRaises(FaultyDataPrepCluster) as context:
-fetch_DT(
-delta_table_name=delta_table_name,
-json_output_folder=json_output_folder,
-http_path=http_path,
-cluster_id=cluster_id,
-use_serverless=use_serverless,
-DATABRICKS_HOST=DATABRICKS_HOST,
-DATABRICKS_TOKEN=DATABRICKS_TOKEN,
-)
+grpc_lib_error = grpc.RpcError()
+grpc_lib_error.code = lambda: grpc.StatusCode.INTERNAL
+grpc_lib_error.details = lambda: 'Job aborted due to stage failure: Task failed due to an error.'
+
+error_contexts = [
+(
+SparkConnectGrpcException('Cannot start cluster etc...'),
+FaultyDataPrepCluster,
+[
+'The data preparation cluster you provided is terminated. Please retry with a cluster that is healthy and alive.',
+],
+),
+(
+SparkConnectGrpcException('cluster ... is not usable'),
+FaultyDataPrepCluster,
+[
+'The data preparation cluster you provided is not usable. Please retry with a cluster that is healthy and alive.',
+],
+),
+(
+grpc_lib_error,
+FaultyDataPrepCluster,
+[
+'Faulty data prep cluster, please try swapping data prep cluster: ',
+'Job aborted due to stage failure',
+],
+),
+]
+
+for (
+err_to_throw,
+err_to_catch,
+texts_to_check_in_error,
+) in error_contexts:
+# Configure the fetch function to raise the error under test
+mock_fetch.side_effect = err_to_throw
+
+# Test inputs
+delta_table_name = 'test_table'
+json_output_folder = '/tmp/to/jsonl'
+http_path = None
+cluster_id = None
+use_serverless = False
+DATABRICKS_HOST = 'https://test-host'
+DATABRICKS_TOKEN = 'test-token'
+
+# Act & Assert
+with self.assertRaises(err_to_catch) as context:
+fetch_DT(
+delta_table_name=delta_table_name,
+json_output_folder=json_output_folder,
+http_path=http_path,
+cluster_id=cluster_id,
+use_serverless=use_serverless,
+DATABRICKS_HOST=DATABRICKS_HOST,
+DATABRICKS_TOKEN=DATABRICKS_TOKEN,
+)

-# Verify that the FaultyDataPrepCluster contains the expected message
-self.assertIn(
-'The data preparation cluster you provided is terminated. Please retry with a cluster that is healthy and alive.',
-str(context.exception),
-)
+# Verify that the FaultyDataPrepCluster contains the expected message
+for text in texts_to_check_in_error:
+self.assertIn(text, str(context.exception))

# Verify that fetch was called
-mock_fetch.assert_called_once()
+mock_fetch.assert_called()

@patch(
'llmfoundry.command_utils.data_prep.convert_delta_to_json.get_total_rows',
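The test rewrite collapses two near-duplicate tests into one table-driven loop: each row pairs an exception injected through `mock_fetch.side_effect` with the substrings expected in the surfaced message, and the final check relaxes from `assert_called_once()` to `assert_called()` because the mock now fires once per row. A stripped-down, runnable sketch of the pattern; `fetch` and `fetch_and_translate` are hypothetical stand-ins for the patched `fetch` and for `fetch_DT`:

import unittest
from unittest.mock import MagicMock, patch


class FaultyDataPrepCluster(Exception):
    """Stand-in for the llm-foundry exception."""


def fetch() -> None:
    """Stand-in for the function the tests patch."""


def fetch_and_translate() -> None:
    # Stand-in for fetch_DT's wrap-and-reraise behavior.
    try:
        fetch()
    except Exception as e:
        raise FaultyDataPrepCluster(f'data prep cluster failed: {e}') from e


class TestErrorTranslation(unittest.TestCase):

    @patch(f'{__name__}.fetch')
    def test_translates_errors(self, mock_fetch: MagicMock) -> None:
        # Each row: error to inject, substrings expected in the surfaced message.
        error_contexts = [
            (RuntimeError('Cannot start cluster'), ['Cannot start cluster']),
            (RuntimeError('cluster is not usable'), ['is not usable']),
        ]
        for err_to_throw, texts_to_check in error_contexts:
            mock_fetch.side_effect = err_to_throw
            with self.assertRaises(FaultyDataPrepCluster) as context:
                fetch_and_translate()
            for text in texts_to_check:
                self.assertIn(text, str(context.exception))
        # assert_called() rather than assert_called_once(): the mock fires
        # once per table row.
        mock_fetch.assert_called()


if __name__ == '__main__':
    unittest.main()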
