Commit

pre-commit

wenfeiy-db committed Nov 10, 2023
1 parent a0626e2 commit f3ee2ae
Showing 2 changed files with 20 additions and 13 deletions.
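The changes below match what a repository's pre-commit formatting hooks typically produce (import reordering and yapf-style argument wrapping). Assuming a standard pre-commit setup, the same rewrite can usually be reproduced locally with `pre-commit run --all-files`.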
1 change: 0 additions & 1 deletion llmfoundry/utils/builders.py
@@ -5,7 +5,6 @@
 import os
 import warnings
 from typing import Any, Dict, List, Optional, Tuple, Union
-from copy import deepcopy
 
 import torch
 from composer import algorithms
32 changes: 20 additions & 12 deletions tests/test_builders.py
@@ -6,13 +6,13 @@
 
 import pytest
 from composer.callbacks import Generate
-from transformers import PreTrainedTokenizerBase
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
+from transformers import PreTrainedTokenizerBase
 
+from llmfoundry.callbacks import HuggingFaceCheckpointer
 from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
 from llmfoundry.utils.builders import build_callback, build_tokenizer
-from llmfoundry.callbacks import HuggingFaceCheckpointer
 
 
 @pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [
@@ -82,19 +82,27 @@ def test_build_generate_callback_unspecified_interval():
         'something': 'else',
     })
 
+
 def test_build_hf_checkpointer_callback():
-    with mock.patch.object(HuggingFaceCheckpointer, '__init__') as mock_hf_checkpointer:
+    with mock.patch.object(HuggingFaceCheckpointer,
+                           '__init__') as mock_hf_checkpointer:
         mock_hf_checkpointer.return_value = None
-        save_folder = "path_to_save_folder"
+        save_folder = 'path_to_save_folder'
         save_interval = 1
-        mlflow_logging_config_dict = {'metadata': {'databricks_model_family': 'MptForCausalLM', 'databricks_model_size_parameters': '7b', 'databricks_model_source': 'mosaic-fine-tuning', 'task': 'llm/v1/completions'}}
-        build_callback(
-            name='hf_checkpointer',
-            kwargs=om.create({
-                "save_folder": save_folder,
-                "save_interval": save_interval,
-                "mlflow_logging_config": mlflow_logging_config_dict
-            }))
+        mlflow_logging_config_dict = {
+            'metadata': {
+                'databricks_model_family': 'MptForCausalLM',
+                'databricks_model_size_parameters': '7b',
+                'databricks_model_source': 'mosaic-fine-tuning',
+                'task': 'llm/v1/completions'
+            }
+        }
+        build_callback(name='hf_checkpointer',
+                       kwargs=om.create({
+                           'save_folder': save_folder,
+                           'save_interval': save_interval,
+                           'mlflow_logging_config': mlflow_logging_config_dict
+                       }))
 
         assert mock_hf_checkpointer.call_count == 1
         _, _, kwargs = mock_hf_checkpointer.mock_calls[0]
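For readers unfamiliar with the pattern used in this test, here is a minimal standalone sketch of mocking a class's __init__ to capture constructor arguments; the Greeter class and its names are illustrative, not part of the repository:

from unittest import mock


class Greeter:

    def __init__(self, name: str):
        self.name = name


with mock.patch.object(Greeter, '__init__') as mock_init:
    # __init__ must return None, so override the MagicMock default.
    mock_init.return_value = None
    Greeter(name='world')  # the real __init__ is never executed
    assert mock_init.call_count == 1
    # Each entry in mock_calls is a (name, args, kwargs) tuple;
    # keyword arguments land in kwargs, as in the test above.
    _, _, kwargs = mock_init.mock_calls[0]
    assert kwargs == {'name': 'world'}

Setting return_value = None matters: a MagicMock returns another MagicMock by default, and Python raises TypeError when __init__ returns anything other than None.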