Skip to content

Commit

Permalink
Major overhaul of Streaming documentation (#636)
Browse files Browse the repository at this point in the history
* getting_started section updated

* preparing datasets section wip

* wrapped up preparing datasets

* shard retrieval

* added dataset configuration folder

* starting on distributed trainng

* added a tip for shard sample access

* complete

* deleted old files

* main docs rewrite

* no tests for code segments

* note

* docs

* math rendered?

* dependency hell

* doc fixes

* doc fixes

* doctests

* fixing doctest

* karan's comments and faqs

* addressing comments

* addressing comments

* addressed james comments

* removed testing file

* pushing

---------

Co-authored-by: Xiaohan Zhang <[email protected]>
  • Loading branch information
snarayan21 and XiaohanZhangCMU authored Apr 5, 2024
1 parent 30326fe commit 24e9182
Show file tree
Hide file tree
Showing 52 changed files with 3,321 additions and 717 deletions.
18 changes: 18 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,21 @@ updates:
interval: "weekly"
# Allow up to 5 open pull requests for pip dependencies
open-pull-requests-limit: 5
ignore:
- dependency-name: "GitPython"
- dependency-name: "docutils"
- dependency-name: "furo"
- dependency-name: "myst-parser"
- dependency-name: "nbsphinx"
- dependency-name: "pandoc"
- dependency-name: "pypandoc"
- dependency-name: "sphinx-argparse"
- dependency-name: "sphinx-copybutton"
- dependency-name: "sphinx"
- dependency-name: "sphinx-tabs"
- dependency-name: "sphinxcontrib.katex"
- dependency-name: "sphinxcontrib-applehelp"
- dependency-name: "sphinxcontrib-devhelp"
- dependency-name: "sphinxcontrib-htmlhelp"
- dependency-name: "sphinxcontrib-qthelp"
- dependency-name: "sphinxcontrib-serializinghtml"
Binary file added docs/source/_static/images/batching_methods.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/images/mds_writing.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/images/py1b_py1br.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/images/py1e.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/images/remote_streams.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/images/sample_partition.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/images/sample_retrieval.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/images/shards_sequential.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/images/shards_shuffled.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/images/shuffling_example.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 2 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,15 @@
'sphinx.ext.extlinks',
'sphinx.ext.coverage',
'sphinx.ext.napoleon',
'sphinxcontrib.katex',
'sphinx.ext.viewcode',
'sphinx.ext.intersphinx',
'sphinx_copybutton',
'myst_parser',
'sphinxarg.ext',
'sphinx.ext.doctest',
'nbsphinx',
'sphinx_tabs.tabs',
'nbsphinx',
]


Expand Down
89 changes: 89 additions & 0 deletions docs/source/dataset_configuration/mixing_data_sources.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Mixing Datasets

