From 3e17fd0639cf9acfe3aba4f8c780a05a1bad4fed Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Tue, 7 Nov 2023 13:09:40 -0800 Subject: [PATCH] test --- llmfoundry/callbacks/hf_checkpointer.py | 9 +++++---- tests/test_hf_conversion_script.py | 24 +++++++++++++++++++----- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 44f2b1348d..b8990f67df 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -74,11 +74,12 @@ def __init__( if self.mlflow_registered_model_name is not None: # Both the metadata and the task are needed in order for mlflow # and databricks optimized model serving to work - default_metadata = { - 'task': 'llm/v1/completions' - } + default_metadata = {'task': 'llm/v1/completions'} passed_metadata = mlflow_logging_config.get('metadata', {}) - mlflow_logging_config['metadata'] = {**default_metadata, **passed_metadata} + mlflow_logging_config['metadata'] = { + **default_metadata, + **passed_metadata + } mlflow_logging_config.setdefault('task', 'text-generation') self.mlflow_logging_config = mlflow_logging_config diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index d2c2a9e1c9..a2de0ea854 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -5,7 +5,7 @@ import os import pathlib import sys -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, ANY from composer import Trainer from composer.loggers import MLFlowLogger @@ -421,10 +421,24 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, trainer.fit() if dist.get_global_rank() == 0: - assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow - else 0) - assert mlflow_logger_mock.register_model.call_count == ( - 1 if log_to_mlflow else 0) + if log_to_mlflow: + # assert mlflow_logger_mock.save_model.call_count == 1 + mlflow_logger_mock.save_model.assert_called_once_with( + flavor='transformers', + transformers_model=ANY, + path=ANY, + task='text-generation', + metatdata={'task': 'llm/v1/completions'}) + assert mlflow_logger_mock.register_model.call_count == 1 + # mlflow_logger.save_model( + # flavor='transformers', + # transformers_model=components, + # path=local_save_path, + # **self.mlflow_logging_config, + # ) + else: + assert mlflow_logger_mock.save_model.call_count == 0 + assert mlflow_logger_mock.register_model.call_count == 0 else: assert mlflow_logger_mock.log_model.call_count == 0 assert mlflow_logger_mock.register_model.call_count == 0