diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index d2c2a9e1c9..7ba2559de2 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -421,10 +421,23 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, trainer.fit() if dist.get_global_rank() == 0: - assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow - else 0) - assert mlflow_logger_mock.register_model.call_count == ( - 1 if log_to_mlflow else 0) + # assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow + # else 0) + # assert mlflow_logger_mock.register_model.call_count == ( + # 1 if log_to_mlflow else 0) + if log_to_mlflow: + assert mlflow_logger_mock.save_model.call_count == 1 + mlflow_logger_mock.save_model.assert_called_with( + flavor='transformers', + transformers_model=ANY, + path=ANY, + task='text-generation', + metatdata={'task': 'llm/v1/completions'} + ) + assert mlflow_logger_mock.register_model.call_count == 1 + else: + assert mlflow_logger_mock.save_model.call_count == 0 + assert mlflow_logger_mock.register_model.call_count == 0 else: assert mlflow_logger_mock.log_model.call_count == 0 assert mlflow_logger_mock.register_model.call_count == 0