diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 6098bbeefd..8956858a93 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -239,6 +239,7 @@ def __init__( + f'Defaulting to final_register_only=False and saving the HuggingFace checkpoint to {save_folder=}.', ) + self.use_mlflow_log_model = False # mlflow config setup if mlflow_logging_config is None: @@ -269,6 +270,8 @@ def __init__( 'input_example', default_input_example, ) + if mlflow_logging_config.get('use_mlflow_log_model', False): + self.use_mlflow_log_model = True self.mlflow_logging_config = mlflow_logging_config if 'metadata' in self.mlflow_logging_config: @@ -709,7 +712,7 @@ def tensor_hook( log.debug('Logging Hugging Face model to MLFlow') for i, mlflow_logger in enumerate(self.mlflow_loggers): log.debug( - f'Logging model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}', + f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}', ) local_save_path = str( Path(temp_save_dir) / f'mlflow_save_{i}', @@ -743,7 +746,7 @@ def tensor_hook( 'transformers', 'torch', ] - mlflow_logger.log_model(**model_saving_kwargs) + mlflow_logger.save_model(**model_saving_kwargs) # Upload the license file generated by mlflow during the model saving. license_filename = _maybe_get_license_filename( @@ -767,8 +770,7 @@ def tensor_hook( # Spawn a new process to register the model. # Slower method to register the model via log_model. - # TODO: design this with some extra param in the model saving config to invoke this - if True: + if self.use_mlflow_log_model: process = SpawnProcess( target=_log_model_multiprocess, kwargs={