Implementation of Variational Diffusion Auto-encoder: Latent Space Extraction from Pre-trained Diffusion Models (Batzolis++23) in jax and equinox.
The idea here is to remedy the restrictive assumption a traditional variational autoencoder (VAE) places on the reconstruction likelihood.
In practice, the score of this likelihood of the data given a latent code decomposes, by Bayes' rule, into the sum of the score of the marginal likelihood of the data (supplied by the pre-trained diffusion model) and the score of the variational posterior over the latent.
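The score decomposition described above can be sketched in JAX as follows. This is a minimal illustration, not the repo's implementation: `marginal_score`, `posterior_log_prob`, and `conditional_score` are hypothetical names, and the pre-trained score model is replaced by a toy stand-in.

```python
import jax
import jax.numpy as jnp


def marginal_score(x_t, t):
    # Toy stand-in for a pre-trained diffusion model's score estimate;
    # here, the score of a standard Gaussian.
    return -x_t


def posterior_log_prob(x_t, t, z, encoder):
    # Log-density of a diagonal-Gaussian variational posterior q(z | x_t, t),
    # with mean and log-std produced by the encoder.
    mean, log_std = encoder(x_t, t)
    return jnp.sum(
        -0.5 * ((z - mean) / jnp.exp(log_std)) ** 2
        - log_std
        - 0.5 * jnp.log(2.0 * jnp.pi)
    )


def conditional_score(x_t, t, z, encoder):
    # Bayes' rule in score form:
    # grad_x log p(x_t | z) = grad_x log p(x_t) + grad_x log q(z | x_t, t)
    posterior_score = jax.grad(posterior_log_prob)(x_t, t, z, encoder)
    return marginal_score(x_t, t) + posterior_score
```

With any differentiable encoder, `jax.grad` supplies the posterior score term automatically, so the conditional score never needs to be modelled directly.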
The variational posterior is modelled by a Gaussian ansatz, parameterised with a mean and diagonal covariance that are functions of the diffusion time. Training uses the same VAE objective, comprising the reconstruction loss and the KL divergence of the variational posterior, but the gradients only adjust the encoder: the objective is formed with a pre-trained diffusion model and a variational posterior, only the latter of which is optimised. This improves on the training dynamics of the traditional VAE, and it allows a latent space to be extracted from existing generative models.
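A minimal sketch of this training setup, assuming a frozen pre-trained model and a trainable time-conditioned Gaussian encoder. All names (`encode`, `reconstruction_term`, the parameter pytrees) are illustrative, and the diffusion-based reconstruction term is replaced by a toy surrogate; `jax.lax.stop_gradient` makes explicit that only the encoder is optimised.

```python
import jax
import jax.numpy as jnp


def encode(params, x, t):
    # Time-conditioned diagonal-Gaussian posterior: both the mean and the
    # log-std are functions of the diffusion time t.
    h = jnp.concatenate([x, jnp.atleast_1d(t)])
    return params["w_mean"] @ h, params["w_std"] @ h


def kl_to_standard_normal(mean, log_std):
    # Closed-form KL( N(mean, diag(std^2)) || N(0, I) ).
    var = jnp.exp(2.0 * log_std)
    return 0.5 * jnp.sum(var + mean**2 - 1.0 - 2.0 * log_std)


def reconstruction_term(score_params, x, t, z):
    # Toy stand-in for the diffusion-based reconstruction likelihood.
    return jnp.sum((x - score_params["proj"] @ z) ** 2)


def loss(params, frozen_score_params, x, t, key):
    mean, log_std = encode(params, x, t)
    eps = jax.random.normal(key, mean.shape)
    z = mean + jnp.exp(log_std) * eps  # reparameterisation trick
    # stop_gradient freezes the pre-trained model: gradients of the VAE
    # objective reach the encoder parameters only.
    recon = reconstruction_term(
        jax.lax.stop_gradient(frozen_score_params), x, t, z
    )
    return recon + kl_to_standard_normal(mean, log_std)
```

Differentiating this loss with respect to both parameter pytrees shows the intended behaviour: the frozen model's gradients are identically zero, while the encoder receives non-trivial updates.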
The corrector model