PyTorch implementations and wrapper classes for torch.nn layers of the following estimators (both sketched below):
- Flipout [1]
- Local Reparameterization Trick [2]
Requires PyTorch >= 1.0.0.
Variationalizers - wrappers for the nn.Module class (support [Lazy|Standard][Linear|Convolutional] layers):
- Flipout Wrapper
- Local Reparameterization Wrapper
Stand-alone Flipout layers:
- Conv2d_flipout
- Linear_flipout
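For orientation, here is a minimal sketch of what the two estimators compute for a fully connected layer, assuming a Gaussian weight posterior N(w_mu, w_sigma^2); the names w_mu and w_sigma are illustrative, not this package's API:

import torch

def lrt_linear(x, w_mu, w_sigma):
    # Local Reparameterization Trick [2]: sample the pre-activations
    # directly, since they are Gaussian with closed-form mean and variance.
    act_mu = x @ w_mu.t()
    act_var = (x ** 2) @ (w_sigma ** 2).t()
    return act_mu + act_var.sqrt() * torch.randn_like(act_mu)

def flipout_linear(x, w_mu, w_sigma):
    # Flipout [1]: share one weight perturbation across the mini-batch and
    # decorrelate it per example with random sign flips s (inputs), r (outputs).
    delta_w = w_sigma * torch.randn_like(w_sigma)
    out_mu = x @ w_mu.t()
    s = torch.randint_like(x, 2) * 2 - 1
    r = torch.randint_like(out_mu, 2) * 2 - 1
    return out_mu + ((x * s) @ delta_w.t()) * r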
Clone this repo and run:
pip install -e torch_variational
Example usage for wrapper classes:
import torch
import torch.nn as nn
from torch_variational.wrapper import Variational_Flipout, Variational_LRT
Flipout_layer = Variational_Flipout(nn.Linear(in_features = 10, out_features = 10, bias = True))
LRT_layer = Variational_LRT(nn.Linear(in_features = 10, out_features = 10, bias = True))
Flipout_output = Flipout_layer(torch.randn(1, 10))
LRT_output = LRT_layer(torch.randn(1, 10))
Flipout_kld = Flipout_layer.kld()
LRT_kld = LRT_layer.kld()
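During training, the kld() terms supply the KL part of the ELBO; a minimal sketch of one optimization step (the model, data, and KL weight are placeholders, not part of this package):

import torch
import torch.nn as nn
from torch_variational.wrapper import Variational_Flipout

model = nn.Sequential(
    Variational_Flipout(nn.Linear(10, 32)),
    nn.ReLU(),
    Variational_Flipout(nn.Linear(32, 1)),
)
opt = torch.optim.Adam(model.parameters(), lr = 1e-3)
x, y = torch.randn(64, 10), torch.randn(64, 1)  # placeholder data

pred = model(x)
kld = sum(m.kld() for m in model if hasattr(m, 'kld'))  # sum KL over wrapped layers
loss = nn.functional.mse_loss(pred, y) + 1e-3 * kld     # KL weight 1e-3 is illustrative
opt.zero_grad()
loss.backward()
opt.step()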
Example usage for stand-alone Flipout layers:
import torch
from torch_variational import flipout
layer = flipout.Linear_flipout(in_features = 10, out_features = 10, bias = True)
output, kld = layer(torch.randn(1, 10))
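Because the stand-alone layers return an (output, kld) pair, a containing module can accumulate the KL terms itself; a minimal sketch under the same assumptions as above:

import torch
import torch.nn as nn
from torch_variational import flipout

class BayesianMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = flipout.Linear_flipout(in_features = 10, out_features = 32)
        self.fc2 = flipout.Linear_flipout(in_features = 32, out_features = 1)

    def forward(self, x):
        # Each layer returns its output together with its KL contribution.
        x, kld1 = self.fc1(x)
        x, kld2 = self.fc2(torch.relu(x))
        return x, kld1 + kld2

model = BayesianMLP()
out, kld = model(torch.randn(1, 10))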
Weight variances are assumed to be multiplicative.
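Concretely, our reading of this (an assumption on our part, following the multiplicative Gaussian noise of [2]) is that each weight's posterior variance scales with its squared mean:

q(w_ij) = N(mu_ij, alpha * mu_ij^2),  i.e.  w_ij = mu_ij * (1 + sqrt(alpha) * eps_ij),  eps_ij ~ N(0, 1)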
[1] @inproceedings{DBLP:conf/iclr/WenVBTG18,
author = {Yeming Wen and Paul Vicol and Jimmy Ba and Dustin Tran and Roger B. Grosse},
title = {Flipout: Efficient Pseudo-Independent Weight Perturbations on Mini-Batches},
booktitle = {6th International Conference on Learning Representations, {ICLR} 2018,
Vancouver, BC, Canada, April 30 - May 3, 2018, Conference Track Proceedings},
year = {2018},
url = {https://openreview.net/forum?id=rJNpifWAb}
}
[2] @inproceedings{NIPS2015_bc731692,
author = {Kingma, Durk P and Salimans, Tim and Welling, Max},
title = {Variational Dropout and the Local Reparameterization Trick},
booktitle = {Advances in Neural Information Processing Systems},
volume = {28},
year = {2015},
url = {https://proceedings.neurips.cc/paper/2015/file/bc7316929fe1545bf0b98d114ee3ecb8-Paper.pdf}
}