From dbd798e00c474cba5f2d53c1bca476c241f1aa84 Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Fri, 14 Jun 2024 14:58:13 -0700
Subject: [PATCH] Fix packing + streaming + resumption (#1277)

---
 llmfoundry/data/finetuning/dataloader.py |  4 +++-
 llmfoundry/data/finetuning/tasks.py      | 12 ++++++++++++
 tests/data/test_packing.py               | 10 ++++++++++
 3 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index 639beba6f0..160e9bfe3b 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -222,7 +222,7 @@ def build_finetuning_dataloader(
             cache_limit=dataset_cfg.get('cache_limit', None),
             partition_algo=dataset_cfg.get('partition_algo', 'relaxed'),
             num_canonical_nodes=dataset_cfg.get('num_canonical_nodes', None),
-            batch_size=dataset_batch_size,
+            batch_size=dataloader_batch_size,
             shuffle=dataset_cfg.get('shuffle', False),
             shuffle_algo=dataset_cfg.get('shuffle_algo', 'py1e'),
             shuffle_seed=dataset_cfg.get('shuffle_seed', 9176),
@@ -233,6 +233,7 @@ def build_finetuning_dataloader(
             max_seq_len=dataset_cfg['max_seq_len'],
             allow_unsafe_types=dataset_cfg.get('allow_unsafe_types', False),
             replication=replication_factor,
+            packing_ratio=dataloader_batch_size / dataset_batch_size,
         )
 
     else:
@@ -390,6 +391,7 @@ def _validate_config(
         'allow_pad_trimming',
         'seq_parallel_replication',
         'auto_packing_replication',
+        'max_leftover_bins_to_keep',
     }
     if not set(kwargs.keys()).issubset(allowed_additional_kwargs):
         raise ValueError(
diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 40f178fb6e..9a0f680bd7 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -592,6 +592,7 @@ def __init__(
         max_seq_len: int = 2048,
         allow_unsafe_types: bool = False,
         replication: Optional[int] = None,
+        packing_ratio: Optional[float] = None,
         **kwargs: Any,
     ):
 
@@ -644,6 +645,7 @@ def __init__(
 
         self.tokenizer = tokenizer
         self.max_seq_len = max_seq_len
+        self.packing_ratio = packing_ratio
 
     # How to process a sample
     def __getitem__(self, idx: int) -> Dict[str, Any]:
@@ -675,6 +677,16 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
             return {'turns': [sample]}
         return tokenize_formatted_example(sample, tokenizer=self.tokenizer)
 
+    def state_dict(self, num_samples: int,
+                   from_beginning: bool) -> Dict[str, Any]:
+        if self.packing_ratio is not None:
+            num_samples = int(self.packing_ratio * num_samples)
+
+        return super().state_dict(
+            num_samples=num_samples,
+            from_beginning=from_beginning,
+        )
+
 
 class DatasetConstructor:
 
diff --git a/tests/data/test_packing.py b/tests/data/test_packing.py
index b910b8c5ff..d181dbde0b 100644
--- a/tests/data/test_packing.py
+++ b/tests/data/test_packing.py
@@ -14,6 +14,7 @@
 from torch.utils.data import DataLoader
 
 from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
+from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset
 from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
 from llmfoundry.utils.builders import build_tokenizer
 
@@ -206,6 +207,15 @@ def test_auto_packing_with_streaming_dataloader(tmp_path: Path):
         if batch_ix >= 3:
             break
 
+    assert isinstance(loader, DataLoader)
+    assert isinstance(loader.dataset, StreamingFinetuningDataset)
+    assert loader.dataset.packing_ratio is not None
+    assert isinstance(loader.batch_size, int)
+    assert loader.dataset.packing_ratio == int(loader.batch_size / 6)
+
+    state_dict = loader.dataset.state_dict(num_samples=2, from_beginning=False)
+    assert state_dict['sample_in_epoch'] == 2 * loader.dataset.packing_ratio
+
 
 @pytest.mark.parametrize('packing_ratio', ['auto', 2.0])
 @patch(