From 2939cc983d0246f68bbff1813317987f450bc4ce Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Tue, 9 Apr 2024 14:43:09 -0400 Subject: [PATCH] Updating the streaming version in setup.py (#1103) * updating the streaming version in setup.py * updating constructor call to StreamingDataset * making allow_unsafe_types and replication configurable through dataset configs * adding docstring --- llmfoundry/data/finetuning/dataloader.py | 2 ++ llmfoundry/data/finetuning/tasks.py | 10 ++++++++++ llmfoundry/data/text_data.py | 10 ++++++++++ setup.py | 2 +- 4 files changed, 23 insertions(+), 1 deletion(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 38c9673a14..1d8711d280 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -170,6 +170,8 @@ def build_finetuning_dataloader(cfg: DictConfig, sampling_granularity=cfg.dataset.get('sampling_granularity', 1), batching_method=cfg.dataset.get('batching_method', 'random'), max_seq_len=cfg.dataset.max_seq_len, + allow_unsafe_types=cfg.dataset.get('allow_unsafe_types', False), + replication=cfg.dataset.get('replication', None), ) else: diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 4ca15e8d1f..42b15e4d6e 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -490,6 +490,12 @@ class StreamingFinetuningDataset(StreamingDataset): Defaults to ``1``. batching_method (str): Which batching method to use, either ``random``, ``stratified``, or ``per_stream``. Defaults to ``random``. + allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code + execution during deserialization, whether to keep going if ``True`` or raise an error + if ``False``. Defaults to ``False``. + replication (int, optional): Determines how many consecutive devices will receive the same + samples.
Useful for training with tensor or sequence parallelism, where multiple + devices need to see the same partition of the dataset. Defaults to ``None``. """ def __init__(self, @@ -516,6 +522,8 @@ def __init__(self, sampling_granularity: int = 1, batching_method: str = 'random', max_seq_len: int = 2048, + allow_unsafe_types: bool = False, + replication: Optional[int] = None, **kwargs: Any): if len(kwargs) > 0: @@ -552,6 +560,8 @@ def __init__(self, sampling_method=sampling_method, sampling_granularity=sampling_granularity, batching_method=batching_method, + allow_unsafe_types=allow_unsafe_types, + replication=replication, ) self.tokenizer = tokenizer diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index e85968543c..fc31b890b0 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -83,6 +83,12 @@ class StreamingTextDataset(StreamingDataset): Defaults to ``1``. batching_method (str): Which batching method to use, either ``random``, ``stratified``, or ``per_stream``. Defaults to ``random``. + allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code + execution during deserialization, whether to keep going if ``True`` or raise an error + if ``False``. Defaults to ``False``. + replication (int, optional): Determines how many consecutive devices will receive the same + samples. Useful for training with tensor or sequence parallelism, where multiple + devices need to see the same partition of the dataset. Defaults to ``None``.
""" def __init__(self, @@ -109,6 +115,8 @@ def __init__(self, sampling_method: str = 'balanced', sampling_granularity: int = 1, batching_method: str = 'random', + allow_unsafe_types: bool = False, + replication: Optional[int] = None, **kwargs: Any): if len(kwargs) > 0: @@ -151,6 +159,8 @@ def __init__(self, sampling_method=sampling_method, sampling_granularity=sampling_granularity, batching_method=batching_method, + allow_unsafe_types=allow_unsafe_types, + replication=replication, ) self.tokenizer = tokenizer self.max_seq_len = max_seq_len diff --git a/setup.py b/setup.py index 79511eeca3..1e384d35ae 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ 'mlflow>=2.10,<3', 'accelerate>=0.25,<0.26', # for HF inference `device_map` 'transformers>=4.39.3,<4.40', - 'mosaicml-streaming>=0.7.4,<0.8', + 'mosaicml-streaming>=0.7.5,<0.8', 'torch>=2.2.1,<2.3', 'datasets>=2.16,<2.17', 'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data