diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index c6881fd276..7071cbdda3 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -16,7 +16,7 @@
                                               SUPPORTED_EXTENSIONS,
                                               dataset_constructor)
 from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
-from llmfoundry.data.text_data import get_tokens_per_batch_func
+from llmfoundry.data.text_data import build_streams, get_tokens_per_batch_func
 
 log = logging.getLogger(__name__)
 
@@ -128,11 +128,14 @@ def build_finetuning_dataloader(cfg: DictConfig,
 
     dataset = None  # for pyright
     sampler = None
-    if cfg.dataset.get('remote') is not None:
+    if cfg.dataset.get('remote') is not None or cfg.dataset.get(
+            'streams') is not None:
         # Build streaming dataloader
+        streams = build_streams(cfg.dataset)
         dataset = dataset_constructor.build_from_streaming(
             tokenizer=tokenizer,
-            local=cfg.dataset.local,
+            streams=streams,
+            local=cfg.dataset.get('local', None),
             remote=cfg.dataset.get('remote', None),
             split=cfg.dataset.get('split', None),
             download_retry=cfg.dataset.get('download_retry', 2),
@@ -279,11 +282,38 @@ def _validate_config(dataset_cfg: DictConfig) -> None:
                 'Using a streaming dataset requires setting both `remote` and `local`, ' +\
                 'but dataset.local is None.'
             )
+    elif dataset_cfg.get('streams') is not None:
+        # Using the streaming dataset codepath
+        illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load']
+        discovered_illegal_keys = []
+        for key in illegal_keys:
+            if dataset_cfg.get(key) is not None:
+                discovered_illegal_keys.append('`' + key + '`')
+        if discovered_illegal_keys:
+            raise ValueError(
+                'The dataset config sets a value for `streams` as well as the ' +\
+                f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\
+                'Those keys are used when building from a HuggingFace dataset, but ' +\
+                'setting `streams` instructs the dataset to build from a streaming dataset.'
+            )
+        illegal_keys = ['remote', 'local']
+        discovered_illegal_keys = []
+        for key in illegal_keys:
+            if dataset_cfg.get(key) is not None:
+                discovered_illegal_keys.append('`' + key + '`')
+        if discovered_illegal_keys:
+            raise ValueError(
+                'The dataset config sets a value for `streams` as well as the ' +\
+                f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\
+                'Please either use a single stream (set `remote`/`local` only) ' +\
+                'or put `remote`/`local` under `streams`.'
+            )
+
     else:
         raise ValueError(
-            'In the dataset config, you must set either `hf_name` to use a ' +\
-            'HuggingFace dataset or set `remote` to use a streaming ' +\
-            'dataset, but both were None.'
+            'In the dataset config, you must set `hf_name` to use a HuggingFace ' +\
+            'dataset, or set `remote` to use a streaming dataset, or set ' +\
+            '`streams` to use multiple streaming datasets, but all were None.'
         )
     if dataset_cfg.get('max_seq_len') is None:
         raise ValueError(
diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 2b397eae96..7f2a5417b4 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -37,14 +37,14 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 import warnings
 from functools import partial
 from pathlib import Path
-from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union,
-                    cast)
+from typing import (Any, Callable, Dict, List, Literal, Optional, Sequence,
+                    Tuple, Union, cast)
 
 import datasets as hf_datasets
 import huggingface_hub as hf_hub
 import numpy as np
 from composer.utils import dist
-from streaming import StreamingDataset
+from streaming import Stream, StreamingDataset
 from transformers import PreTrainedTokenizerBase
 
 from llmfoundry.utils.logging_utils import SpecificWarningFilter
@@ -257,12 +257,25 @@ def is_valid_ift_example(pad_token_id: int, max_seq_len: int,
             non_padding_response)
 
 
+def _stream_remote_local_validate(remote: Optional[str], local: Optional[str],
+                                  split: Optional[str]):
+    if remote is None or (local == remote):
+        if local is not None and os.path.isdir(local):
+            contents = set(os.listdir(local))
+            if split is not None and split not in contents:
+                raise ValueError(
+                    f'local directory {local} does not contain split {split}')
+
+
 class StreamingFinetuningDataset(StreamingDataset):
     """Finetuning dataset with flexible tokenization using StreamingDataset.
 
     Args:
         tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to
             tokenize samples.
+        streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
+            which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
+            ``remote``/``local``. Defaults to ``None``.
         local (str): Local dataset directory where shards are cached by split.
         remote (str, optional): Remote path or directory to download the dataset from. If
             ``None``, its data must exist locally. StreamingDataset uses either ``streams`` or
@@ -313,7 +326,8 @@ class StreamingFinetuningDataset(StreamingDataset):
 
     def __init__(self,
                  tokenizer: PreTrainedTokenizerBase,
-                 local: str,
+                 streams: Optional[Sequence[Stream]] = None,
+                 local: Optional[str] = None,
                  remote: Optional[str] = None,
                  split: Optional[str] = None,
                  download_retry: int = 2,
@@ -341,15 +355,15 @@ def __init__(self,
                 f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}'
             )
 
-        if remote is None or (local == remote):
-            if os.path.isdir(local):
-                contents = set(os.listdir(local))
-                if split not in contents:
-                    raise ValueError(
-                        f'local directory {local} does not contain split {split}'
-                    )
+        if streams is None:
+            _stream_remote_local_validate(remote, local, split)
+        else:
+            for stream in streams:
+                _stream_remote_local_validate(stream.remote, stream.local,
+                                              split)
 
         super().__init__(
+            streams=streams,
             local=local,
             remote=remote,
             split=split,
diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py
index 3301d455e5..8d7ff5849d 100644
--- a/llmfoundry/data/text_data.py
+++ b/llmfoundry/data/text_data.py
@@ -232,6 +232,19 @@ def get_sequence_id_from_batch(
     return torch.cat([left_zeros, cumulative_sep[:, :-1]], dim=1)
 
 
+def build_streams(dataset_cfg: DictConfig):
+    streams_dict = dataset_cfg.pop('streams', None)
+    # build streams
+    streams = None
+    if streams_dict is not None:
+        streams = []
+        for _, stream in streams_dict.items():
+            # stream is the streams kwargs
+            # fwd all kwargs with **stream allows streaming to check args
+            streams.append(Stream(**stream))
+    return streams
+
+
 def build_text_dataloader(
     cfg: DictConfig,
     tokenizer: PreTrainedTokenizerBase,
@@ -240,19 +253,11 @@ def build_text_dataloader(
     assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}'
 
     # get kwargs
-    streams_dict = cfg.dataset.pop('streams', None)
     mlm_probability = cfg.dataset.pop('mlm_probability', None)
     eos_token_id = cfg.dataset.pop('eos_token_id', None)
     bos_token_id = cfg.dataset.pop('bos_token_id', None)
 
-    # build streams
-    streams = None
-    if streams_dict is not None:
-        streams = []
-        for _, stream in streams_dict.items():
-            # stream is the streams kwargs
-            # fwd all kwargs with **stream allows streaming to check args
-            streams.append(Stream(**stream))
+    streams = build_streams(cfg.dataset)
 
     # build dataset potentially with streams
     dataset = StreamingTextDataset(
diff --git a/scripts/train/yamls/finetune/gpt2-arc-easy-cpu-streaming-dataset.yaml b/scripts/train/yamls/finetune/gpt2-arc-easy-cpu-streaming-dataset.yaml
index d5fe26d3fc..4047256614 100644
--- a/scripts/train/yamls/finetune/gpt2-arc-easy-cpu-streaming-dataset.yaml
+++ b/scripts/train/yamls/finetune/gpt2-arc-easy-cpu-streaming-dataset.yaml
@@ -24,9 +24,11 @@ train_loader:
   name: finetuning
   dataset:
     ############
-    remote: ${data_remote}
-    local: ${data_local}
-    split: train
+    streams:
+      my_data:
+        remote: ${data_remote}
+        local: ${data_local}
+        split: train
     ############
     shuffle: true
     max_seq_len: ${max_seq_len}
diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py
index 0a7edc3d7a..319e6eafdf 100644
--- a/tests/data/test_dataloader.py
+++ b/tests/data/test_dataloader.py
@@ -548,31 +548,38 @@ def test_finetuning_dataloader_custom_split_remote(split: str):
 
 
 @pytest.mark.parametrize('pretokenize', [True, False])
+@pytest.mark.parametrize('use_multiple_streams', [True, False])
 @pytest.mark.parametrize('use_bytes', [True, False])
-def test_finetuning_dataloader_streaming(pretokenize: bool, use_bytes: bool,
+def test_finetuning_dataloader_streaming(pretokenize: bool,
+                                         use_multiple_streams: bool,
+                                         use_bytes: bool,
                                          tmp_path: pathlib.Path):
     max_seq_len = 2048
-    remote_path = os.path.join(tmp_path, 'remote')
-    local_path = os.path.join(tmp_path, 'local')
-
     tokenizer = build_tokenizer(
         tokenizer_name='gpt2',
         tokenizer_kwargs={'model_max_length': max_seq_len},
     )
-    build_mock_ft_streaming_dataset(remote_path,
-                                    'train',
-                                    pretokenize,
-                                    use_bytes=use_bytes,
-                                    tokenizer=tokenizer)
+    streams_config = {'streams': {}}
+    num_streams = 2
+    for i in range(num_streams):
+        remote_path = os.path.join(tmp_path, f'remote_{i}')
+        local_path = os.path.join(tmp_path, f'local_{i}')
+        build_mock_ft_streaming_dataset(remote_path,
+                                        'train',
+                                        pretokenize,
+                                        use_bytes=use_bytes,
+                                        tokenizer=tokenizer)
+        streams_config['streams'][f'stream_{i}'] = {
+            'remote': remote_path,
+            'local': local_path,
+            'split': 'train'
+        }
 
     cfg = {
         'name': 'finetuning',
         'dataset': {
-            'remote': remote_path,
-            'local': local_path,
-            'split': 'train',
             'max_seq_len': 2048,
             'decoder_only_format': True,
             'allow_pad_trimming': False,
@@ -586,6 +593,10 @@ def test_finetuning_dataloader_streaming(pretokenize: bool, use_bytes: bool,
         'persistent_workers': False,
         'timeout': 0
     }
+    if use_multiple_streams:
+        cfg['dataset'].update(streams_config)
+    else:
+        cfg['dataset'].update(streams_config['streams']['stream_0'])
 
     cfg = om.create(cfg)
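Reviewer note (illustration, not part of the patch): the sketch below shows how the new `streams` key is expected to flow through the `build_streams` helper added to llmfoundry/data/text_data.py above. The stream names, remote/local paths, and the optional `proportion` weight are placeholder values chosen for the example, not values taken from this PR.

# Hypothetical multi-stream finetuning dataset config (paths are placeholders).
from omegaconf import OmegaConf as om

from llmfoundry.data.text_data import build_streams

dataset_cfg = om.create({
    'streams': {
        'instruct': {
            'remote': 's3://my-bucket/instruct-mds',  # placeholder remote
            'local': '/tmp/instruct-mds',  # placeholder local cache dir
            'split': 'train',
        },
        'chat': {
            'remote': 's3://my-bucket/chat-mds',
            'local': '/tmp/chat-mds',
            'split': 'train',
            # extra per-stream kwargs are forwarded via Stream(**stream),
            # e.g. a sampling weight; illustrative only
            'proportion': 0.25,
        },
    },
    'max_seq_len': 2048,
    'decoder_only_format': True,
    'shuffle': True,
})

# build_streams pops `streams` from the dataset config and returns a list of
# streaming.Stream objects (or None when no `streams` key is present).
streams = build_streams(dataset_cfg)
assert streams is not None and len(streams) == 2
assert 'streams' not in dataset_cfg  # the key is consumed by build_streams

# build_finetuning_dataloader then passes this list to
# dataset_constructor.build_from_streaming, which forwards it to
# StreamingFinetuningDataset(streams=..., local=None, remote=None, ...).

Because the same `build_streams` helper now backs both the text and finetuning dataloaders, the `streams` config format stays identical across the two codepaths.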