PyTorch implementation of DuoVAE (a VAE framework for property-controlled data generation), proposed in
Collagen Fiber Centerline Tracking in Fibrotic Tissue via Deep Neural Networks with Variational Autoencoder-based Synthetic Training Data Generation,
Hyojoon Park*, Bin Li*, Yuming Liu, Michael S. Nelson, Helen M. Wilson, Eftychios Sifakis, Kevin W. Eliceiri,
Medical Image Analysis 2023.
DuoVAE generates data with desired properties controlled by continuous property values.
This repository is designed specifically for training and testing on the VAE benchmark datasets dSprites and 3dshapes.
- Linux, macOS, Windows
- CPU, CUDA, MPS (arm64 Apple silicon)
(MPS is supported on macOS 12.3+, but may be unstable as of September 2022)
For conda users,
conda env create -f environment.yml
For pip users,
pip install -r requirements.txt
Then, install PyTorch with CPU, CUDA, or MPS support.
Other configuration methods can be found here.
List of tested versions can be found here.
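Once PyTorch is installed, a quick check of which backend is usable can save debugging time. The snippet below is a minimal sketch, not code from this repository: `pick_device_name` is a hypothetical helper, and the CUDA > MPS > CPU preference order is an assumption. It falls back to CPU if PyTorch is not installed or the MPS backend is missing from the installed build.

```python
def pick_device_name(cuda_available: bool, mps_available: bool) -> str:
    """Hypothetical helper: prefer CUDA, then MPS, then fall back to CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


try:
    import torch

    # torch.backends.mps is absent in older PyTorch builds, so look it up defensively.
    mps_backend = getattr(torch.backends, "mps", None)
    device_name = pick_device_name(
        torch.cuda.is_available(),
        mps_backend is not None and mps_backend.is_available(),
    )
except ImportError:
    # PyTorch is not installed yet; report CPU so the script still runs.
    device_name = "cpu"

print(f"Selected device: {device_name}")
```

If MPS turns out to be unstable on a given machine, passing `False` for `mps_available` forces the CPU fallback.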
Running train.py (see the Train section below) will automatically download the needed datasets.
To download them directly, run
bash download_datasets.sh
This will download dSprites and 3dshapes to ./datasets/data/.
Command format is python train.py <dataset-type>, for example
python train.py duovae 2d
<dataset-type>
Additional parameters can be configured in parameters.json.
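For experiments that need several parameter variants, it can be convenient to load the JSON file and override values programmatically. The following is a minimal sketch under stated assumptions: it treats the file as a flat JSON object, `load_parameters` is a hypothetical helper, and the keys `batch_size` and `learning_rate` are made up for the demo. They are not taken from this repository's parameters.json, so adapt the code if the real file is nested or uses other names.

```python
import json
import tempfile
from pathlib import Path


def load_parameters(path, overrides=None):
    """Load a flat JSON object and apply overrides for keys that already exist."""
    with open(path, "r", encoding="utf-8") as f:
        params = json.load(f)
    for key, value in (overrides or {}).items():
        if key not in params:
            # Reject unknown keys so that typos do not silently add new settings.
            raise KeyError(f"unknown parameter: {key}")
        params[key] = value
    return params


# Demo on a temporary file with hypothetical keys (not the repository's real ones).
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "parameters_demo.json"
    demo_path.write_text(json.dumps({"batch_size": 64, "learning_rate": 1e-4}))
    demo_params = load_parameters(demo_path, {"batch_size": 128})

print(demo_params)
```

Rejecting unknown keys is a design choice for this sketch; it keeps a misspelled override from being ignored without notice.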
More command examples can be found here.
(coming)
dSprites dataset
The controlled properties (from left to right in each row) are
- $y_1$: scale of a shape $\rightarrow$ from small to large,
- $y_2$: $x$ position of a shape $\rightarrow$ from left to right,
- $y_3$: $y$ position of a shape $\rightarrow$ from top to bottom.
3dshapes dataset
The controlled properties (from left to right in each row) are
- $y_1$: scale of a shape $\rightarrow$ from small to large,
- $y_2$: wall color $\rightarrow$ from red to violet,
- $y_3$: floor color $\rightarrow$ from red to violet.
The mutual information of two random variables quantifies the amount of information (in units such as shannons (bits) or nats) obtained about one random variable by observing the other random variable.
In the ideal case, the heatmaps of normalized MI between each of the properties and latent variables should be 1 in the diagonal values and 0 in the off-diagonal values as well as for
Below are the heatmaps of the normalized MI on the dSprites (left) and 3dshapes (right) datasets.
We can see that the diagonal values of the heatmap are high and close to 1 for both datasets, implying high correlations between each of the respective property values
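As a rough illustration of how a normalized MI matrix between properties and latent variables could be estimated, the sketch below uses a simple histogram estimator in NumPy. It is not this repository's evaluation code: the bin count, the normalization by the property entropy $H(y)$, and the function names `normalized_mi` and `mi_matrix` are all assumptions, and the data is synthetic (three latent dimensions that track the properties plus two unrelated ones).

```python
import numpy as np


def entropy(counts):
    """Shannon entropy (in nats) of a histogram given as raw counts."""
    p = counts / counts.sum()
    p = p[p > 0]  # empty bins contribute nothing
    return -np.sum(p * np.log(p))


def normalized_mi(y, z, bins=20):
    """Histogram estimate of I(y; z) / H(y); the normalization is an assumption."""
    joint, _, _ = np.histogram2d(y, z, bins=bins)
    h_y = entropy(joint.sum(axis=1))
    h_z = entropy(joint.sum(axis=0))
    h_yz = entropy(joint.ravel())
    mi = h_y + h_z - h_yz
    return mi / h_y if h_y > 0 else 0.0


def mi_matrix(Y, Z, bins=20):
    """Rows index properties, columns index latent variables."""
    return np.array([
        [normalized_mi(Y[:, i], Z[:, j], bins) for j in range(Z.shape[1])]
        for i in range(Y.shape[1])
    ])


# Synthetic stand-in data: the first 3 latents follow the 3 properties closely,
# the last 2 latents are independent noise.
rng = np.random.default_rng(0)
Y = rng.uniform(size=(5000, 3))
Z = np.concatenate(
    [Y + 0.01 * rng.normal(size=Y.shape), rng.normal(size=(5000, 2))], axis=1
)
M = mi_matrix(Y, Z)
print(np.round(M, 2))
```

On this toy data the leading 3x3 block is strongly diagonal and the remaining columns stay near zero, which is the pattern the heatmaps are meant to reveal. Histogram estimators are biased for small sample sizes, so the exact values depend on the bin count.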
A model can achieve a very high MI score while still producing very poor reconstructions. Therefore, we visualize the reconstructions generated when traversing the latent variables to validate whether the latent representation spaces are smooth and disentangled.
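The traversal idea can be sketched in a few lines: hold a base latent vector fixed, sweep one dimension over a range of values, and decode each resulting vector. The example below is an illustration only. `traverse_latent` is a hypothetical helper, the sweep range of -2 to 2 is an assumption, and `toy_decode` is a random stand-in for a trained decoder, not the DuoVAE model.

```python
import numpy as np


def traverse_latent(decode, z_base, dim, values):
    """Decode copies of z_base in which only dimension `dim` takes each value."""
    z_batch = np.repeat(z_base[None, :], len(values), axis=0)
    z_batch[:, dim] = values
    return decode(z_batch)


# Toy stand-in for a trained decoder: a fixed random linear map followed by tanh.
rng = np.random.default_rng(0)
W = rng.normal(size=(4, 16))


def toy_decode(z):
    return np.tanh(z @ W)


z_base = np.zeros(4)                 # base point in a 4-dimensional latent space
values = np.linspace(-2.0, 2.0, 7)   # assumed traversal range
frames = traverse_latent(toy_decode, z_base, dim=1, values=values)
print(frames.shape)  # (7, 16): one decoded output per traversal step
```

With a real model, each row of `frames` would be reshaped into an image and the rows placed side by side, giving one strip per latent variable.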
dSprites dataset with supervised latent variables
3dshapes dataset with supervised latent variables
Each of the supervised latent variables
@article{park2023collagen,
title={Collagen fiber centerline tracking in fibrotic tissue via deep neural networks with variational autoencoder-based synthetic training data generation},
author={Park, Hyojoon and Li, Bin and Liu, Yuming and Nelson, Michael S and Wilson, Helen M and Sifakis, Eftychios and Eliceiri, Kevin W},
journal={Medical Image Analysis},
volume={90},
pages={102961},
year={2023},
publisher={Elsevier}
}