Convert to DataSpec and add token counts that include padding #676

Merged: 15 commits, Oct 17, 2023
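At a high level, each dataloader builder touched here (`build_text_dataloader`, `build_finetuning_dataloader`, `build_text_denoising_dataloader`) now returns a composer `DataSpec` wrapping the `DataLoader` together with a token-counting callback, instead of the bare `DataLoader`. A minimal sketch of the new calling pattern, not part of the diff, assuming `cfg`, `tokenizer`, and `device_batch_size` have been built the usual way (e.g. via `om.create` and `build_tokenizer`):

```python
from llmfoundry.data.text_data import build_text_dataloader

# cfg, tokenizer, and device_batch_size are assumed to exist already.
data_spec = build_text_dataloader(cfg, tokenizer, device_batch_size)

loader = data_spec.dataloader  # the underlying torch DataLoader
batch = next(iter(loader))

# composer stores the callback passed as get_num_tokens_in_batch on the
# DataSpec, so per-batch token counts can be tracked during training.
num_tokens = data_spec.get_num_tokens_in_batch(batch)
```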
18 changes: 14 additions & 4 deletions llmfoundry/data/denoising.py
@@ -10,13 +10,15 @@

import numpy as np
import torch
from composer.core.data_spec import DataSpec
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.packing import BinPackWrapper
from llmfoundry.data.text_data import StreamingTextDataset
from llmfoundry.data.text_data import (StreamingTextDataset,
get_tokens_per_batch_func)
from llmfoundry.models import utils

__all__ = ['MixtureOfDenoisersCollator', 'build_text_denoising_dataloader']
@@ -353,7 +355,7 @@ def build_text_denoising_dataloader(
cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int,
) -> DataLoader[Dict]:
) -> DataSpec:
"""Constructor function for a Mixture of Denoisers dataloader.

This function constructs a dataloader that can be used to train an
@@ -506,7 +508,7 @@ def build_text_denoising_dataloader(
'but cfg.dataset.packing_ratio has not been set. Please set ' +\
'the latter to turn on packing or remove the former from the config.')

return DataLoader(
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=device_batch_size,
@@ -518,6 +520,12 @@
timeout=cfg.get('timeout', 0),
)

token_counting_func = get_tokens_per_batch_func(
pad_token_id=tokenizer.pad_token_id,
decoder_only=cfg.mixture_of_denoisers.decoder_only_format)

return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)


def noise_token_sequence(
example: Union[torch.Tensor, Mapping[str, Any]],
@@ -869,7 +877,9 @@ def _format_tokens_for_decoder_only(
tokenizer = build_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_kwargs=tokenizer_kwargs)

loader = build_text_denoising_dataloader(cfg, tokenizer, device_batch_size)
loader = build_text_denoising_dataloader(cfg, tokenizer,
device_batch_size).dataloader
assert isinstance(loader, DataLoader)
assert isinstance(loader.dataset, StreamingTextDataset)

print(f'\n\nTRUNCATING TO: {loader.dataset.max_seq_len}\n\n')
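The denoising loader builds its counter with `decoder_only=cfg.mixture_of_denoisers.decoder_only_format`, so encoder-decoder batches count tokens on both `input_ids` and `decoder_input_ids`. A small self-contained check of that behavior (toy tensors; the pad id of 0 is arbitrary):

```python
import torch

from llmfoundry.data.text_data import get_tokens_per_batch_func

count_tokens = get_tokens_per_batch_func(pad_token_id=0, decoder_only=False)

batch = {
    'input_ids': torch.tensor([[11, 12, 13, 0, 0]]),  # 3 non-pad encoder tokens
    'decoder_input_ids': torch.tensor([[21, 22, 0, 0, 0]]),  # 2 non-pad decoder tokens
}

assert count_tokens(batch) == 5  # encoder + decoder tokens, pad positions skipped
```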
16 changes: 12 additions & 4 deletions llmfoundry/data/finetuning/dataloader.py
@@ -6,6 +6,7 @@

import datasets as hf_datasets
import torch
from composer.core.data_spec import DataSpec
from composer.utils import dist, get_file, parse_uri
from omegaconf import DictConfig
from torch.utils.data import DataLoader
@@ -14,6 +15,7 @@
from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
from llmfoundry.data.finetuning.tasks import dataset_constructor
from llmfoundry.data.packing import BinPackWrapper
from llmfoundry.data.text_data import get_tokens_per_batch_func

log = logging.getLogger(__name__)

@@ -23,7 +25,7 @@

def build_finetuning_dataloader(cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int) -> DataLoader:
device_batch_size: int) -> DataSpec:
"""Builds a finetuning dataloader for training or evaluating.

The underlying dataset can be built through one of two code paths:
@@ -143,7 +145,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
collate_fn, dataloader_batch_size = _build_collate_fn(
cfg.dataset, tokenizer, device_batch_size)

return DataLoader(
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
@@ -193,7 +195,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
)

assert dataset is not None
return DataLoader(
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
@@ -208,6 +210,11 @@
timeout=cfg.get('timeout', 0),
)

token_counting_func = get_tokens_per_batch_func(
pad_token_id=tokenizer.pad_token_id)

return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)


def _validate_config(dataset_cfg: DictConfig) -> None:
"""Validates the dataset configuration.
@@ -442,7 +449,8 @@ def _build_collate_fn(
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)

device_batch_size = 2
dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)
dataloader = build_finetuning_dataloader(cfg, tokenizer,
device_batch_size).dataloader

packing = cfg.dataset.get('packing_ratio') is not None

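Callers that previously iterated the return value directly now unwrap it via `.dataloader`, as the smoke test above does. When the `DataSpec` is handed to composer's `Trainer` instead, the token-counting hook is picked up without extra wiring. A hedged sketch (other `Trainer` arguments omitted; `model` and `data_spec` are assumed to exist):

```python
from composer import Trainer

# train_dataloader accepts a DataSpec, so composer can use the attached
# get_num_tokens_in_batch callback when tracking token counts.
trainer = Trainer(
    model=model,
    train_dataloader=data_spec,
    max_duration='1ep',
)
trainer.fit()
```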
2 changes: 1 addition & 1 deletion llmfoundry/data/packing.py
@@ -377,7 +377,7 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
dataloader_cfg.dataset.packing_ratio = None
dataloader_cfg.dataset.max_leftovers_to_keep = None
train_dataloader = build_dataloader(dataloader_cfg, tokenizer,
max(raw_batch_sizes) * 100)
max(raw_batch_sizes) * 100).dataloader

# Get a bunch of raw examples
big_batch = next(iter(train_dataloader))
61 changes: 58 additions & 3 deletions llmfoundry/data/text_data.py
@@ -11,6 +11,8 @@
import numpy as np
import torch
import transformers
from composer.core.data_spec import DataSpec
from composer.core.types import Batch
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from streaming import Stream, StreamingDataset
@@ -237,7 +239,7 @@ def build_text_dataloader(
cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int,
) -> DataLoader:
) -> DataSpec:
assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}'
if cfg.dataset.get('group_method', None) is not None:
raise NotImplementedError(
@@ -281,7 +283,7 @@ def build_text_dataloader(
eos_token_id=eos_token_id,
bos_token_id=bos_token_id)

return DataLoader(
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=device_batch_size,
@@ -293,6 +295,58 @@
timeout=cfg.get('timeout', 0),
)

# If we pretokenized, we may not have padding, in which case the
# tokenizer may not have a pad_token_id. In this case, we can
# just use the default token counting function. This is correct
# because we do not support training on pretokenized data with padding,
# and if tokenizing on the fly, we require that the tokenizer has a pad token.
token_counting_func = None
if tokenizer.pad_token_id is not None:
token_counting_func = get_tokens_per_batch_func(
pad_token_id=tokenizer.pad_token_id)

return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)


def get_tokens_per_batch_func(pad_token_id: int,
decoder_only: bool = True
) -> Callable[[Batch], int]:
"""Returns a callable that counts the number of tokens in a batch.

Args:
pad_token_id (int): The id of the padding token.
decoder_only (bool, optional): Whether to expect the batch to just contain ``input_ids`` (decoder only)
or to also contain ``decoder_input_ids`` (encoder decoder). Defaults to ``True``.

Returns:
Callable[[Batch], int]: A callable that counts the number of tokens in a batch.
"""

def get_num_samples_in_batch(batch: Batch) -> int:
if not isinstance(batch, Mapping) or 'input_ids' not in batch:
raise ValueError(
'get_tokens_per_batch_func() requires a batch with an input_ids key'
)

if not decoder_only and 'decoder_input_ids' not in batch:
raise ValueError(
'get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_input_ids key'
)

# Count number of non-padding tokens in batch
input_ids_tokens = int(
torch.sum(batch['input_ids'] != pad_token_id).item())

# For encoder decoder models only
decoder_input_ids_tokens = 0
if not decoder_only:
decoder_input_ids_tokens = int(
torch.sum(batch['decoder_input_ids'] != pad_token_id).item())

return input_ids_tokens + decoder_input_ids_tokens

return get_num_samples_in_batch


# Helpful to test if your dataloader is working locally
# Run `python data.py --local_path [local] [--remote_path remote, optional]` and verify that batches are printed out
@@ -353,7 +407,8 @@ def build_text_dataloader(
tokenizer_kwargs = {'model_max_length': args.max_seq_len}
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)

loader = build_text_dataloader(cfg, tokenizer, device_batch_size)
loader = build_text_dataloader(cfg, tokenizer, device_batch_size).dataloader
assert isinstance(loader, DataLoader)
assert isinstance(loader.dataset, StreamingTextDataset)
tokenizer = loader.dataset.tokenizer

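For reference, the new `get_tokens_per_batch_func` counts only entries that differ from `pad_token_id`, so padded positions do not inflate the count. A toy decoder-only example (pad id 0 chosen arbitrarily):

```python
import torch

from llmfoundry.data.text_data import get_tokens_per_batch_func

count_tokens = get_tokens_per_batch_func(pad_token_id=0)

# Two sequences padded to length 5: 3 + 4 = 7 real tokens.
batch = {'input_ids': torch.tensor([[5, 6, 7, 0, 0],
                                    [8, 9, 10, 11, 0]])}

assert count_tokens(batch) == 7
```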