
make everything DataSpec with token counting function and tests
dakinggg committed Oct 15, 2023
1 parent aecadc9 commit b3fe63b
Showing 4 changed files with 179 additions and 8 deletions.
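The change repeated across all three dataloader builders is the same: construct the DataLoader exactly as before, build a token counting callback from the tokenizer's pad_token_id, and return both wrapped in a Composer DataSpec. A minimal sketch of that pattern, assuming dataset, collate_fn, tokenizer, and device_batch_size are already set up as in the existing builders (the builder name here is a placeholder, not code from this commit):

from composer.core.data_spec import DataSpec
from torch.utils.data import DataLoader

from llmfoundry.data.text_data import get_tokens_per_batch_func


def build_example_dataloader(dataset, collate_fn, tokenizer, device_batch_size):
    # Hypothetical builder illustrating the commit's pattern.
    dl = DataLoader(
        dataset,
        collate_fn=collate_fn,
        batch_size=device_batch_size,
    )

    # Count non-padding tokens per batch so the Trainer can track tokens seen.
    token_counting_func = get_tokens_per_batch_func(
        pad_token_id=tokenizer.pad_token_id)

    return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)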
11 changes: 9 additions & 2 deletions llmfoundry/data/denoising.py
@@ -10,13 +10,15 @@

import numpy as np
import torch
from composer.core.data_spec import DataSpec
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.packing import BinPackWrapper
from llmfoundry.data.text_data import StreamingTextDataset
from llmfoundry.data.text_data import (StreamingTextDataset,
get_tokens_per_batch_func)
from llmfoundry.models import utils

__all__ = ['MixtureOfDenoisersCollator', 'build_text_denoising_dataloader']
@@ -506,7 +508,7 @@ def build_text_denoising_dataloader(
'but cfg.dataset.packing_ratio has not been set. Please set ' +\
'the latter to turn on packing or remove the former from the config.')

return DataLoader(
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=device_batch_size,
@@ -518,6 +520,11 @@
timeout=cfg.get('timeout', 0),
)

token_counting_func = get_tokens_per_batch_func(
pad_token_id=tokenizer.pad_token_id)

return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)


def noise_token_sequence(
example: Union[torch.Tensor, Mapping[str, Any]],
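Because build_text_denoising_dataloader now returns a Composer DataSpec rather than a bare DataLoader, callers that iterated over the return value directly need to go through its dataloader attribute. A minimal sketch, assuming cfg, tokenizer, and device_batch_size are configured as before (not code from this commit):

data_spec = build_text_denoising_dataloader(cfg, tokenizer, device_batch_size)

for batch in data_spec.dataloader:
    # The token counter is exposed on the DataSpec; the new tests below call it
    # the same way via dl.get_num_tokens_in_batch(...).
    num_tokens = data_spec.get_num_tokens_in_batch(batch)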
13 changes: 10 additions & 3 deletions llmfoundry/data/finetuning/dataloader.py
@@ -6,6 +6,7 @@

import datasets as hf_datasets
import torch
from composer.core.data_spec import DataSpec
from composer.utils import dist, get_file, parse_uri
from omegaconf import DictConfig
from torch.utils.data import DataLoader
@@ -14,6 +15,7 @@
from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
from llmfoundry.data.finetuning.tasks import dataset_constructor
from llmfoundry.data.packing import BinPackWrapper
from llmfoundry.data.text_data import get_tokens_per_batch_func

log = logging.getLogger(__name__)

@@ -23,7 +25,7 @@

def build_finetuning_dataloader(cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int) -> DataLoader:
device_batch_size: int) -> DataSpec:
"""Builds a finetuning dataloader for training or evaluating.
The underlying dataset can be built through one of two code paths:
@@ -143,7 +145,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
collate_fn, dataloader_batch_size = _build_collate_fn(
cfg.dataset, tokenizer, device_batch_size)

return DataLoader(
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
@@ -193,7 +195,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
)

assert dataset is not None
return DataLoader(
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
@@ -208,6 +210,11 @@
timeout=cfg.get('timeout', 0),
)

token_counting_func = get_tokens_per_batch_func(
pad_token_id=tokenizer.pad_token_id)

return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)


def _validate_config(dataset_cfg: DictConfig) -> None:
"""Validates the dataset configuration.
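The return-type change is transparent to training code because Composer's Trainer accepts a DataSpec anywhere a dataloader is expected. A hedged sketch, with the model and training duration as placeholders:

from composer import Trainer

train_spec = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)

trainer = Trainer(
    model=composer_model,         # assumed to be built elsewhere
    train_dataloader=train_spec,  # a DataSpec passes through unchanged
    max_duration='1ep',
)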
29 changes: 27 additions & 2 deletions llmfoundry/data/text_data.py
@@ -11,6 +11,8 @@
import numpy as np
import torch
import transformers
from composer.core.data_spec import DataSpec
from composer.core.types import Batch
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from streaming import Stream, StreamingDataset
@@ -237,7 +239,7 @@ def build_text_dataloader(
cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int,
) -> DataLoader:
) -> DataSpec:
assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}'
if cfg.dataset.get('group_method', None) is not None:
raise NotImplementedError(
@@ -281,7 +283,7 @@ def build_text_dataloader(
eos_token_id=eos_token_id,
bos_token_id=bos_token_id)

return DataLoader(
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=device_batch_size,
@@ -293,6 +295,29 @@
timeout=cfg.get('timeout', 0),
)

    # If we pretokenized, we may not have padding, in which case the
    # tokenizer may not have a pad_token_id. In this case, we can
    # just use the default token counting function.
token_counting_func = None
if tokenizer.pad_token_id is not None:
token_counting_func = get_tokens_per_batch_func(
pad_token_id=tokenizer.pad_token_id)

return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)


def get_tokens_per_batch_func(pad_token_id: int) -> Callable[[Batch], int]:
def get_num_samples_in_batch(batch: Batch) -> int:
if not isinstance(batch, Mapping) or 'input_ids' not in batch:
raise ValueError(
'get_tokens_per_batch_func() requires a batch with an input_ids key'
)

# count number of non padding tokens in batch
return int(torch.sum(batch['input_ids'] != pad_token_id).item())

return get_num_samples_in_batch


# Helpful to test if your dataloader is working locally
# Run `python data.py --local_path [local] [--remote_path remote, optional]` and verify that batches are printed out
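To make the new helper concrete, here is an illustrative batch (not from the repository) with two right-padded sequences; the returned function counts every position whose input id differs from pad_token_id:

import torch

from llmfoundry.data.text_data import get_tokens_per_batch_func

# Two sequences padded to length 5 with pad_token_id=0: 3 + 4 real tokens.
batch = {
    'input_ids': torch.tensor([
        [11, 12, 13, 0, 0],
        [21, 22, 23, 24, 0],
    ])
}

count_tokens = get_tokens_per_batch_func(pad_token_id=0)
assert count_tokens(batch) == 7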
134 changes: 133 additions & 1 deletion tests/test_dataloader.py
@@ -3,22 +3,27 @@
import contextlib
import os
import pathlib
import random
import shutil
import sys
import tempfile
from argparse import Namespace
from typing import Optional
from unittest.mock import MagicMock

import pytest
import torch
import transformers
from composer.utils import dist, using_torch_2
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from streaming import MDSWriter

from llmfoundry import (build_finetuning_dataloader,
build_text_denoising_dataloader)
from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper,
build_text_dataloader)
build_text_dataloader,
get_tokens_per_batch_func)
from llmfoundry.utils.builders import build_tokenizer

# Add repo root to path so we can import scripts and test it
@@ -552,3 +557,130 @@ def test_malformed_data(
actual_num_batches += 1

assert actual_num_batches == expected_num_batches


@pytest.mark.parametrize('pad_token_id', [0, 100, 1000])
@pytest.mark.parametrize('batch_size', [1, 8, 16])
@pytest.mark.parametrize('model_max_length', [1024, 2048])
@pytest.mark.parametrize('padding_side', ['left', 'right'])
def test_token_counting_func(pad_token_id: int, batch_size: int,
model_max_length: int, padding_side: str):
gptt = transformers.AutoTokenizer.from_pretrained('gpt2')
gptt.pad_token_id = pad_token_id
gptt.model_max_length = model_max_length
gptt.padding_side = padding_side

batch_strings = []
expected_token_count = 0
for _ in range(batch_size):
sample_length = random.randint(1, model_max_length)
batch_strings.append(' '.join(['hello'] * sample_length))
expected_token_count += sample_length

batch_tokenized = gptt(batch_strings, padding=True, return_tensors='pt')

token_counting_func = get_tokens_per_batch_func(pad_token_id)

actual_token_count = token_counting_func(batch_tokenized)

assert actual_token_count == expected_token_count


@pytest.mark.parametrize('dataloader_type',
['finetuning-hf', 'finetuning-streaming', 'text'])
@pytest.mark.parametrize('pad_token_id', [100, None])
@pytest.mark.parametrize('batch_size', [1, 8])
@pytest.mark.parametrize('model_max_length', [1024])
@pytest.mark.parametrize('padding_side', ['left'])
def test_token_counting_func_dataloader_setting(
dataloader_type: str, pad_token_id: Optional[int], batch_size: int,
model_max_length: int, padding_side: str,
monkeypatch: pytest.MonkeyPatch):
gptt = transformers.AutoTokenizer.from_pretrained('gpt2')
gptt.pad_token_id = pad_token_id
gptt.model_max_length = model_max_length
gptt.padding_side = padding_side

batch_strings = []
expected_token_count = 0
for _ in range(batch_size):
sample_length = random.randint(
1,
model_max_length) if pad_token_id is not None else model_max_length
batch_strings.append(' '.join(['hello'] * sample_length))
expected_token_count += sample_length

batch_tokenized = gptt(batch_strings,
padding=True if pad_token_id is not None else False,
return_tensors='pt')

common_args = {
'drop_last': False,
'num_workers': 0,
'prefetch_factor': None if using_torch_2() else 2,
'pin_memory': False,
'persistent_workers': False,
'timeout': 0
}

if dataloader_type == 'finetuning-hf':
cfg = DictConfig({
'name': 'finetuning',
'dataset': {
'hf_name': 'dummy-path',
'split': 'train',
'max_seq_len': model_max_length,
'decoder_only_format': True,
'allow_pad_trimming': False,
'packing_ratio': None,
'shuffle': True,
},
**common_args
})
monkeypatch.setattr(
'llmfoundry.data.finetuning.tasks.DatasetConstructor.build_from_hf',
lambda *args, **kwargs: [])
dl = build_finetuning_dataloader(cfg, gptt, batch_size)
elif dataloader_type == 'finetuning-streaming':
cfg = DictConfig({
'name': 'finetuning',
'dataset': {
'remote': 'dummy-path',
'local': 'dummy-path',
'split': 'train',
'max_seq_len': model_max_length,
'decoder_only_format': True,
'allow_pad_trimming': False,
'packing_ratio': None,
'shuffle': True,
},
**common_args
})
monkeypatch.setattr(
'llmfoundry.data.finetuning.tasks.DatasetConstructor.build_from_streaming',
lambda *args, **kwargs: [])
dl = build_finetuning_dataloader(cfg, gptt, batch_size)
elif dataloader_type == 'text':
cfg = DictConfig({
'name': 'text',
'dataset': {
'local': 'dummy-path',
'remote': 'dummy-path',
'split': 'train',
'max_seq_len': model_max_length,
'shuffle': True,
'shuffle_seed': 0,
},
**common_args
})
monkeypatch.setattr('llmfoundry.data.text_data.StreamingTextDataset',
lambda *args, **kwargs: MagicMock())
dl = build_text_dataloader(cfg, gptt, batch_size)
else:
raise NotImplementedError()

cfg = om.create(cfg)

actual_token_count = dl.get_num_tokens_in_batch(batch_tokenized)

assert actual_token_count == expected_token_count
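One way to run just the new token-counting tests locally (a sketch, assuming pytest and the repository's test dependencies are installed and the call is made from the repository root):

import pytest

# Both new tests match the 'token_counting' keyword filter.
pytest.main(['-q', '-k', 'token_counting', 'tests/test_dataloader.py'])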
