diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index 7ba2559de2..07b951a382 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -5,7 +5,7 @@ import os import pathlib import sys -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch from composer import Trainer from composer.loggers import MLFlowLogger @@ -242,9 +242,22 @@ def get_config( return cast(DictConfig, test_cfg) -def test_callback_inits_with_defaults(): +def test_callback_inits(): + # test with defaults _ = HuggingFaceCheckpointer(save_folder='test', save_interval='1ba') + # test default metatdata when mlflow registered name is given + hf_checkpointer = HuggingFaceCheckpointer( + save_folder='test', + save_interval='1ba', + mlflow_registered_model_name='test_model_name') + assert hf_checkpointer.mlflow_logging_config == { + 'task': 'text-generation', + 'metadata': { + 'task': 'llm/v1/completions' + } + } + @pytest.mark.world_size(2) @pytest.mark.gpu @@ -421,10 +434,6 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, trainer.fit() if dist.get_global_rank() == 0: - # assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow - # else 0) - # assert mlflow_logger_mock.register_model.call_count == ( - # 1 if log_to_mlflow else 0) if log_to_mlflow: assert mlflow_logger_mock.save_model.call_count == 1 mlflow_logger_mock.save_model.assert_called_with( @@ -432,8 +441,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, transformers_model=ANY, path=ANY, task='text-generation', - metatdata={'task': 'llm/v1/completions'} - ) + metadata={'task': 'llm/v1/completions'}) assert mlflow_logger_mock.register_model.call_count == 1 else: assert mlflow_logger_mock.save_model.call_count == 0