Add support for auto packing ratio
irenedea committed Sep 22, 2023
1 parent 6883562 commit 6d53fca
Showing 2 changed files with 141 additions and 74 deletions.
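The user-facing change is that `packing_ratio` in a finetuning dataloader config may now be set to the string 'auto': the dataloader then profiles a range of candidate ratios and picks one automatically instead of requiring a hand-tuned float. A minimal sketch of how this might be used, assuming an OmegaConf config; only `name`, `dataset.max_seq_len`, `dataset.packing_ratio`, `dataset.max_leftovers_to_keep`, and `drop_last` are keys referenced in this diff, while the dataset name, split, tokenizer, and remaining fields are illustrative placeholders.

from omegaconf import OmegaConf as om

from llmfoundry import build_finetuning_dataloader
from llmfoundry.utils import build_tokenizer

# Hypothetical finetuning dataloader config; 'auto' is the new option.
cfg = om.create({
    'name': 'finetuning',
    'dataset': {
        'hf_name': 'my-org/my-instruct-data',  # placeholder dataset
        'split': 'train',
        'max_seq_len': 2048,
        'packing_ratio': 'auto',  # <-- profiled and resolved automatically
        'max_leftovers_to_keep': None,
        'decoder_only_format': True,  # assumed finetuning-dataset fields
        'shuffle': True,
    },
    'drop_last': False,
    'num_workers': 8,
})

tokenizer = build_tokenizer('EleutherAI/gpt-neox-20b', {})  # placeholder tokenizer

# 8 examples per device; the 'auto' ratio is resolved inside _build_collate_fn.
dataloader = build_finetuning_dataloader(cfg, tokenizer, 8)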
13 changes: 9 additions & 4 deletions llmfoundry/data/finetuning/dataloader.py
@@ -13,7 +13,7 @@

from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
from llmfoundry.data.finetuning.tasks import dataset_constructor
from llmfoundry.data.packing import BinPackWrapper
from llmfoundry.data.packing import BinPackWrapper, auto_packing_ratio

log = logging.getLogger(__name__)

@@ -141,7 +141,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg.dataset, tokenizer, device_batch_size)
cfg, tokenizer, device_batch_size)

return DataLoader(
dataset,
@@ -172,7 +172,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg.dataset, tokenizer, device_batch_size)
cfg, tokenizer, device_batch_size)

if cfg.drop_last:
world_size = dist.get_world_size()
@@ -355,9 +355,10 @@ def _build_hf_dataset_from_remote(


def _build_collate_fn(
dataset_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size: int
) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackWrapper], int]:
dataset_cfg = dataloader_cfg.dataset
collate_fn = Seq2SeqFinetuningCollator(
tokenizer=tokenizer,
max_seq_len=dataset_cfg.max_seq_len,
@@ -374,6 +375,10 @@ def _build_collate_fn(
'the latter to turn on packing or remove the former from the config.')
return collate_fn, device_batch_size

if packing_ratio == 'auto':
packing_ratio = auto_packing_ratio(dataloader_cfg, tokenizer,
device_batch_size)

if packing_ratio == 1.0:
return collate_fn, device_batch_size
elif packing_ratio < 1.0:
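Because `_build_collate_fn` now receives the whole dataloader config rather than just `cfg.dataset`, the same resolution step can also be run by hand to inspect what 'auto' would pick, by calling `auto_packing_ratio` (added in `packing.py` below) directly. A rough sketch under the same assumptions as above; the abbreviated config and tokenizer name are placeholders, not a complete schema.

from omegaconf import OmegaConf as om

from llmfoundry.data.packing import auto_packing_ratio
from llmfoundry.utils import build_tokenizer

# Abbreviated version of the hypothetical config sketched earlier.
cfg = om.create({
    'name': 'finetuning',
    'dataset': {
        'hf_name': 'my-org/my-instruct-data',  # placeholder dataset
        'split': 'train',
        'max_seq_len': 2048,
        'packing_ratio': 'auto',
        'decoder_only_format': True,
        'shuffle': True,
    },
    'drop_last': False,
    'num_workers': 8,
})

tokenizer = build_tokenizer('EleutherAI/gpt-neox-20b', {})  # placeholder tokenizer

# Profiles ratios between 1 and 20 on a packing-disabled copy of the config and
# returns the largest ratio that incurred no waste during profiling.
ratio = auto_packing_ratio(cfg, tokenizer, device_batch_size=8)
print(f'auto-selected packing_ratio: {ratio}')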
202 changes: 132 additions & 70 deletions llmfoundry/data/packing.py
@@ -260,14 +260,140 @@ def pad_tensor(tensor: torch.Tensor, pad_value: int):
return batch


def auto_packing_ratio(dataloader_cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int):
"""Find a packing ratio that minimizes padding with zero waste.
Args:
dataloader_cfg (DictConfig): The dataloader configuration for profiling.
tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling.
device_batch_size (int): The size of the batches (number of examples) per device.
Returns:
A packing ratio that minimizes padding while maintaining zero waste.
"""
min_ratio = 1
max_ratio = 20
num_packing_ratios = 10
profiling_results = profile_packing(dataloader_cfg, tokenizer, min_ratio,
max_ratio, num_packing_ratios,
device_batch_size)

    # Obtain the maximum packing_ratio/minimum padding that has no waste.
    # profiling_results is ordered by increasing packing_ratio, so keep the
    # largest candidate whose waste is still zero.
    packing_ratio = 1
    for candidate_ratio, _, waste in profiling_results:
        if waste > 0:
            break
        packing_ratio = candidate_ratio
    return packing_ratio


def profile_packing(dataloader_cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase, min_ratio: float,
max_ratio: float, num_packing_ratios: int,
device_batch_size: int):
"""Profile packing.
Args:
dataloader_cfg (DictConfig): The dataloader configuration for profiling.
tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling.
min_ratio (float): Smallest packing_ratio to test. Must be >=1.
max_ratio (float): Largest packing_ratio to test. Must be larger than `min_ratio`.
num_packing_ratios (int): Number of packing_ratio values (spaced between `min_ratio` and `max_ratio`) to try.
device_batch_size (int): The size of the batches (number of examples) per device.
Returns:
A list of tuples of packing ratio, padding, and waste.
"""
import copy

from llmfoundry import (build_finetuning_dataloader,
build_text_denoising_dataloader)
from llmfoundry.data import build_text_dataloader

    # Capture this value before it is cleared below; the profiling packer still
    # needs it when constructing BinPackWrapper.
    max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep', None)

    # Turn off packing for the dataloader (we want raw, pre-packed examples)
    dataloader_cfg = copy.deepcopy(dataloader_cfg)
    dataloader_cfg.dataset.packing_ratio = None
    dataloader_cfg.dataset.max_leftovers_to_keep = None
    dataloader_cfg.drop_last = False

# Determine the packing_ratio values we'll try
packing_ratios, raw_batch_sizes = [], []
for packing_ratio in np.linspace(min_ratio,
max_ratio,
num_packing_ratios,
endpoint=True):
packing_ratio = np.round(10 * packing_ratio) / 10
raw_batch_size = int(packing_ratio * device_batch_size)
if raw_batch_size not in raw_batch_sizes:
packing_ratios.append(packing_ratio)
raw_batch_sizes.append(raw_batch_size)

n_profile_examples = max(raw_batch_sizes) * 100

def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase):
if cfg.name == 'text':
return build_text_dataloader(cfg, tokenizer, n_profile_examples)
elif cfg.name == 'text_denoising':
return build_text_denoising_dataloader(cfg, tokenizer,
n_profile_examples)
elif cfg.name == 'finetuning':
return build_finetuning_dataloader(cfg, tokenizer,
n_profile_examples)
else:
raise ValueError(
f'Not sure how to build dataloader with config: {cfg}')

train_dataloader = build_dataloader(dataloader_cfg, tokenizer)

# Get a bunch of raw examples
big_batch = next(iter(train_dataloader))

def split_big_batch(raw_batch_size: int) -> List:
input_ids = big_batch['input_ids'].split(raw_batch_size)
batches = [{'input_ids': x} for x in input_ids]

for key in big_batch.keys():
if key == 'input_ids':
continue
for idx, split in enumerate(big_batch[key].split(raw_batch_size)):
batches[idx].update({key: split})
return batches

def profile(raw_batch_size: int) -> Tuple[float, float]:
packer = BinPackWrapper(
collator=lambda x: x,
target_batch_size=device_batch_size,
max_seq_len=dataloader_cfg.dataset.max_seq_len,
pad_token_id=0, # <-- Doesn't need to be correct for profiling
padding_side='left', # <-- Doesn't need to be correct for profiling
max_leftover_bins_to_keep=max_leftovers_to_keep)

# Simulate feeding the packing collator a bunch of data
for batch in split_big_batch(raw_batch_size):
if batch['input_ids'].shape[0] < device_batch_size:
continue
_ = packer(batch)

# Return the padding / waste stats over that bunch of data
padding_percent = 100 * (1 - packer.efficiency)
waste_percent = 100 * packer.waste
return padding_percent, waste_percent

results = []
for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes):
padding, waste = profile(raw_batch_size)
results.append((packing_ratio, padding, waste))
return results
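As an aside on the helper above: `split_big_batch` is what turns the single oversized profiling batch into candidate raw batches for each packing ratio. A toy, self-contained illustration of the same splitting idea (the tensors and sizes are made up, not taken from the commit):

import torch

# Stand-in for the oversized profiling batch: 6 examples, 4 tokens each.
big_batch = {
    'input_ids': torch.arange(6 * 4).reshape(6, 4),
    'attention_mask': torch.ones(6, 4, dtype=torch.long),
}

device_batch_size = 4
packing_ratio = 1.0
raw_batch_size = int(packing_ratio * device_batch_size)

# Carve every tensor into chunks of raw_batch_size examples, keeping the
# dict keys aligned across chunks -- the same idea as split_big_batch.
chunks = big_batch['input_ids'].split(raw_batch_size)
batches = [{'input_ids': x} for x in chunks]
for key, value in big_batch.items():
    if key == 'input_ids':
        continue
    for idx, piece in enumerate(value.split(raw_batch_size)):
        batches[idx][key] = piece

print([b['input_ids'].shape[0] for b in batches])  # [4, 2]
# profile() skips the trailing 2-example chunk because it is smaller than
# device_batch_size, so only full candidate batches reach the packer.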


if __name__ == '__main__':
from argparse import ArgumentParser, Namespace

from omegaconf import OmegaConf as om

from llmfoundry import (build_finetuning_dataloader,
build_text_denoising_dataloader)
from llmfoundry.data import build_text_dataloader
from llmfoundry.utils import build_tokenizer

def parse_args() -> Namespace:
@@ -316,20 +442,6 @@ def parse_args() -> Namespace:
raise ValueError('`num_packing_ratios` must be a positive integer.')
return args

def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size: int):
if cfg.name == 'text':
return build_text_dataloader(cfg, tokenizer, device_batch_size)
elif cfg.name == 'text_denoising':
return build_text_denoising_dataloader(cfg, tokenizer,
device_batch_size)
elif cfg.name == 'finetuning':
return build_finetuning_dataloader(cfg, tokenizer,
device_batch_size)
else:
raise ValueError(
f'Not sure how to build dataloader with config: {cfg}')

args = parse_args()

with open(args.yaml_path) as f:
@@ -339,18 +451,6 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
cfg = om.create(cfg)
device_batch_size = cfg.global_train_batch_size // args.num_devices

# Determine the packing_ratio values we'll try
packing_ratios, raw_batch_sizes = [], []
for packing_ratio in np.linspace(args.min,
args.max,
args.num_packing_ratios,
endpoint=True):
packing_ratio = np.round(10 * packing_ratio) / 10
raw_batch_size = int(packing_ratio * device_batch_size)
if raw_batch_size not in raw_batch_sizes:
packing_ratios.append(packing_ratio)
raw_batch_sizes.append(raw_batch_size)

# Fetch a bunch of raw examples once, which we'll re-use
if 'train_loader' not in cfg:
raise ValueError('config must define train_loader')
@@ -373,51 +473,13 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)

# Turn off packing for the dataloader (we want raw, pre-packed examples)
dataloader_cfg.dataset.packing_ratio = None
dataloader_cfg.dataset.max_leftovers_to_keep = None
train_dataloader = build_dataloader(dataloader_cfg, tokenizer,
max(raw_batch_sizes) * 100)

# Get a bunch of raw examples
big_batch = next(iter(train_dataloader))

def split_big_batch(raw_batch_size: int) -> List:
input_ids = big_batch['input_ids'].split(raw_batch_size)
batches = [{'input_ids': x} for x in input_ids]

for key in big_batch.keys():
if key == 'input_ids':
continue
for idx, split in enumerate(big_batch[key].split(raw_batch_size)):
batches[idx].update({key: split})
return batches

def profile_packing(raw_batch_size: int) -> Tuple[float, float]:
packer = BinPackWrapper(
collator=lambda x: x,
target_batch_size=device_batch_size,
max_seq_len=dataloader_cfg.dataset.max_seq_len,
pad_token_id=0, # <-- Doesn't need to be correct for profiling
padding_side='left', # <-- Doesn't need to be correct for profiling
max_leftover_bins_to_keep=max_leftovers_to_keep)

# Simulate feeding the packing collator a bunch of data
for batch in split_big_batch(raw_batch_size):
if batch['input_ids'].shape[0] < device_batch_size:
continue
_ = packer(batch)

# Return the padding / waste stats over that bunch of data
padding_percent = 100 * (1 - packer.efficiency)
waste_percent = 100 * packer.waste
return padding_percent, waste_percent
results = profile_packing(dataloader_cfg, tokenizer, args.min, args.max,
args.num_packing_ratios, device_batch_size)

header = '\n\n\n packing_ratio | % PADDING | % WASTE'
fstr = ' {:5.1f} | {:5.2f}% | {:6.2f}%'

print(header)
print('-' * len(header))
for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes):
padding, waste = profile_packing(raw_batch_size)
for packing_ratio, padding, waste in results:
print(fstr.format(packing_ratio, padding, waste))
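Tying the printed table back to `auto_packing_ratio`: the automatic choice is simply the largest profiled ratio whose waste is still zero. A toy walk-through of that rule with invented numbers (not real profiling output):

# Hypothetical (packing_ratio, % padding, % waste) rows in the same shape as
# profile_packing's return value; the values are made up for illustration.
profiling_results = [
    (1.0, 61.2, 0.0),
    (3.1, 24.5, 0.0),
    (5.2, 6.3, 0.0),
    (7.3, 2.1, 4.8),   # packing this aggressively starts discarding tokens
    (9.4, 1.0, 11.7),
]

# Keep the largest ratio whose waste is still zero -- the rule
# auto_packing_ratio applies to its own profiling results.
chosen = 1.0
for ratio, _, waste in profiling_results:
    if waste > 0:
        break
    chosen = ratio

print(chosen)  # 5.2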
