The main purpose of this repository is to provide a pipeline for training and evaluating capsule networks on 2D image or 3D volumetric datasets for classification. We apply the pipeline to the Lung-PET-CT-Dx dataset, and to MNIST as a sanity check of the pipeline itself.
Some collected results can be found in the presentation (`AdvancedMl-Final.pdf`) in this repo.
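Requirements:
- conda

To create the environment, run `conda env create --file environment.yml` from this folder.

To update `environment.yml` after changing the environment, run `conda env export > environment.yml` with the environment activated.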
The interface for using this repository is `src/train.py`. Always execute it from `src` (run `cd src` from the project root directory).
Command-line arguments specify which model to train, on which dataset, and with which hyperparameters. The main arguments are described below.
In general, the command structure is `python train.py [MODEL] [PARAMETER_1] ...` with `[MODEL]` = `resnet_2d`, `capsnet_2d` or `capsnet_3d`.
- `--dataset=[DATASET_NAME]`: Specifies the dataset. Can be `lungpetctx` or `mnist`.
- `--slice_width=[INTEGER]`: Specifies the CT image width per slice, or the image width for MNIST.
- `--early_stopping`: If set, training stops once the validation loss has not improved for 2 epochs.
- `--class_imbalance=[STRATEGY]`: Specifies how to deal with class imbalance (only available for the `lungpetctx` dataset). `[STRATEGY]` can be `class_weights` (scale the loss by the inverse class frequency), `undersample` (use fewer samples of the majority class to achieve class balance) or `none` (do nothing).
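The early stopping and class imbalance options above can be sketched as follows. This is a minimal, framework-agnostic sketch under stated assumptions, not the code in `src/train.py`: the helper names `inverse_frequency_weights`, `undersample` and `EarlyStopping` are hypothetical, the class weights are normalised to average 1 (an assumption), and "has not improved for 2 epochs" is read as two consecutive epochs without a new best validation loss.

```python
from collections import Counter
import random


def inverse_frequency_weights(labels):
    """Hypothetical helper: per-class loss weights proportional to 1 / class frequency.

    Assumption: weights are normalised so they average to 1 over the classes present.
    """
    counts = Counter(labels)
    raw = {c: 1.0 / n for c, n in counts.items()}
    mean = sum(raw.values()) / len(raw)
    return {c: w / mean for c, w in raw.items()}


def undersample(indices_by_class, seed=0):
    """Hypothetical helper: randomly keep as many samples per class as the rarest class has."""
    rng = random.Random(seed)
    n_min = min(len(idx) for idx in indices_by_class.values())
    return {c: rng.sample(list(idx), n_min) for c, idx in indices_by_class.items()}


class EarlyStopping:
    """Hypothetical helper: stop once the validation loss has not improved for `patience` epochs.

    Assumption: "not improved" means no new best loss in `patience` consecutive epochs.
    """

    def __init__(self, patience=2):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

For example, labels `[0, 0, 0, 1]` yield weights of 0.5 for class 0 and 1.5 for class 1, so errors on the minority class count three times as much. Such weights would typically be passed to a weighted cross-entropy loss.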
Other arguments can be listed via `python train.py [MODEL] -h` (e.g. `python train.py capsnet_2d -h`).
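The `[MODEL] [PARAMETER_1] ...` structure, with model-specific help via `-h`, matches what `argparse` subcommands provide. The following is a hypothetical sketch of such a parser, not the repository's actual one: the default values, and restricting `--iterations` to the capsule models, are assumptions.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring `python train.py [MODEL] [PARAMETER_1] ...`."""
    parser = argparse.ArgumentParser(prog="train.py")
    # Each model is a subcommand, so `train.py capsnet_2d -h` prints that model's options.
    models = parser.add_subparsers(dest="model", required=True)
    for name in ("resnet_2d", "capsnet_2d", "capsnet_3d"):
        sub = models.add_parser(name, help=f"train {name}")
        # All defaults below are illustrative assumptions.
        sub.add_argument("--dataset", choices=["lungpetctx", "mnist"], default="lungpetctx")
        sub.add_argument("--slice_width", type=int, default=128)
        sub.add_argument("--epochs", type=int, default=10)
        sub.add_argument("--batch_size", type=int, default=4)
        sub.add_argument("--learning_rate", type=float, default=0.001)
        sub.add_argument("--early_stopping", action="store_true")
        sub.add_argument("--class_imbalance",
                         choices=["class_weights", "undersample", "none"], default="none")
        if name.startswith("capsnet"):
            # Assumption: only the capsule models take a number of routing iterations.
            sub.add_argument("--iterations", type=int, default=3)
    return parser
```

For instance, `build_parser().parse_args(["resnet_2d", "--dataset=mnist", "--slice_width=28"])` returns a namespace with `model="resnet_2d"` and `slice_width=28`.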
To check that the pipeline works, one can train `capsnet_2d` or `resnet_2d` on the MNIST dataset.
- `resnet_2d` on MNIST: run `python train.py resnet_2d --epochs=10 --slice_width=28 --learning_rate=0.001 --dataset=mnist`
- `capsnet_2d` on MNIST: run `python train.py capsnet_2d --dataset=mnist --slice_width=128 --batch_size=4 --learning_rate=0.001`

Example runs on the Lung-PET-CT-Dx dataset:
- `resnet_2d`: run `python train.py resnet_2d --epochs=10 --slice_width=128 --learning_rate=0.001 --dataset=lungpetctx --early_stopping`
- `capsnet_2d`: run `python train.py capsnet_2d --epochs=10 --slice_width=128 --learning_rate=0.01 --dataset=lungpetctx`
- `capsnet_3d`: run `python train.py capsnet_3d --epochs=10 --slice_width=128 --learning_rate=0.01 --dataset=lungpetctx --iterations=3 --early_stopping --batch_size=4`

Training `capsnet_3d` on MNIST is not supported.
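The `--iterations` flag in the `capsnet_3d` command presumably sets the number of routing iterations between capsule layers. Assuming the routing-by-agreement scheme of Sabour et al. (2017), it can be sketched in NumPy as below. The function names and the tensor layout are illustrative assumptions, and the repository's (3D) capsule layers may implement routing differently.

```python
import numpy as np


def squash(s, axis=-1, eps=1e-8):
    """Capsule non-linearity: keeps the direction of s and maps its length into [0, 1)."""
    sq_norm = np.sum(s * s, axis=axis, keepdims=True)
    return (sq_norm / (1.0 + sq_norm)) * s / np.sqrt(sq_norm + eps)


def dynamic_routing(u_hat, iterations=3):
    """Routing-by-agreement (illustrative sketch).

    u_hat: predictions of shape (num_in, num_out, dim_out); u_hat[i, j] is
           input capsule i's prediction for output capsule j.
    Returns the output capsules v of shape (num_out, dim_out).
    """
    if iterations < 1:
        raise ValueError("iterations must be >= 1")
    num_in, num_out, _ = u_hat.shape
    b = np.zeros((num_in, num_out))  # routing logits, initially uniform
    for r in range(iterations):
        # Coupling coefficients: softmax over output capsules for each input capsule.
        e = np.exp(b - b.max(axis=1, keepdims=True))
        c = e / e.sum(axis=1, keepdims=True)
        s = np.einsum("ij,ijd->jd", c, u_hat)  # weighted sum of predictions
        v = squash(s)
        if r < iterations - 1:
            # Agreement (dot product) between prediction and output raises that route's logit.
            b = b + np.einsum("ijd,jd->ij", u_hat, v)
    return v
```

With `iterations=1` the coupling coefficients stay uniform, so each output capsule is simply the squashed average of its predictions; additional iterations shift weight toward input capsules whose predictions agree with the output.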