Official Implementation of the paper: Fast Samplers for Inverse Problems in Iterative Refinement Models
We propose fast samplers based on Conjugate Integrators for solving inverse problems with pretrained diffusion or flow models. See also our official project page for qualitative results.
This repo builds on top of the official RED-Diff and Rectified Flow implementations. After cloning the repo, cd into either diffusion or flow and use it as the working directory for the downstream steps, such as setting up the dependencies and running inference.
Diffusion: Our requirements.txt file can be used to set up a conda environment directly with the following command:
conda create --name <env> --file requirements.txt
Flow: Please see Rectified Flow for setting up the dependencies.
Diffusion: In this work we use pretrained unconditional ImageNet checkpoints from the ADM repository and a pretrained unconditional FFHQ checkpoint from the DPS repository.
Flow: We use the pretrained weights from Rectified Flow. A few additional packages are also required, including pytorch_msssim, omegaconf, compressai, and lpips (mostly for evaluation and degradation transforms).
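Before running the flow code, it can help to check which of these extra packages are already importable in the Rectified Flow environment. The function below is a minimal sketch, not part of the repo: the name check_flow_extras and the PYTHON override are hypothetical. It relies on each package's import name matching the name listed above (true for all four), prints one status line per package, and, if anything is missing, ends with a pip install line for the missing ones.

```shell
# Hypothetical pre-flight check (not part of the repo) for the extra packages
# the flow code needs on top of the Rectified Flow environment.
# Set PYTHON to point at a specific interpreter; defaults to "python".
check_flow_extras() {
  py=${PYTHON:-python}
  missing=""
  for pkg in pytorch_msssim omegaconf compressai lpips; do
    # A package counts as installed if the interpreter can import it.
    if "$py" -c "import $pkg" >/dev/null 2>&1; then
      echo "found: $pkg"
    else
      echo "missing: $pkg"
      missing="$missing $pkg"
    fi
  done
  if [ -n "$missing" ]; then
    # pip normalizes "_" to "-", so pytorch_msssim resolves to pytorch-msssim.
    echo "pip install$missing"
    return 1
  fi
}
```

For example, running check_flow_extras inside the activated environment either reports all four packages as found (exit status 0) or finishes with a pip install command that can be copied as-is.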
Diffusion: Configs are managed with Hydra. All configs are in the _configs/ directory, with algorithm-specific configs in _configs/algo.
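Hydra names each option of a config group after its YAML file, so the available algorithms can be read off the _configs/algo directory. The helper below is a hypothetical sketch (list_config_options is not part of the repo) that prints those option names; each one can be passed on the command line as algo=<name>, the same way the sample command below uses algo=cpgdm.

```shell
# Hypothetical helper (not part of the repo): list the option names of a Hydra
# config group. Hydra names each option after its YAML file (minus ".yaml"),
# so every printed name is a valid value for the group, e.g. algo=<name>.
list_config_options() {
  dir=${1:-_configs/algo}
  for f in "$dir"/*.yaml; do
    # If nothing matches, the glob stays literal; skip it instead of printing it.
    [ -e "$f" ] || continue
    basename "$f" .yaml
  done
}
```

Run it from the diffusion directory. To inspect the fully composed config for a given choice without running inference, Hydra's standard --cfg flag should work, assuming main.py is a regular @hydra.main entry point: python main.py algo=cpgdm --cfg job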
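A sample inference script used in this work is provided at test_scripts/test.sh. For instance, the following runs the noiseless linear Conjugate sampler (algo=cpgdm) for 4x bicubic super-resolution (algo.deg=bicubic4) on ImageNet 256 with exp.num_steps=5: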
samples_root=demo_samples/
exp_root=/home/pandeyk1/ciip_results/ # replace with your own path
ckpt_root=/home/pandeyk1/ciip_results/pretrained_chkpts/adm/ # directory with the pretrained ADM checkpoints; replace with your own path
save_deg=True
save_ori=True
overwrite=True
smoke_test=1 # Controls the number of batches generated
batch_size=1
python main.py \
diffusion=vpsde \
classifier=none \
algo=cpgdm \
algo.lam=0 \
algo.w=15.0 \
algo.deg=bicubic4 \
algo.num_eps=1e-6 \
algo.denoise=False \
loader=imagenet256_ddrmpp \
loader.batch_size=$batch_size \
loader.num_workers=2 \
dist.num_processes_per_node=1 \
exp.num_steps=5 \
exp.seed=0 \
exp.stride=uniform \
exp.t_start=0.6 \
exp.t_end=0.999 \
exp.root=$exp_root \
exp.name=demo \
exp.ckpt_root=$ckpt_root \
exp.samples_root=$samples_root \
exp.overwrite=$overwrite \
exp.save_ori=$save_ori \
exp.save_deg=$save_deg \
exp.smoke_test=$smoke_test
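Since these samplers are meant to work with very few steps, a natural experiment is to repeat the command above for several step budgets. The function below is a hypothetical sketch (sweep_num_steps and DRY_RUN are not part of the repo) that reuses the overrides of the sample command, varies exp.num_steps, and gives each run its own exp.name. It assumes exp.num_steps sets the number of sampling steps, as its name suggests, and that the exp_root, ckpt_root, and samples_root variables from the script above are set.

```shell
# Hypothetical helper (not part of the repo): rerun the sample command for
# several values of exp.num_steps. All other overrides mirror the sample
# command; set DRY_RUN=1 to print the commands instead of executing them.
sweep_num_steps() {
  if [ "${DRY_RUN:-0}" = "1" ]; then run=echo; else run=; fi
  for steps in "$@"; do
    # An empty $run expands to nothing, so the python command runs directly.
    $run python main.py \
      diffusion=vpsde \
      classifier=none \
      algo=cpgdm \
      algo.lam=0 \
      algo.w=15.0 \
      algo.deg=bicubic4 \
      algo.num_eps=1e-6 \
      algo.denoise=False \
      loader=imagenet256_ddrmpp \
      loader.batch_size=1 \
      loader.num_workers=2 \
      dist.num_processes_per_node=1 \
      exp.name="demo_steps$steps" \
      exp.num_steps="$steps" \
      exp.seed=0 \
      exp.stride=uniform \
      exp.t_start=0.6 \
      exp.t_end=0.999 \
      exp.root="$exp_root" \
      exp.ckpt_root="$ckpt_root" \
      exp.samples_root="$samples_root" \
      exp.overwrite=True \
      exp.save_ori=True \
      exp.save_deg=True \
      exp.smoke_test=1
  done
}
```

For example, DRY_RUN=1 sweep_num_steps 5 10 20 prints the three commands without running them; dropping DRY_RUN=1 executes them one after another, each writing under its own exp.name.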
Flow: See run.sh in the flow directory for running inference with the flow model.
If you find the code useful for your research, please consider citing our paper:
@misc{pandey2024fastsamplersinverseproblems,
title={Fast Samplers for Inverse Problems in Iterative Refinement Models},
author={Kushagra Pandey and Ruihan Yang and Stephan Mandt},
year={2024},
eprint={2405.17673},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2405.17673},
}