Code for the NIPS 2017 paper Prototypical Networks for Few-shot Learning.
If you use this code, please cite our paper:
@inproceedings{snell2017prototypical,
title={Prototypical Networks for Few-shot Learning},
author={Snell, Jake and Swersky, Kevin and Zemel, Richard},
booktitle={Advances in Neural Information Processing Systems},
year={2017}
}
- This code has been tested on Ubuntu 16.04 with Python 2.7 and 3.5. If you're using conda, you can create a Python environment by running
conda create -n protonets python=2.7
or
conda create -n protonets python=3.5
and then activate it by running
source activate protonets
- Install PyTorch and torchvision.
- Install torchnet by running
pip install git+https://github.com/pytorch/tnt.git@master
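- Install the protonets package by running
python setup.py install
or
python setup.py develop
(the latter installs the package in development mode, so changes to the source take effect without reinstalling).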
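After the installation steps above, a quick way to confirm that the dependencies can be found is a small script such as the following. It is a hypothetical sketch, not part of this repository. The module names are assumptions (torchnet importing as torchnet, this package as protonets), and because it relies on importlib.util.find_spec it needs Python 3.4 or later, so it will not run in the Python 2.7 environment.

```python
"""Hypothetical helper: report which of the README's dependencies are importable."""
import importlib.util
import sys
from collections import OrderedDict

# Assumed top-level module names for PyTorch, torchvision, torchnet and this package.
REQUIRED_MODULES = ("torch", "torchvision", "torchnet", "protonets")

# Python versions the README reports as tested.
TESTED_PYTHONS = {(2, 7), (3, 5)}


def check_environment(modules=REQUIRED_MODULES):
    """Map each module name to True if the import system can locate it.

    find_spec only locates the module; it does not import it, so a broken
    installation can still report True here.
    """
    status = OrderedDict()  # keeps the input order on Python 3.5
    for name in modules:
        status[name] = importlib.util.find_spec(name) is not None
    return status


if __name__ == "__main__":
    version = sys.version_info[:2]
    if version not in TESTED_PYTHONS:
        print("Note: Python %d.%d is not one of the tested versions." % version)
    for name, found in check_environment().items():
        print("%-12s %s" % (name, "ok" if found else "MISSING"))
```

Running it inside the activated protonets environment prints one line per module; any line marked MISSING points at the install step to repeat.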
- Download the Omniglot dataset by running
sh download_omniglot.sh
- Train the model by running
python scripts/train/few_shot/run_train.py
This will run training and write the results into the results directory.
  - You can specify a different output directory by passing in the option
    --log.exp_dir EXP_DIR
    where EXP_DIR is your desired output directory.
  - If you are running on a GPU, you can pass in the option
    --data.cuda
- Re-run in trainval mode by running
python scripts/train/few_shot/run_trainval.py
This will save your model into results/trainval by default.
- Run evaluation with
python scripts/predict/few_shot/run_eval.py --model.model_path results/trainval/best_model.t7
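The training, trainval and evaluation steps can also be chained from Python. The sketch below is hypothetical: the helper names train_command, trainval_command, eval_command and run_pipeline are illustrative and not part of the repository. It only assembles the scripts and options listed in this README, assumes it is started from the repository root (the script paths are relative), and leaves every step at its default output location so that evaluation finds results/trainval/best_model.t7.

```python
"""Hypothetical helper that assembles the README's train/trainval/eval commands."""
import subprocess
import sys

# Script paths as given in the README, relative to the repository root.
TRAIN_SCRIPT = "scripts/train/few_shot/run_train.py"
TRAINVAL_SCRIPT = "scripts/train/few_shot/run_trainval.py"
EVAL_SCRIPT = "scripts/predict/few_shot/run_eval.py"


def train_command(exp_dir=None, cuda=False):
    """Build the run_train.py argument list using only the documented options."""
    cmd = [sys.executable, TRAIN_SCRIPT]
    if exp_dir is not None:
        cmd += ["--log.exp_dir", exp_dir]  # custom output directory
    if cuda:
        cmd.append("--data.cuda")  # train on the GPU
    return cmd


def trainval_command():
    """Build the run_trainval.py argument list (saves to results/trainval by default)."""
    return [sys.executable, TRAINVAL_SCRIPT]


def eval_command(model_path="results/trainval/best_model.t7"):
    """Build the run_eval.py argument list pointing at a saved model."""
    return [sys.executable, EVAL_SCRIPT, "--model.model_path", model_path]


def run_pipeline(cuda=False, dry_run=True):
    """Print (dry_run=True) or execute the three steps in order."""
    commands = [train_command(cuda=cuda), trainval_command(), eval_command()]
    for cmd in commands:
        if dry_run:
            print(" ".join(cmd))
        else:
            # check_call raises CalledProcessError if a step exits non-zero.
            subprocess.check_call(cmd)
    return commands


if __name__ == "__main__":
    run_pipeline(dry_run=True)
```

With dry_run=False the steps run one after another and the pipeline stops at the first failing command. Only the training step receives --data.cuda, since that is the only step for which this README documents the option.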