From 229ab4f95d716c4f7094ec365cef122719c4f587 Mon Sep 17 00:00:00 2001 From: snarayan21 Date: Fri, 15 Sep 2023 22:32:15 -0700 Subject: [PATCH] Updated streaming args for StreamingDataset subclasses (#602) * Fix streaming dataset setup for fine-tuning Currently, when a StreamingFinetuningDataset is created using the build_finetuning_dataloader method, a failure is returned as some incorrect parameters are passed through to the constructor of StreamingFinetuningDataset. This patch fixes the paramter mismatch and adds test coverage for this case. * updated StreamingTextDataset and StreamingFinetuningDataset with new streaming args, bumped streaming version * updated StreamingTextDataset and StreamingFinetuningDataset with new streaming args, bumped streaming version --------- Co-authored-by: Aiden Grossman Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/data/finetuning/dataloader.py | 2 + llmfoundry/data/finetuning/tasks.py | 104 ++++++++++++++++------- llmfoundry/data/text_data.py | 13 ++- setup.py | 2 +- tests/test_dataloader.py | 58 +++++++++++++ 5 files changed, 146 insertions(+), 33 deletions(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index a009f13660..661b1e808d 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -136,6 +136,8 @@ def build_finetuning_dataloader(cfg: DictConfig, shuffle_seed=cfg.dataset.get('shuffle_seed', 9176), shuffle_block_size=cfg.dataset.get('shuffle_block_size', 1 << 18), sampling_method=cfg.dataset.get('sampling_method', 'balanced'), + sampling_granularity=cfg.dataset.get('sampling_granularity', 1), + batching_method=cfg.dataset.get('batching_method', 'random'), ) collate_fn, dataloader_batch_size = _build_collate_fn( diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index c184dc9848..0a2b386048 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -71,44 +71,76 @@ class StreamingFinetuningDataset(StreamingDataset): """Finetuning dataset with flexible tokenization using StreamingDataset. Args: - local (str): Local dataset directory where shards are cached by split. tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to tokenize samples. - remote (str, optional): Download shards from this remote path or directory. If None, this - rank and worker's partition of the dataset must all exist locally. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to ``False``. - predownload (int, optional): Target number of samples ahead to download the shards of while - iterating. Defaults to ``100_000``. - keep_zip (bool, optional): Whether to keep or delete the compressed file when - decompressing downloaded shards. If set to None, keep if remote is local. Defaults to - ``None``. + local (str): Local dataset directory where shards are cached by split. + remote (str, optional): Remote path or directory to download the dataset from. If ``None``, + its data must exist locally. StreamingDataset uses either ``streams`` or + ``remote``/``local``. Defaults to ``None``. + split (str, optional): Which dataset split to use, if any. If provided, we stream from/to + the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. download_timeout (float): Number of seconds to wait for a shard to download before raising an exception. Defaults to ``60``. validate_hash (str, optional): Optional hash or checksum algorithm to use to validate shards. Defaults to ``None``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with resumption. - If ``None``, defaults to the number of nodes of the initial run. Defaults to 128. + keep_zip (bool): Whether to keep or delete the compressed form when decompressing + downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to + `False``. + epoch_size (int, optional): Number of samples to draw per epoch balanced across all + streams. If ``None``, takes its value from the total number of underlying samples. + Provide this field if you are weighting streams relatively to target a larger or + smaller epoch size. Defaults to ``None``. + predownload (int, optional): Target number of samples ahead to download the shards of while + iterating. Defaults to ``100_000``. + cache_limit (Union[int, str], optional) - Maximum size in bytes of this StreamingDataset's + shard cache. Before downloading a shard, the least recently used resident shard(s) may + be evicted (deleted from the local cache) in order to stay under the limit. Set to None + to disable shard eviction. Supports integer bytes as well as string human-readable + bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None. + partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. + num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with + resumption. Defaults to ``None``, which is interpreted as the number of nodes of the + initial run. batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is partitioned over the workers. Defaults to ``None``. + shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to + ``False``. + shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``. + shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. + shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. + sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``. + Defaults to ``balanced``. + sampling_granularity (int): When picking samples for a stream's final partial repeat, + how many samples to pick from the same shard at a time (``1`` for evenly balanced + across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc). + Defaults to ``1``. + batching_method (str): Which batching method to use, either ``random``, ``stratified``, or + ``per_stream``. Defaults to ``random``. """ def __init__(self, - local: str, tokenizer: PreTrainedTokenizerBase, + local: str, remote: Optional[str] = None, split: Optional[str] = None, - shuffle: bool = False, - predownload: Optional[int] = 100_000, - keep_zip: bool = False, download_retry: int = 2, download_timeout: float = 60, validate_hash: Optional[str] = None, - shuffle_seed: int = 9176, - num_canonical_nodes: Optional[int] = 128, + keep_zip: bool = False, + epoch_size: Optional[int] = None, + predownload: Optional[int] = None, + cache_limit: Optional[Union[int, str]] = None, + partition_algo: str = 'orig', + num_canonical_nodes: Optional[int] = None, batch_size: Optional[int] = None, + shuffle: bool = False, + shuffle_algo: str = 'py1b', + shuffle_seed: int = 9176, + shuffle_block_size: int = 1 << 18, + sampling_method: str = 'balanced', + sampling_granularity: int = 1, + batching_method: str = 'random', **kwargs: Any): if len(kwargs) > 0: @@ -125,18 +157,28 @@ def __init__(self, ) # Build Dataset - super().__init__(local=local, - remote=remote, - split=split, - shuffle=shuffle, - predownload=predownload, - keep_zip=keep_zip, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - shuffle_seed=shuffle_seed, - num_canonical_nodes=num_canonical_nodes, - batch_size=batch_size) + super().__init__( + local=local, + remote=remote, + split=split, + download_retry=download_retry, + download_timeout=download_timeout, + validate_hash=validate_hash, + keep_zip=keep_zip, + epoch_size=epoch_size, + predownload=predownload, + cache_limit=cache_limit, + partition_algo=partition_algo, + num_canonical_nodes=num_canonical_nodes, + batch_size=batch_size, + shuffle=shuffle, + shuffle_algo=shuffle_algo, + shuffle_seed=shuffle_seed, + shuffle_block_size=shuffle_block_size, + sampling_method=sampling_method, + sampling_granularity=sampling_granularity, + batching_method=batching_method, + ) self.tokenizer = tokenizer diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 31626b237f..afdd243adf 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -66,7 +66,14 @@ class StreamingTextDataset(StreamingDataset): shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``. shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. - sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``. Defaults to ``balanced``. + sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``. + Defaults to ``balanced``. + sampling_granularity (int): When picking samples for a stream's final partial repeat, + how many samples to pick from the same shard at a time (``1`` for evenly balanced + across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc). + Defaults to ``1``. + batching_method (str): Which batching method to use, either ``random``, ``stratified``, or + ``per_stream``. Defaults to ``random``. """ def __init__(self, @@ -91,6 +98,8 @@ def __init__(self, shuffle_seed: int = 9176, shuffle_block_size: int = 1 << 18, sampling_method: str = 'balanced', + sampling_granularity: int = 1, + batching_method: str = 'random', **kwargs: Any): group_method = kwargs.pop('group_method', None) @@ -138,6 +147,8 @@ def __init__(self, shuffle_seed=shuffle_seed, shuffle_block_size=shuffle_block_size, sampling_method=sampling_method, + sampling_granularity=sampling_granularity, + batching_method=batching_method, ) self.tokenizer = tokenizer self.max_seq_len = max_seq_len diff --git a/setup.py b/setup.py index b07b8afe08..1a93bd05f7 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ 'mosaicml[libcloud,wandb,mlflow]>=0.16.1,<0.17', 'accelerate>=0.20,<0.21', # for HF inference `device_map` 'transformers>=4.33,<4.34', - 'mosaicml-streaming>=0.5.1,<0.6', + 'mosaicml-streaming>=0.6,<0.7', 'torch>=1.13.1,<2.1.1', 'datasets>=2.14.5,<2.15', 'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 53549ccfe1..72bfac1d08 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -13,6 +13,7 @@ import torch from composer.utils import dist, using_torch_2 from omegaconf import OmegaConf as om +from streaming import MDSWriter from llmfoundry import (build_finetuning_dataloader, build_text_denoising_dataloader) @@ -42,6 +43,25 @@ def get_abs_data_path(data_local: str): return os.path.join(os.getcwd(), data_local) +def build_mock_ft_streaming_dataset(data_path: str, split: str): + columns = {'prompt': 'str', 'response': 'str'} + + dataset = [{ + 'prompt': 'This is just a test1', + 'response': 'Hello World1' + }, { + 'prompt': 'This is just a test2', + 'response': 'Hello world2' + }] + + output_path = os.path.join(data_path, split) + + with MDSWriter(columns=columns, out=output_path, + compression=None) as output_writer: + for sample in dataset: + output_writer.write(sample) + + @pytest.mark.parametrize('tokenizer_name', ['gpt2', 'facebook/opt-125m']) @pytest.mark.parametrize('pretokenize', [False, True]) def test_correct_padding(tokenizer_name: str, @@ -414,6 +434,44 @@ def test_finetuning_dataloader_custom_split_remote( _ = build_finetuning_dataloader(cfg, tokenizer, 4) +def test_finetuning_dataloader_streaming(tmp_path: pathlib.Path): + max_seq_len = 2048 + + remote_path = os.path.join(tmp_path, 'remote') + local_path = os.path.join(tmp_path, 'local') + + build_mock_ft_streaming_dataset(remote_path, 'train') + + cfg = { + 'name': 'finetuning', + 'dataset': { + 'remote': remote_path, + 'local': local_path, + 'split': 'train', + 'max_seq_len': 2048, + 'decoder_only_format': True, + 'allow_pad_trimming': False, + 'packing_ratio': None, + 'shuffle': True, + }, + 'drop_last': False, + 'num_workers': 4, + 'pin_memory': False, + 'prefetch_factor': 2, + 'persistent_workers': False, + 'timeout': 0 + } + + cfg = om.create(cfg) + + tokenizer = build_tokenizer( + tokenizer_name='gpt2', + tokenizer_kwargs={'model_max_length': max_seq_len}, + ) + + _ = build_finetuning_dataloader(cfg, tokenizer, 4) + + @pytest.mark.parametrize('add_bad_data_dropped', [True, False]) @pytest.mark.parametrize('add_bad_data_error', [True, False]) def test_malformed_data(