Reorganize tests to make them easier to find (#768) #16

Merged
39 changes: 30 additions & 9 deletions llmfoundry/data/finetuning/tasks.py
@@ -47,25 +47,46 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:

__all__ = ['dataset_constructor']

_ALLOWED_RESPONSE_KEYS = {'response', 'completion'}
_ALLOWED_PROMPT_KEYS = {'prompt'}


def _tokenize_formatted_example(
example: Dict[str, Any],
tokenizer: PreTrainedTokenizerBase) -> Dict[str, List[int]]:
if ('prompt' not in example) or ('response' not in example):
"""Tokenize a formatted example and validate expected keys."""
example_keys = set(example.keys())
prompt_keys = example_keys.intersection(_ALLOWED_PROMPT_KEYS)
response_keys = example_keys.intersection(_ALLOWED_RESPONSE_KEYS)

if len(prompt_keys) != 1:
raise KeyError(
f'Unable to tokenize example because {len(prompt_keys)} of the allowed prompt keys ' +\
f'were present in {example_keys=}. Please specify exactly one. {_ALLOWED_PROMPT_KEYS=}'
)

if len(response_keys) != 1:
raise KeyError(
'Unable to tokenize example because it has not been properly formatted. ' +\
'"prompt" and "response" are required keys but at least one was missing ' +\
f'from {example=}.'
f'Unable to tokenize example because {len(response_keys)} of the allowed response keys ' +\
f'were present in {example_keys=}. Please specify exactly one. {_ALLOWED_RESPONSE_KEYS=}'
)
if not isinstance(example['prompt'], str):

prompt_key = prompt_keys.pop()
response_key = response_keys.pop()
prompt = example[prompt_key]
response = example[response_key]

if not isinstance(prompt, str):
raise TypeError(
f'Unable to tokenize example because "prompt" was not a string. {example=}'
f'Unable to tokenize example because {prompt_key} was not a string. {example=}'
)
if not isinstance(example['response'], str):

if not isinstance(response, str):
raise TypeError(
f'Unable to tokenize example because "response" was not a string. {example=}'
f'Unable to tokenize example because {response_key} was not a string. {example=}'
)
return tokenizer(text=example['prompt'], text_target=example['response'])

return tokenizer(text=prompt, text_target=response)


class StreamingFinetuningDataset(StreamingDataset):
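The change above widens `_tokenize_formatted_example` to accept either `response` or `completion` as the response column while still requiring exactly one prompt key and exactly one response key. A minimal sketch of how the updated validation behaves, assuming a GPT-2 tokenizer and illustrative example dicts (neither is part of this PR):

```python
import pytest
from transformers import AutoTokenizer

from llmfoundry.data.finetuning.tasks import _tokenize_formatted_example

tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Accepted: exactly one allowed prompt key and one allowed response key
# ('response' or 'completion') are present.
tokens = _tokenize_formatted_example(
    {'prompt': 'What is 2 + 2?', 'completion': '4'}, tokenizer)
assert 'input_ids' in tokens and 'labels' in tokens

# Rejected: no allowed response key is present, so the new KeyError fires.
with pytest.raises(KeyError):
    _tokenize_formatted_example(
        {'prompt': 'What is 2 + 2?', 'answer': '4'}, tokenizer)
```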
30 changes: 21 additions & 9 deletions scripts/inference/convert_composer_to_hf.py
@@ -168,19 +168,11 @@ def parse_args() -> Namespace:
return parser.parse_args()


def convert_composer_to_hf(args: Namespace) -> None:
def _convert_composer_to_hf(args: Namespace) -> None:
print()
print('#' * 30)
print('Converting Composer checkpoint to HuggingFace checkpoint format...')

# Register MPT auto classes so that this script works with MPT
# This script will not work without modification for other custom models,
# but will work for other HuggingFace causal LMs
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
MPTConfig.register_for_auto_class()
MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')

_, _, local_folder_path = parse_uri(args.hf_output_path)

config, tokenizer = write_huggingface_pretrained_from_composer_checkpoint(
@@ -296,5 +288,25 @@ def convert_composer_to_hf(args: Namespace) -> None:
)


def convert_composer_to_hf(args: Namespace) -> None:
# Register MPT auto classes so that this script works with MPT
# This script will not work without modification for other custom models,
# but will work for other HuggingFace causal LMs
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
MPTConfig.register_for_auto_class()
MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')

try:
_convert_composer_to_hf(args)
except Exception as e:
raise e
finally:
# Undo auto registration after running the script
del CONFIG_MAPPING._extra_content['mpt']
delattr(MPTConfig, '_auto_class')
delattr(MPTForCausalLM, '_auto_class')


if __name__ == '__main__':
convert_composer_to_hf(parse_args())
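For reuse outside this script, the register-then-clean-up flow that `convert_composer_to_hf` now wraps around `_convert_composer_to_hf` can also be expressed as a context manager. A minimal sketch, assuming `MPTConfig` and `MPTForCausalLM` are importable from `llmfoundry.models.mpt`; the helper name `registered_mpt_auto_classes` is hypothetical and not part of the PR:

```python
from contextlib import contextmanager

from transformers.models.auto.configuration_auto import CONFIG_MAPPING

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM


@contextmanager
def registered_mpt_auto_classes():
    # Register MPT so AutoConfig/AutoModelForCausalLM can resolve 'mpt' checkpoints.
    CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
    MPTConfig.register_for_auto_class()
    MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')
    try:
        yield
    finally:
        # Undo the registration so later imports and tests see an unmodified mapping.
        del CONFIG_MAPPING._extra_content['mpt']
        delattr(MPTConfig, '_auto_class')
        delattr(MPTForCausalLM, '_auto_class')


# Usage: convert inside the context, leaving global HF state untouched afterwards.
# with registered_mpt_auto_classes():
#     _convert_composer_to_hf(args)
```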
6 changes: 6 additions & 0 deletions tests/a_scripts/__init__.py
@@ -0,0 +1,6 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# TODO: This test directory is called "a_scripts" to enforce that these tests are run
# first. More cleanup should be done to ensure tests can be run in any order and
# don't leave artifacts behind.
2 changes: 2 additions & 0 deletions tests/a_scripts/data_prep/__init__.py
@@ -0,0 +1,2 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
28 changes: 28 additions & 0 deletions tests/a_scripts/data_prep/test_convert_dataset_hf.py
@@ -0,0 +1,28 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import os
from argparse import Namespace
from pathlib import Path

from scripts.data_prep.convert_dataset_hf import main as main_hf


def test_download_script_from_api(tmp_path: Path):
# test calling it directly
path = os.path.join(tmp_path, 'my-copy-c4-1')
main_hf(
Namespace(
**{
'dataset': 'c4',
'data_subset': 'en',
'splits': ['val_xsmall'],
'out_root': path,
'compression': None,
'concat_tokens': None,
'bos_text': None,
'eos_text': None,
'no_wrap': False,
'num_workers': None
}))
assert os.path.exists(path)
27 changes: 27 additions & 0 deletions tests/a_scripts/data_prep/test_convert_dataset_json.py
@@ -0,0 +1,27 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import os
from argparse import Namespace
from pathlib import Path

from scripts.data_prep.convert_dataset_json import main as main_json


def test_json_script_from_api(tmp_path: Path):
# test calling it directly
path = os.path.join(tmp_path, 'my-copy-arxiv-1')
main_json(
Namespace(
**{
'path': 'scripts/data_prep/example_data/arxiv.jsonl',
'out_root': path,
'compression': None,
'split': 'train',
'concat_tokens': None,
'bos_text': None,
'eos_text': None,
'no_wrap': False,
'num_workers': None
}))
assert os.path.exists(path)
@@ -2,20 +2,14 @@
# SPDX-License-Identifier: Apache-2.0

import os
import sys

import pytest

# Add repo root to path so we can import scripts and test it
repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(repo_dir)
import pathlib
from concurrent.futures import ProcessPoolExecutor
from glob import glob
from typing import Callable, Iterable, List
from unittest.mock import Mock, patch

import numpy as np
import pytest
from streaming import StreamingDataset
from transformers import AutoTokenizer

2 changes: 2 additions & 0 deletions tests/a_scripts/eval/__init__.py
@@ -0,0 +1,2 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
34 changes: 16 additions & 18 deletions tests/test_eval.py → tests/a_scripts/eval/test_eval.py
@@ -4,8 +4,7 @@
import copy
import os
import pathlib
import sys
from typing import Any
from typing import Any, Union

import omegaconf as om
import pytest
@@ -14,15 +13,10 @@

from llmfoundry import COMPOSER_MODEL_REGISTRY
from llmfoundry.utils import build_tokenizer
from scripts.eval.eval import main # noqa: E402
from tests.data_utils import (create_arxiv_dataset, create_c4_dataset_xxsmall,
gpt_tiny_cfg)

# Add repo root to path so we can import scripts and test it
repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(repo_dir)

from scripts.eval.eval import main # noqa: E402


@pytest.fixture(autouse=True)
def set_correct_cwd():
@@ -35,11 +29,16 @@ def set_correct_cwd():
os.chdir('..')


@pytest.fixture()
def mock_saved_model_path():
# load the eval and model config
with open('eval/yamls/test_eval.yaml', 'r', encoding='utf-8') as f:
@pytest.fixture
def eval_cfg(foundry_dir: str) -> Union[om.ListConfig, om.DictConfig]:
yaml_path = os.path.join(foundry_dir, 'scripts/eval/yamls/test_eval.yaml')
with open(yaml_path, 'r', encoding='utf-8') as f:
eval_cfg = om.OmegaConf.load(f)
return eval_cfg


@pytest.fixture()
def mock_saved_model_path(eval_cfg: Union[om.ListConfig, om.DictConfig]):
model_cfg = eval_cfg.models[0]
# set device to cpu
device = 'cpu'
@@ -60,12 +59,11 @@ def mock_saved_model_path():
os.remove(saved_model_path)


def test_icl_eval(capfd: Any, mock_saved_model_path: Any):
with open('eval/yamls/test_eval.yaml', 'r', encoding='utf-8') as f:
test_cfg = om.OmegaConf.load(f)
test_cfg.models[0].load_path = mock_saved_model_path
assert isinstance(test_cfg, om.DictConfig)
main(test_cfg)
def test_icl_eval(eval_cfg: Union[om.ListConfig, om.DictConfig], capfd: Any,
mock_saved_model_path: Any):
eval_cfg.models[0].load_path = mock_saved_model_path
assert isinstance(eval_cfg, om.DictConfig)
main(eval_cfg)
out, _ = capfd.readouterr()
expected_results = '| Category | Benchmark | Subtask | Accuracy | Number few shot | Model |\n|:----------------------------|:---------------|:----------|-----------:|:------------------|:---------|\n| language_understanding_lite | lambada_openai | | 0 | 0-shot | tiny_mpt |'
assert expected_results in out
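The `eval_cfg` and `cfg` fixtures in the relocated tests now take a `foundry_dir` fixture instead of appending the repo root to `sys.path`. That fixture is defined outside this diff; a minimal sketch of one possible conftest implementation, assuming the conftest sits three directories below the repo root (the exact path arithmetic is an assumption, not shown in the PR):

```python
import os

import pytest


@pytest.fixture
def foundry_dir() -> str:
    # Resolve the llm-foundry repo root (the directory containing `scripts/`)
    # relative to this conftest file rather than mutating sys.path.
    return os.path.abspath(
        os.path.join(os.path.dirname(__file__), '..', '..', '..'))
```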
@@ -2,29 +2,26 @@
# SPDX-License-Identifier: Apache-2.0
import copy
import os
import sys
import warnings

import omegaconf
import pytest
from omegaconf import DictConfig
from omegaconf import OmegaConf as om

# Add repo root to path so we can import scripts and test it
repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(repo_dir)

from scripts.eval.eval import main # noqa: E402


class TestHuggingFaceEvalYAMLInputs:
"""Validate and tests error handling for the input YAML file."""

@pytest.fixture
def cfg(self) -> DictConfig:
def cfg(self, foundry_dir: str) -> DictConfig:
"""Create YAML cfg fixture for testing purposes."""
conf_path: str = os.path.join(repo_dir,
'scripts/eval/yamls/hf_eval.yaml')
conf_path: str = os.path.join(
foundry_dir,
'scripts/eval/yamls/hf_eval.yaml',
)
with open(conf_path, 'r', encoding='utf-8') as config:
test_cfg = om.load(config)
assert isinstance(test_cfg, DictConfig)
@@ -78,15 +75,17 @@ def test_optional_mispelled_params_raise_warning(self,
class TestMPTEvalYAMLInputs:

@pytest.fixture
def cfg(self) -> DictConfig:
def cfg(self, foundry_dir: str) -> DictConfig:
"""Create YAML cfg fixture for testing purposes."""
conf_path: str = os.path.join(repo_dir,
'scripts/eval/yamls/mpt_eval.yaml')
conf_path: str = os.path.join(
foundry_dir,
'scripts/eval/yamls/mpt_eval.yaml',
)
with open(conf_path, 'r', encoding='utf-8') as config:
test_cfg = om.load(config)

test_cfg.icl_tasks[0].dataset_uri = os.path.join(
repo_dir, 'scripts', test_cfg.icl_tasks[0].dataset_uri)
foundry_dir, 'scripts', test_cfg.icl_tasks[0].dataset_uri)

# make tests use cpu initialized transformer models only
test_cfg.models[0].model.init_device = 'cpu'
2 changes: 2 additions & 0 deletions tests/a_scripts/inference/__init__.py
@@ -0,0 +1,2 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
@@ -4,34 +4,26 @@
import math
import os
import pathlib
import sys
from typing import Callable
from unittest.mock import ANY, MagicMock, patch

from composer import Trainer
from composer.loggers import MLFlowLogger
from composer.utils import dist, get_device, using_torch_2

from llmfoundry.callbacks import HuggingFaceCheckpointer
from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM

# Add repo root to path so we can import scripts and test it
repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(repo_dir)
import shutil
from argparse import Namespace
from typing import Optional, cast
from typing import Callable, Optional, cast
from unittest.mock import ANY, MagicMock, patch

import pytest
import torch
import transformers
from composer import Trainer
from composer.loggers import MLFlowLogger
from composer.utils import dist, get_device, using_torch_2
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from torch.utils.data import DataLoader
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from llmfoundry import COMPOSER_MODEL_REGISTRY
from llmfoundry.callbacks import HuggingFaceCheckpointer
from llmfoundry.data.finetuning import build_finetuning_dataloader
from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM
from llmfoundry.utils.builders import build_optimizer, build_tokenizer
from scripts.inference.convert_composer_to_hf import convert_composer_to_hf
from tests.data_utils import make_tiny_ft_dataset
2 changes: 2 additions & 0 deletions tests/a_scripts/train/__init__.py
@@ -0,0 +1,2 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0