Constructed and trained a Capsule Network to classify handwritten digits from the MNIST dataset.
A Capsule Network is a neural network that attempts to perform inverse graphics (the process of converting a visual image into an internal, hierarchical representation of geometric data), and in doing so it captures the relative spatial relationships between objects and their parts. Capsule Networks use vectors called capsules that encode all the important information about the state of the feature they detect. A capsule is any function that tries to predict the presence and the instantiation parameters of a particular object at a given location. The architecture consists of an encoder network and a decoder network, and the forward pass of the combined network is computed using the dynamic routing algorithm.
Fig 1. The CapsNet Architecture (Encoder) from the original paper by S Sabour et al., 2017.
Fig 2. The CapsNet Architecture (Decoder) from the original paper by S Sabour et al., 2017.
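For intuition, here is a minimal sketch of the squash non-linearity and the routing-by-agreement loop described in the paper. Tensor shapes and function names are illustrative and are not taken from capsulenetwork.py.

```python
import torch
import torch.nn.functional as F

def squash(s, dim=-1, eps=1e-8):
    # Scale a capsule vector so its length lies in (0, 1) while keeping its direction.
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    return (sq_norm / (1.0 + sq_norm)) * s / torch.sqrt(sq_norm + eps)

def dynamic_routing(u_hat, num_iterations=3):
    # u_hat: prediction vectors from the primary capsules,
    # shape (batch, num_primary, num_digit, out_dim).
    b = torch.zeros(u_hat.size(0), u_hat.size(1), u_hat.size(2), device=u_hat.device)
    for _ in range(num_iterations):
        c = F.softmax(b, dim=2)                       # routing coefficients over digit capsules
        s = (c.unsqueeze(-1) * u_hat).sum(dim=1)      # weighted sum over primary capsules
        v = squash(s)                                 # output capsules, shape (batch, num_digit, out_dim)
        b = b + (u_hat * v.unsqueeze(1)).sum(dim=-1)  # agreement update
    return v
```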
These instructions will get you a copy of the project up and running on your local machine for development and testing purposes.
To use this project, you need to install PyTorch and Plotly.
pip install torch torchvision
pip install plotly
The MNIST dataset consists of 60,000 28x28 grayscale training images of the 10 digits, along with a test set of 10,000 images.
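If you want to fetch the dataset yourself, torchvision can download it directly. A minimal sketch; the transform and batch size here are assumptions, not necessarily what capsulenetwork.py uses.

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.ToTensor()  # converts 28x28 grayscale images to tensors in [0, 1]

train_set = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_set = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False)

print(len(train_set), len(test_set))  # 60000 10000
```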
Run the script capsulenetwork.py with mode="Train" in the terminal as follows.
python capsulenetwork.py
Each epoch takes about 87 seconds on Google Colab's GPU.
The reconstructions for every 10th epoch are stored in the Training folder.
Fig 3. Reconstructions for every 10th epoch.
After 100 Epochs:
Final Training Accuracy = 99.91%
Final Training Loss = 0.4595620
Fig 4. Training Loss and Training Accuracy Graph
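The training loss reported above follows the objective from the paper: a margin loss on the lengths of the digit capsules plus a reconstruction (sum-of-squared-errors) term scaled by 0.0005. A minimal sketch, with tensor shapes and argument names assumed rather than taken from capsulenetwork.py:

```python
import torch

def capsnet_loss(v, labels, reconstructions, images,
                 m_plus=0.9, m_minus=0.1, lam=0.5, recon_weight=0.0005):
    # v: digit capsules (batch, 10, 16); labels: one-hot (batch, 10)
    lengths = v.norm(dim=-1)  # capsule lengths act as class probabilities
    margin = labels * torch.clamp(m_plus - lengths, min=0.0) ** 2 \
        + lam * (1.0 - labels) * torch.clamp(lengths - m_minus, min=0.0) ** 2
    margin_loss = margin.sum(dim=1).mean()
    recon_loss = ((reconstructions - images.view(images.size(0), -1)) ** 2).sum(dim=1).mean()
    return margin_loss + recon_weight * recon_loss
```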
Run the script capsulenetwork.py with mode="Test" in the terminal as follows.
python capsulenetwork.py
Test Set Accuracy = 98.80%
Fig 5. Training Accuracy vs Testing Accuracy Graph
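At test time, the predicted digit is the digit capsule with the largest length. A minimal sketch of how test accuracy could be computed under that assumption; the model interface here is hypothetical and may not match capsulenetwork.py.

```python
import torch

@torch.no_grad()
def evaluate(model, loader, device="cuda"):
    correct, total = 0, 0
    for images, targets in loader:
        images, targets = images.to(device), targets.to(device)
        v = model(images)                       # digit capsules, shape (batch, 10, 16) (assumed interface)
        preds = v.norm(dim=-1).argmax(dim=1)    # class = capsule with the largest length
        correct += (preds == targets).sum().item()
        total += targets.size(0)
    return correct / total
```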
| Fig 6. Ground Truth Image | Fig 7. Reconstructed Image |
| --- | --- |
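The reconstructions come from the decoder shown in Fig 2: every digit capsule except the one for the target class is masked out, and the surviving 16-D vector is passed through three fully connected layers (512, 1024, and 784 units in the paper). A minimal sketch; the class name and interface are assumptions, not the code in this repository.

```python
import torch
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self, num_classes=10, capsule_dim=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(num_classes * capsule_dim, 512), nn.ReLU(inplace=True),
            nn.Linear(512, 1024), nn.ReLU(inplace=True),
            nn.Linear(1024, 784), nn.Sigmoid(),  # 784 = 28 * 28 pixel intensities
        )

    def forward(self, v, labels):
        # v: digit capsules (batch, 10, 16); labels: one-hot (batch, 10)
        masked = v * labels.unsqueeze(-1)            # zero out every capsule except the target class
        return self.net(masked.view(v.size(0), -1))  # flattened 28x28 reconstruction
```

At test time, the mask can use the predicted class (the longest capsule) instead of the ground-truth label.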
Each capsule in the Digit Capsule Layer is a 16-dimensional vector. By holding 15 dimensions constant and slightly varying the remaining one, we can see which property that dimension captures, as shown below:
| Fig 8. Dimension 4 (Localised Skew) | Fig 9. Dimension 5 (Curvature) |
| --- | --- |

| Fig 10. Dimension 7 (Stroke & Thickness) | Fig 11. Dimension 9 (Edge Translation) |
| --- | --- |
This is my interpretation of what some of these dimensions capture. In this way, Capsule Networks capture the important spatial hierarchies between simple features and complex features.
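A minimal sketch of the perturbation described above, reusing the hypothetical Decoder from the previous snippet: one of the 16 dimensions is swept over a small range (the paper uses [-0.25, 0.25] in steps of 0.05) while the other 15 are held fixed, and each perturbed vector is decoded.

```python
import torch

@torch.no_grad()
def perturb_dimension(decoder, v, labels, dim, steps=11, span=0.25):
    # v: digit capsules (batch, 10, 16); labels: one-hot (batch, 10)
    # Returns reconstructions for offsets in [-span, span] added to a single dimension.
    outputs = []
    for offset in torch.linspace(-span, span, steps):
        v_perturbed = v.clone()
        v_perturbed[..., dim] += offset.item()   # non-target capsules are masked out by the decoder anyway
        outputs.append(decoder(v_perturbed, labels))
    return torch.stack(outputs)                  # shape (steps, batch, 784)
```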
- PyTorch - Deep Learning Framework
- Google Colab - Cloud Service
- Vikram Shenoy - Initial work
- Project is inspired by Sara Sabour, Nicholas Frosst, and Geoffrey E Hinton's paper, Dynamic Routing Between Capsules
- Initial understanding of Capsule Networks was made easy through Aurélien Géron's YouTube video on Capsule Networks.
- Gained an in-depth understanding of Capsule Networks and the dynamic routing algorithm through Max Pechyonkin's blog, Understanding Hinton’s Capsule Networks.
- Referenced Gram.AI's code for some details.