(Results of Latte with skip-branches on text-to-video and class-to-video tasks. Left: text-to-video with 1.7x and 2.0x speedup. Right: class-to-video with 2.2x and 2.5x speedup. Latency is measured on one A100.)
This repository contains the official PyTorch implementation of the paper Accelerating Vision Diffusion Transformers with Skip Branches. In this work, we enhance standard DiT models with Skip-DiT, which introduces skip branches to improve feature smoothness, and we propose Skip-Cache, a method that leverages these skip branches to cache DiT features across timesteps during inference. We validate the approach on various DiT backbones for both video and image generation, showing that skip branches preserve generation quality while enabling significant speedup: Skip-Cache delivers up to 2.0x speedup on text-to-video and up to 2.5x speedup on class-to-video generation (see the figure above).
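For intuition, below is a minimal, hypothetical PyTorch sketch of the skip-branch idea, not the repository's actual implementation: shallow-block outputs are carried forward over skip branches and merged into the symmetric deep blocks. The class name `SkipDiTSketch`, the layer sizes, and the concatenate-then-linear merge are illustrative assumptions.

```python
import torch
import torch.nn as nn

class SkipDiTSketch(nn.Module):
    """Toy sketch: block i's output is carried over a skip branch and merged
    into the input of the symmetric deep block (depth - 1 - i)."""
    def __init__(self, depth=12, d_model=256, n_heads=4):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
            for _ in range(depth)
        )
        # one merge layer per skip branch (concatenate shallow + deep features)
        self.merges = nn.ModuleList(
            nn.Linear(2 * d_model, d_model) for _ in range(depth // 2)
        )

    def forward(self, x):
        depth = len(self.blocks)
        shallow_feats = []
        for i, block in enumerate(self.blocks):
            if i >= depth // 2:
                # merge the stored shallow feature into the deep half
                skip = shallow_feats[depth - 1 - i]
                x = self.merges[i - depth // 2](torch.cat([x, skip], dim=-1))
            x = block(x)
            if i < depth // 2:
                shallow_feats.append(x)
        return x

# Example: a (batch, tokens, channels) tensor passes through unchanged in shape.
y = SkipDiTSketch()(torch.randn(2, 16, 256))
```

These long-range connections are what Skip-Cache later exploits: once deep features have been computed, the skip branches make them reusable without rerunning every block.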
(Illustration of Skip-DiT and Skip-Cache for DiT visual generation caching. (a) The vanilla DiT block for image and video generation. (b) Skip-DiT modifies the vanilla DiT model with skip branches that connect shallow and deep DiT blocks. (c) Skip-Cache uses these skip branches to cache features across timesteps during inference.)
(Feature smoothness analysis of DiT in the class-to-video generation task using DDPM, with normalized disturbances controlled by strength coefficients.)
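As a rough, assumed illustration of how such a disturbance can be injected (the exact formulation in the paper may differ): Gaussian noise is rescaled to the feature's norm and weighted by a strength coefficient before being added back to the feature.

```python
import torch

def add_normalized_disturbance(feature: torch.Tensor, strength: float) -> torch.Tensor:
    """Hypothetical normalized disturbance: noise rescaled to the feature norm,
    then weighted by a strength coefficient."""
    noise = torch.randn_like(feature)
    noise = noise * (feature.norm() / (noise.norm() + 1e-8))
    return feature + strength * noise
```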
Model | Task | Training Data | Backbone | Size (GB) | Skip-Cache |
---|---|---|---|---|---|
Latte-skip | text-to-video | Vimeo | Latte | 8.76 | ✅ |
DiT-XL/2-skip | class-to-image | ImageNet | DiT-XL/2 | 11.40 | ✅ |
ucf101-skip | class-to-video | UCF101 | Latte | 2.77 | ✅ |
taichi-skip | class-to-video | Taichi-HD | Latte | 2.77 | ✅ |
skytimelapse-skip | class-to-video | SkyTimelapse | Latte | 2.77 | ✅ |
ffs-skip | class-to-video | FaceForensics | Latte | 2.77 | ✅ |
The pretrained text-to-image model of HunYuan-DiT can be found on Hugging Face and Tencent Cloud.
To prepare environments for the class-to-video and text-to-video tasks, please refer to Latte, or run:
cd class-to-video
conda env create -f environment.yaml
conda activate latte
To prepare environments for the class-to-image task, please refer to DiT, or run:
cd class-to-image
conda env create -f environment.yaml
conda activate DiT
To prepare environments for the text-to-image task, please refer to Hunyuan-DiT:
cd text-to-image
conda env create -f environment.yaml
conda activate HunyuanDiT
To download models, first install:
python -m pip install "huggingface_hub[cli]"
To download the models needed for the text-to-video task:
# 1. Download the VAE, encoder, tokenizer, and original checkpoints of Latte. Models are downloaded to ./text-to-video/pretrained/ by default.
cd text-to-video
python pretrained/download_ckpts.py
# 2. Download the checkpoint of Latte-skip
huggingface-cli download GuanjieChen/Skip-DiT Latte-skip.pt --local-dir ./pretrained/
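If you prefer the Python API of huggingface_hub over the CLI, the same checkpoint can be fetched with hf_hub_download (repo id and filename taken from the command above):

```python
from huggingface_hub import hf_hub_download

# Download Latte-skip.pt from GuanjieChen/Skip-DiT into ./pretrained/
ckpt_path = hf_hub_download(
    repo_id="GuanjieChen/Skip-DiT",
    filename="Latte-skip.pt",
    local_dir="./pretrained",
)
print(ckpt_path)
```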
To download the models needed for the class-to-video tasks:
cd class-to-video
# 1. Download vae models and original checkpoints of Latte:
python pretrained/download_ckpts.py
# 2. Download the checkpoint of Latte-skip
task_name=ffs # or ucf101, skytimelapse, taichi
huggingface-cli download GuanjieChen/Skip-DiT ${task_name}-skip.pt --local-dir ./pretrained/
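To grab all four class-to-video checkpoints in one go, a small Python loop over the task names from the model table works as well (same repo id and filename pattern as the command above):

```python
from huggingface_hub import hf_hub_download

# Fetch every class-to-video Skip-DiT checkpoint listed in the model table.
for task in ["ffs", "ucf101", "skytimelapse", "taichi"]:
    print(hf_hub_download(
        repo_id="GuanjieChen/Skip-DiT",
        filename=f"{task}-skip.pt",
        local_dir="./pretrained",
    ))
```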
To download the models needed for the text-to-image task:
cd text-to-image
mkdir ckpts
huggingface-cli download Tencent-Hunyuan/HunyuanDiT-v1.2 --local-dir ./ckpts
To download the models needed for the class-to-image task:
cd class-to-image
mkdir ckpts
# 1. download DiT-XL/2
wget -P ckpts https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt
# 2. download DiT-XL/2-skip
huggingface-cli download GuanjieChen/Skip-DiT DiT-XL-2-skip.pt --local-dir ./ckpts
To infer with text-to-video models:
cd text-to-video
# 1. infer with original latte
./sample/t2v.sh
# 2. infer with skip dit
./sample/t2v_skip.sh
# 3. accelerate with skip-cache
./sample/t2v_skip_cache.sh
To infer with class-to-video models:
task_name=ffs # or UCF101, skytimelapse, taichi-hd
cd class-to-video
# 1. infer with original latte
./sample/${task_name}.sh
# 2. infer with skip dit
./sample/${task_name}_skip.sh
# 3. accelerate with skip-cache
./sample/${task_name}_cache.sh
To infer with the text-to-image model, please follow the instructions in the official implementation of HunYuan-DiT. To use Skip-Cache, add the following flags to the inference script ./infer.sh:
--deepcache -N {2,3,4...}
To infer with class-to-image models:
cd class-to-image
# 1. infer with DiT-XL/2
python sample.py --ckpt path/to/DiT-XL-2.pt --model DiT-XL/2
# 2. infer with DiT-XL/2-skip
python sample.py --ckpt path/to/DiT-XL-2-skip.pt --model DiT-skip
# 3. accelerate with skip-cache
python sample.py --ckpt path/to/DiT-XL-2-skip.pt --model DiT-cache-2 # or DiT-cache-3, DiT-cache-4...
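To make the DiT-cache-N naming concrete, here is a simplified, self-contained sketch of the caching idea (a toy stand-in, not the repository's model or sampler): every N-th denoising step runs the expensive deep half and caches its output; the steps in between recompute only the shallow half and the skip-branch merge, reusing the cached deep features.

```python
import torch
import torch.nn as nn

class ToyCachedDiT(nn.Module):
    """Toy stand-in with one skip branch: a cheap shallow half, an expensive
    deep half, and a merge layer fed by both."""
    def __init__(self, d=64):
        super().__init__()
        self.shallow = nn.Sequential(nn.Linear(d, d), nn.GELU())
        self.deep = nn.Sequential(nn.Linear(d, d), nn.GELU(), nn.Linear(d, d), nn.GELU())
        self.merge = nn.Linear(2 * d, d)

    def forward(self, x, deep_cache=None):
        s = self.shallow(x)
        # the deep half is the expensive part; skip it when a cache is supplied
        deep = self.deep(s) if deep_cache is None else deep_cache
        return self.merge(torch.cat([s, deep], dim=-1)), deep

@torch.no_grad()
def skip_cache_sampling(model, x, num_steps=50, cache_interval=2):
    """Run the deep half every `cache_interval` steps; reuse its output otherwise."""
    deep_cache = None
    for step in range(num_steps):
        refresh = step % cache_interval == 0
        out, deep = model(x, deep_cache=None if refresh else deep_cache)
        if refresh:
            deep_cache = deep
        x = x - 0.01 * out  # toy update standing in for the real denoising step
    return x

x = skip_cache_sampling(ToyCachedDiT(), torch.randn(4, 64), cache_interval=3)
```

Larger N reuses the cached deep features for more consecutive steps, trading a little quality for more speedup.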
To train the DiT-XL/2-skip:
- Download the ImageNet dataset.
- Implement the TODO in train.py.
- Run the script run_train.sh.
To train the class-to-video models:
- Download the datasets provided by Xin Ma on Hugging Face: skytimelapse, taichi, and ffs. UCF101 must be downloaded from its website.
- Implement the TODOs in the training configs under class-to-video/configs.
- Run the training scripts under class-to-video/train_scripts.
To train the text-to-video model:
- Prepare your text-video datasets and implement text-to-video/datasets/t2v_joint_dataset.py.
- Run the two-stage training strategy:
  - Stage 1: Freeze all parameters except the skip branches (see the sketch after this list). Set freeze=True in text-to-video/configs/train_t2v.yaml, then run the training scripts.
  - Stage 2: Overall training. Set freeze=False in text-to-video/configs/train_t2v.yaml, then run the training scripts.
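For the first stage, freezing everything except the skip branches could look roughly like the sketch below; the name filter "skip" is an assumption about how the branch parameters are named, not the repository's exact convention.

```python
import torch.nn as nn

def freeze_all_but_skip_branches(model: nn.Module, skip_keyword: str = "skip"):
    """Stage-1 sketch: leave only parameters whose names contain `skip_keyword` trainable."""
    for name, param in model.named_parameters():
        param.requires_grad = skip_keyword in name
    n_trainable = sum(p.requires_grad for _, p in model.named_parameters())
    print(f"{n_trainable} trainable parameter tensors (skip branches only)")
```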
Skip-DiT has been greatly inspired by the following amazing works and teams: DeepCache, Latte, DiT, and HunYuan-DiT. We thank all the contributors for open-sourcing their work.
The code and model weights are licensed under LICENSE.