Skip to content

Latest commit

 

History

History
156 lines (129 loc) · 8.84 KB

README.md

File metadata and controls

156 lines (129 loc) · 8.84 KB

Phase Space Langevin Diffusion

ICCV'23 (Oral Presentation)


Official Implementation of the paper: A Complete Recipe for Diffusion Generative Models

Overview

We propose a complete recipe for constructing novel diffusion processes which are guaranteed to converge to a specified stationary distribution. This serves two benefits:

  • Firstly, a principled parameterization to construct diffusion models can allow for construction of more flexible processes which do not necessarily rely on physical intuition (like CLD)
  • Secondly, given a diffusion process, our parameterization can validate if the process converges to a specified stationary distribution.

To instantiate this recipe, we propose a new diffusion model: Phase Space Langevin Diffusion (PSLD) which outperforms diffusion models like VP-SDE and CLD, and achieves excellent sample quality on standard image synthesis benchmarks like CIFAR-10 (FID:2.10) and CelebA-64 (FID: 2.01). PSLD also supports classifier-(free) guidance out of the box like other diffusion models.

Code Overview

This repo uses PyTorch Lightning for training and Hydra for config management so basic familiarity with both these tools is expected. Please clone the repo with PSLD as the working directory for any downstream tasks like setting up the dependencies, training and inference.

Dependency Setup

We use pipenv for a project-level dependency management. Simply install pipenv and run the following command:

pipenv install

Config Management

We manage hydra configurations separately for each benchmark/dataset used in this work. All configs are present in the main/configs directory. This directory has subfolders named according to the dataset where each subfolder contains the corresponding config associated with a specific diffusion models.

Training and Evaluation

We include a sample training and inference script for CIFAR-10. We include all scripts used in this work in the directory /scripts_psld/.

  • Training a PSLD model on CIFAR-10:
python main/train_sde.py +dataset=cifar10/cifar10_psld \
                     dataset.diffusion.data.root=\'/path/to/cifar10/\' \
                     dataset.diffusion.data.name='cifar10' \
                     dataset.diffusion.data.norm=True \
                     dataset.diffusion.data.hflip=True \
                     dataset.diffusion.model.score_fn.in_ch=6 \
                     dataset.diffusion.model.score_fn.out_ch=6 \
                     dataset.diffusion.model.score_fn.nf=128 \
                     dataset.diffusion.model.score_fn.ch_mult=[2,2,2] \
                     dataset.diffusion.model.score_fn.num_res_blocks=8 \
                     dataset.diffusion.model.score_fn.attn_resolutions=[16] \
                     dataset.diffusion.model.score_fn.dropout=0.15 \
                     dataset.diffusion.model.score_fn.progressive_input='residual' \
                     dataset.diffusion.model.score_fn.fir=True \
                     dataset.diffusion.model.score_fn.embedding_type='fourier' \
                     dataset.diffusion.model.sde.beta_min=8.0 \
                     dataset.diffusion.model.sde.beta_max=8.0 \
                     dataset.diffusion.model.sde.decomp_mode='lower' \
                     dataset.diffusion.model.sde.nu=4.01 \
                     dataset.diffusion.model.sde.gamma=0.01 \
                     dataset.diffusion.model.sde.kappa=0.04 \
                     dataset.diffusion.training.seed=0 \
                     dataset.diffusion.training.chkpt_interval=50 \
                     dataset.diffusion.training.mode='hsm' \
                     dataset.diffusion.training.fp16=False \
                     dataset.diffusion.training.use_ema=True \
                     dataset.diffusion.training.batch_size=16 \
                     dataset.diffusion.training.epochs=2500 \
                     dataset.diffusion.training.accelerator='gpu' \
                     dataset.diffusion.training.devices=8 \
                     dataset.diffusion.training.results_dir=\'/path/to/logdir/' \
                     dataset.diffusion.training.workers=1
  • Generating CIFAR-10 samples from the PSLD model
python main/eval/sample.py +dataset=cifar10/cifar10_psld \
                     dataset.diffusion.model.score_fn.in_ch=6 \
                     dataset.diffusion.model.score_fn.out_ch=6 \
                     dataset.diffusion.model.score_fn.nf=128 \
                     dataset.diffusion.model.score_fn.ch_mult=[2,2,2] \
                     dataset.diffusion.model.score_fn.num_res_blocks=8 \
                     dataset.diffusion.model.score_fn.attn_resolutions=[16] \
                     dataset.diffusion.model.score_fn.dropout=0.15 \
                     dataset.diffusion.model.score_fn.progressive_input='residual' \
                     dataset.diffusion.model.score_fn.fir=True \
                     dataset.diffusion.model.score_fn.embedding_type='fourier' \
                     dataset.diffusion.model.sde.beta_min=8.0 \
                     dataset.diffusion.model.sde.beta_max=8.0 \
                     dataset.diffusion.model.sde.nu=4.01 \
                     dataset.diffusion.model.sde.gamma=0.01 \
                     dataset.diffusion.model.sde.kappa=0.04 \
                     dataset.diffusion.model.sde.decomp_mode='lower' \
                     dataset.diffusion.evaluation.seed=0 \
                     dataset.diffusion.evaluation.sample_prefix='some_prefix_for_sample_names' \
                     dataset.diffusion.evaluation.devices=8 \
                     dataset.diffusion.evaluation.save_path=\'/path/to/generated/samples/\' \
                     dataset.diffusion.evaluation.batch_size=16 \
                     dataset.diffusion.evaluation.stride_type='uniform' \
                     dataset.diffusion.evaluation.sample_from='target' \
                     dataset.diffusion.evaluation.workers=1 \
                     dataset.diffusion.evaluation.chkpt_path=\'/path/to/pretrained/chkpt.ckpt\' \
                     dataset.diffusion.evaluation.sampler.name="em_sde" \
                     dataset.diffusion.evaluation.n_samples=50000 \
                     dataset.diffusion.evaluation.n_discrete_steps=50 \
                     dataset.diffusion.evaluation.path_prefix="50"

We evaluate sample quality using FID scores. We compute FID scores using the torch-fidelity package.

Pretrained Checkpoints

Pre-trained PSLD checkpoints can be found here

Citation

If you find the code useful for your research, please consider citing our ICCV'23 paper:

@InProceedings{Pandey_2023_ICCV,
    author    = {Pandey, Kushagra and Mandt, Stephan},
    title     = {A Complete Recipe for Diffusion Generative Models},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2023},
    pages     = {4261-4272}
}

License

See the LICENSE file.