Imagen
Imagen is a text-to-image generative diffusion model that operates in pixel space. This repository contains the tools and scripts for performant training of Imagen, from the base model through its super-resolution models, in JAX on GPUs.
Architecture
For maximum flexibility and low disk requirements, this repo supports a distributed architecture for text embedding in diffusion model training. Upon launching training, it spawns LLM inference servers that calculate text embeddings online, performantly and with no latency hit. It does this by creating several inference clients in the diffusion model trainer's dataloaders, which send embedding requests to the inference servers. These servers are based on NVIDIA PyTriton, so all requests are executed batched. Currently, the inference server supports T5x LLMs, but it can be backed by anything (it doesn't even have to be JAX!) since the diffusion model trainer's client simply makes PyTriton (HTTP) calls.
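The core idea can be sketched in a few lines. This is an illustrative, stdlib-only sketch of the batching behavior, not the actual PyTriton implementation; the names (BatchingEmbeddingServer, embed_fn) are hypothetical:

```python
import queue
import threading

class BatchingEmbeddingServer:
    """Sketch: many dataloader clients submit text concurrently, and the
    server embeds all pending requests together in one batched LLM call."""

    def __init__(self, embed_fn, max_batch=8):
        self.embed_fn = embed_fn      # e.g. a T5 encoder forward pass
        self.max_batch = max_batch
        self.pending = queue.Queue()

    def submit(self, text):
        """Called by a dataloader client; returns a slot to wait on."""
        slot = {"text": text, "done": threading.Event(), "embedding": None}
        self.pending.put(slot)
        return slot

    def serve_once(self):
        """Drain up to max_batch requests and embed them in a single call."""
        batch = [self.pending.get()]  # block until at least one request
        while len(batch) < self.max_batch:
            try:
                batch.append(self.pending.get_nowait())
            except queue.Empty:
                break
        for slot, emb in zip(batch, self.embed_fn([s["text"] for s in batch])):
            slot["embedding"] = emb
            slot["done"].set()
```

In the real system the clients talk to the servers over HTTP and PyTriton performs this request batching itself, which is what keeps online embedding from adding a latency hit.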
GPU Scripts and Usage
We provide scripts to run interactively or on SLURM.
Container
We provide a fully built and ready-to-use container here:
ghcr.io/nvidia/t5x:imagen-2023-10-02.
We do not currently have custom-built container workflows, but are actively working on supporting this. Stay tuned for updates! Imagen will also be available in our T5x container in future releases.
Dataset
This model accepts webdataset-format datasets for training. For reference, we have an ImageNet webdataset example here (NOTE: Imagen is not directly compatible with ImageNet). For Imagen training, you can find or create your own compatible webdataset with image and text modalities.
Once you have your webdataset, update the dataset configs {base, sr1, sr2} with the paths to your dataset(s) under MIXTURE_OR_TASK_NAME.
The 'img-txt-ds' configs assume a webdataset with a text and an image modality, where the images are in JPEG format and the text is raw text in a '.txt' file. Currently, the configs are set up to do resolution-based filtering, scale-preserved square random cropping, and low-resolution image generation for SR model training. This can be changed (e.g., if you want your text in '.json' format or want to do additional processing) in the dataset configuration files {base, sr1, sr2}.

Downloading the LLM checkpoint
You will need to acquire the LLM checkpoint for T5 (for multimodal training) from T5x here. All models use T5 1.1-format T5-xxl by default. Once you have the checkpoint, place it at
rosetta/projects/inference_serving/checkpoints/checkpoint_1000000_t5_1_1_xxl
(appending the _{size} to the checkpoint folder). NOTE: We're working on adding TransformerEngine support to the inference server, but for now, please run with the DISABLE_TE=True environment variable (the example scripts include this).

Running interactively
Note: this should only be done with single-node jobs.
Single Node runs
Pretraining can be done on multiple GPUs within one host with scripts/singlenode_inf_train.sh. This will build an Imagen model with the Adam optimizer and relevant parameters, and will also launch the relevant LLM inference servers.

Multi Node runs
For a SLURM+pyxis cluster, the scripts/example_slurm_inf_train.sub file provides an example SLURM submit file (edit it with your details), which calls scripts/multinode_train.sh and scripts/specialized_run.py to execute training.

Pretraining run commands
All commands below assume you are in $ROSETTA_DIR=/opt/rosetta and have the scripts and slurm scripts locally.

Multinode
Arguments are set as follows:
All parameters can be found in the relevant script.
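As a rough sketch of the submission flow (the file and script names are the ones above; the values here are placeholders for your cluster, and the full parameter list lives in the scripts themselves):

```shell
# Placeholders for your cluster; edit the .sub file (account, partition,
# container, dataset paths) before submitting.
ROSETTA_DIR=/opt/rosetta
export DISABLE_TE=True   # required until the inference server supports TE
# The submit file calls scripts/multinode_train.sh on each node:
SUBMIT_CMD="sbatch ${ROSETTA_DIR}/scripts/example_slurm_inf_train.sub"
echo "${SUBMIT_CMD}"
```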
Example training Commands
Assumes nodes with 8x 80GB A100/H100 GPUs.
Imagen-base small (500M):
Imagen-base large (2B):
Imagen-sr1 (efficient unet) (600M):
Imagen-sr2 (efficient unet) (600M):
Sampling
You can find example sampling scripts that use the 500M base model and EfficientUnet SR models in scripts. Prompts should be specified as in the example.
Sampling 256x256 images
Defaults to the imagen_256_sample.gin config (can be adjusted in the script).
Sampling 1024x1024 images
Defaults to the imagen_1024_sample.gin config (can be adjusted in the script).
Convergence and Performance
Global Batch size = 2048. We assume 2.5B Training examples in these calculations. LLM Inference server nodes are not included in these numbers.
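At these settings, the implied optimizer step count is a quick arithmetic check (derived from the numbers above, not a measured figure):

```python
# One pass over 2.5B training examples at global batch size 2048:
global_batch = 2048
examples = 2.5e9
steps = int(examples / global_batch)
print(steps)  # roughly 1.22M optimizer steps
```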
Imagen-SR1-430M-UNet is not currently supported; you can use the sr1-efficient-unet instead. Coming soon!

Imagen base 500M + Efficient SR1 (600M):