Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix passed metadata to mlflow logging #713

Merged
merged 24 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,13 @@ def __init__(
if self.mlflow_registered_model_name is not None:
# Both the metadata and the task are needed in order for mlflow
# and databricks optimized model serving to work
if 'metadata' not in mlflow_logging_config:
mlflow_logging_config['metadata'] = {
'task': 'llm/v1/completions'
}
if 'task' not in mlflow_logging_config:
mlflow_logging_config['task'] = 'text-generation'
default_metadata = {'task': 'llm/v1/completions'}
passed_metadata = mlflow_logging_config.get('metadata', {})
mlflow_logging_config['metadata'] = {
**default_metadata,
**passed_metadata
}
mlflow_logging_config.setdefault('task', 'text-generation')
self.mlflow_logging_config = mlflow_logging_config

self.huggingface_folder_name_fstr = os.path.join(
Expand All @@ -93,7 +94,6 @@ def __init__(
self.save_interval = save_interval
self.check_interval = create_interval_scheduler(
save_interval, include_end_of_training=True)

self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(
save_folder, loggers=[])
if self.remote_ud is not None:
Expand Down
5 changes: 4 additions & 1 deletion llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def build_icl_data_and_gauntlet(
return icl_evaluators, logger_keys, eval_gauntlet_cb


def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
def build_callback(name: str, kwargs: Union[DictConfig, Dict[str,
Any]]) -> Callback:
if name == 'lr_monitor':
return LRMonitor()
elif name == 'memory_monitor':
Expand Down Expand Up @@ -117,6 +118,8 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
elif name == 'early_stopper':
return EarlyStopper(**kwargs)
elif name == 'hf_checkpointer':
if isinstance(kwargs, DictConfig):
kwargs = om.to_object(kwargs) # pyright: ignore
return HuggingFaceCheckpointer(**kwargs)
else:
raise ValueError(f'Not sure how to build callback: {name}')
Expand Down
32 changes: 32 additions & 0 deletions tests/test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

import pytest
from composer.callbacks import Generate
from omegaconf import OmegaConf as om
from transformers import PreTrainedTokenizerBase

from llmfoundry.callbacks import HuggingFaceCheckpointer
from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
from llmfoundry.utils.builders import build_callback, build_tokenizer

Expand Down Expand Up @@ -78,3 +80,33 @@ def test_build_generate_callback_unspecified_interval():
'foo': 'bar',
'something': 'else',
})


def test_build_hf_checkpointer_callback():
    """Check that build_callback('hf_checkpointer') converts an omegaconf
    DictConfig into plain Python containers before constructing
    HuggingFaceCheckpointer, so downstream consumers (e.g. mlflow) receive
    real dicts rather than DictConfig objects.
    """
    expected_folder = 'path_to_save_folder'
    expected_interval = 1
    expected_mlflow_config = {
        'metadata': {
            'databricks_model_family': 'MptForCausalLM',
            'databricks_model_size_parameters': '7b',
            'databricks_model_source': 'mosaic-fine-tuning',
            'task': 'llm/v1/completions'
        }
    }

    with mock.patch.object(HuggingFaceCheckpointer,
                           '__init__') as mocked_init:
        # A patched __init__ must return None.
        mocked_init.return_value = None

        callback_cfg = om.create({
            'save_folder': expected_folder,
            'save_interval': expected_interval,
            'mlflow_logging_config': expected_mlflow_config
        })
        build_callback(name='hf_checkpointer', kwargs=callback_cfg)

        assert mocked_init.call_count == 1
        # mock_calls entries unpack as (name, args, kwargs).
        _, _, passed_kwargs = mocked_init.mock_calls[0]
        assert passed_kwargs['save_folder'] == expected_folder
        assert passed_kwargs['save_interval'] == expected_interval
        # The DictConfig must have been materialized into builtin containers.
        assert isinstance(passed_kwargs['mlflow_logging_config'], dict)
        assert isinstance(passed_kwargs['mlflow_logging_config']['metadata'],
                          dict)
        assert passed_kwargs['mlflow_logging_config'] == expected_mlflow_config
33 changes: 27 additions & 6 deletions tests/test_hf_conversion_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import pathlib
import sys
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch

from composer import Trainer
from composer.loggers import MLFlowLogger
Expand Down Expand Up @@ -242,9 +242,22 @@ def get_config(
return cast(DictConfig, test_cfg)


def test_callback_inits_with_defaults():
def test_callback_inits():
    """HuggingFaceCheckpointer should construct with only the required
    arguments, and should populate default mlflow task/metadata whenever a
    registered model name is supplied.
    """
    # test with defaults
    _ = HuggingFaceCheckpointer(save_folder='test', save_interval='1ba')

    # test default metadata when mlflow registered name is given
    # (fixes comment typo: "metatdata" -> "metadata")
    hf_checkpointer = HuggingFaceCheckpointer(
        save_folder='test',
        save_interval='1ba',
        mlflow_registered_model_name='test_model_name')
    assert hf_checkpointer.mlflow_logging_config == {
        'task': 'text-generation',
        'metadata': {
            'task': 'llm/v1/completions'
        }
    }


@pytest.mark.world_size(2)
@pytest.mark.gpu
Expand Down Expand Up @@ -425,10 +438,18 @@ def test_huggingface_conversion_callback(
trainer.fit()

if dist.get_global_rank() == 0:
assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow
else 0)
assert mlflow_logger_mock.register_model.call_count == (
1 if log_to_mlflow else 0)
if log_to_mlflow:
assert mlflow_logger_mock.save_model.call_count == 1
mlflow_logger_mock.save_model.assert_called_with(
flavor='transformers',
transformers_model=ANY,
path=ANY,
task='text-generation',
metadata={'task': 'llm/v1/completions'})
assert mlflow_logger_mock.register_model.call_count == 1
else:
assert mlflow_logger_mock.save_model.call_count == 0
assert mlflow_logger_mock.register_model.call_count == 0
else:
assert mlflow_logger_mock.log_model.call_count == 0
assert mlflow_logger_mock.register_model.call_count == 0
Expand Down
Loading