Commit 34cdaf6: Add the model license file for mlflow (#915)

dakinggg authored Jan 29, 2024
1 parent bdcce63 · commit 34cdaf6

Showing 3 changed files with 51 additions and 4 deletions.
llmfoundry/callbacks/hf_checkpointer.py (27 additions, 0 deletions)
@@ -6,6 +6,7 @@
 import logging
 import math
 import os
+import re
 import tempfile
 from pathlib import Path
 from typing import Optional, Sequence, Union
@@ -27,6 +28,23 @@
 
 log = logging.getLogger(__name__)
 
+_LICENSE_FILE_PATTERN = re.compile(r'license(\.[a-z]+|$)', re.IGNORECASE)
+
+
+def _maybe_get_license_filename(local_dir: str) -> Optional[str]:
+    """Returns the name of the license file if it exists in the local_dir.
+
+    Note: This is intended to be consistent with the code in MLflow.
+    https://github.com/mlflow/mlflow/blob/5d13d6ec620a02de9a5e31201bf1becdb9722ea5/mlflow/transformers/__init__.py#L1152
+
+    If the license file does not exist, returns None.
+    """
+    try:
+        return next(file for file in os.listdir(local_dir)
+                    if _LICENSE_FILE_PATTERN.search(file))
+    except StopIteration:
+        return None
+
 
 class HuggingFaceCheckpointer(Callback):
     """Save a huggingface formatted checkpoint during training.
@@ -279,6 +297,15 @@ def _save_checkpoint(self, state: State, logger: Logger):
                        path=local_save_path,
                        **self.mlflow_logging_config,
                    )
+
+                    license_filename = _maybe_get_license_filename(
+                        local_save_path)
+                    if license_filename is not None:
+                        mlflow_logger._mlflow_client.log_artifact(
+                            mlflow_logger._run_id,
+                            os.path.join(local_save_path, license_filename),
+                        )
+
                    mlflow_logger.register_model(
                        model_uri=local_save_path,
                        name=self.mlflow_registered_model_name,
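For context, a standalone sketch of the MlflowClient call the new block makes. The tracking setup, experiment id, and file name here are illustrative assumptions, not part of the commit:

    import mlflow

    client = mlflow.MlflowClient()              # uses the configured tracking URI
    run = client.create_run(experiment_id='0')  # stand-in for the logger's run
    # log_artifact uploads a local file into the run's artifact store,
    # which is what the checkpointer now does with the found license file.
    client.log_artifact(run.info.run_id, 'LICENSE')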
setup.py (2 additions, 1 deletion)
@@ -50,7 +50,8 @@
 ]
 
 install_requires = [
-    'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17.2,<0.18',
+    'mosaicml[libcloud,wandb,oci,gcs]>=0.17.2,<0.18',
+    'mlflow>=2.10,<3',
     'accelerate>=0.25,<0.26',  # for HF inference `device_map`
     'transformers>=4.37,<4.38',
     'mosaicml-streaming>=0.7.2,<0.8',
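This change makes mlflow an explicit, pinned requirement instead of one inherited through the mosaicml[mlflow] extra. A quick, illustrative check that an installed environment satisfies the new pin (the packaging library is assumed to be available):

    import mlflow
    from packaging.version import Version

    # Mirrors the 'mlflow>=2.10,<3' bound from setup.py.
    assert Version('2.10') <= Version(mlflow.__version__) < Version('3'), mlflow.__version__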
tests/a_scripts/inference/test_convert_composer_to_hf.py (22 additions, 3 deletions)
@@ -6,7 +6,7 @@
 import pathlib
 import shutil
 from argparse import Namespace
-from typing import Callable, Optional, cast
+from typing import Any, Callable, Optional, cast
 from unittest.mock import ANY, MagicMock, patch
 
 import pytest
@@ -22,13 +22,18 @@
 
 from llmfoundry import COMPOSER_MODEL_REGISTRY
 from llmfoundry.callbacks import HuggingFaceCheckpointer
+from llmfoundry.callbacks.hf_checkpointer import _maybe_get_license_filename
 from llmfoundry.data.finetuning import build_finetuning_dataloader
 from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM
 from llmfoundry.utils.builders import build_optimizer, build_tokenizer
 from scripts.inference.convert_composer_to_hf import convert_composer_to_hf
 from tests.data_utils import make_tiny_ft_dataset
 
 
+def _save_model_mock(*args: Any, path: str, **kwargs: Any):
+    os.makedirs(path, exist_ok=True)
+
+
 def check_hf_tokenizer_equivalence(tokenizer1: PreTrainedTokenizerBase,
                                    tokenizer2: PreTrainedTokenizerBase):
     """WARNING: Parameters are updated within the check so don't call check_hf_tokenizer_equivalence on the same
@@ -297,7 +302,7 @@ def test_huggingface_conversion_callback_interval(
 
     mlflow_logger_mock = MagicMock(spec=MLFlowLogger)
     mlflow_logger_mock.state_dict = lambda *args, **kwargs: {}
-    mlflow_logger_mock.save_model = MagicMock()
+    mlflow_logger_mock.save_model = MagicMock(wraps=_save_model_mock)
     mlflow_logger_mock.register_model = MagicMock()
     mlflow_logger_mock.model_registry_prefix = ''
     mlflow_logger_mock._experiment_id = 'mlflow-experiment-id'
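A minimal sketch of the wraps= pattern used here and in the second test below: the MagicMock still records calls for assertions, but also delegates to _save_model_mock, so save_model(path=...) really creates the directory that the license-file lookup later scans. The wrapped function below is a hypothetical stand-in:

    from unittest.mock import MagicMock

    def _fake_save(path: str):
        print(f'would create {path}')

    m = MagicMock(wraps=_fake_save)
    m(path='/tmp/example')  # delegates to _fake_save and records the call
    m.assert_called_once_with(path='/tmp/example')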
@@ -533,7 +538,7 @@ def test_huggingface_conversion_callback(
 
     mlflow_logger_mock = MagicMock(spec=MLFlowLogger)
     mlflow_logger_mock.state_dict = lambda *args, **kwargs: {}
-    mlflow_logger_mock.save_model = MagicMock()
+    mlflow_logger_mock.save_model = MagicMock(wraps=_save_model_mock)
     mlflow_logger_mock.register_model = MagicMock()
     mlflow_logger_mock.model_registry_prefix = ''
     mlflow_logger_mock._experiment_id = 'mlflow-experiment-id'
@@ -817,3 +822,17 @@ def test_convert_and_generate_meta(tie_word_embeddings: str,
         assert torch.allclose(p1, p2)
 
     delete_transformers_cache()
+
+
+@pytest.mark.parametrize(
+    'license_file_name',
+    ['LICENSE', 'LICENSE.txt', 'license', 'license.md', None])
+def test_license_file_finder(tmp_path: pathlib.Path,
+                             license_file_name: Optional[str]):
+    if license_file_name is not None:
+        with open(os.path.join(tmp_path, license_file_name), 'w') as f:
+            f.write('test')
+
+    found_path = _maybe_get_license_filename(str(tmp_path))
+    assert (found_path == license_file_name
+            ) if license_file_name is not None else (found_path is None)
