diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py
index d685d0077d..f7ff642be0 100644
--- a/llmfoundry/data/denoising.py
+++ b/llmfoundry/data/denoising.py
@@ -15,7 +15,7 @@
 from torch.utils.data import DataLoader
 from transformers import PreTrainedTokenizerBase
 
-from llmfoundry.data.packing import BinPackWrapper
+from llmfoundry.data.packing import BinPackCollator
 from llmfoundry.data.text_data import StreamingTextDataset
 from llmfoundry.models import utils
 
@@ -490,7 +490,7 @@ def build_text_denoising_dataloader(
             raise NotImplementedError(
                 'On-the-fly packing is currently only supported for decoder-only formats.'
             )
-        collate_fn = BinPackWrapper(
+        collate_fn = BinPackCollator(
             collator=collate_fn,
             target_batch_size=device_batch_size,
             max_seq_len=cfg.dataset.max_seq_len,
diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index 661b1e808d..fd4b438fb6 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
 import os
-from typing import Tuple, Union
+from typing import Tuple, TypeVar, Union
 
 import datasets as hf_datasets
 import torch
@@ -13,7 +13,7 @@
 from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
 from llmfoundry.data.finetuning.tasks import dataset_constructor
-from llmfoundry.data.packing import BinPackWrapper
+from llmfoundry.data.packing import BinPackCollator, BinPackDataset, auto_packing_ratio
 
 log = logging.getLogger(__name__)
@@ -141,20 +141,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
         )
         collate_fn, dataloader_batch_size = _build_collate_fn(
-            cfg.dataset, tokenizer, device_batch_size)
-
-        return DataLoader(
-            dataset,
-            collate_fn=collate_fn,
-            batch_size=dataloader_batch_size,
-            drop_last=cfg.drop_last,
-            num_workers=cfg.num_workers,
-            pin_memory=cfg.get('pin_memory', True),
-            prefetch_factor=cfg.get('prefetch_factor', 2),
-            persistent_workers=cfg.get('persistent_workers', True),
-            timeout=cfg.get('timeout', 0),
-        )
-
+            cfg, tokenizer, device_batch_size)
     else:
         backend, _, _ = parse_uri(cfg.dataset.hf_name)
         if backend not in ['', None]:
@@ -172,7 +159,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
         )
 
         collate_fn, dataloader_batch_size = _build_collate_fn(
-            cfg.dataset, tokenizer, device_batch_size)
+            cfg, tokenizer, device_batch_size)
 
         if cfg.drop_last:
             world_size = dist.get_world_size()
@@ -192,21 +179,23 @@ def build_finetuning_dataloader(cfg: DictConfig,
                 f'of samples in your dataset to at least {minimum_dataset_size}.'
             )
 
-    assert dataset is not None
-    return DataLoader(
-        dataset,
-        collate_fn=collate_fn,
-        batch_size=dataloader_batch_size,
-        drop_last=cfg.drop_last,
-        sampler=dist.get_sampler(dataset,
-                                 drop_last=cfg.drop_last,
-                                 shuffle=cfg.dataset.shuffle),
-        num_workers=cfg.num_workers,
-        pin_memory=cfg.get('pin_memory', True),
-        prefetch_factor=cfg.get('prefetch_factor', 2),
-        persistent_workers=cfg.get('persistent_workers', True),
-        timeout=cfg.get('timeout', 0),
-    )
+
+    assert dataset is not None
+    dataset = _maybe_apply_bin_packing(dataset, cfg, tokenizer,
+                                       device_batch_size)
+    return DataLoader(
+        dataset,
+        collate_fn=collate_fn,
+        batch_size=dataloader_batch_size,
+        drop_last=cfg.drop_last,
+        # TODO: the streaming branch of the original code returned its
+        # DataLoader without a sampler; confirm whether one is needed here
+        # before re-enabling the lines below.
+        # sampler=dist.get_sampler(dataset,
+        #    drop_last=cfg.drop_last,
+        #    shuffle=cfg.dataset.shuffle),
+        num_workers=cfg.num_workers,
+        pin_memory=cfg.get('pin_memory', True),
+        prefetch_factor=cfg.get('prefetch_factor', 2),
+        persistent_workers=cfg.get('persistent_workers', True),
+        timeout=cfg.get('timeout', 0),
+    )
 
 
 def _validate_config(dataset_cfg: DictConfig) -> None:
@@ -353,18 +342,14 @@ def _build_hf_dataset_from_remote(
     )
     return dataset
 
+T = TypeVar('T')
 
-def _build_collate_fn(
-    dataset_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
-    device_batch_size: int
-) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackWrapper], int]:
-    collate_fn = Seq2SeqFinetuningCollator(
-        tokenizer=tokenizer,
-        max_seq_len=dataset_cfg.max_seq_len,
-        decoder_only_format=dataset_cfg.decoder_only_format,
-        allow_pad_trimming=dataset_cfg.get('allow_pad_trimming', False),
-    )
-
+def _maybe_apply_bin_packing(
+    dataset: T,
+    dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
+    device_batch_size: int
+) -> T:
+    dataset_cfg = dataloader_cfg.dataset
     packing_ratio = dataset_cfg.get('packing_ratio')
     if packing_ratio is None:
         if dataset_cfg.get('max_leftover_bins_to_keep') is not None:
@@ -372,10 +357,13 @@
                 'dataset.max_leftover_bins_to_keep has been defined, ' +\
                 'but dataset.packing_ratio has not been set. Please set ' +\
                 'the latter to turn on packing or remove the former from the config.')
-        return collate_fn, device_batch_size
+        return dataset
 
+    if packing_ratio == 'auto':
+        packing_ratio = auto_packing_ratio(dataloader_cfg, tokenizer,
+                                           device_batch_size)
     if packing_ratio == 1.0:
-        return collate_fn, device_batch_size
+        return dataset
     elif packing_ratio < 1.0:
         raise ValueError('packing_ratio must be >= 1, if supplied')
 
@@ -384,16 +372,31 @@
             'On-the-fly packing is currently only supported for decoder-only formats.'
         )
-    collate_fn = BinPackWrapper(
-        collator=collate_fn,
+    bpd = BinPackDataset(
+        dataset,
+        packing_ratio,
         target_batch_size=device_batch_size,
         max_seq_len=dataset_cfg.max_seq_len,
         pad_token_id=tokenizer.pad_token_id,
-        padding_side=tokenizer.padding_side,
-        max_leftover_bins_to_keep=dataset_cfg.get('max_leftover_bins_to_keep'),
     )
-    n_examples_to_pack = int(device_batch_size * packing_ratio)
-    return collate_fn, n_examples_to_pack
+
+    return hf_datasets.IterableDataset.from_generator(bpd)
+
+def _build_collate_fn(
+    dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
+    device_batch_size: int
+) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator], int]:
+    dataset_cfg = dataloader_cfg.dataset
+    collate_fn = Seq2SeqFinetuningCollator(
+        tokenizer=tokenizer,
+        max_seq_len=dataset_cfg.max_seq_len,
+        decoder_only_format=dataset_cfg.decoder_only_format,
+        allow_pad_trimming=dataset_cfg.get('allow_pad_trimming', False),
+    )
+
+    return collate_fn, device_batch_size
+
+
 if __name__ == '__main__':
diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py
index d0a73be801..1378e30189 100644
--- a/llmfoundry/data/packing.py
+++ b/llmfoundry/data/packing.py
@@ -2,15 +2,58 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
 
 import numpy as np
 import torch
 from omegaconf import DictConfig
 from transformers import PreTrainedTokenizerBase
-
-
-class BinPackWrapper:
+from torch.utils.data import IterableDataset
+
+
+class BinPackDataset(IterableDataset):
+    """An IterableDataset that returns packed examples."""
+
+    def __init__(
+        self,
+        dataset: IterableDataset,
+        packing_ratio: int,
+        target_batch_size: int,
+        max_seq_len: int,
+        pad_token_id: int,
+    ):
+        self.dataset = dataset
+        self.packing_ratio = int(packing_ratio)
+        self.out_size = int(target_batch_size)
+        self.max_seq_len = int(max_seq_len)
+        self.pad_token_id = int(pad_token_id)
+        self.collator = BinPackCollator(
+            lambda x: x,
+            target_batch_size=self.out_size,
+            max_seq_len=max_seq_len,
+            pad_token_id=pad_token_id,
+            max_leftover_bins_to_keep=None  # Keep all leftovers.
+        )
+
+    def __iter__(self) -> Iterable:
+        examples = []
+        for example in self.dataset:
+            examples.append(example)
+            if len(examples) == self.packing_ratio * self.out_size:
+                packed_examples = self.collator(examples)
+                for packed_example in packed_examples:
+                    yield packed_example
+                examples = []
+        # Pack whatever is left over after the final full chunk.
+        if examples:
+            packed_examples = self.collator(examples)
+            for packed_example in packed_examples:
+                yield packed_example
+
+        # Finally, yield anything still sitting in the collator's leftover bins.
+        for _, leftover in self.collator._leftover_bins:
+            yield leftover
+
+
+class BinPackCollator:
     """Utility collator for packing to reduce padding."""
 
     def __init__(self,
@@ -18,13 +61,11 @@ def __init__(self,
                  target_batch_size: int,
                  max_seq_len: int,
                  pad_token_id: int,
-                 padding_side: Literal['left', 'right'],
                  max_leftover_bins_to_keep: Optional[int] = None):
         self.base_collator = collator
         self.out_size = int(target_batch_size)
         self.max_seq_len = int(max_seq_len)
         self.pad_token_id = int(pad_token_id)
-        self.padding_side = padding_side
 
         if self.out_size <= 0:
             raise ValueError(f'{target_batch_size=} must be >0.')
@@ -33,13 +74,13 @@ def __init__(self,
         if self.pad_token_id < 0:
             raise ValueError(f'{pad_token_id=} must be >=0.')
 
-        if max_leftover_bins_to_keep is None:
-            self.max_leftover_bins_to_keep = int(10 * self.out_size)
-        elif max_leftover_bins_to_keep < 0:
-            raise ValueError(
-                f'{max_leftover_bins_to_keep=} must be >=0 or None.')
-        else:
-            self.max_leftover_bins_to_keep = int(max_leftover_bins_to_keep)
+        self.max_leftover_bins_to_keep = max_leftover_bins_to_keep
+        if max_leftover_bins_to_keep is not None:
+            if max_leftover_bins_to_keep < 0:
+                raise ValueError(
+                    f'{max_leftover_bins_to_keep=} must be >=0 or None.')
+            else:
+                self.max_leftover_bins_to_keep = int(max_leftover_bins_to_keep)
 
         self.n_packed_tokens = 0
         self.n_total_tokens = 0
@@ -59,76 +100,46 @@ def efficiency(self) -> float:
     def __call__(
         self,
         examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
-        batch = self.base_collator(examples)
-
-        assert 'attention_mask' in batch
-        assert 'input_ids' in batch
-
-        for key in batch.keys():
-            assert key in [
-                'input_ids',
-                'labels',
-                'attention_mask',
-                'bidirectional_mask',
-            ]
-
-        # Cut everything down to size
-        sizes, trimmed_examples = [], []
-        for idx in range(batch['attention_mask'].shape[0]):
-            size, trimmed_example = extract_trim_batch_idx(batch, idx)
-            sizes.append(size)
-            trimmed_examples.append(trimmed_example)
+        sizes = [len(example['input_ids']) for example in examples]
 
         # Apply our CS 101 bin packing algorithm.
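+        # First-fit packing, seeded with the leftover bins from previous
+        # calls: each example goes into the first bin that still has room
+        # (capacity max_seq_len). For example, with max_seq_len=10, sizes
+        # [1, 2, 8, 9] pack into bins [8, 2] and [9, 1], which is what
+        # tests/test_packing.py asserts.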
         packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = first_fit_bin_packing(
             sizes=sizes,
-            examples=trimmed_examples,
+            examples=examples,
             num_bins=self.out_size,
             max_bin_size=self.max_seq_len,
             existing_bins=self._leftover_bins,
         )
         self.n_packed_tokens += n_packed_tokens
         self.n_total_tokens += n_total_tokens
         self.n_packed_examples += self.out_size
 
-        self._leftover_bins = leftover_bins[:self.max_leftover_bins_to_keep]
-
-        # Re-pad to max_seq_len and batch
-        batch = repad(packed_examples,
-                      max_seq_len=self.max_seq_len,
-                      pad_token_id=self.pad_token_id,
-                      padding_side=self.padding_side)
-        return batch
-
-
-def extract_trim_batch_idx(batch: Dict[str, torch.Tensor],
-                           idx: int) -> Tuple[int, Dict[str, torch.Tensor]]:
-    example = {k: v[idx] for k, v in batch.items()}
-
-    keep = example['attention_mask'] == 1
-    size = int(keep.sum())
-    trim_example = {k: v[keep] for k, v in example.items()}
-    trim_example['sequence_id'] = torch.zeros_like(trim_example['input_ids'])
-
-    return size, trim_example
+        if self.max_leftover_bins_to_keep is not None:
+            leftover_bins = leftover_bins[:self.max_leftover_bins_to_keep]
+        self._leftover_bins = leftover_bins
+        return packed_examples
 
 
 def combine_in_place(
-        example: Dict[str, torch.Tensor],
-        add_on: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        example: Dict[str, List[int]],
+        add_on: Dict[str, List[int]]) -> Dict[str, List[int]]:
+
     if 'labels' in add_on:
         # Prevents the last token in example from being trained to
         # predict the first token in add_on, which would make no sense.
         add_on['labels'][0] = -100
 
     for k in example.keys():
-        if k == 'sequence_id':
-            example[k] = torch.cat(
-                [example[k], add_on[k] + 1 + torch.max(example[k])])
-        else:
-            example[k] = torch.cat([example[k], add_on[k]])
+        # if k == 'sequence_id':
+        #     example[k] = torch.cat(
+        #         [example[k], add_on[k] + 1 + torch.max(example[k])])
+        # else:
+        example[k] = example[k] + add_on[k]
+
     return example
 
-
 def first_fit_bin_packing(
     sizes: List[int], examples: List[Dict[str, torch.Tensor]], num_bins: int,
     max_bin_size: int, existing_bins: List[Tuple[int, Dict[str, torch.Tensor]]]
@@ -224,40 +235,134 @@ def first_fit_bin_packing(
     return packed_examples[:num_bins], sum(
         bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:]
 
 
+def auto_packing_ratio(dataloader_cfg: DictConfig,
+                       tokenizer: PreTrainedTokenizerBase,
+                       device_batch_size: int):
+    """Find a packing ratio that minimizes padding with zero waste.
+
+    Args:
+        dataloader_cfg (DictConfig): The dataloader configuration for profiling.
+        tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling.
+        device_batch_size (int): The size of the batches (number of examples) per device.
+
+    Returns:
+        A packing ratio that minimizes padding while maintaining zero waste.
+    """
+    # min_ratio = 2
+    # max_ratio = 2
+    # num_packing_ratios = 1
+    # profiling_results = profile_packing(dataloader_cfg, tokenizer, min_ratio,
+    #                                     max_ratio, num_packing_ratios,
+    #                                     device_batch_size)
+
+    # # Obtain the maximum packing_ratio/minimum padding that has no waste.
+    # i = 0
+    # waste = 0
+    # packing_ratio = 1
+    # while i < len(profiling_results) and waste == 0:
+    #     packing_ratio, _, waste = profiling_results[i]
+    #     i += 1
+    # NOTE: the profiling-based selection above is disabled for now; fall back
+    # to a fixed ratio.
+    packing_ratio = 15
+    return packing_ratio
+
+
+def profile_packing(dataloader_cfg: DictConfig,
+                    tokenizer: PreTrainedTokenizerBase, min_ratio: float,
+                    max_ratio: float, num_packing_ratios: int,
+                    device_batch_size: int):
+    """Profile padding and waste for a range of packing ratios.
+
+    Args:
+        dataloader_cfg (DictConfig): The dataloader configuration for profiling.
+        tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling.
+        min_ratio (float): Smallest packing_ratio to test. Must be >=1.
+        max_ratio (float): Largest packing_ratio to test. Must be larger than `min_ratio`.
+        num_packing_ratios (int): Number of packing_ratio values (spaced between `min_ratio` and `max_ratio`) to try.
+        device_batch_size (int): The size of the batches (number of examples) per device.
+
+    Returns:
+        A list of tuples of packing ratio, padding, and waste.
+    """
+    import copy
+
+    from llmfoundry import (build_finetuning_dataloader,
+                            build_text_denoising_dataloader)
+    from llmfoundry.data import build_text_dataloader
+
+    # Turn off packing for the dataloader (we want raw, pre-packed examples)
+    dataloader_cfg = copy.deepcopy(dataloader_cfg)
+    dataloader_cfg.dataset.packing_ratio = None
+    dataloader_cfg.dataset.max_leftovers_to_keep = None
+    dataloader_cfg.drop_last = False
+
+    # Determine the packing_ratio values we'll try
+    packing_ratios, raw_batch_sizes = [], []
+    for packing_ratio in np.linspace(min_ratio,
+                                     max_ratio,
+                                     num_packing_ratios,
+                                     endpoint=True):
+        packing_ratio = np.round(10 * packing_ratio) / 10
+        raw_batch_size = int(packing_ratio * device_batch_size)
+        if raw_batch_size not in raw_batch_sizes:
+            packing_ratios.append(packing_ratio)
+            raw_batch_sizes.append(raw_batch_size)
+
+    n_profile_examples = max(raw_batch_sizes) * 100
 
-def repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int,
-          pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]:
-
-    def pad_tensor(tensor: torch.Tensor, pad_value: int):
-        if len(tensor) == max_seq_len:
-            return tensor
-        t = torch.full((max_seq_len,),
-                       pad_value,
-                       dtype=tensor.dtype,
-                       device=tensor.device)
-        if padding_side == 'left':
-            t[-len(tensor):] = tensor
-        elif padding_side == 'right':
-            t[:len(tensor)] = tensor
+    def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase):
+        if cfg.name == 'text':
+            return build_text_dataloader(cfg, tokenizer, n_profile_examples)
+        elif cfg.name == 'text_denoising':
+            return build_text_denoising_dataloader(cfg, tokenizer,
+                                                   n_profile_examples)
+        elif cfg.name == 'finetuning':
+            return build_finetuning_dataloader(cfg, tokenizer,
+                                               n_profile_examples)
         else:
-            raise ValueError(f'Unknown {padding_side=}')
-        return t
-
-    pad_vals = {
-        'input_ids': pad_token_id,
-        'labels': -100,
-        'attention_mask': 0,
-        'bidirectional_mask': 0,
-        'sequence_id': -1,
-    }
-    keys = packed_examples[0].keys()
-    batch = {}
-    for key in keys:
-        batch[key] = torch.stack([
-            pad_tensor(example[key], pad_vals[key])
-            for example in packed_examples
-        ])
-    return batch
+            raise ValueError(
+                f'Not sure how to build dataloader with config: {cfg}')
+
+    train_dataloader = build_dataloader(dataloader_cfg, tokenizer)
+
+    # Get a bunch of raw examples
+    big_batch = next(iter(train_dataloader))
+
+    def split_big_batch(raw_batch_size: int) -> List:
+        input_ids = big_batch['input_ids'].split(raw_batch_size)
+        batches = [{'input_ids': x} for x in input_ids]
+
+        for key in big_batch.keys():
+            if key == 'input_ids':
+                continue
+            for idx, split in enumerate(big_batch[key].split(raw_batch_size)):
+                batches[idx].update({key: split})
+        return batches
+
+    def profile(raw_batch_size: int) -> Tuple[float, float]:
+        packer = BinPackCollator(
+            collator=lambda x: x,
+            target_batch_size=device_batch_size,
+            max_seq_len=dataloader_cfg.dataset.max_seq_len,
+            pad_token_id=0,  # <-- Doesn't need to be correct for profiling
+            max_leftover_bins_to_keep=dataloader_cfg.dataset.max_leftovers_to_keep)
+
+        # Simulate feeding the packing collator a bunch of data
+        for batch in split_big_batch(raw_batch_size):
+            if batch['input_ids'].shape[0] < device_batch_size:
+                continue
+            _ = packer(batch)
+
+        # Return the padding / waste stats over that bunch of data
+        padding_percent = 100 * (1 - packer.efficiency)
+        waste_percent = 100 * packer.waste
+        return padding_percent, waste_percent
+
+    results = []
+    for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes):
+        padding, waste = profile(raw_batch_size)
+        results.append((packing_ratio, padding, waste))
+    return results
 
 
 if __name__ == '__main__':
@@ -265,9 +370,6 @@ def pad_tensor(tensor: torch.Tensor, pad_value: int):
 
     from omegaconf import OmegaConf as om
 
-    from llmfoundry import (build_finetuning_dataloader,
-                            build_text_denoising_dataloader)
-    from llmfoundry.data import build_text_dataloader
     from llmfoundry.utils import build_tokenizer
 
     def parse_args() -> Namespace:
@@ -316,20 +418,6 @@ def parse_args() -> Namespace:
             raise ValueError('`num_packing_ratios` must be a positive integer.')
         return args
 
-    def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
-                         device_batch_size: int):
-        if cfg.name == 'text':
-            return build_text_dataloader(cfg, tokenizer, device_batch_size)
-        elif cfg.name == 'text_denoising':
-            return build_text_denoising_dataloader(cfg, tokenizer,
-                                                   device_batch_size)
-        elif cfg.name == 'finetuning':
-            return build_finetuning_dataloader(cfg, tokenizer,
-                                               device_batch_size)
-        else:
-            raise ValueError(
-                f'Not sure how to build dataloader with config: {cfg}')
-
     args = parse_args()
 
     with open(args.yaml_path) as f:
@@ -339,18 +427,6 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
     cfg = om.create(cfg)
     device_batch_size = cfg.global_train_batch_size // args.num_devices
 
-    # Determine the packing_ratio values we'll try
-    packing_ratios, raw_batch_sizes = [], []
-    for packing_ratio in np.linspace(args.min,
-                                     args.max,
-                                     args.num_packing_ratios,
-                                     endpoint=True):
-        packing_ratio = np.round(10 * packing_ratio) / 10
-        raw_batch_size = int(packing_ratio * device_batch_size)
-        if raw_batch_size not in raw_batch_sizes:
-            packing_ratios.append(packing_ratio)
-            raw_batch_sizes.append(raw_batch_size)
-
     # Fetch a bunch of raw examples once, which we'll re-use
     if 'train_loader' not in cfg:
         raise ValueError('config must define train_loader')
@@ -373,51 +449,13 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
     tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
     tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
 
-    # Turn off packing for the dataloader (we want raw, pre-packed examples)
-    dataloader_cfg.dataset.packing_ratio = None
-    dataloader_cfg.dataset.max_leftovers_to_keep = None
-    train_dataloader = build_dataloader(dataloader_cfg, tokenizer,
-                                        max(raw_batch_sizes) * 100)
-
-    # Get a bunch of raw examples
-    big_batch = next(iter(train_dataloader))
-
-    def split_big_batch(raw_batch_size: int) -> List:
-        input_ids = big_batch['input_ids'].split(raw_batch_size)
-        batches = [{'input_ids': x} for x in input_ids]
-
-        for key in big_batch.keys():
-            if key == 'input_ids':
-                continue
-            for idx, split in enumerate(big_batch[key].split(raw_batch_size)):
-                batches[idx].update({key: split})
-        return batches
-
-    def profile_packing(raw_batch_size: int) -> Tuple[float, float]:
-        packer = BinPackWrapper(
-            collator=lambda x: x,
-            target_batch_size=device_batch_size,
-            max_seq_len=dataloader_cfg.dataset.max_seq_len,
-            pad_token_id=0,  # <-- Doesn't need to be correct for profiling
-            padding_side='left',  # <-- Doesn't need to be correct for profiling
-            max_leftover_bins_to_keep=max_leftovers_to_keep)
-
-        # Simulate feeding the packing collator a bunch of data
-        for batch in split_big_batch(raw_batch_size):
-            if batch['input_ids'].shape[0] < device_batch_size:
-                continue
-            _ = packer(batch)
-
-        # Return the padding / waste stats over that bunch of data
-        padding_percent = 100 * (1 - packer.efficiency)
-        waste_percent = 100 * packer.waste
-        return padding_percent, waste_percent
+    results = profile_packing(dataloader_cfg, tokenizer, args.min, args.max,
+                              args.num_packing_ratios, device_batch_size)
 
     header = '\n\n\n packing_ratio | % PADDING | % WASTE'
     fstr = ' {:5.1f} | {:5.2f}% | {:6.2f}%'
 
     print(header)
     print('-' * len(header))
-    for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes):
-        padding, waste = profile_packing(raw_batch_size)
+    for packing_ratio, padding, waste in results:
         print(fstr.format(packing_ratio, padding, waste))
diff --git a/tests/test_packing.py b/tests/test_packing.py
new file mode 100644
index 0000000000..1596013046
--- /dev/null
+++ b/tests/test_packing.py
@@ -0,0 +1,113 @@
+from typing import List
+
+from llmfoundry.data.packing import BinPackDataset
+from torch.utils.data import IterableDataset
+
+class TestDataset(IterableDataset):
+    def __init__(self, data: List[List[int]]):
+        super().__init__()
+        self.data = data
+
+    def __iter__(self):
+        for d in self.data:
+            yield {'input_ids': d}
+
+def test_simple_packing():
+    dataset = TestDataset([
+        [1],
+        [2] * 2,
+        [8] * 8,
+        [9] * 9,
+    ])
+
+    packed_dataset = BinPackDataset(
+        dataset,
+        packing_ratio=2,
+        target_batch_size=2,
+        max_seq_len=10,
+        pad_token_id=0,
+    )
+
+    packed_samples = [sample['input_ids'] for sample in packed_dataset]
+
+    assert packed_samples[0] == [8] * 8 + [2] * 2
+    assert packed_samples[1] == [9] * 9 + [1]
+
+def test_simple_packing_with_leftovers():
+    dataset = TestDataset([
+        [5] * 5,
+        [6] * 6,
+        [5] * 5,
+        [7] * 7,
+    ])
+
+    packed_dataset = BinPackDataset(
+        dataset,
+        packing_ratio=2,
+        target_batch_size=2,
+        max_seq_len=10,
+        pad_token_id=0,
+    )
+
+    packed_samples = [sample['input_ids'] for sample in packed_dataset]
+
+    assert packed_samples[0] == [5] * 10
+    assert packed_samples[1] == [7] * 7
+    assert packed_samples[2] == [6] * 6
+
+# def test_auto_packing():
+#     reproducibility.seed_all(17)
+#     dataloader_cfg = DictConfig({
+#         'name': 'finetuning',
+#         'dataset': {
+#             'hf_name': 'mosaicml/dolly_hhrlhf',
+#             'split': 'train',
+#             'max_seq_len': 1024,
+#             'allow_pad_trimming': False,
+#             'decoder_only_format': True,
+#             'packing_ratio': 'auto',
+#             'shuffle': False,
+#         },
+#         'drop_last': False,
+#         'num_workers': 1,
+#         'pin_memory': False,
+#         'prefetch_factor': 1,
+#         'persistent_workers': True,
+#         'timeout': 0,
+#     })
+
+#     tokenizer = build_tokenizer('mosaicml/mpt-7b', {})
+
+#     dataloader = build_finetuning_dataloader(dataloader_cfg, tokenizer, 6)
+
+#     print(next(iter(dataloader)))
+
+#     # print('length!', len([sample for sample in dataloader]))
+
+#     # dataloader_cfg = DictConfig({
+#     #     'name': 'finetuning',
+#     #     'dataset': {
+#     #         'hf_name': 'mosaicml/dolly_hhrlhf',
+#     #         'split': 'train',
+#     #         'max_seq_len': 1024,
+#     #         'allow_pad_trimming': False,
+#     #         'decoder_only_format': True,
+#     #         'shuffle': False,
+#     #     },
+#     #     'drop_last': False,
+#     #     'num_workers': 1,
+#     #     'pin_memory': False,
+#     #     'prefetch_factor': 1,
+#     #     'persistent_workers': True,
+#     #     'timeout': 0,
+#     # })
+
+#     # tokenizer = build_tokenizer('mosaicml/mpt-7b', {})
+
+#     # dataloader = build_finetuning_dataloader(dataloader_cfg, tokenizer, 6)
+
+#     # print(len(dataloader))
+
+
+#     # for sample in dataloader:
+#     #     print(sample)
\ No newline at end of file