Automatically select packing ratio #622

Closed · wants to merge 9 commits
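For context, this PR adds a new 'auto' value for dataset.packing_ratio in the finetuning dataloader config. A rough usage sketch is below; the config keys mirror what the diff reads via cfg and cfg.dataset, while the dataset name and concrete values are hypothetical:

```python
# Hypothetical config sketch; keys match what the diff reads from cfg/cfg.dataset,
# values (including the HF dataset name) are illustrative only.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'dataset': {
        'hf_name': 'tatsu-lab/alpaca',   # hypothetical finetuning dataset
        'max_seq_len': 2048,
        'decoder_only_format': True,
        'packing_ratio': 'auto',         # new in this PR: auto-select the packing ratio
        'shuffle': True,
    },
    'drop_last': True,
    'num_workers': 8,
})

# Roughly how the loader would then be built (tokenizer construction omitted):
# dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size=8)
```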
4 changes: 2 additions & 2 deletions llmfoundry/data/denoising.py
@@ -15,7 +15,7 @@
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.packing import BinPackWrapper
from llmfoundry.data.packing import BinPackCollator
from llmfoundry.data.text_data import StreamingTextDataset
from llmfoundry.models import utils

@@ -490,7 +490,7 @@ def build_text_denoising_dataloader(
raise NotImplementedError(
'On-the-fly packing is currently only supported for decoder-only formats.'
)
collate_fn = BinPackWrapper(
collate_fn = BinPackCollator(
collator=collate_fn,
target_batch_size=device_batch_size,
max_seq_len=cfg.dataset.max_seq_len,
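The change above only renames BinPackWrapper to BinPackCollator. For readers unfamiliar with the technique, the sketch below shows greedy bin packing over tokenized examples in a heavily simplified, hypothetical form; it is not the implementation in llmfoundry.data.packing:

```python
import torch

def greedy_pack(examples: list, max_seq_len: int, pad_token_id: int) -> torch.Tensor:
    """Greedily concatenate tokenized examples into bins of at most max_seq_len tokens.

    Simplified stand-in for what a bin-packing collator does before padding;
    `examples` is a list of token-id lists.
    """
    bins = []
    for tokens in sorted(examples, key=len, reverse=True):  # first-fit decreasing
        for b in bins:
            if len(b) + len(tokens) <= max_seq_len:
                b.extend(tokens)
                break
        else:
            bins.append(list(tokens[:max_seq_len]))
    # Pad every bin to max_seq_len so the result stacks into a single tensor.
    padded = [b + [pad_token_id] * (max_seq_len - len(b)) for b in bins]
    return torch.tensor(padded, dtype=torch.long)
```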
105 changes: 54 additions & 51 deletions llmfoundry/data/finetuning/dataloader.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import os
from typing import Tuple, Union
from typing import Tuple, TypeVar, Union

import datasets as hf_datasets
import torch
@@ -13,7 +13,7 @@

from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
from llmfoundry.data.finetuning.tasks import dataset_constructor
from llmfoundry.data.packing import BinPackWrapper
from llmfoundry.data.packing import BinPackCollator, BinPackDataset, auto_packing_ratio

log = logging.getLogger(__name__)

@@ -141,20 +141,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg.dataset, tokenizer, device_batch_size)

return DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
drop_last=cfg.drop_last,
num_workers=cfg.num_workers,
pin_memory=cfg.get('pin_memory', True),
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
)

cfg, tokenizer, device_batch_size)
else:
backend, _, _ = parse_uri(cfg.dataset.hf_name)
if backend not in ['', None]:
@@ -172,7 +159,7 @@
)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg.dataset, tokenizer, device_batch_size)
cfg, tokenizer, device_batch_size)

if cfg.drop_last:
world_size = dist.get_world_size()
@@ -192,21 +179,23 @@
f'of samples in your dataset to at least {minimum_dataset_size}.'
)

assert dataset is not None
return DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
drop_last=cfg.drop_last,
sampler=dist.get_sampler(dataset,
drop_last=cfg.drop_last,
shuffle=cfg.dataset.shuffle),
num_workers=cfg.num_workers,
pin_memory=cfg.get('pin_memory', True),
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
)

assert dataset is not None
dataset = _maybe_apply_bin_packing(dataset, cfg, tokenizer, device_batch_size)
return DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
drop_last=cfg.drop_last,
# sampler=dist.get_sampler(dataset, # TODO why was this not used in the first return in the original code?
Contributor Author (review comment): todo: add back in
# drop_last=cfg.drop_last,
# shuffle=cfg.dataset.shuffle),
num_workers=cfg.num_workers,
pin_memory=cfg.get('pin_memory', True),
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
)
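Note that with packing moved out of the collator, the DataLoader batch size stays at device_batch_size; the old code instead enlarged it by the packing ratio (see the removed n_examples_to_pack logic further down). A small worked example, assuming a hypothetical packing_ratio of 2.0:

```python
device_batch_size = 8
packing_ratio = 2.0

# Old collator-side packing: the loader pulled extra raw examples per batch and
# BinPackCollator squeezed them into device_batch_size bins.
old_loader_batch_size = int(device_batch_size * packing_ratio)  # 16 raw examples -> 8 bins

# New dataset-side packing: BinPackDataset packs upstream, so the loader already
# receives packed examples and its batch size is unchanged.
new_loader_batch_size = device_batch_size  # 8 packed examples
```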


def _validate_config(dataset_cfg: DictConfig) -> None:
@@ -353,29 +342,28 @@ def _build_hf_dataset_from_remote(
)
return dataset

T = TypeVar('T')

def _build_collate_fn(
dataset_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size: int
) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackWrapper], int]:
collate_fn = Seq2SeqFinetuningCollator(
tokenizer=tokenizer,
max_seq_len=dataset_cfg.max_seq_len,
decoder_only_format=dataset_cfg.decoder_only_format,
allow_pad_trimming=dataset_cfg.get('allow_pad_trimming', False),
)

def _maybe_apply_bin_packing(
dataset: T,
dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size: int
) -> T:
dataset_cfg = dataloader_cfg.dataset
packing_ratio = dataset_cfg.get('packing_ratio')
if packing_ratio is None:
if dataset_cfg.get('max_leftover_bins_to_keep') is not None:
raise ValueError(
'dataset.max_leftover_bins_to_keep has been defined, ' +\
'but dataset.packing_ratio has not been set. Please set ' +\
'the latter to turn on packing or remove the former from the config.')
return collate_fn, device_batch_size
return dataset
if packing_ratio == 'auto':
packing_ratio = auto_packing_ratio(dataloader_cfg, tokenizer,
device_batch_size)

if packing_ratio == 1.0:
return collate_fn, device_batch_size
return dataset
elif packing_ratio < 1.0:
raise ValueError('packing_ratio must be >= 1, if supplied')
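auto_packing_ratio is imported from llmfoundry.data.packing but not defined in this diff. Purely as a hypothetical sketch of the kind of heuristic such a routine could use (profile a small sample and keep the largest ratio whose packed examples still fit), consider:

```python
def estimate_packing_ratio(sample_lengths, max_seq_len,
                           candidates=(1.0, 1.5, 2.0, 2.5, 3.0)):
    """Illustrative only: pick the largest candidate ratio that fits on average.

    `sample_lengths` would come from tokenizing a small slice of the dataset;
    the real auto_packing_ratio may use a completely different strategy.
    """
    avg_len = sum(sample_lengths) / max(len(sample_lengths), 1)
    best = 1.0
    for ratio in candidates:
        # With ratio r, roughly r raw examples share one max_seq_len bin,
        # so the average raw length must not exceed max_seq_len / r.
        if avg_len * ratio <= max_seq_len:
            best = ratio
    return best
```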

@@ -384,16 +372,31 @@ def _build_collate_fn(
'On-the-fly packing is currently only supported for decoder-only formats.'
)

collate_fn = BinPackWrapper(
collator=collate_fn,
bpd = BinPackDataset(
dataset,
packing_ratio,
target_batch_size=device_batch_size,
max_seq_len=dataset_cfg.max_seq_len,
pad_token_id=tokenizer.pad_token_id,
padding_side=tokenizer.padding_side,
max_leftover_bins_to_keep=dataset_cfg.get('max_leftover_bins_to_keep'),
)
n_examples_to_pack = int(device_batch_size * packing_ratio)
return collate_fn, n_examples_to_pack

return hf_datasets.IterableDataset.from_generator(bpd)

def _build_collate_fn(
dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size: int
) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator], int]:
dataset_cfg = dataloader_cfg.dataset
collate_fn = Seq2SeqFinetuningCollator(
tokenizer=tokenizer,
max_seq_len=dataset_cfg.max_seq_len,
decoder_only_format=dataset_cfg.decoder_only_format,
allow_pad_trimming=dataset_cfg.get('allow_pad_trimming', False),
)

return collate_fn, device_batch_size




if __name__ == '__main__':
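In _maybe_apply_bin_packing above, the packed dataset is produced with hf_datasets.IterableDataset.from_generator(bpd). from_generator expects a callable that yields examples, so BinPackDataset is presumably callable. A rough, hypothetical sketch of a dataset-side packer with that shape follows; the real BinPackDataset is not shown in this diff and may differ substantially:

```python
from typing import Dict, Iterable, Iterator, List

class NaivePackingGenerator:
    """Illustrative only: yields packed examples by concatenating raw ones up to max_seq_len."""

    def __init__(self, dataset: Iterable, max_seq_len: int):
        self.dataset = dataset
        self.max_seq_len = max_seq_len

    def __call__(self) -> Iterator[Dict[str, List[int]]]:
        buffer: List[int] = []
        for example in self.dataset:
            tokens = example['input_ids']
            # Flush the current bin if the next example would overflow it.
            if buffer and len(buffer) + len(tokens) > self.max_seq_len:
                yield {'input_ids': buffer}
                buffer = []
            buffer.extend(tokens[:self.max_seq_len])
        if buffer:
            yield {'input_ids': buffer}

# Hypothetical usage mirroring return hf_datasets.IterableDataset.from_generator(bpd):
# packed = hf_datasets.IterableDataset.from_generator(
#     NaivePackingGenerator(dataset, max_seq_len=2048))
```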