
Commit

fixes, add tests
aspfohl committed Nov 28, 2023
1 parent 204d2f7 commit 5b85218
Showing 5 changed files with 191 additions and 94 deletions.
8 changes: 4 additions & 4 deletions llmfoundry/utils/builders.py
@@ -224,8 +224,8 @@ def build_tokenizer(

signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup'

if dist.is_available() and dist.is_initialized(
) and dist.get_world_size() > 1:
if dist.is_available() and dist.is_initialized() and dist.get_world_size(
) > 1:
# Make sure the tokenizer files are downloaded and cached first by local rank 0
with dist.local_rank_zero_download_and_wait(signal_file_path):
pass
@@ -244,8 +244,8 @@ def build_tokenizer(
int(1e30),
)

if dist.is_available() and dist.is_initialized(
) and dist.get_world_size() > 1:
if dist.is_available() and dist.is_initialized() and dist.get_world_size(
) > 1:
if dist.get_local_rank() == 0:
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_tokenizer_setup')
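The blocks touched above implement composer's local-rank-zero handshake: non-zero local ranks block on a signal file while local rank 0 downloads and caches the tokenizer files, then rank 0 writes the file to release them. Below is a minimal sketch of that pattern, assuming `from composer.utils import dist` (as used in builders.py) and the `transformers` library; the helper name `warm_tokenizer_cache` is illustrative and not part of this commit.

from composer.utils import dist
from transformers import AutoTokenizer, PreTrainedTokenizerBase


def warm_tokenizer_cache(tokenizer_name: str) -> PreTrainedTokenizerBase:
    signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup'

    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        # Non-zero local ranks block here until the signal file exists;
        # local rank 0 falls through immediately.
        with dist.local_rank_zero_download_and_wait(signal_file_path):
            pass

    # Local rank 0 downloads and populates the HF cache; the other ranks
    # reach this line only after the cache is already warm.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        if dist.get_local_rank() == 0:
            # Writing the signal file releases the ranks waiting above.
            with open(signal_file_path, 'wb') as f:
                f.write(b'local_rank0_completed_tokenizer_setup')

    return tokenizer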
4 changes: 4 additions & 0 deletions scripts/eval/eval.py
@@ -292,6 +292,7 @@ def main(cfg: DictConfig):
eval_gauntlet_df = None
models_df = None
composite_scores = None
trainers = []
for model_cfg in model_configs:
(trainer, logger_keys, eval_gauntlet_callback,
eval_gauntlet_df) = evaluate_model(
@@ -311,6 +312,7 @@ def main(cfg: DictConfig):
precision=precision,
eval_gauntlet_df=eval_gauntlet_df,
icl_subset_num_batches=icl_subset_num_batches)
trainers.append(trainer)

if eval_gauntlet_callback is not None:
composite_scores = eval_gauntlet_callback.eval_after_all(
@@ -349,6 +351,8 @@ def main(cfg: DictConfig):
assert models_df is not None
print(models_df.to_markdown(index=False))

return trainers, eval_gauntlet_df


def calculate_markdown_results(logger_keys: List[str], trainer: Trainer,
benchmark_to_taxonomy: Dict[str, str],
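With `main` now collecting one `Trainer` per model and returning `(trainers, eval_gauntlet_df)`, callers (and the new test below) can inspect eval results programmatically. A hypothetical caller, assuming the repo root is on sys.path and a valid eval config; the YAML path here is illustrative only.

from omegaconf import OmegaConf

from scripts.eval.eval import main as run_evaluation

cfg = OmegaConf.load('scripts/eval/yamls/hf_eval.yaml')  # any valid eval config
trainers, eval_gauntlet_df = run_evaluation(cfg)
assert len(trainers) == len(cfg.models)  # one Trainer per evaluated model
if eval_gauntlet_df is not None:
    print(eval_gauntlet_df.to_markdown(index=False))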
95 changes: 95 additions & 0 deletions tests/data_utils.py
@@ -3,8 +3,23 @@

import json
import os
import pathlib
import shutil
import sys
from argparse import Namespace
from typing import Optional

from omegaconf import DictConfig
from omegaconf import OmegaConf as om

# Add repo root to path so we can import scripts and test it
repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(repo_dir)

from scripts.data_prep.convert_dataset_hf import main as main_hf  # noqa: E402
from scripts.data_prep.convert_dataset_json import \
    main as main_json  # noqa: E402


def make_tiny_ft_dataset(
path: str,
@@ -65,3 +80,83 @@ def make_tiny_ft_dataset(
for sample in samples:
_f.write(json.dumps(sample))
_f.write('\n')


def create_c4_dataset_xsmall(path: pathlib.Path) -> str:
"""Creates a small mocked version of the C4 dataset."""
c4_dir = os.path.join(path, f'my-copy-c4')
downloaded_split = 'val_xsmall' # very fast to convert

# Hyperparameters from https://github.com/mosaicml/llm-foundry/blob/340a56658560ebceb2a3aa69d6e37813e415acd0/README.md#L188
main_hf(
Namespace(
**{
'dataset': 'c4',
'data_subset': 'en',
'splits': [downloaded_split],
'out_root': c4_dir,
'compression': None,
'concat_tokens': 2048,
'tokenizer': 'EleutherAI/gpt-neox-20b',
'tokenizer_kwargs': {},
'bos_text': '',
'eos_text': '<|endoftext|>',
'no_wrap': False,
'num_workers': 8
}))

# copy the small downloaded_split to other c4 splits for mocking purposes
mocked_splits = ['train', 'val']
for mocked_split in mocked_splits:
shutil.copytree(os.path.join(c4_dir, 'val_xsmall'),
os.path.join(c4_dir, mocked_split))
assert os.path.exists(c4_dir)
return c4_dir


def create_arxiv_dataset(path: pathlib.Path) -> str:
"""Creates an arxiv dataset."""
arxiv_dir = os.path.join(path, f'my-copy-arxiv')
downloaded_split = 'train'

main_json(
Namespace(
**{
'path': 'data_prep/example_data/arxiv.jsonl',
'out_root': arxiv_dir,
'compression': None,
'split': downloaded_split,
'concat_tokens': None,
'bos_text': None,
'eos_text': None,
'no_wrap': False,
'num_workers': None
}))

return arxiv_dir


def gpt_tiny_cfg(dataset_name: str, device: str):
"""Create gpt tiny cfg."""
conf_path: str = os.path.join(repo_dir,
'scripts/train/yamls/pretrain/testing.yaml')
with open(conf_path) as f:
test_cfg = om.load(f)
assert isinstance(test_cfg, DictConfig)

test_cfg.data_local = dataset_name
test_cfg.global_train_batch_size = 8
test_cfg.device_eval_batch_size = 4
test_cfg.device_train_microbatch_size = 4
test_cfg.max_duration = '4ba'
test_cfg.eval_interval = '4ba'
test_cfg.run_name = 'gpt-mini-integration-test'

if device == 'cpu':
test_cfg.model.init_device = 'cpu'
test_cfg.fsdp_config = None
test_cfg.model.attn_config.attn_impl = 'torch'
test_cfg.model.loss_fn = 'torch_crossentropy'
test_cfg.precision = 'fp32'

return test_cfg
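
These helpers are shared by the training and eval tests below. A sketch of how they compose inside a pytest test; the test name and assertions are illustrative, not part of this commit.

import pathlib

from tests.data_utils import create_c4_dataset_xsmall, gpt_tiny_cfg


def test_tiny_c4_cfg(tmp_path: pathlib.Path):
    dataset_dir = create_c4_dataset_xsmall(tmp_path)  # mocked C4 copy under tmp_path
    test_cfg = gpt_tiny_cfg(dataset_dir, 'cpu')  # tiny GPT config pointed at that copy
    assert test_cfg.data_local == dataset_dir
    assert test_cfg.precision == 'fp32'  # set by the cpu branch above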
86 changes: 86 additions & 0 deletions tests/test_eval.py
@@ -1,16 +1,21 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import copy
import os
import pathlib
import sys
from typing import Any

import omegaconf as om
import pytest
from composer import Trainer
from composer.loggers import InMemoryLogger

from llmfoundry import COMPOSER_MODEL_REGISTRY
from llmfoundry.utils import build_tokenizer
from tests.data_utils import (create_arxiv_dataset, create_c4_dataset_xsmall,
gpt_tiny_cfg)

# Add repo root to path so we can import scripts and test it
repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
@@ -66,3 +71,84 @@ def test_icl_eval(capfd: Any, mock_saved_model_path: Any):
assert expected_results in out
expected_results = '| model_name | default_average | language_understanding_lite |\n|:-------------|------------------:|------------------------------:|\n| tiny_mpt | 0 | 0 |'
assert expected_results in out


@pytest.mark.gpu
def test_loader_eval(capfd: Any, mock_saved_model_path: Any,
tmp_path: pathlib.Path):

c4_dataset_name = create_c4_dataset_xsmall(tmp_path)

# Use a training config that already has eval loader configured
test_cfg = gpt_tiny_cfg(c4_dataset_name, 'cpu')

# define icl eval task
test_cfg.icl_tasks = om.ListConfig([
om.DictConfig({
'label':
'lambada_openai',
'dataset_uri':
'eval/local_data/language_understanding/lambada_openai_small.jsonl',
'num_fewshot': [0],
'icl_task_type':
'language_modeling'
})
])

# convert the model from a training to eval model
model = test_cfg.pop('model')
new_model = {
'model_name': model.get('name'),
'model': model,
'load_path': mock_saved_model_path
}

tokenizer = test_cfg.pop('tokenizer', None)
if tokenizer:
new_model['tokenizer'] = tokenizer
test_cfg.models = [new_model]

# Set up multiple eval dataloaders
first_eval_loader = test_cfg.eval_loader
first_eval_loader.label = 'c4'
# Create second eval dataloader using the arxiv dataset.
second_eval_loader = copy.deepcopy(first_eval_loader)
arxiv_dataset_name = create_arxiv_dataset(tmp_path)
second_eval_loader.data_local = arxiv_dataset_name
second_eval_loader.label = 'arxiv'
test_cfg.eval_loader = om.OmegaConf.create(
[first_eval_loader, second_eval_loader])

trainers, eval_gauntlet_df = main(test_cfg)
assert eval_gauntlet_df is None

assert len(trainers) == 1 # one per model
trainer = trainers[0]

assert isinstance(trainer.logger.destinations, tuple)

assert len(trainer.logger.destinations) > 0
inmemorylogger = trainer.logger.destinations[
0] # pyright: ignore [reportGeneralTypeIssues]
assert isinstance(inmemorylogger, InMemoryLogger)
print(inmemorylogger.data.keys())

# Checks for first eval dataloader
assert 'metrics/eval/c4/LanguageCrossEntropy' in inmemorylogger.data.keys()
assert isinstance(
inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], list)
assert len(
inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1]) > 0
assert isinstance(
inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], tuple)

# Checks for second eval dataloader
assert 'metrics/eval/arxiv/LanguageCrossEntropy' in inmemorylogger.data.keys(
)
assert isinstance(
inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'], list)
assert len(
inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1]) > 0
assert isinstance(
inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1],
tuple)
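
The assertions above rely on `InMemoryLogger.data` mapping each metric key to a list of `(timestamp, value)` tuples, keyed as `metrics/eval/<loader label>/<metric name>`. A small illustrative helper for pulling out the final logged value:

from composer.loggers import InMemoryLogger


def final_metric_value(logger: InMemoryLogger, key: str):
    """Return the most recently logged value for `key`."""
    timestamp, value = logger.data[key][-1]
    return value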
92 changes: 2 additions & 90 deletions tests/test_training.py
@@ -3,104 +3,16 @@
import copy
import os
import pathlib
import shutil
import sys
from argparse import Namespace
from typing import Any, Optional

import pytest
from composer.loggers import InMemoryLogger
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om

# Add repo root to path so we can import scripts and test it
repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(repo_dir)

from scripts.data_prep.convert_dataset_hf import main as main_hf # noqa: E402
from scripts.data_prep.convert_dataset_json import \
main as main_json # noqa: E402
from scripts.train.train import main # noqa: E402


def create_c4_dataset_xsmall(path: pathlib.Path) -> str:
"""Creates a small mocked version of the C4 dataset."""
c4_dir = os.path.join(path, f'my-copy-c4')
downloaded_split = 'val_xsmall' # very fast to convert

# Hyperparameters from https://github.com/mosaicml/llm-foundry/blob/340a56658560ebceb2a3aa69d6e37813e415acd0/README.md#L188
main_hf(
Namespace(
**{
'dataset': 'c4',
'data_subset': 'en',
'splits': [downloaded_split],
'out_root': c4_dir,
'compression': None,
'concat_tokens': 2048,
'tokenizer': 'EleutherAI/gpt-neox-20b',
'tokenizer_kwargs': {},
'bos_text': '',
'eos_text': '<|endoftext|>',
'no_wrap': False,
'num_workers': 8
}))

# copy the small downloaded_split to other c4 splits for mocking purposes
mocked_splits = ['train', 'val']
for mocked_split in mocked_splits:
shutil.copytree(os.path.join(c4_dir, 'val_xsmall'),
os.path.join(c4_dir, mocked_split))
assert os.path.exists(c4_dir)
return c4_dir


def create_arxiv_dataset(path: pathlib.Path) -> str:
"""Creates an arxiv dataset."""
arxiv_dir = os.path.join(path, f'my-copy-arxiv')
downloaded_split = 'train'

main_json(
Namespace(
**{
'path': 'data_prep/example_data/arxiv.jsonl',
'out_root': arxiv_dir,
'compression': None,
'split': downloaded_split,
'concat_tokens': None,
'bos_text': None,
'eos_text': None,
'no_wrap': False,
'num_workers': None
}))

return arxiv_dir


def gpt_tiny_cfg(dataset_name: str, device: str):
"""Create gpt tiny cfg."""
conf_path: str = os.path.join(repo_dir,
'scripts/train/yamls/pretrain/testing.yaml')
with open(conf_path) as f:
test_cfg = om.load(f)
assert isinstance(test_cfg, DictConfig)

test_cfg.data_local = dataset_name
test_cfg.global_train_batch_size = 8
test_cfg.device_eval_batch_size = 4
test_cfg.device_train_microbatch_size = 4
test_cfg.max_duration = '4ba'
test_cfg.eval_interval = '4ba'
test_cfg.run_name = 'gpt-mini-integration-test'

if device == 'cpu':
test_cfg.model.init_device = 'cpu'
test_cfg.fsdp_config = None
test_cfg.model.attn_config.attn_impl = 'torch'
test_cfg.model.loss_fn = 'torch_crossentropy'
test_cfg.precision = 'fp32'

return test_cfg
from tests.data_utils import (create_arxiv_dataset, create_c4_dataset_xsmall,
gpt_tiny_cfg)


@pytest.fixture(autouse=False)

