Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wenfeiy-db committed Nov 15, 2023
1 parent 157e059 commit e801bbd
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions tests/test_hf_conversion_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import pathlib
import sys
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch

from composer import Trainer
from composer.loggers import MLFlowLogger
Expand Down Expand Up @@ -242,9 +242,22 @@ def get_config(
return cast(DictConfig, test_cfg)


def test_callback_inits_with_defaults():
def test_callback_inits():
# test with defaults
_ = HuggingFaceCheckpointer(save_folder='test', save_interval='1ba')

# test default metatdata when mlflow registered name is given
hf_checkpointer = HuggingFaceCheckpointer(
save_folder='test',
save_interval='1ba',
mlflow_registered_model_name='test_model_name')
assert hf_checkpointer.mlflow_logging_config == {
'task': 'text-generation',
'metadata': {
'task': 'llm/v1/completions'
}
}


@pytest.mark.world_size(2)
@pytest.mark.gpu
Expand Down Expand Up @@ -421,19 +434,14 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
trainer.fit()

if dist.get_global_rank() == 0:
# assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow
# else 0)
# assert mlflow_logger_mock.register_model.call_count == (
# 1 if log_to_mlflow else 0)
if log_to_mlflow:
assert mlflow_logger_mock.save_model.call_count == 1
mlflow_logger_mock.save_model.assert_called_with(
flavor='transformers',
transformers_model=ANY,
path=ANY,
task='text-generation',
metatdata={'task': 'llm/v1/completions'}
)
metadata={'task': 'llm/v1/completions'})
assert mlflow_logger_mock.register_model.call_count == 1
else:
assert mlflow_logger_mock.save_model.call_count == 0
Expand Down

0 comments on commit e801bbd

Please sign in to comment.