diff --git a/README.md b/README.md
index 5ade3a0924..f5ab4f3e79 100644
--- a/README.md
+++ b/README.md
@@ -475,6 +475,20 @@ tokens: # these are delimiters
 
 When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary.
 
+#### Streaming dataset
+
+Use [mosaicml-streaming](https://github.com/mosaicml/streaming?tab=readme-ov-file#quick-start) to prepare your dataset for streaming. This lets you train on arbitrarily large ("infinite") datasets. Just add `streaming: true` to your `datasets` entry:
+
+```yaml
+datasets:
+  - ds_type: json
+    path: s3://my-bucket/datasets-path/
+    type: completion
+    streaming: true
+```
+
+Ensure that you have uploaded the dataset in [mosaicml-streaming](https://github.com/mosaicml/streaming?tab=readme-ov-file#quick-start)'s format beforehand.
+
 ### Inference Playground
 
 Axolotl allows you to load your model in an interactive terminal playground for quick experimentation.
diff --git a/docs/config.qmd b/docs/config.qmd
index e2ea778603..b9d3b21b71 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -90,6 +90,7 @@ datasets:
     shards: # Optional[int] number of shards to split data into
     name: # Optional[str] name of dataset configuration to load
    train_on_split: train # Optional[str] name of dataset split to load from
+    streaming: null # Optional[bool] whether to stream the dataset from remote storage via `mosaicml-streaming`
 
     # Optional[str] fastchat conversation type, only used with type: sharegpt
     conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
diff --git a/requirements.txt b/requirements.txt
index f707946a02..d545135571 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -31,6 +31,7 @@ art
 fschat==0.2.36
 gradio==3.50.2
 tensorboard
+mosaicml-streaming==0.7.5
 
 mamba-ssm==1.2.0.post1
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 0fbed08ca3..a6af007b0c 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -96,6 +96,7 @@ class SFTDataset(BaseModel):
     data_files: Optional[Union[str, List[str]]] = None
     name: Optional[str] = None
     ds_type: Optional[str] = None
+    streaming: Optional[bool] = None
     train_on_split: Optional[str] = None
     field: Optional[str] = None
 
diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index 39c50b1a07..f36bf90b03 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -51,6 +51,50 @@
 LOG = logging.getLogger("axolotl")
 
 
+def load_streaming_dataset(config_dataset):
+    """
+    Load a streaming dataset from remote storage.
+
+    This function initializes a streaming dataset from a remote S3 bucket,
+    wraps the data into a generator, and then converts it into a Hugging Face
+    dataset for compatibility purposes.
+
+    Parameters:
+    - config_dataset: Dataset configuration object whose `path` attribute points to the remote dataset location.
+
+    Returns:
+    - ds (datasets.Dataset): A Hugging Face dataset object that streams data from the specified remote location.
+    """
+    # These imports are local due to the optionality of `mosaicml-streaming`.
+    from functools import partial
+
+    from datasets import Features, Value
+    from streaming import StreamingDataset
+
+    # Initialize the `StreamingDataset` from the configured remote path.
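+    #
+    # A note on the hard-coded arguments (assumptions, not read from the axolotl
+    # config): `local=None` lets `mosaicml-streaming` cache downloaded shards in a
+    # temporary local directory, and `batch_size=4` is the DataLoader batch-size
+    # hint that `StreamingDataset` uses when partitioning samples across workers.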
+    streaming_dataset = StreamingDataset(
+        local=None, remote=config_dataset.path, shuffle=True, batch_size=4
+    )
+
+    # Define dataset features according to the axolotl structure.
+    features = Features({"text": Value("string")})
+
+    # Shim between `StreamingDataset` and `Dataset`.
+    def generator_from_streaming_dataset(streaming_dataset):
+        yield from streaming_dataset
+
+    # Create a Hugging Face dataset from the generator.
+    #
+    # This is necessary because downstream functions use a different interface
+    # than `StreamingDataset` (e.g. the `features` attribute).
+    ds = Dataset.from_generator(  # pylint: disable=invalid-name
+        generator=partial(generator_from_streaming_dataset, streaming_dataset),
+        features=features,
+    )
+
+    return ds
+
+
 def prepare_dataset(cfg, tokenizer):
     prompters = []
     if not cfg.pretraining_dataset:
@@ -317,10 +361,13 @@ def for_d_in_datasets(dataset_configs):
                )
            elif ds_from_cloud and remote_file_system:
                if remote_file_system.isdir(config_dataset.path):
-                    ds = load_from_disk(
-                        config_dataset.path,
-                        storage_options=storage_options,
-                    )
+                    if config_dataset.streaming:
+                        ds = load_streaming_dataset(config_dataset)
+                    else:
+                        ds = load_from_disk(
+                            config_dataset.path,
+                            storage_options=storage_options,
+                        )
                elif remote_file_system.isfile(config_dataset.path):
                    ds_type = get_ds_type(config_dataset)
                    ds = load_dataset(
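For reference, the "upload the dataset in mosaicml-streaming's format beforehand" step that the README change relies on could look roughly like the sketch below. This is not part of the patch: the bucket path, sample list, and column name are placeholders, chosen to match the `Features({"text": Value("string")})` schema that `load_streaming_dataset` expects.

```python
# Minimal sketch: write text samples in MDS format directly to the S3 prefix
# that `datasets[].path` will later point at. Requires `mosaicml-streaming`.
from streaming import MDSWriter

samples = ["first training example", "second training example"]  # placeholder data

# A single "text" column, matching the schema used by load_streaming_dataset().
columns = {"text": "str"}

with MDSWriter(out="s3://my-bucket/datasets-path/", columns=columns) as writer:
    for text in samples:
        writer.write({"text": text})
```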