diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py
index ea693a4105..96be6ad45d 100644
--- a/llmfoundry/utils/builders.py
+++ b/llmfoundry/utils/builders.py
@@ -5,7 +5,6 @@
 import os
 import warnings
 from typing import Any, Dict, List, Optional, Tuple, Union
-from copy import deepcopy
 
 import torch
 from composer import algorithms
diff --git a/tests/test_builders.py b/tests/test_builders.py
index 487eb0ffe1..a8b484bb24 100644
--- a/tests/test_builders.py
+++ b/tests/test_builders.py
@@ -6,13 +6,13 @@
 
 import pytest
 from composer.callbacks import Generate
-from transformers import PreTrainedTokenizerBase
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
+from transformers import PreTrainedTokenizerBase
 
+from llmfoundry.callbacks import HuggingFaceCheckpointer
 from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
 from llmfoundry.utils.builders import build_callback, build_tokenizer
-from llmfoundry.callbacks import HuggingFaceCheckpointer
 
 
 @pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [
@@ -82,19 +82,27 @@ def test_build_generate_callback_unspecified_interval():
         'something': 'else',
     })
 
+
 def test_build_hf_checkpointer_callback():
-    with mock.patch.object(HuggingFaceCheckpointer, '__init__') as mock_hf_checkpointer:
+    with mock.patch.object(HuggingFaceCheckpointer,
+                           '__init__') as mock_hf_checkpointer:
         mock_hf_checkpointer.return_value = None
-        save_folder = "path_to_save_folder"
+        save_folder = 'path_to_save_folder'
         save_interval = 1
-        mlflow_logging_config_dict = {'metadata': {'databricks_model_family': 'MptForCausalLM', 'databricks_model_size_parameters': '7b', 'databricks_model_source': 'mosaic-fine-tuning', 'task': 'llm/v1/completions'}}
-        build_callback(
-            name='hf_checkpointer',
-            kwargs=om.create({
-                "save_folder": save_folder,
-                "save_interval": save_interval,
-                "mlflow_logging_config": mlflow_logging_config_dict
-            }))
+        mlflow_logging_config_dict = {
+            'metadata': {
+                'databricks_model_family': 'MptForCausalLM',
+                'databricks_model_size_parameters': '7b',
+                'databricks_model_source': 'mosaic-fine-tuning',
+                'task': 'llm/v1/completions'
+            }
+        }
+        build_callback(name='hf_checkpointer',
+                       kwargs=om.create({
+                           'save_folder': save_folder,
+                           'save_interval': save_interval,
+                           'mlflow_logging_config': mlflow_logging_config_dict
+                       }))
         assert mock_hf_checkpointer.call_count == 1
         _, _, kwargs = mock_hf_checkpointer.mock_calls[0]
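
Aside (illustrative, not part of the diff): the reformatted test above patches HuggingFaceCheckpointer.__init__ with unittest.mock.patch.object and then inspects mock_calls to verify the keyword arguments that build_callback forwarded to the constructor. Below is a minimal, self-contained Python sketch of that same pattern; DummyCallback and build are hypothetical stand-ins for the real llm-foundry classes, not part of the library.

import unittest.mock as mock


class DummyCallback:
    """Hypothetical stand-in for HuggingFaceCheckpointer."""

    def __init__(self, save_folder: str, save_interval: int):
        self.save_folder = save_folder
        self.save_interval = save_interval


def build(**kwargs):
    # Hypothetical stand-in for build_callback: forwards kwargs to the
    # callback constructor.
    return DummyCallback(**kwargs)


with mock.patch.object(DummyCallback, '__init__') as mock_init:
    # A mocked __init__ still has to return None, or instantiation
    # raises TypeError.
    mock_init.return_value = None
    build(save_folder='path_to_save_folder', save_interval=1)

    # Each entry in mock_calls unpacks to (name, args, kwargs).
    assert mock_init.call_count == 1
    _, args, kwargs = mock_init.mock_calls[0]
    assert kwargs == {
        'save_folder': 'path_to_save_folder',
        'save_interval': 1,
    }

One detail worth knowing: without autospec=True, a plain patch.object on __init__ does not record self, so the forwarded keyword arguments land in kwargs, and checking them there, as the test does, is the reliable approach.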