Currently MaxText has three data input pipelines:

| Pipeline | Dataset formats | Features | Limitations |
| --- | --- | --- | --- |
| HuggingFace | Datasets in HuggingFace Hub; local/Cloud Storage datasets in json, parquet, arrow, csv, txt | Convenience; supports multiple formats | Limited scalability when streaming from HuggingFace Hub; non-deterministic with preemption (deterministic without preemption) |
| Grain | ArrayRecord, available through Tensorflow Datasets | Fully deterministic, regardless of preemption | Only supports random-access datasets |
| TFDS | TFRecord, available through Tensorflow Datasets | | Only supports TFRecords; non-deterministic with preemption (deterministic without preemption) |
- Performance data for all three input pipelines: https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Perf.md
In a multihost environment, when using an input pipeline that reads data sequentially (HuggingFace or TFDS), the most performant approach is to have each data file accessed by only one host, with each host reading a subset of the data files (shuffling happens within each host's subset of files). This requires the number of data files to be a multiple of the number of hosts loading data. We recommend resharding the dataset, or using a subset of hosts to load data by setting expansion_factor_real_data (only available for some topologies; it will error out otherwise). Since MaxText's goal is to demonstrate the most performant experience, the pipelines behave as follows (a sketch of the file-to-host assignment follows this list):
- HuggingFace pipeline:
  - When (# of data files) >= (# of hosts loading data), files are assigned to the hosts as evenly as possible; some hosts may end up with one more file than the others. When some hosts run out of data, they produce empty padding batches so that the data from the hosts that still have data can be utilized. In this stage training/eval is less effective, and you will see a decrease in total_weights and a slower change in loss. If all hosts run out of data before the step number you set, you will see 0 total_weights and 0 loss. Training/eval runs until the steps/eval_steps set in the config. Note that even when each host is assigned the same number of data files, the differing example counts per file, together with example packing, mean that hosts still end up with different numbers of batches near the end of the epoch.
  - When (# of data files) < (# of hosts loading data), files are read sequentially with multiple hosts accessing each file; performance can degrade quickly as the number of hosts increases.
- TFDS pipeline:
  - When (# of data files) >= (# of hosts loading data), an equal number of files is assigned to each host and the remaining files are skipped. Training/eval will hang if steps/eval_steps has not been reached but some hosts have run out of data, so set steps/eval_steps accordingly.
  - When (# of data files) < (# of hosts loading data), files are read sequentially with multiple hosts accessing each file; performance can degrade quickly as the number of hosts increases.
- Grain pipeline:
  - Performance is not affected by (# of data files) vs. (# of hosts loading data), and data is shuffled globally. This is because Grain uses a data format (ArrayRecord) that supports random access by index: even when multiple hosts access the same file, they read different indices, so the issue seen with sequential reading does not arise.
  - At the end of the dataset, some hosts may still run out of indices and hang, so set steps/eval_steps accordingly.
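For intuition, here is a minimal sketch of the file-to-host assignment described above for the sequential pipelines. This is not MaxText's actual implementation; the function and variable names (`shard_files`, `num_hosts`, `host_index`) are illustrative assumptions.

```python
# Illustrative sketch only: how a sequential pipeline can assign each data
# file to exactly one host. `num_hosts` and `host_index` stand in for the
# values a real multihost job would obtain from its runtime.

def shard_files(files: list[str], num_hosts: int, host_index: int) -> list[str]:
    # Round-robin assignment: host i reads files i, i + num_hosts, ...
    return files[host_index::num_hosts]

files = [f"c4-train.{i:05d}-of-00016.parquet" for i in range(16)]

# 16 files over 8 hosts: every host gets exactly 2 files.
for host in range(8):
    print(host, shard_files(files, num_hosts=8, host_index=host))

# 16 files over 6 hosts: some hosts get 3 files and others 2, so hosts
# finish at different times near the end of an epoch (hence the padding
# batches or hangs described above).
```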
The HuggingFace pipeline supports streaming directly from HuggingFace Hub, or from a GCS bucket in HuggingFace-supported formats (parquet, json, etc.). This goes through the HuggingFace `datasets.load_dataset` API with `streaming=True`, which takes in the `hf_*` parameters.
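As a point of reference, the underlying call looks roughly like the following sketch (not MaxText's internal code):

```python
from datasets import load_dataset

# Stream the 'en' config of allenai/c4 straight from HuggingFace Hub,
# without downloading the full dataset first.
ds = load_dataset("allenai/c4", "en", split="train", streaming=True)

# A streaming dataset is an iterable; peek at the first example.
print(next(iter(ds))["text"][:100])
```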
Example config for streaming from HuggingFace Hub:

```yml
dataset_type: hf
hf_path: 'allenai/c4'  # for using https://huggingface.co/datasets/allenai/c4
hf_data_dir: 'en'
hf_train_files: ''
# set eval_interval > 0 to use the specified eval dataset; otherwise, only metrics on the train set will be calculated
eval_interval: 10000
hf_eval_split: 'validation'
hf_eval_files: ''
# for the HF pipeline, tokenizer_path can be a path in HuggingFace Hub,
# or a local path containing a tokenizer in a format supported by transformers.AutoTokenizer
tokenizer_path: 'google-t5/t5-large'  # for using https://huggingface.co/google-t5/t5-large
hf_access_token: ''  # provide a token if using a gated dataset or tokenizer
```
Example config for streaming from a GCS bucket:

```yml
dataset_type: hf
hf_path: 'parquet'  # or json, arrow, etc.
hf_data_dir: ''
hf_train_files: 'gs://<bucket>/<folder>/*-train-*.parquet'  # match the train files
# set eval_interval > 0 to use the specified eval dataset; otherwise, only metrics on the train set will be calculated
eval_interval: 10000
hf_eval_split: ''
hf_eval_files: 'gs://<bucket>/<folder>/*-validation-*.parquet'  # match the validation files
# for the HF pipeline, tokenizer_path can be a path in HuggingFace Hub,
# or a local path containing a tokenizer in a format supported by transformers.AutoTokenizer
tokenizer_path: 'google-t5/t5-large'  # for using https://huggingface.co/google-t5/t5-large
```
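For reference, the roughly equivalent direct call is sketched below. The paths are placeholders, and reading gs:// URLs assumes the gcsfs fsspec backend is installed.

```python
from datasets import load_dataset

# Stream parquet shards from a GCS bucket (requires `pip install gcsfs`).
# Bucket and folder names below are placeholders.
ds = load_dataset(
    "parquet",
    data_files={
        "train": "gs://<bucket>/<folder>/*-train-*.parquet",
        "validation": "gs://<bucket>/<folder>/*-validation-*.parquet",
    },
    split="train",
    streaming=True,
)
```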
- Streaming data directly from HuggingFace Hub may be affected by server traffic. During peak hours you may encounter "504 Server Error: Gateway Time-out". It's recommended to download the HuggingFace dataset to a GCS bucket or disk for the most stable experience.
- Streaming data directly from HuggingFace Hub works in multihost settings with a small number of hosts. We have encountered "read time out" errors when the number of hosts exceeds 16.
- Only a single epoch (epoch=1) is supported at the moment.
Determinism in a data input pipeline means that the same input data always results in the same sequence of batches at each step. This is typically achieved by setting a fixed shuffle seed during pipeline initialization. In an ideal scenario, where training runs uninterrupted, this determinism is straightforward (deterministic without preemption). However, real-world distributed training environments often face preemptions due to maintenance, hardware failures, or resource constraints. When a preempted training run resumes, the data input pipeline is re-initialized. If the same shuffle seed is used, the pipeline restarts from the beginning, potentially re-training the model on initial data. Conversely, a new seed produces a different batch sequence, making it difficult to track which data has been seen and how often each example is used for training. This lack of control can impact model performance and reproducibility.
Grain ensures determinism in data input pipelines by saving the pipeline's state, including dataset metadata and processed data indices, within a small JSON file in checkpoints. When a training run is resumed with the same dataset and shuffle seed, Grain restores the pipeline's exact state from the checkpoint. This enables fully deterministic, reproducible training that is resilient to disruptions.
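To make the mechanism concrete, here is a small framework-agnostic sketch of the idea. The class and field names are illustrative assumptions, not Grain's actual API.

```python
import json
import numpy as np

# Illustrative sketch of deterministic, resumable data iteration.
# `DeterministicPipeline` and its fields are hypothetical, not Grain's API.

class DeterministicPipeline:
    def __init__(self, num_records: int, seed: int, next_index: int = 0):
        self.num_records = num_records
        self.seed = seed
        self.next_index = next_index  # how many records have been served
        # Same seed -> same permutation every time the pipeline starts.
        self.order = np.random.default_rng(seed).permutation(num_records)

    def __next__(self) -> int:
        record_id = int(self.order[self.next_index % self.num_records])
        self.next_index += 1
        return record_id

    def get_state(self) -> str:
        # A small JSON blob, saved alongside the model checkpoint.
        return json.dumps({"seed": self.seed, "next_index": self.next_index})

    @classmethod
    def from_state(cls, num_records: int, state: str):
        s = json.loads(state)
        return cls(num_records, seed=s["seed"], next_index=s["next_index"])

pipe = DeterministicPipeline(num_records=1000, seed=42)
_ = [next(pipe) for _ in range(3)]
state = pipe.get_state()  # persisted with the checkpoint

# After a preemption, restoring the state resumes exactly where we left off.
resumed = DeterministicPipeline.from_state(1000, state)
assert next(resumed) == next(pipe)
```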
Determinism of this kind matters in scenarios such as:
- Models sensitive to repetition. When models are sensitive to the frequency with which they encounter specific examples, precise control over the order and repetition of data during training is essential.
- Convergence comparison. In sensitive convergence experiments like testing quantization techniques, maintaining identical data batches between runs (e.g., quantized vs. unquantized) is essential for comparison. Determinism ensures consistency even when the runs are long and undergo saving/resuming at different steps.
- Debug training anomalies. When troubleshooting training spikes or anomalies, the ability to replay the exact data sequence helps distinguish between bad data batches and underlying hardware or software issues.
In the HF and TFDS pipelines, "global" shuffle is approximated by a shuffle buffer of limited size. Grain instead performs a true global shuffle: it shuffles the indices of the entire dataset at the beginning of each epoch and then reads the elements in that random order. We have found this to be generally fast enough, even when using hard drives and distributed file systems.
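The difference can be sketched as follows (illustrative only; the buffer size and record count are made up):

```python
import itertools
import numpy as np

NUM_RECORDS = 10

# Shuffle-buffer approach (HF/TFDS style): only records currently in the
# buffer can be swapped, so the shuffle is local, not global.
def buffered_shuffle(stream, buffer_size, rng):
    buf = list(itertools.islice(stream, buffer_size))
    for item in stream:
        i = rng.integers(len(buf))
        yield buf[i]
        buf[i] = item
    rng.shuffle(buf)
    yield from buf

# Global index shuffle (Grain style): permute all indices up front, then
# random-access each record by index.
def global_shuffle(num_records, rng):
    for idx in rng.permutation(num_records):
        yield int(idx)

rng = np.random.default_rng(0)
print(list(buffered_shuffle(iter(range(NUM_RECORDS)), buffer_size=3, rng=rng)))
print(list(global_shuffle(NUM_RECORDS, rng)))
```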
To use the Grain pipeline:
- The dataset needs to be in a format that supports random access. The default format is ArrayRecord. For converting a dataset into ArrayRecord, see instructions. Additionally, other random-accessible data sources can be supported via a custom data source class (docs); a sketch appears at the end of this section.
- An ArrayRecord dataset, when hosted on a GCS bucket, can only be read through Cloud Storage FUSE. The installation of Cloud Storage FUSE is included in setup.sh. The user then needs to mount the GCS bucket to a local path on each worker, using the script setup_gcsfuse.sh, which configures some parameters for the mount:
```
bash setup_gcsfuse.sh DATASET_GCS_BUCKET=$BUCKET_NAME MOUNT_PATH=$MOUNT_PATH
```
- Set `dataset_type=grain` and set `grain_train_files` to match the ArrayRecord files via a local path, since the bucket has been mounted.
- Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in behind_the_scene, code). If you use a large number of workers, check your gcsfuse config in setup_gcsfuse.sh to avoid gcsfuse throttling.
- Example command:
```
bash setup_gcsfuse.sh \
DATASET_GCS_BUCKET=maxtext-dataset \
MOUNT_PATH=/tmp/gcsfuse && \
python3 MaxText/train.py MaxText/configs/base.yml \
run_name=<RUN_NAME> base_output_directory=gs://<MY_BUCKET> \
dataset_type=grain \
grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record* \
grain_worker_count=2
```
- Using a validation set for eval: when setting eval_interval > 0, eval will be run with the specified eval dataset. Example config:
```yml
eval_interval: 10000
grain_eval_files: '/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-validation.array_record*'
```
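As referenced above, a custom random-access data source is, conceptually, anything indexable with a known length. The sketch below is a hedged illustration of that contract; consult the Grain docs for the exact protocol and class names.

```python
# Hedged sketch of a custom random-access data source for Grain.
# The documented contract is essentially __len__ + __getitem__; the class
# name here is made up for illustration.

class InMemoryTextSource:
    """Serves records by index, which is what enables global shuffle."""

    def __init__(self, records: list[bytes]):
        self._records = records

    def __len__(self) -> int:
        return len(self._records)

    def __getitem__(self, index: int) -> bytes:
        # Any index can be read at any time, in any order, by any host.
        return self._records[index]

source = InMemoryTextSource([b"example one", b"example two", b"example three"])
print(len(source), source[2])
```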
To use the TFDS pipeline:
- Download the AllenAI c4 dataset in TFRecord format to a GCS bucket (this will cost about $100; details):
```
bash download_dataset.sh {GCS_PROJECT} {GCS_BUCKET_NAME}
```
- Use the following config:
```yml
dataset_type: tfds
dataset_name: 'c4/en:3.0.1'
# set eval_interval > 0 to use the specified eval dataset; otherwise, only metrics on the train set will be calculated
eval_interval: 10000
eval_dataset_name: 'c4/en:3.0.1'
eval_split: 'validation'
# the TFDS input pipeline only supports tokenizers in spm format
tokenizer_path: "assets/tokenizer.llama2"
```