diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 9eb23f1030..ec85f7624a 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -93,7 +93,7 @@ class HuggingFaceCheckpointer(Callback): will not be registered. Default is ``None``. mlflow_logging_config (Optional[dict]): A dictionary of config arguments that will get passed along to the MLflow ``save_model`` call. - Expected to contain ``metadata`` and ``task`` keys. If either is + Expected to contain ``task`` and ``metadata`` keys. If either is unspecified, the defaults are ``'text-generation'`` and ``{'task': 'llm/v1/completions'}`` respectively. A default input example and signature intended for text generation is also included under the @@ -134,7 +134,7 @@ def __init__( # and databricks optimized model serving to work passed_metadata = mlflow_logging_config.get('metadata', {}) mlflow_logging_config['metadata'] = passed_metadata - mlflow_logging_config.setdefault('task', 'llm/v1/completions') + mlflow_logging_config.setdefault('task', 'text-generation') default_input_example = { 'prompt': np.array(['What is Machine Learning?'])