From b3fe63b8d6f4c6206088093f5e4f8391967f7faf Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Sat, 14 Oct 2023 20:06:52 -0700
Subject: [PATCH] make everything DataSpec with token counting function and tests

---
 llmfoundry/data/denoising.py             |  11 +-
 llmfoundry/data/finetuning/dataloader.py |  13 ++-
 llmfoundry/data/text_data.py             |  29 ++++-
 tests/test_dataloader.py                 | 134 ++++++++++++++++++++++-
 4 files changed, 179 insertions(+), 8 deletions(-)

diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py
index d685d0077d..98b2f662fe 100644
--- a/llmfoundry/data/denoising.py
+++ b/llmfoundry/data/denoising.py
@@ -10,13 +10,15 @@
 
 import numpy as np
 import torch
+from composer.core.data_spec import DataSpec
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
 from torch.utils.data import DataLoader
 from transformers import PreTrainedTokenizerBase
 
 from llmfoundry.data.packing import BinPackWrapper
-from llmfoundry.data.text_data import StreamingTextDataset
+from llmfoundry.data.text_data import (StreamingTextDataset,
+                                       get_tokens_per_batch_func)
 from llmfoundry.models import utils
 
 __all__ = ['MixtureOfDenoisersCollator', 'build_text_denoising_dataloader']
@@ -506,7 +508,7 @@ def build_text_denoising_dataloader(
             'but cfg.dataset.packing_ratio has not been set. Please set ' +\
             'the latter to turn on packing or remove the former from the config.')
 
-    return DataLoader(
+    dl = DataLoader(
         dataset,
         collate_fn=collate_fn,
         batch_size=device_batch_size,
@@ -518,6 +520,11 @@
         timeout=cfg.get('timeout', 0),
     )
 
+    token_counting_func = get_tokens_per_batch_func(
+        pad_token_id=tokenizer.pad_token_id)
+
+    return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
+
 
 def noise_token_sequence(
     example: Union[torch.Tensor, Mapping[str, Any]],
diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index ebb7991dde..5c6c9e01ce 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -6,6 +6,7 @@
 
 import datasets as hf_datasets
 import torch
+from composer.core.data_spec import DataSpec
 from composer.utils import dist, get_file, parse_uri
 from omegaconf import DictConfig
 from torch.utils.data import DataLoader
@@ -14,6 +15,7 @@
 from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
 from llmfoundry.data.finetuning.tasks import dataset_constructor
 from llmfoundry.data.packing import BinPackWrapper
+from llmfoundry.data.text_data import get_tokens_per_batch_func
 
 log = logging.getLogger(__name__)
 
@@ -23,7 +25,7 @@
 
 def build_finetuning_dataloader(cfg: DictConfig,
                                 tokenizer: PreTrainedTokenizerBase,
-                                device_batch_size: int) -> DataLoader:
+                                device_batch_size: int) -> DataSpec:
     """Builds a finetuning dataloader for training or evaluating.
 
     The underlying dataset can be built through one of two code paths:
@@ -143,7 +145,7 @@
         collate_fn, dataloader_batch_size = _build_collate_fn(
             cfg.dataset, tokenizer, device_batch_size)
 
-        return DataLoader(
+        dl = DataLoader(
             dataset,
             collate_fn=collate_fn,
             batch_size=dataloader_batch_size,
@@ -193,7 +195,7 @@
                     )
 
         assert dataset is not None
-        return DataLoader(
+        dl = DataLoader(
            dataset,
            collate_fn=collate_fn,
            batch_size=dataloader_batch_size,
@@ -208,6 +210,11 @@
             timeout=cfg.get('timeout', 0),
         )
 
+    token_counting_func = get_tokens_per_batch_func(
+        pad_token_id=tokenizer.pad_token_id)
+
+    return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
+
 
 def _validate_config(dataset_cfg: DictConfig) -> None:
     """Validates the dataset configuration.
diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py
index afdd243adf..adb1421510 100644
--- a/llmfoundry/data/text_data.py
+++ b/llmfoundry/data/text_data.py
@@ -11,6 +11,8 @@
 import numpy as np
 import torch
 import transformers
+from composer.core.data_spec import DataSpec
+from composer.core.types import Batch
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
 from streaming import Stream, StreamingDataset
@@ -237,7 +239,7 @@ def build_text_dataloader(
     cfg: DictConfig,
     tokenizer: PreTrainedTokenizerBase,
     device_batch_size: int,
-) -> DataLoader:
+) -> DataSpec:
     assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}'
     if cfg.dataset.get('group_method', None) is not None:
         raise NotImplementedError(
@@ -281,7 +283,7 @@
             eos_token_id=eos_token_id,
             bos_token_id=bos_token_id)
 
-    return DataLoader(
+    dl = DataLoader(
         dataset,
         collate_fn=collate_fn,
         batch_size=device_batch_size,
@@ -293,6 +295,29 @@
         timeout=cfg.get('timeout', 0),
     )
 
+    # If we pretokenized, we may not have padding, in which case the
+    # tokenizer may not have a pad_token_id. In this case, we can
+    # just use the default token counting function.
+    token_counting_func = None
+    if tokenizer.pad_token_id is not None:
+        token_counting_func = get_tokens_per_batch_func(
+            pad_token_id=tokenizer.pad_token_id)
+
+    return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
+
+
+def get_tokens_per_batch_func(pad_token_id: int) -> Callable[[Batch], int]:
+    def get_num_samples_in_batch(batch: Batch) -> int:
+        if not isinstance(batch, Mapping) or 'input_ids' not in batch:
+            raise ValueError(
+                'get_tokens_per_batch_func() requires a batch with an input_ids key'
+            )
+
+        # count number of non padding tokens in batch
+        return int(torch.sum(batch['input_ids'] != pad_token_id).item())
+
+    return get_num_samples_in_batch
+
 
 # Helpful to test if your dataloader is working locally
 # Run `python data.py --local_path [local] [--remote_path remote, optional]` and verify that batches are printed out
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 6495eccf65..f1e3bd4d0f 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -3,22 +3,27 @@
 import contextlib
 import os
 import pathlib
+import random
 import shutil
 import sys
 import tempfile
 from argparse import Namespace
 from typing import Optional
+from unittest.mock import MagicMock
 
 import pytest
 import torch
+import transformers
 from composer.utils import dist, using_torch_2
+from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
 from streaming import MDSWriter
 
 from llmfoundry import (build_finetuning_dataloader,
                         build_text_denoising_dataloader)
 from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper,
-                                       build_text_dataloader)
+                                       build_text_dataloader,
+                                       get_tokens_per_batch_func)
 from llmfoundry.utils.builders import build_tokenizer
 
 # Add repo root to path so we can import scripts and test it
@@ -552,3 +557,130 @@ def test_malformed_data(
             actual_num_batches += 1
 
     assert actual_num_batches == expected_num_batches
+
+
+@pytest.mark.parametrize('pad_token_id', [0, 100, 1000])
+@pytest.mark.parametrize('batch_size', [1, 8, 16])
+@pytest.mark.parametrize('model_max_length', [1024, 2048])
+@pytest.mark.parametrize('padding_side', ['left', 'right'])
+def test_token_counting_func(pad_token_id: int, batch_size: int,
+                             model_max_length: int, padding_side: str):
+    gptt = transformers.AutoTokenizer.from_pretrained('gpt2')
+    gptt.pad_token_id = pad_token_id
+    gptt.model_max_length = model_max_length
+    gptt.padding_side = padding_side
+
+    batch_strings = []
+    expected_token_count = 0
+    for _ in range(batch_size):
+        sample_length = random.randint(1, model_max_length)
+        batch_strings.append(' '.join(['hello'] * sample_length))
+        expected_token_count += sample_length
+
+    batch_tokenized = gptt(batch_strings, padding=True, return_tensors='pt')
+
+    token_counting_func = get_tokens_per_batch_func(pad_token_id)
+
+    actual_token_count = token_counting_func(batch_tokenized)
+
+    assert actual_token_count == expected_token_count
+
+
+@pytest.mark.parametrize('dataloader_type',
+                         ['finetuning-hf', 'finetuning-streaming', 'text'])
+@pytest.mark.parametrize('pad_token_id', [100, None])
+@pytest.mark.parametrize('batch_size', [1, 8])
+@pytest.mark.parametrize('model_max_length', [1024])
+@pytest.mark.parametrize('padding_side', ['left'])
+def test_token_counting_func_dataloader_setting(
+        dataloader_type: str, pad_token_id: Optional[int], batch_size: int,
+        model_max_length: int, padding_side: str,
+        monkeypatch: pytest.MonkeyPatch):
+    gptt = transformers.AutoTokenizer.from_pretrained('gpt2')
+    gptt.pad_token_id = pad_token_id
+    gptt.model_max_length = model_max_length
+    gptt.padding_side = padding_side
+
+    batch_strings = []
+    expected_token_count = 0
+    for _ in range(batch_size):
+        sample_length = random.randint(
+            1,
+            model_max_length) if pad_token_id is not None else model_max_length
+        batch_strings.append(' '.join(['hello'] * sample_length))
+        expected_token_count += sample_length
+
+    batch_tokenized = gptt(batch_strings,
+                           padding=True if pad_token_id is not None else False,
+                           return_tensors='pt')
+
+    common_args = {
+        'drop_last': False,
+        'num_workers': 0,
+        'prefetch_factor': None if using_torch_2() else 2,
+        'pin_memory': False,
+        'persistent_workers': False,
+        'timeout': 0
+    }
+
+    if dataloader_type == 'finetuning-hf':
+        cfg = DictConfig({
+            'name': 'finetuning',
+            'dataset': {
+                'hf_name': 'dummy-path',
+                'split': 'train',
+                'max_seq_len': model_max_length,
+                'decoder_only_format': True,
+                'allow_pad_trimming': False,
+                'packing_ratio': None,
+                'shuffle': True,
+            },
+            **common_args
+        })
+        monkeypatch.setattr(
+            'llmfoundry.data.finetuning.tasks.DatasetConstructor.build_from_hf',
+            lambda *args, **kwargs: [])
+        dl = build_finetuning_dataloader(cfg, gptt, batch_size)
+    elif dataloader_type == 'finetuning-streaming':
+        cfg = DictConfig({
+            'name': 'finetuning',
+            'dataset': {
+                'remote': 'dummy-path',
+                'local': 'dummy-path',
+                'split': 'train',
+                'max_seq_len': model_max_length,
+                'decoder_only_format': True,
+                'allow_pad_trimming': False,
+                'packing_ratio': None,
+                'shuffle': True,
+            },
+            **common_args
+        })
+        monkeypatch.setattr(
+            'llmfoundry.data.finetuning.tasks.DatasetConstructor.build_from_streaming',
+            lambda *args, **kwargs: [])
+        dl = build_finetuning_dataloader(cfg, gptt, batch_size)
+    elif dataloader_type == 'text':
+        cfg = DictConfig({
+            'name': 'text',
+            'dataset': {
+                'local': 'dummy-path',
+                'remote': 'dummy-path',
+                'split': 'train',
+                'max_seq_len': model_max_length,
+                'shuffle': True,
+                'shuffle_seed': 0,
+            },
+            **common_args
+        })
+        monkeypatch.setattr('llmfoundry.data.text_data.StreamingTextDataset',
+                            lambda *args, **kwargs: MagicMock())
+        dl = build_text_dataloader(cfg, gptt, batch_size)
+    else:
+        raise NotImplementedError()
+
+    cfg = om.create(cfg)
+
+    actual_token_count = dl.get_num_tokens_in_batch(batch_tokenized)
+
+    assert actual_token_count == expected_token_count
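
Note for reviewers (not part of the patch): a minimal sketch of how the new token counting hook can be exercised on its own, the same way Composer's Trainer calls DataSpec.get_num_tokens_in_batch on each batch it pulls from the dataloader. The gpt2 tokenizer and the pad_token_id of 0 are illustrative assumptions mirroring the unit tests above, not requirements of the API.

# Illustrative sketch only -- assumes this branch of llm-foundry, plus
# transformers and torch, are installed; mirrors test_token_counting_func.
import transformers

from llmfoundry.data.text_data import get_tokens_per_batch_func

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token_id = 0  # assumption: gpt2 ships without a pad token

# Two samples of different lengths, so the shorter one gets padded.
batch = tokenizer(['hello ' * 8, 'hello ' * 3],
                  padding=True,
                  return_tensors='pt')

count_tokens = get_tokens_per_batch_func(pad_token_id=tokenizer.pad_token_id)
print(count_tokens(batch))  # counts only the non-padding positions in input_ids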