diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 3050529a5a..44f2b1348d 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -74,12 +74,12 @@ def __init__( if self.mlflow_registered_model_name is not None: # Both the metadata and the task are needed in order for mlflow # and databricks optimized model serving to work - if 'metadata' not in mlflow_logging_config: - mlflow_logging_config['metadata'] = { - 'task': 'llm/v1/completions' - } - if 'task' not in mlflow_logging_config: - mlflow_logging_config['task'] = 'text-generation' + default_metadata = { + 'task': 'llm/v1/completions' + } + passed_metadata = mlflow_logging_config.get('metadata', {}) + mlflow_logging_config['metadata'] = {**default_metadata, **passed_metadata} + mlflow_logging_config.setdefault('task', 'text-generation') self.mlflow_logging_config = mlflow_logging_config self.huggingface_folder_name_fstr = os.path.join(