
Adds Imagen to rosetta #278

Merged
6 commits merged into main from imagen-release on Oct 5, 2023
Conversation

terrykong
Contributor

Imagen

Imagen is a text-to-image generative diffusion model that operates in pixel space. This repository contains the tools and scripts needed to performantly train Imagen, from the base model to its super-resolution models, in JAX on GPUs.

Prompts:

  • A raccoon wearing a hat and leather jacket in front of a backyard window. There are raindrops on the window
  • A blue colored pizza
  • a highly detailed digital painting of a portal in a mystic forest with many beautiful trees. A person is standing in front of the portal.

Architecture

For maximum flexibility and low disk requirements, this repo supports a distributed architecture for text embedding in diffusion model training. Upon launching training, it spawns LLM inference servers that calculate text embeddings online, performantly and with no latency hit. It does this by creating several inference clients in the diffusion model trainer's dataloaders, which send embedding requests to the inference servers. These servers are based on NVIDIA PyTriton, so all requests are executed in batches. Currently, the inference server supports T5x LLMs, but it can be backed by anything (it doesn't even have to be JAX!) since the diffusion model trainer's client simply makes PyTriton (HTTP) calls.
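As a minimal sketch of the client-side batching described above: the class, the stand-in infer function, and the 4-dimensional dummy embeddings below are all hypothetical, not the repo's actual API. In the real setup, the infer function would be a PyTriton HTTP call into the LLM inference server.

```python
# Hypothetical sketch: a dataloader-side embedding client that batches
# prompt strings before handing them to an infer function. In the repo,
# the infer function would be a PyTriton HTTP request; here it is a
# local stand-in returning dummy 4-dim embeddings.
from typing import Callable, List


def fake_llm_infer(batch: List[str]) -> List[List[float]]:
    """Stand-in for the inference server: one dummy embedding per prompt."""
    return [[float(len(text))] * 4 for text in batch]


class EmbeddingClient:
    """Collects prompts and sends them to the server in fixed-size batches."""

    def __init__(self, infer_fn: Callable[[List[str]], List[List[float]]],
                 max_batch: int = 8):
        self.infer_fn = infer_fn
        self.max_batch = max_batch

    def embed(self, prompts: List[str]) -> List[List[float]]:
        out: List[List[float]] = []
        for i in range(0, len(prompts), self.max_batch):
            # Each slice becomes one (batched) request to the server.
            out.extend(self.infer_fn(prompts[i:i + self.max_batch]))
        return out


client = EmbeddingClient(fake_llm_infer, max_batch=2)
embeddings = client.embed(["a blue colored pizza", "a mystic forest", "a raccoon"])
```

Because the trainer's client only makes HTTP calls, swapping the stand-in for a real PyTriton client changes nothing on the dataloader side.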

GPU Scripts and Usage

We provide scripts to run interactively or on SLURM.

Container

We provide a fully built and ready-to-use container here: ghcr.io/nvidia/t5x:imagen-2023-10-02.

We do not currently have custom-built container workflows, but we are actively working on supporting this; stay tuned for updates! Imagen will also be available in our T5x container in future releases.

Dataset

This model accepts webdataset-format datasets for training. For reference, we have an ImageNet webdataset example here. (Note: Imagen is not directly compatible with ImageNet.) For Imagen training, you can find or create your own compatible webdataset (with image and text modalities).
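To illustrate the webdataset layout, the snippet below builds a tiny shard with the standard library's tarfile: each sample is a shared basename carrying a .jpg image and a .txt caption. The keys, image bytes, and captions are placeholders, not real data.

```python
# Sketch of a webdataset-format shard: a tar archive where each sample's
# members share a basename and differ only by extension (.jpg / .txt).
import io
import tarfile

samples = [
    ("000000", b"\xff\xd8fake-jpeg-bytes", "a blue colored pizza"),
    ("000001", b"\xff\xd8fake-jpeg-bytes", "a portal in a mystic forest"),
]

buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tar:
    for key, jpg_bytes, caption in samples:
        for name, payload in ((f"{key}.jpg", jpg_bytes),
                              (f"{key}.txt", caption.encode("utf-8"))):
            info = tarfile.TarInfo(name=name)
            info.size = len(payload)
            tar.addfile(info, io.BytesIO(payload))

# Reopen and list members to confirm the pairing.
buf.seek(0)
with tarfile.open(fileobj=buf, mode="r") as tar:
    names = tar.getnames()
```

Real pipelines would write the shard to disk (e.g. shard-000000.tar) and point the dataset configs at it.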

Once you have your webdataset, update the dataset configs {base, sr1, sr2} with the paths to your dataset(s) under MIXTURE_OR_TASK_NAME.

The 'img-txt-ds' configs assume a webdataset with a text and an image modality: images in .jpg format and raw text in a .txt file. Currently, the configs are set up to do resolution-based filtering, scale-preserved square random cropping, and low-resolution image generation for SR model training. This can be changed (e.g. if you want your text in .json format or want to do additional processing) in the dataset configuration files {base, sr1, sr2}.
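A geometry-only sketch of two of the preprocessing steps just named, a resolution filter and a scale-preserved square random crop; the threshold values and function names are made up for illustration and do not reflect the repo's actual config values.

```python
# Hypothetical preprocessing geometry: filter out images whose shorter
# side is below a threshold, then pick a random square crop of the
# largest size that fits (no rescaling, so scale is preserved).
import random

MIN_SIDE = 256  # made-up resolution-filter threshold


def passes_filter(h: int, w: int) -> bool:
    """Resolution-based filtering on the shorter side."""
    return min(h, w) >= MIN_SIDE


def square_crop_box(h: int, w: int, rng: random.Random):
    """Scale-preserved square crop: largest square at a random offset."""
    side = min(h, w)
    top = rng.randint(0, h - side)
    left = rng.randint(0, w - side)
    return top, left, side


rng = random.Random(0)
top, left, side = square_crop_box(480, 640, rng)
```

For SR training, the low-resolution conditioning image would then be produced by downsampling the cropped square, which is left out here.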

Downloading the LLM checkpoint

You will need to acquire the LLM checkpoint for T5 (for multimodal training) from T5x here. All models use T5 1.1-format T5-xxl by default. Once you have the checkpoint, place it at rosetta/projects/inference_serving/checkpoints/checkpoint_1000000_t5_1_1_xxl (appending the _{size} to the checkpoint folder). NOTE: We're working on adding TransformerEngine support to the inference server, but for now, please run with the DISABLE_TE=True environment variable (example scripts include this).
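To make the _{size} suffix convention concrete, here is how the expected checkpoint directory from the paragraph above could be assembled; the string-building itself is just illustrative.

```python
# Illustration only: the checkpoint folder name gets the _{size} suffix
# appended, as described above for the default xxl model.
CKPT_ROOT = "rosetta/projects/inference_serving/checkpoints"
size = "xxl"  # the inference-server LLM size passed to the scripts
ckpt_dir = f"{CKPT_ROOT}/checkpoint_1000000_t5_1_1_{size}"
```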

Running interactively

Note: this should only be done with single-node jobs.

CONTAINER=ghcr.io/nvidia/t5x:imagen-2023-10-02
docker run --rm --gpus=all -it --net=host --ipc=host -v ${PWD}:/opt/rosetta -v ${DATASET_PATH}:/mnt/datasets --privileged $CONTAINER bash

Single Node runs

Pretraining can be done on multiple GPUs within one host with scripts/singlenode_inf_train.sh. This builds an Imagen model with the Adam optimizer and relevant parameters, and also launches the relevant LLM inference servers.

#### Pretraining (interactive: already inside container) with example args
bash rosetta/projects/imagen/scripts/singlenode_inf_train.sh {DATASET NAME} {MODEL NAME} {PRECISION} {NUM GPUS} {BSIZE/GPU} {LOGDIR} {MODEL DIR} {NUM LLM INFERENCE GPUS} {INFERENCE SERVER LLM SIZE}

#### Pretraining (non-interactive)
docker run --rm --gpus=all --net=host --ipc=host -v ${DATASET_PATH}:/mnt/datasets $CONTAINER bash rosetta/projects/imagen/scripts/singlenode_inf_train.sh {args from above}

Multi Node runs

For a SLURM+pyxis cluster, the scripts/example_slurm_inf_train.sub file provides an example slurm submit file (edit with your details), which calls scripts/multinode_train.sh and scripts/specialized_run.py to execute training.

Pretraining run commands

All commands below assume you are in $ROSETTA_DIR=/opt/rosetta and have the training and SLURM scripts available locally.

Multinode

Arguments are set as such:

sbatch -N {NODE_CT} rosetta/projects/imagen/scripts/example_slurm_inf_train.sub \
{DATASET NAME} {MODEL NAME} {PRECISION} {NUM GPUS / NODE} {BSIZE/GPU} {MODEL DIR} {NUM LLM INFERENCE GPUS} {INFERENCE SERVER LLM SIZE}

All parameters can be found in the relevant script.

Example Training Commands

Assumes 8-GPU 80GB A100/H100 nodes.

Imagen-base small (500M):

sbatch -N 14 rosetta/projects/imagen/scripts/example_slurm_inf_train.sub \
{DATASET} imagen_base_500M bfloat16 8 32 runs/imagen-base 48 xxl

Imagen-base large (2B):

sbatch -N 20 rosetta/projects/imagen/scripts/example_slurm_inf_train.sub \
{DATASET} imagen_base_2B bfloat16 8 16 runs/imagen-base 32 xxl

Imagen-sr1 (efficient unet) (600M):

sbatch -N 14 rosetta/projects/imagen/scripts/example_slurm_inf_train.sub \
{DATASET} imagen_sr1_efficientunet_600M bfloat16 8 32 runs/imagen-sr1 48 xxl

Imagen-sr2 (efficient unet) (600M):

sbatch -N 14 rosetta/projects/imagen/scripts/example_slurm_inf_train.sub \
{DATASET} imagen_sr2_efficientunet_600M bfloat16 8 32 runs/imagen-sr2 48 xxl

Sampling

You can find example sampling scripts that use the 500M base model and the EfficientUnet SR models in scripts. Prompts should be specified as in the example.
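The exact prompt-file format is defined by the example file in scripts; the sketch below assumes a one-prompt-per-line layout for the file passed via PROMPT_TEXT_FILES, which is an assumption for illustration rather than a documented guarantee.

```python
# Hypothetical prompt file for PROMPT_TEXT_FILES, assuming one prompt
# per line (check the repo's example prompt file for the real format).
import tempfile
from pathlib import Path

prompts = [
    "A blue colored pizza",
    "a highly detailed digital painting of a portal in a mystic forest",
]
path = Path(tempfile.mkdtemp()) / "prompts.txt"
path.write_text("\n".join(prompts) + "\n", encoding="utf-8")

# Read it back the way a sampler plausibly would.
lines = path.read_text(encoding="utf-8").splitlines()
```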

Sampling 256x256 images

Defaults to imagen_256_sample.gin config (can be adjusted in script)

CUDA_VISIBLE_DEVICES=<DEVICES> CFG=5.0 BASE_PATH=<BASE_CKPT> SR1_PATH=<SR1_CKPT> PROMPT_TEXT_FILES=<FILE> ./rosetta/projects/imagen/scripts/sample_imagen_256.sh 

Sampling 1024x1024 images

Defaults to imagen_1024_sample.gin config (can be adjusted in script).

CUDA_VISIBLE_DEVICES=<DEVICES> CFG=5.0 BASE_PATH=<BASE_CKPT> SR1_PATH=<SR1_CKPT> SR2_PATH=<SR2_CKPT> PROMPT_TEXT_FILES=<FILE> ./rosetta/projects/imagen/scripts/sample_imagen_1024.sh 

Convergence and Performance

Global Batch size = 2048. We assume 2.5B Training examples in these calculations. LLM Inference server nodes are not included in these numbers.

| Size | GPU | Precision | #GPUs | BS/GPU | Images/sec | Im/sec/GPU | Est. walltime (hr) | GPU-days | Config |
|---|---|---|---|---|---|---|---|---|---|
| Imagen-base-500M | A100-80G-SXM | BF16 | 8 | 64 | 858 | 107.0 | 809 | 269 | cfg |
| Imagen-base-500M | A100-80G-SXM | BF16 | 32 | 64 | 3056 | 95.5 | 227 | 303 | cfg |
| Imagen-base-2B | A100-80G-SXM | BF16 | 8 | 16 | 219 | 27.4 | 3170 | 1057 | cfg |
| Imagen-base-2B | A100-80G-SXM | BF16 | 32 | 16 | 795 | 24.8 | 873 | 1164 | cfg |
| Imagen-base-2B | A100-80G-SXM | BF16 | 128 | 16 | 2934 | 22.9 | 236 | 1258 | cfg |
| Imagen-SR1-600M-EffUNet | A100-80G-SXM | BF16 | 8 | 64 | 674 | 84.3 | 1030 | 343 | cfg |
| Imagen-SR1-600M-EffUNet | A100-80G-SXM | BF16 | 32 | 64 | 2529 | 79.1 | 274 | 365 | cfg |
| Imagen-SR2-600M-EffUNet | A100-80G-SXM | BF16 | 8 | 64 | 678 | 84.8 | 1024 | 341 | cfg |
| Imagen-SR2-600M-EffUNet | A100-80G-SXM | BF16 | 32 | 64 | 2601 | 81.3 | 267 | 356 | cfg |
| Imagen-SR1-430M-UNet | A100-80G-SXM | BF16 | 8 | 16 | 194 | 24.3 | 3580 | 1193 | cfg |
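The walltime and GPU-days columns follow from Images/sec and the stated 2.5B training examples; the check below reproduces them (small rounding differences versus the table are expected).

```python
# Reproduce the derived columns of the performance table:
#   walltime (hr) = training examples / images-per-sec / 3600
#   GPU-days      = walltime (hr) * #GPUs / 24
TRAIN_EXAMPLES = 2.5e9  # as stated above


def walltime_hr(images_per_sec: float) -> float:
    return TRAIN_EXAMPLES / images_per_sec / 3600


def gpu_days(images_per_sec: float, n_gpus: int) -> float:
    return walltime_hr(images_per_sec) * n_gpus / 24


base_500m_wall = walltime_hr(858)   # ~809 hr on 8 GPUs
base_500m_days = gpu_days(858, 8)   # ~270 GPU-days
```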

Imagen-SR1-430M-UNet is not currently supported; you can use the SR1 Efficient UNet instead. Support is coming soon!

Imagen base 500M + Efficient SR1 (600M):

| CFG | FID-30K (256x256) |
|---|---|
| 2 | 11.30 |
| 3 | 10.23 |
| 4 | 11.33 |
| 6 | 12.34 |

@terrykong terrykong requested a review from sharathts October 2, 2023 23:25
sharathts
sharathts previously approved these changes Oct 2, 2023
@terrykong
Contributor Author

There are several build/test issues in the latest t5x nightlies that are causing this PR's presubmit to fail.

I have manually built the image and verified that no unit-test failures result from this PR: https://github.com/NVIDIA/JAX-Toolbox/actions/runs/6422945370/job/17440459630

This should be safe to merge.

@terrykong terrykong requested a review from sharathts October 5, 2023 18:31
@terrykong terrykong merged commit 25de1cd into main Oct 5, 2023
27 of 41 checks passed
@terrykong terrykong deleted the imagen-release branch October 5, 2023 18:43