diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index f9618c5fa2..082c767288 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -394,8 +394,8 @@ def _save_checkpoint(self, state: State, logger: Logger): os.path.join(local_save_path, license_filename), ) - mlflow_logger.register_model( + mlflow_logger.register_model_with_run_id( model_uri=local_save_path, name=self.mlflow_registered_model_name, - await_registration_for=3600, + await_creation_for=3600, ) diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index ab2d569132..bc4214e76a 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -306,7 +306,7 @@ def test_huggingface_conversion_callback_interval( mlflow_logger_mock = MagicMock(spec=MLFlowLogger) mlflow_logger_mock.state_dict = lambda *args, **kwargs: {} mlflow_logger_mock.save_model = MagicMock(wraps=_save_model_mock) - mlflow_logger_mock.register_model = MagicMock() + mlflow_logger_mock.register_model_with_run_id = MagicMock() mlflow_logger_mock.model_registry_prefix = '' mlflow_logger_mock._experiment_id = 'mlflow-experiment-id' mlflow_logger_mock._run_id = 'mlflow-run-id' @@ -334,10 +334,10 @@ def test_huggingface_conversion_callback_interval( input_example=ANY, signature=ANY, metadata={'task': 'llm/v1/completions'}) - assert mlflow_logger_mock.register_model.call_count == 1 + assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 else: assert mlflow_logger_mock.save_model.call_count == 0 - assert mlflow_logger_mock.register_model.call_count == 0 + assert mlflow_logger_mock.register_model_with_run_id.call_count == 0 normal_checkpoints = [ name for name in os.listdir(os.path.join(tmp_path, 'checkpoints')) @@ -564,7 +564,7 @@ def test_huggingface_conversion_callback( mlflow_logger_mock = MagicMock(spec=MLFlowLogger) mlflow_logger_mock.state_dict = lambda *args, **kwargs: {} mlflow_logger_mock.save_model = MagicMock(wraps=_save_model_mock) - mlflow_logger_mock.register_model = MagicMock() + mlflow_logger_mock.register_model_with_run_id = MagicMock() mlflow_logger_mock.model_registry_prefix = '' mlflow_logger_mock._experiment_id = 'mlflow-experiment-id' mlflow_logger_mock._run_id = 'mlflow-run-id' @@ -628,10 +628,10 @@ def test_huggingface_conversion_callback( } } mlflow_logger_mock.save_model.assert_called_with(**expectation) - assert mlflow_logger_mock.register_model.call_count == 1 + assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 else: assert mlflow_logger_mock.log_model.call_count == 0 - assert mlflow_logger_mock.register_model.call_count == 0 + assert mlflow_logger_mock.register_model_with_run_id.call_count == 0 # summon full params to check equivalence from torch.distributed.fsdp import FullyShardedDataParallel as FSDP