Training a model often requires combining data from multiple different sources. Streaming makes combining these data sources, or streams, easy and configurable. See the [main concepts page](../getting_started/main_concepts.md#distributed-model-training) for a high-level view of distributed training with multiple streams.

## Using Streams

A stream is a data source, as a collection of shard files (or set of subdirectories containing shard files). Shard files can optionally be compressed. Streams are represented by the {class}`streaming.Stream` object. Similar to {class}`streaming.StreamingDataset` itself, a `Stream` object can take in `remote` and `local` paths -- see [here](../getting_started/main_concepts.md#remote-data-streams) for an example.

It is possible, though not recommended, for streams to have different schemas.

## Configuring the data mix
The `proportion`, `repeat`, or `choose` arguments to `Stream` are used to configure different dataset mixing schemes. Only one of them may be set at a time, and all streams must use the same mixing scheme (e.g., Stream A with `proportion` and Stream B with `choose` are incompatible).
- **`proportion`**: Specifies how to sample this Stream relative to other Streams.
- **`repeat`**: Specifies the degree to which a Stream is upsampled or downsampled.
- **`choose`**: Specifies the number of samples to choose from a Stream.

Let's look at some examples of dataset mixing in action.

### Using `proportion` for relative weighting

As an example, let's say we have Stream A with 100 samples and Stream B with 200 samples. The `epoch_size`, if not set, will default to the total number of unique samples -- in this case, 300. To configure our training dataset to be 25% from Stream A and 75% from Stream B, we simply set `proportion` to those values:
<!--pytest.mark.skip-->
```python
stream_A = Stream(
remote = 's3://stream_A_remote',
local = '/tmp/stream_A',
proportion = 0.25,
)
stream_B = Stream(
remote = 's3://stream_B_remote',
local = '/tmp/stream_B',
proportion = 0.75,
)
dataset = StreamingDataset(
streams = [stream_A, stream_B],
)
```

Since `epoch_size` has not been specified, the epoch will be 300 samples long, of which 75 samples will come from Stream A, and 225 from Stream B. Equivalently, we could have also set `proportion` to 2 for Stream A and 6 for Stream B for the same weighting -- StreamingDataset will normalize the proportion weights.

If `epoch_size` is explicitly set, then proportions will apply to that value instead. For example, if `epoch_size` was passed as 400 to StreamingDataset, as below, and proportions stayed the same, then in each epoch, 100 samples would be from Stream A and 300 would be from Stream B.
<!--pytest.mark.skip-->
```python
dataset = StreamingDataset(
epoch_size = 400,
streams = [stream_A, stream_B], # With proportions A: 0.25 and B: 0.75.
)
```

For multi-epoch training, to control how samples are chosen between epochs, see the [inter-epoch sampling](replication_and_sampling.md#inter-epoch-sampling) section.

### Using `repeat` for absolute weighting

It can be useful to specify how many times to upsample or downsample a Stream -- the `repeat` argument fulfills this use case. For example, to see every sample from Stream A 3 times per epoch, simply set `repeat` to 3:
<!--pytest.mark.skip-->
```python
stream_A = Stream(
remote = 's3://stream_A_remote',
local = '/tmp/stream_A',
repeat = 3,
)
```

To downsample a stream, meaning that only a fraction of the total samples from that stream are seen every epoch, set `repeat` to less than 1. For example, to see only a quarter of the samples from Stream A per epoch, set `repeat` to 0.25.

### Using `choose` for absolute weighting

Specifying the absolute number of samples to choose from a Stream can also be useful when mixing datasets. Use the `choose` argument to indicate the number of samples to take from a stream per epoch. For example, to see exactly 250 samples from Stream A per epoch, set `choose` to 250:
<!--pytest.mark.skip-->
```python
stream_A = Stream(
remote = 's3://stream_A_remote',
local = '/tmp/stream_A',
choose = 250,
)
```

## Batching Methods

Controlling how a global batch is constructed is a requirement for some training runs. StreamingDataset's `batching_method` argument takes in three different options to configure the composition of each global batch:
- **`'random'`**: (default) Global batches respect dataset mixing *in expectation*. Stream proportions can vary somewhat between batches.
- **`'stratified'`**: *Every* global batch respects dataset mixing exactly. Can help mitigate loss spikes and divergence by making sure stream proportions hold for every batch.
- **`'per_stream'`**: Each global batch contains samples from only one stream at a time. Particularly useful when your streams contain data of different tensor shapes/sizes, so that each batch can contain samples of the same shape/size.

As an example, suppose we have Stream A (green) and Stream B (blue), each making up half of our total dataset. Applying each of the batching methods would make global batches look like this:

<img src="../_static/images/batching_methods.png" alt="Batching Methods" width="800"/>

Each bar represents a single global batch. We see that `random` batching can have some variance in stream composition, while `stratified` batching keeps composition exact, and `per_stream` batching constructs each batch with a single stream.
51 changes: 51 additions & 0 deletions docs/source/dataset_configuration/replication_and_sampling.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Replication and Sampling

You can control how samples are replicated, chosen between epochs, and chosen from shards. These are useful for a variety of cases:
- **Replication**: Replicate training samples among subsets of devices. This is particularly useful for Tensor Parallelism (TP) or Sequence Parallelism (SP).
- **Inter-epoch Sampling**: Control if the samples seen across epochs should vary or not.
- **Sampling from shards**: Control how many samples to choose from each shard at a time.

Let's see when and how to use these features.

## Replication

Training with Tensor Parallelism (TP) or Sequence Parallelism (SP) requires multiple devices to see the same sample of data. The `replication` parameter of {class}`streaming.StreamingDataset`, controls how many consecutive devices will see the same samples in each batch. For example, if `replication` is set to 4 for a training job with 16 GPUs, devices 0 through 3 will see the same samples, devices 4 through 7 will see the same samples, and so on.
<!--pytest.mark.skip-->
```python
dataset = StreamingDataset(
...
replication = 4, # Every 4 GPUs will see the same samples.
...
)
```

Be aware that samples are only replicated across consecutive GPUs, as denoted by their rank from [PyTorch's distributed module](https://pytorch.org/docs/stable/distributed.html).

## Epoch size

You can specify the size of each epoch of training with the `epoch_size` argument:
<!--pytest.mark.skip-->
```python
dataset = StreamingDataset(
...
epoch_size = 10000, # Each epoch will be 10k samples.
...
)
```

## Inter-epoch sampling

You can choose how sampling from your dataset(s) occurs between epochs by specifying the `sampling_method` when instantiating `StreamingDataset`. This can be one of two values:

- `'balanced'`: (default) Samples are chosen at random from dataset(s) during each epoch.
- `'fixed'`: The same samples from the dataset(s) are chosen during every epoch.

For example, with `balanced` sampling, if the size of an epoch is 1000 samples, but my dataset contains 2000 samples, then each epoch will consist of 1000 samples taken at random from the underlying 2000. But with `fixed` sampling, the same 1000 samples that are seen in epoch 0 will be seen in all subsequent epochs as well.

## Sampling from shards

If all samples from a shard don't have to be used in training, the number of samples to choose from each shard is set by the `sampling_granularity` parameter to StreamingDataset. The `sampling_granularity` arg defaults to 1, meaning that one sample is chosen from each shard at a time.

This is particularly useful if just training on a small subset of your overall dataset. In this case, the way in which samples are chosen from shards becomes important, and directly impacts how many shards I have to download for the training job. For example, suppose the overall dataset has 10,000 samples, split up between 1000 shards of 100 samples each, but the epoch size is just 1000 samples. If `sampling_granularity` is set to 1, then the training dataset will consist of a single sample from each of the 1000 shards, meaning that all 1000 shards have to be downloaded over the course of the run. Instead, if `sampling_granularity` is set to 100, then the training dataset will consist of all 100 samples from just 10 shards, and only 10 shards will have to be downloaded for the run.

If the run's epoch size is large enough such that all shards have to be downloaded anyways, setting `sampling_granularity` will not change shard download demand.
100 changes: 100 additions & 0 deletions docs/source/dataset_configuration/shard_retrieval.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Shard Retrieval

Shards are downloaded on the fly during training and samples are retrieved from them. You can configure {class}`streaming.StreamingDataset`'s shard retrieval to meet your training job's needs. For more information about shard retrieval during distributed model training, refer to the [main concepts page](../getting_started/main_concepts.md#distributed-model-training).

## Loading datasets

### Pointing to your dataset

To train on a dataset that lives in a remote location, simply pass the path to StreamingDataset's `remote` argument. The dataset's `index.json` file should live at this directory. StreamingDataset works with all major cloud providers. The `local` argument should be used to specify where the downloaded shards will be stored, on local disk.
<!--pytest.mark.skip-->
```python
dataset = StreamingDataset(
remote = 's3://some-bucket/my-dataset', # dataset lives at this remote path
local = '/local/dataset', # shards downloaded and stored locally at this path
)
```

If your dataset is already available on local disk for your GPUs to access, only specify the `local` argument.
<!--pytest.mark.skip-->
```python
dataset = StreamingDataset(
local = '/local/dataset', # dataset shards are already locally available at this path
)
```

The `split` argument can be used to specify a particular subdirectory to use -- for example, a training dataset split.
<!--pytest.mark.skip-->
```python
dataset = StreamingDataset(
remote = 's3://some-bucket/my-dataset',
local = '/local/dataset',
split = 'train', # dataset will be loaded from 's3://some-bucket/my-dataset/train'
)
```

### Multiple streams

If using multiple data sources, specify the `remote` and/or `local` paths for each one in a separate {class}`streaming.Stream` object, and pass those to StreamingDataset's `streams` argument. An example can be found [here](../getting_started/main_concepts.md#Remote-data-streams).

### Hash Validation

If you wrote out your dataset shards with specific hash functions (see [here](../preparing_datasets/basic_dataset_conversion.md#Configuring-dataset-writing)) and want to validate them at training time, set the `validate_hash` argument to StreamingDataset. Depending on the hash function, this may slow down data loading.
<!--pytest.mark.skip-->
```python
dataset = StreamingDataset(
...
validate_hash = 'sha1', # validate shard using sha1 hash function
...
)
```

## Controlling shard downloads

### Downloading ahead

Setting the `predownload` argument ensures that StreamingDataset will download the shards needed for the upcoming `predownload` samples, per worker. For example, if `predownload` is set to 8, then each DataLoader worker will download the shards needed for up to 8 samples ahead of the current point in training. The default value of `predownload` in StreamingDataset performs well, so only set this argument if you want to prepare more samples ahead of the current training batch.
<!--pytest.mark.skip-->
```python
dataset = StreamingDataset(
...
predownload = 8, # each worker will download shards for up to 8 samples ahead
...
)
```

### Retries and Timeout

Set the `download_retry` argument to the number of times a shard download should be retried. The `download_timeout` argument specifies, in seconds, how long to wait for a shard download before throwing an exception. For larger shards, a longer `download_timeout` can be necessary.
<!--pytest.mark.skip-->
```python
dataset = StreamingDataset(
...
download_retry = 3, # retry shard downloads up to 3 times
download_timeout = 120, # wait 2 minutes for a shard to download
...
)
```

## Configure shard storage

### Cache limit

If you have limited local disk space, specify the `cache_limit` argument. Once locally stored shards reach the `cache_limit`, Streaming will begin evicting shards to stay under the limit. This is particularly useful for very large datasets or small disks. Setting `cache_limit` too low will hinder performance, since shards may be continually evicted and redownloaded. This can be specified as integer bytes or as a human-readable string.
<!--pytest.mark.skip-->
```python
cache_limit = 10*1024**2 # cache limit of 10mb
cache_limit = '10mb' # also a cache limit of 10mb
```

### Keeping compressed shards

If your dataset shards are compressed (see [here](../preparing_datasets/basic_dataset_conversion.md#Configuring-dataset-writing)), StreamingDataset will decompress them upon download for use in training. To control whether the compressed versions of shards are kept locally, use the `keep_zip` flag. This defaults to `False`, meaning that StreamingDataset will default to deleting compressed shards and only keeping the decompressed shards.
<!--pytest.mark.skip-->
```python
dataset = StreamingDataset(
...
keep_zip = True, # keep compressed versions of shards locally
...
)
```
Loading

0 comments on commit 24e9182

Please sign in to comment.