diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index b44859e15a..63c9b9ca0d 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -40,7 +40,8 @@ class HuggingFaceCheckpointer(Callback):
         huggingface_folder_name (str): Folder to save each checkpoint under (can be a format string). Default is ``ba{batch}``.
         precision: The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``.
         overwrite (bool): Whether to overwrite previous checkpoints.
-        log_to_mlflow (bool): Whether to register the model to MLflow. This will only register one model at the end of training. Default is ``False``.
+        mlflow_registered_model_name (Optional[str]): The name to register the model under in the MLflow model registry. If ``None``, the model will not
+            be registered. Default is ``None``.
         mlflow_logging_config (Optional[dict]): A dictionary of config arguments that will get passed along to the MLflow ``save_model`` call.
             Expected to contain ``metadata`` and ``task`` keys. If either is unspecified, the defaults are ``'text-generation'`` and
             ``{'task': 'llm/v1/completions'}`` respectively.
@@ -53,7 +54,7 @@ def __init__(
         huggingface_folder_name: str = 'ba{batch}',
         precision: str = 'float32',
         overwrite: bool = False,
-        log_to_mlflow: bool = False,
+        mlflow_registered_model_name: Optional[str] = None,
         mlflow_logging_config: Optional[dict] = None,
     ):
         self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
@@ -67,13 +68,14 @@ def __init__(
         }[precision]
 
         # mlflow config setup
-        self.log_to_mlflow = log_to_mlflow
-        if mlflow_logging_config is None:
-            mlflow_logging_config = {}
-        if 'metadata' not in mlflow_logging_config:
-            mlflow_logging_config['metadata'] = {'task': 'llm/v1/completions'}
-        if 'task' not in mlflow_logging_config:
-            mlflow_logging_config['task'] = 'text-generation'
+        self.mlflow_registered_model_name = mlflow_registered_model_name
+        if self.mlflow_registered_model_name is not None:
+            if mlflow_logging_config is None:
+                mlflow_logging_config = {}
+            if 'metadata' not in mlflow_logging_config:
+                mlflow_logging_config['metadata'] = {'task': 'llm/v1/completions'}
+            if 'task' not in mlflow_logging_config:
+                mlflow_logging_config['task'] = 'text-generation'
         self.mlflow_logging_config = mlflow_logging_config
 
         self.huggingface_folder_name_fstr = os.path.join(
@@ -106,7 +108,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
                 self.remote_ud.init(state, logger)
                 state.callbacks.append(self.remote_ud)
 
-            if self.log_to_mlflow:
+            if self.mlflow_registered_model_name is not None:
                 self.mlflow_loggers = [
                     logger_destination
                     for logger_destination in logger.destinations
@@ -216,16 +218,15 @@ def _save_checkpoint(self, state: State, logger: Logger):
                     )
 
             elapsed_duration = state.get_elapsed_duration()
-            if self.log_to_mlflow and elapsed_duration is not None and elapsed_duration >= 1.0:
+            if self.mlflow_registered_model_name is not None and elapsed_duration is not None and elapsed_duration >= 1.0:
                 components = {'model': new_model_instance}
                 if original_tokenizer is not None:
                     components['tokenizer'] = original_tokenizer
 
                 log.debug('Logging Hugging Face model to MLFlow')
-                registered_model_name = f'{state.run_name}_{os.path.basename(save_dir)}'
                 for i, mlflow_logger in enumerate(self.mlflow_loggers):
                     log.debug(
-                        f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{registered_model_name}'
+                        f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}'
                     )
                     local_save_path = str(
                         Path(temp_save_dir) / f'mlflow_save_{i}')
@@ -241,6 +242,6 @@ def _save_checkpoint(self, state: State, logger: Logger):
                     )
                     mlflow_logger.register_model(
                         model_uri=local_save_path,
-                        name=registered_model_name,
+                        name=self.mlflow_registered_model_name,
                         await_registration_for=3600,
                     )
diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py
index a5113973d4..e0774240f8 100644
--- a/tests/test_hf_conversion_script.py
+++ b/tests/test_hf_conversion_script.py
@@ -224,7 +224,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
         save_folder=os.path.join(tmp_path, 'checkpoints'),
         save_interval=f'{huggingface_save_interval_batches}ba',
         precision=precision_str,
-        log_to_mlflow=log_to_mlflow,
+        mlflow_registered_model_name='dummy-registered-name',
     )
 
     # get small version of each model
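
A minimal usage sketch of the renamed argument, assuming the import path shown in this diff; the save folder, interval, and registry name below are illustrative placeholders, not values from the PR:

```python
# Hypothetical configuration: passing mlflow_registered_model_name both enables
# MLflow registration of the final checkpoint and supplies the registry name,
# replacing the old log_to_mlflow boolean. Leaving it as None (the default)
# skips registration entirely.
from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer

checkpointer = HuggingFaceCheckpointer(
    save_folder='s3://my-bucket/my-run/checkpoints',  # placeholder destination
    save_interval='1000ba',
    precision='bfloat16',
    mlflow_registered_model_name='my-registered-model',  # placeholder name
)
```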