*Generated images from ANT-UW (DiT-L) with guidance scale 3.0*
This repository contains the official PyTorch implementation of the following paper: "Addressing Negative Transfer in Diffusion Models" (NeurIPS 2023). To gain a better understanding of the paper, please visit our project page and read the paper on arXiv.
This code trains DiT models with ANT on the ImageNet dataset. Our implementation is based on DiT, LibMTL, and NashMTL.
- 2024.07.04: Our project Switch Diffusion Transformer, which utilizes ANT-UW, has been accepted to ECCV 2024!
- 2024.05.31: EDM2 adopts uncertainty weighting, a concept we had previously explored in our work. We believe that rethinking diffusion models as multi-task learners is an effective direction for improving these models!
- 2023.11.05: Initial release.
- 2023.10.07: DTR uses ANT-UW and shows that ANT-UW outperforms P2 weighting and Min-SNR. Please check our new project, "Denoising Task Routing for Diffusion Models".
```bash
$ pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
$ pip install -r requirements.txt
```
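To verify that the CUDA build of PyTorch installed correctly, an optional quick check:

```bash
$ python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```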
Train ANT on a single GPU:

```bash
python train_ant_single_gpu.py \
    --data_path [Your Data Path] \
    --total_clusters [Your cluster size (Default: 5)] \
    --mtl_method [uw, nash, or pcgrad]
```
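The `--mtl_method` flag selects how the per-cluster losses are combined. As a rough illustration of the `uw` option, uncertainty weighting (Kendall et al., 2018) learns one log-variance per task cluster; the sketch below is illustrative only, and the class and variable names are assumptions rather than the repository's actual module API:

```python
import torch
import torch.nn as nn

class UncertaintyWeighting(nn.Module):
    """Illustrative uncertainty weighting over denoising-task clusters.

    Each of the `total_clusters` interval clusters gets a learnable
    log-variance s_k; the combined objective is sum_k exp(-s_k) * L_k + s_k,
    so high-uncertainty clusters are automatically down-weighted.
    """

    def __init__(self, total_clusters: int = 5):
        super().__init__()
        self.log_vars = nn.Parameter(torch.zeros(total_clusters))

    def forward(self, cluster_losses: torch.Tensor) -> torch.Tensor:
        # cluster_losses: tensor of shape (total_clusters,), one mean
        # diffusion loss per timestep cluster for the current batch.
        return (torch.exp(-self.log_vars) * cluster_losses + self.log_vars).sum()
```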
Train ANT-UW on multiple GPUs:

```bash
torchrun --nproc_per_node=[number of gpus] train_uw_multi_gpu.py \
    --data_path [Your Data Path] \
    --total_clusters [Your cluster size (Default: 5)]
```
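For example, to train on 8 GPUs with the default five clusters (the data path below is a placeholder):

```bash
torchrun --nproc_per_node=8 train_uw_multi_gpu.py \
    --data_path ./data/imagenet/train \
    --total_clusters 5
```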
Sample images with `sample_ddp.py`:

```bash
torchrun --nproc_per_node=[number of gpus] sample_ddp.py \
    --model_config [Your config path (config/DiT-L.yaml, config/DiT-S.yaml)] \
    --ckpt [Your ckpt path] \
    --sample-dir [Your sample dir]
```
Example code for interval clustering is provided in interval_clustering.py.
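Interval clustering partitions the diffusion timesteps into contiguous intervals, each treated as one task cluster during training. Below is a minimal sketch of the simplest uniform split; the function name and the uniform boundaries are illustrative assumptions, and the actual clustering objectives are the ones implemented in interval_clustering.py:

```python
import torch

def uniform_interval_clusters(num_timesteps: int = 1000,
                              total_clusters: int = 5) -> torch.Tensor:
    """Map each diffusion timestep to one of `total_clusters` contiguous intervals.

    Returns a LongTensor of shape (num_timesteps,) whose entry at index t is
    the cluster id of timestep t, so a sampled batch of timesteps can be
    routed to its per-cluster loss via cluster_ids[t].
    """
    boundaries = torch.linspace(0, num_timesteps, total_clusters + 1).long()
    cluster_ids = torch.empty(num_timesteps, dtype=torch.long)
    for k in range(total_clusters):
        cluster_ids[boundaries[k]:boundaries[k + 1]] = k
    return cluster_ids

# Example: with 1000 timesteps and 5 clusters, timesteps 0-199 form cluster 0,
# 200-399 form cluster 1, and so on.
cluster_ids = uniform_interval_clusters()
assert cluster_ids[0].item() == 0 and cluster_ids[999].item() == 4
```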
All pretrained models can be downloaded from the OneDrive link.
| Model | FID | IS | Precision | Recall |
|---|---|---|---|---|
| DiT-L + ANT-UW (Multi-GPU) | 5.695 | 186.661 | 0.811 | 0.491 |
| DiT-S + Nash | 44.65 | 33.48 | 0.4209 | 0.5272 |
| DiT-S + UW | 48.40 | 30.84 | 0.4196 | 0.5172 |