Added simple tests
irenedea committed Oct 17, 2023
1 parent ebea921 commit 5e2f29a
Showing 3 changed files with 111 additions and 121 deletions.
1 change: 0 additions & 1 deletion llmfoundry/data/finetuning/dataloader.py
@@ -378,7 +378,6 @@ def _maybe_apply_bin_packing(
target_batch_size=device_batch_size,
max_seq_len=dataset_cfg.max_seq_len,
pad_token_id=tokenizer.pad_token_id,
padding_side=tokenizer.padding_side,
)

return hf_datasets.IterableDataset.from_generator(bpd)
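
For reference, from_generator (on datasets.IterableDataset in the Hugging Face datasets library) wraps a callable that yields example dicts into a streaming dataset. A minimal, self-contained sketch of that pattern follows; the toy_packed_examples generator is illustrative only and is not the repo's BinPackDataset:

from datasets import IterableDataset

def toy_packed_examples():
    # Stand-in for the packed-example source: yield dicts of token ids.
    for ids in ([1, 2, 3], [4, 5], [6]):
        yield {'input_ids': ids}

ds = IterableDataset.from_generator(toy_packed_examples)
for example in ds:
    print(example['input_ids'])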
56 changes: 6 additions & 50 deletions llmfoundry/data/packing.py
@@ -20,42 +20,38 @@ def __init__(
target_batch_size: int,
max_seq_len: int,
pad_token_id: int,
padding_side: Literal['left', 'right'],
):
self.dataset = dataset
self.packing_ratio = int(packing_ratio)
self.out_size = int(target_batch_size)
self.max_seq_len = int(max_seq_len)
self.pad_token_id = int(pad_token_id)
self.padding_side = padding_side
self.collator = BinPackCollator(
lambda x: x,
target_batch_size=self.out_size,
max_seq_len=max_seq_len,
pad_token_id=pad_token_id,
padding_side=padding_side,
max_leftover_bins_to_keep=None # Keep all leftovers.
)

def __call__(self) -> Iterable:
def __iter__(self) -> Iterable:
examples = []
for example in self.dataset:
examples.append(example)

if len(examples) == self.packing_ratio * self.out_size:
packed_examples = self.collator(examples)
print('len packed examples', len(packed_examples))
for packed_example in packed_examples:
yield packed_example
examples = []

# Finish the last batch
packed_examples = self.collator(examples)
for packed_example in packed_examples:
yield packed_example
examples = []
print('leftovers!', len(self.collator._leftover_bins))

# Iterate over leftovers.
# for _, leftover in self.collator._leftover_bins:
# yield leftover
for _, leftover in self.collator._leftover_bins:
yield leftover

class BinPackCollator:
"""Utility collator for packing to reduce padding."""
@@ -65,13 +61,11 @@ def __init__(self,
target_batch_size: int,
max_seq_len: int,
pad_token_id: int,
padding_side: Literal['left', 'right'],
max_leftover_bins_to_keep: Optional[int] = None):
self.base_collator = collator
self.out_size = int(target_batch_size)
self.max_seq_len = int(max_seq_len)
self.pad_token_id = int(pad_token_id)
self.padding_side = padding_side

if self.out_size <= 0:
raise ValueError(f'{target_batch_size=} must be >0.')
@@ -125,8 +119,6 @@ def __call__(
if self.max_leftover_bins_to_keep is not None:
leftover_bins = leftover_bins[:self.max_leftover_bins_to_keep]
self._leftover_bins = leftover_bins


return packed_examples


@@ -243,42 +235,6 @@ def first_fit_bin_packing(
return packed_examples[:num_bins], sum(
bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:]


# def repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int,
# pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]:

# def pad_tensor(tensor: torch.Tensor, pad_value: int):
# if len(tensor) == max_seq_len:
# return tensor
# t = torch.full((max_seq_len,),
# pad_value,
# dtype=tensor.dtype,
# device=tensor.device)
# if padding_side == 'left':
# t[-len(tensor):] = tensor
# elif padding_side == 'right':
# t[:len(tensor)] = tensor
# else:
# raise ValueError(f'Unknown {padding_side=}')
# return t

# pad_vals = {
# 'input_ids': pad_token_id,
# 'labels': -100,
# 'attention_mask': 0,
# 'bidirectional_mask': 0,
# 'sequence_id': -1,
# }
# keys = packed_examples[0].keys()
# batch = {}
# for key in keys:
# batch[key] = torch.stack([
# pad_tensor(example[key], pad_vals[key])
# for example in packed_examples
# ])
# return batch


def auto_packing_ratio(dataloader_cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int):
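
The diff above touches first_fit_bin_packing, which packs tokenized examples into bins of at most max_seq_len tokens. As background for the new tests below, here is a minimal illustrative sketch of first-fit packing over length-sorted sequences; it is a simplified re-implementation for clarity (longest-first ordering, no leftover tracking), not the repo's function:

from typing import List

def toy_first_fit_pack(sequences: List[List[int]],
                       max_seq_len: int) -> List[List[int]]:
    # Illustrative only: sort longest-first, then append each sequence to the
    # first bin with room; open a new bin when none fits.
    bins: List[List[int]] = []
    for seq in sorted(sequences, key=len, reverse=True):
        for b in bins:
            if len(b) + len(seq) <= max_seq_len:
                b.extend(seq)
                break
        else:
            bins.append(list(seq))
    return bins

# With the inputs used in test_simple_packing below, this produces the same two
# packed samples the test expects ([9]*9 + [1] and [8]*8 + [2]*2), although the
# real dataset may yield them in a different order.
print(toy_first_fit_pack([[1], [2] * 2, [8] * 8, [9] * 9], max_seq_len=10))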
175 changes: 105 additions & 70 deletions tests/test_packing.py
@@ -1,78 +1,113 @@
from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader

from llmfoundry.utils.builders import build_tokenizer
from omegaconf import DictConfig
from composer.utils import reproducibility
from typing import List

from llmfoundry.data.packing import BinPackDataset
from torch.utils.data import IterableDataset

class TestDataset(IterableDataset):
def __init__(self, data: List[List[int]]):
super().__init__()
self.data = data

def __iter__(self):
for d in self.data:
yield {'input_ids': d }

def test_simple_packing():
dataset = TestDataset([
[1],
[2] * 2,
[8] * 8,
[9] * 9,
])

packed_dataset = BinPackDataset(
dataset,
packing_ratio=2,
target_batch_size=2,
max_seq_len=10,
pad_token_id=0,
)

packed_samples = [sample['input_ids'] for sample in packed_dataset]

assert packed_samples[0] == [8] * 8 + [2] * 2
assert packed_samples[1] == [9] * 9 + [1]

def test_simple_packing_with_leftovers():
dataset = TestDataset([
[5] * 5,
[6] * 6,
[5] * 5,
[7] * 7,
])

packed_dataset = BinPackDataset(
dataset,
packing_ratio=2,
target_batch_size=2,
max_seq_len=10,
pad_token_id=0,
)

packed_samples = [sample['input_ids'] for sample in packed_dataset]

assert packed_samples[0] == [5] * 10
assert packed_samples[1] == [7] * 7
assert packed_samples[2] == [6] * 6

# def test_auto_packing():
# reproducibility.seed_all(17)
# dataloader_cfg = DictConfig({
# 'name': 'finetuning',
# 'dataset': {
# 'hf_name': 'mosaicml/dolly_hhrlhf',
# 'split': 'train',
# 'max_seq_len': 1024,
# 'allow_pad_trimming': False,
# 'decoder_only_format': True,
# 'packing_ratio': 'auto',
# 'shuffle': False,
# },
# 'drop_last': False,
# 'num_workers': 1,
# 'pin_memory': False,
# 'prefetch_factor': 1,
# 'persistent_workers': True,
# 'timeout': 0,
# })

# def test_simple_packing():
# # TODO: write simple test
# # TODO: investigate base version, do outputs match okay?
# tokenizer = build_tokenizer('mosaicml/mpt-7b', {})
# BinPackDataset(
# dataset,
# packing_ratio,
# target_batch_size,
# 10, tokenizer.pad_token_id,
# 'left',
# )

def test_auto_packing():
reproducibility.seed_all(17)
dataloader_cfg = DictConfig({
'name': 'finetuning',
'dataset': {
'hf_name': 'mosaicml/dolly_hhrlhf',
'split': 'train',
'max_seq_len': 1024,
'allow_pad_trimming': False,
'decoder_only_format': True,
'packing_ratio': 'auto',
'shuffle': False,
},
'drop_last': False,
'num_workers': 1,
'pin_memory': False,
'prefetch_factor': 1,
'persistent_workers': True,
'timeout': 0,
})

tokenizer = build_tokenizer('mosaicml/mpt-7b', {})

dataloader = build_finetuning_dataloader(dataloader_cfg, tokenizer, 6)

print(next(iter(dataloader)))

# print('length!', len([sample for sample in dataloader]))

# dataloader_cfg = DictConfig({
# 'name': 'finetuning',
# 'dataset': {
# 'hf_name': 'mosaicml/dolly_hhrlhf',
# 'split': 'train',
# 'max_seq_len': 1024,
# 'allow_pad_trimming': False,
# 'decoder_only_format': True,
# 'shuffle': False,
# },
# 'drop_last': False,
# 'num_workers': 1,
# 'pin_memory': False,
# 'prefetch_factor': 1,
# 'persistent_workers': True,
# 'timeout': 0,
# })

# tokenizer = build_tokenizer('mosaicml/mpt-7b', {})

# dataloader = build_finetuning_dataloader(dataloader_cfg, tokenizer, 6)

# print(len(dataloader))

# dataloader = build_finetuning_dataloader(dataloader_cfg, tokenizer, 6)

# print(next(iter(dataloader)))

# # print('length!', len([sample for sample in dataloader]))

# # dataloader_cfg = DictConfig({
# # 'name': 'finetuning',
# # 'dataset': {
# # 'hf_name': 'mosaicml/dolly_hhrlhf',
# # 'split': 'train',
# # 'max_seq_len': 1024,
# # 'allow_pad_trimming': False,
# # 'decoder_only_format': True,
# # 'shuffle': False,
# # },
# # 'drop_last': False,
# # 'num_workers': 1,
# # 'pin_memory': False,
# # 'prefetch_factor': 1,
# # 'persistent_workers': True,
# # 'timeout': 0,
# # })

# # tokenizer = build_tokenizer('mosaicml/mpt-7b', {})

# # dataloader = build_finetuning_dataloader(dataloader_cfg, tokenizer, 6)

# # print(len(dataloader))


# for sample in dataloader:
# print(sample)
# # for sample in dataloader:
# # print(sample)
