Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
wenfeiy-db committed Nov 7, 2023
1 parent 5872f3f commit 3e17fd0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
9 changes: 5 additions & 4 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,12 @@ def __init__(
if self.mlflow_registered_model_name is not None:
# Both the metadata and the task are needed in order for mlflow
# and databricks optimized model serving to work
default_metadata = {
'task': 'llm/v1/completions'
}
default_metadata = {'task': 'llm/v1/completions'}
passed_metadata = mlflow_logging_config.get('metadata', {})
mlflow_logging_config['metadata'] = {**default_metadata, **passed_metadata}
mlflow_logging_config['metadata'] = {
**default_metadata,
**passed_metadata
}
mlflow_logging_config.setdefault('task', 'text-generation')
self.mlflow_logging_config = mlflow_logging_config

Expand Down
24 changes: 19 additions & 5 deletions tests/test_hf_conversion_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import pathlib
import sys
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, patch, ANY

from composer import Trainer
from composer.loggers import MLFlowLogger
Expand Down Expand Up @@ -421,10 +421,24 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
trainer.fit()

if dist.get_global_rank() == 0:
assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow
else 0)
assert mlflow_logger_mock.register_model.call_count == (
1 if log_to_mlflow else 0)
if log_to_mlflow:
# assert mlflow_logger_mock.save_model.call_count == 1
mlflow_logger_mock.save_model.assert_called_once_with(
flavor='transformers',
transformers_model=ANY,
path=ANY,
task='text-generation',
metatdata={'task': 'llm/v1/completions'})
assert mlflow_logger_mock.register_model.call_count == 1
# mlflow_logger.save_model(
# flavor='transformers',
# transformers_model=components,
# path=local_save_path,
# **self.mlflow_logging_config,
# )
else:
assert mlflow_logger_mock.save_model.call_count == 0
assert mlflow_logger_mock.register_model.call_count == 0
else:
assert mlflow_logger_mock.log_model.call_count == 0
assert mlflow_logger_mock.register_model.call_count == 0
Expand Down

0 comments on commit 3e17fd0

Please sign in to comment.