This code is reposted from the official google-research repository.
This is a TensorFlow implementation of the Stacked Capsule Autoencoder (SCAE), which was introduced in the following paper: A. R. Kosiorek, Sara Sabour, Y. W. Teh, and Geoffrey E. Hinton, "Stacked Capsule Autoencoders".
- Author: Adam R. Kosiorek, Oxford Robotics Institute & Department of Statistics, University of Oxford
- Email: adamk(at)robots.ox.ac.uk
- Webpage: http://akosiorek.github.io/
This work was done during Adam's internship at Google Brain in Toronto.
If you look at natural images containing objects, you will quickly see that the same object can be captured from various viewpoints. Capsule Networks are specifically designed to be robust to viewpoint changes, which makes learning more data-efficient and allows better generalization to unseen viewpoints. This project introduces a novel unsupervised version of Capsule Networks called Stacked Capsule Autoencoders (SCAE). Unlike the original Capsule Networks, SCAE is a generative model with an affine-aware decoder. This forces the encoder to learn an image representation that is equivariant to viewpoint changes, which in turn leads to state-of-the-art unsupervised classification performance on MNIST and SVHN. For a more detailed description, please have a look at the paper or at Adam's blog.
Fig 1: TSNE embeddings of object-capsule presence probabilities on MNIST digits color-coded by digit class.
If you execute setup_virtualenv.sh, it will create a virtual environment and install all required dependencies. Alternatively, you can install all the dependencies using pip install -r requirements.txt. You can also manually install TensorFlow v1.15 and the following dependencies:
- absl_py>=0.8.1
- imageio>=2.6.1
- matplotlib>=3.0.3
- monty>=3.0.2
- numpy>=1.16.2
- Pillow>=6.2.1
- scikit_learn>=0.20.4
- scipy>=1.2.1
- dm_sonnet==1.35
- tensorflow==1.15.0
- tensorflow_probability==0.8.0
- tensorflow_datasets==1.3.0
The current implementation is not compatible with tensorflow==2.0.
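As a quick guard against running under the wrong TensorFlow release, a check along the following lines could be added to your own launch script. This is a minimal sketch and not part of the repository: the helper names parse_version and is_supported_tensorflow are hypothetical, and treating every 1.15.x release as acceptable is an assumption based on the version pinned above.

```python
import re


def parse_version(version):
    """Return the leading numeric components of a version string as ints."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"not a version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def is_supported_tensorflow(version):
    """True only for the 1.15 series (assumption: any 1.15.x patch is fine)."""
    return parse_version(version)[:2] == (1, 15)


if __name__ == "__main__":
    # Hypothetical version strings; in practice pass tf.__version__.
    for candidate in ("1.15.0", "1.15.5", "2.0.0"):
        print(candidate, is_supported_tensorflow(candidate))
```

In an environment where TensorFlow is installed, the same function can be called with tf.__version__ before any model code is imported.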
- The training loop is defined in train.py.
- The model is built in capsules/configs/model_config.py.
- The part capsule encoder and decoder are defined in capsules/primary.py.
- The object capsule can be found in capsules/capsule.py.
- The SCAE model is defined in capsules/models/scae.py.
How do you run TensorFlow 1.15 on CUDA 11?
Use one of NVIDIA's TensorFlow 1 containers. You can get the container version from here: link
Example:
sudo docker pull nvcr.io/nvidia/tensorflow:21.07-tf1-py3
sudo docker run --gpus all -d -it --name=tf1.15-cuda11.4 nvcr.io/nvidia/tensorflow:21.07-tf1-py3
sudo docker cp stacked_capsule_autoencoders tf1.15-cuda11.4:/workspace
sudo docker attach tf1.15-cuda11.4
You can train the model by invoking:
pip install -r requirements.txt
./run_mnist.sh
Different model configurations and datasets can be chosen by command-line flags. For an overview of the flags, have a look at the top of the training script (train.py) and the top of the model and data configs in capsules/configs.
To replicate the experiments reported in the paper, you can run one of the run_*.sh files:
- run_mnist.sh runs the full MNIST model as reported in the paper.
- run_mnist_coupled.sh runs the full MNIST model, with the difference that the part mixing probabilities are constrained to be the same as pixel intensities. We used a similar setup for the SVHN experiments.
- run_constellation.sh runs the constellation experiment.
You can also pass command-line arguments to the above scripts. For example, ./run.sh --batch_size=128 will invoke the run.sh script and set the batch size to 128.
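If you prefer to launch runs from Python (for example, to sweep over settings), the flag overrides can be assembled programmatically. The sketch below is an illustration only: build_command is a hypothetical helper, and it assumes the run scripts forward any extra arguments to train.py unchanged, as the example above suggests.

```python
import shlex


def build_command(script, **flags):
    """Return an argv list: the script followed by one --name=value per flag."""
    argv = [script]
    for name, value in flags.items():
        argv.append(f"--{name}={value}")
    return argv


if __name__ == "__main__":
    argv = build_command("./run_mnist.sh", batch_size=128, plot=True)
    # Print a shell-safe version of the command for inspection.
    print(" ".join(shlex.quote(arg) for arg in argv))
    # To actually launch the run: subprocess.run(argv, check=True)
```

Keeping the command as a list rather than a single string avoids quoting problems when it is later handed to subprocess.run.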
Model snapshots, TensorBoard logs, and intermediate plots (which require passing --plot=True) will be stored in the {logdir}/{name} directory, where logdir and name can be set as command-line arguments. logdir defaults to google-research/stacked_capsule_autoencoders/checkpoints.
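To locate the outputs of a given run from your own tooling, the {logdir}/{name} convention can be mirrored in a few lines. This is a sketch under stated assumptions: run_directory is a hypothetical helper, the run name "mnist" is only an example, and the default below is copied from the text above rather than read from the training script.

```python
import os

# Default taken from the README text; the authoritative value lives in train.py.
DEFAULT_LOGDIR = "google-research/stacked_capsule_autoencoders/checkpoints"


def run_directory(name, logdir=DEFAULT_LOGDIR):
    """Join logdir and name following the {logdir}/{name} convention."""
    return os.path.join(logdir, name)


if __name__ == "__main__":
    # "mnist" is a hypothetical run name used for illustration.
    print(run_directory("mnist"))
    print(run_directory("mnist", logdir="/tmp/scae"))
```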
You can download a pretrained model from here: link
You can use the eval_mnist_model.py script to evaluate a model snapshot trained on MNIST. To do so, you need to invoke the script with the same command-line arguments you used to train the model. We provide eval_mnist.sh and eval_mnist_coupled.sh for evaluating snapshots produced by run_mnist.sh and run_mnist_coupled.sh, respectively. Invoking the script will load the model snapshot, report linear classification accuracy and unsupervised classification accuracy (clustering + bipartite graph matching), and produce a TSNE plot of capsule probabilities. The plot will be stored in the checkpoint directory.
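To make the "clustering + bipartite graph matching" metric concrete: each cluster is assigned to exactly one class, and the assignment that maximizes the number of correctly labelled examples determines the accuracy. The sketch below is not the repository's evaluation code. It assumes cluster ids and labels are integers in the range 0 to num_classes - 1, and it searches all assignments by brute force, which is only practical for a handful of classes. For ten digit classes, a proper assignment solver such as scipy.optimize.linear_sum_assignment would be the usual choice.

```python
from itertools import permutations


def unsupervised_accuracy(cluster_ids, labels, num_classes):
    """Best accuracy over all one-to-one cluster -> class assignments."""
    # counts[c][k] = number of examples in cluster c whose true class is k.
    counts = [[0] * num_classes for _ in range(num_classes)]
    for cluster, label in zip(cluster_ids, labels):
        counts[cluster][label] += 1

    # Brute-force search over all bijections (stand-in for a real solver).
    best = max(
        sum(counts[cluster][cls] for cluster, cls in enumerate(assignment))
        for assignment in permutations(range(num_classes))
    )
    return best / len(labels)


if __name__ == "__main__":
    # Toy data: cluster ids are arbitrary, so they need not equal class ids.
    clusters = [2, 2, 0, 0, 1, 1, 1]
    labels = [0, 0, 1, 1, 2, 2, 0]
    print(unsupervised_accuracy(clusters, labels, num_classes=3))
```

The one-to-one constraint is what distinguishes this metric from simply giving each cluster its majority label: two clusters cannot both claim the same class.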
If you find this repo or the corresponding paper useful in your research, please consider citing:
@inproceedings{Kosiorek2019scae,
title={Stacked Capsule Autoencoders},
author={Kosiorek, Adam Roman and Sabour, Sara and Teh, Yee Whye and Hinton, Geoffrey Everest},
booktitle={Advances in Neural Information Processing Systems},
url = {https://arxiv.org/abs/1906.06818},
pdf = {https://arxiv.org/pdf/1906.06818},
year={2019}
}
Version 1.0
- Original implementation; contains the constellation and MNIST experiments.