diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index ffdb09ca98..9eb214e83d 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -388,6 +388,9 @@ def test_huggingface_conversion_callback_interval( checkpointer_callback.transform_model_pre_registration = MagicMock( wraps=checkpointer_callback.transform_model_pre_registration, ) + checkpointer_callback.pre_register_edit = MagicMock( + wraps=checkpointer_callback.pre_register_edit, + ) trainer = Trainer( model=original_model, device='gpu', @@ -413,9 +416,11 @@ def test_huggingface_conversion_callback_interval( metadata={}, ) assert checkpointer_callback.transform_model_pre_registration.call_count == 1 + assert checkpointer_callback.pre_register_edit.call_count == 1 assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 else: assert checkpointer_callback.transform_model_pre_registration.call_count == 0 + assert checkpointer_callback.pre_register_edit.call_count == 0 assert mlflow_logger_mock.save_model.call_count == 0 assert mlflow_logger_mock.register_model_with_run_id.call_count == 0