From 34cdaf68fde5e92f42b41ae0e68d45ccc22e0d8a Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Sun, 28 Jan 2024 21:59:18 -0800 Subject: [PATCH] Add the model license file for mlflow (#915) --- llmfoundry/callbacks/hf_checkpointer.py | 27 +++++++++++++++++++ setup.py | 3 ++- .../inference/test_convert_composer_to_hf.py | 25 ++++++++++++++--- 3 files changed, 51 insertions(+), 4 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index b50d81d09e..904f2e208f 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -6,6 +6,7 @@ import logging import math import os +import re import tempfile from pathlib import Path from typing import Optional, Sequence, Union @@ -27,6 +28,23 @@ log = logging.getLogger(__name__) +_LICENSE_FILE_PATTERN = re.compile(r'license(\.[a-z]+|$)', re.IGNORECASE) + + +def _maybe_get_license_filename(local_dir: str) -> Optional[str]: + """Returns the name of the license file if it exists in the local_dir. + + Note: This is intended to be consistent with the code in MLflow. + https://github.com/mlflow/mlflow/blob/5d13d6ec620a02de9a5e31201bf1becdb9722ea5/mlflow/transformers/__init__.py#L1152 + + If the license file does not exist, returns None. + """ + try: + return next(file for file in os.listdir(local_dir) + if _LICENSE_FILE_PATTERN.search(file)) + except StopIteration: + return None + class HuggingFaceCheckpointer(Callback): """Save a huggingface formatted checkpoint during training. @@ -279,6 +297,15 @@ def _save_checkpoint(self, state: State, logger: Logger): path=local_save_path, **self.mlflow_logging_config, ) + + license_filename = _maybe_get_license_filename( + local_save_path) + if license_filename is not None: + mlflow_logger._mlflow_client.log_artifact( + mlflow_logger._run_id, + os.path.join(local_save_path, license_filename), + ) + mlflow_logger.register_model( model_uri=local_save_path, name=self.mlflow_registered_model_name, diff --git a/setup.py b/setup.py index 511e665ed4..e5bc7e81d2 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,8 @@ ] install_requires = [ - 'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17.2,<0.18', + 'mosaicml[libcloud,wandb,oci,gcs]>=0.17.2,<0.18', + 'mlflow>=2.10,<3', 'accelerate>=0.25,<0.26', # for HF inference `device_map` 'transformers>=4.37,<4.38', 'mosaicml-streaming>=0.7.2,<0.8', diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index deed181475..e85cdb213c 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -6,7 +6,7 @@ import pathlib import shutil from argparse import Namespace -from typing import Callable, Optional, cast +from typing import Any, Callable, Optional, cast from unittest.mock import ANY, MagicMock, patch import pytest @@ -22,6 +22,7 @@ from llmfoundry import COMPOSER_MODEL_REGISTRY from llmfoundry.callbacks import HuggingFaceCheckpointer +from llmfoundry.callbacks.hf_checkpointer import _maybe_get_license_filename from llmfoundry.data.finetuning import build_finetuning_dataloader from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM from llmfoundry.utils.builders import build_optimizer, build_tokenizer @@ -29,6 +30,10 @@ from tests.data_utils import make_tiny_ft_dataset +def _save_model_mock(*args: Any, path: str, **kwargs: Any): + os.makedirs(path, exist_ok=True) + + def check_hf_tokenizer_equivalence(tokenizer1: PreTrainedTokenizerBase, tokenizer2: PreTrainedTokenizerBase): """WARNING: Parameters are updated within the check so don't call check_hf_tokenizer_equivalence on the same @@ -297,7 +302,7 @@ def test_huggingface_conversion_callback_interval( mlflow_logger_mock = MagicMock(spec=MLFlowLogger) mlflow_logger_mock.state_dict = lambda *args, **kwargs: {} - mlflow_logger_mock.save_model = MagicMock() + mlflow_logger_mock.save_model = MagicMock(wraps=_save_model_mock) mlflow_logger_mock.register_model = MagicMock() mlflow_logger_mock.model_registry_prefix = '' mlflow_logger_mock._experiment_id = 'mlflow-experiment-id' @@ -533,7 +538,7 @@ def test_huggingface_conversion_callback( mlflow_logger_mock = MagicMock(spec=MLFlowLogger) mlflow_logger_mock.state_dict = lambda *args, **kwargs: {} - mlflow_logger_mock.save_model = MagicMock() + mlflow_logger_mock.save_model = MagicMock(wraps=_save_model_mock) mlflow_logger_mock.register_model = MagicMock() mlflow_logger_mock.model_registry_prefix = '' mlflow_logger_mock._experiment_id = 'mlflow-experiment-id' @@ -817,3 +822,17 @@ def test_convert_and_generate_meta(tie_word_embeddings: str, assert torch.allclose(p1, p2) delete_transformers_cache() + + +@pytest.mark.parametrize( + 'license_file_name', + ['LICENSE', 'LICENSE.txt', 'license', 'license.md', None]) +def test_license_file_finder(tmp_path: pathlib.Path, + license_file_name: Optional[str]): + if license_file_name is not None: + with open(os.path.join(tmp_path, license_file_name), 'w') as f: + f.write('test') + + found_path = _maybe_get_license_filename(str(tmp_path)) + assert (found_path == license_file_name + ) if license_file_name is not None else (found_path is None)