A Flax (Linen) implementation of ResNet (He et al. 2015), Wide ResNet (Zagoruyko & Komodakis 2016), ResNeXt (Xie et al. 2017), ResNet-D (He et al. 2019), and ResNeSt (Zhang et al. 2020). The code is modular, so you can mix and match the various stem, residual, and bottleneck implementations.
You can install this package from PyPI:
```
pip install jax-resnet
```
Or directly from GitHub:
```
pip install --upgrade git+https://github.com/n2cholas/jax-resnet.git
```
See the bottom of jax-resnet/resnet.py for the available aliases/options for the ResNet variants (all models are in Flax).
Pretrained checkpoints from torch.hub are available for the following networks:
- ResNet [18, 34, 50, 101, 152]
- WideResNet [50, 101]
- ResNeXt [50, 101]
- ResNeSt [50-Fast, 50, 101, 200, 269]
The models are tested to have the same intermediate activations and outputs as the torch.hub implementations, except ResNeSt-50 Fast, whose activations do not match exactly although its final accuracy does.
A pretrained checkpoint for ResNetD-50 is available from fast.ai. The activations do not match exactly, but the final accuracy matches.
```python
import jax.numpy as jnp
from jax_resnet import pretrained_resnest

ResNeSt50, variables = pretrained_resnest(50)
model = ResNeSt50()
out = model.apply(variables,
                  jnp.ones((32, 224, 224, 3)),  # ImageNet sized inputs.
                  mutable=False)  # Ensure `batch_stats` aren't updated.
```
You must install PyTorch yourself (see the official PyTorch installation instructions) to use these functions.
To extract a subset of the model, you can use `Sequential(model.layers[start:end])`. The `slice_variables` function (found in `common.py`) allows you to extract the corresponding subset of the variables dict. Check out that docstring for more information.
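To make the idea of slicing the variables dict concrete, here is a minimal plain-Python sketch. It is not the library's code: `slice_layer_dict` is a hypothetical helper, and it assumes that each collection in the variables dict is keyed by `layers_0`, `layers_1`, ... and that a sliced model expects those keys renumbered from zero. Consult the `slice_variables` docstring for the actual behavior.

```python
def slice_layer_dict(variables, start, end):
    """Keep entries layers_{start}..layers_{end-1}, renumbered from 0.

    Hypothetical helper for illustration only; not part of jax-resnet.
    """
    sliced = {}
    for collection, layer_vars in variables.items():
        sliced[collection] = {
            f"layers_{i - start}": layer_vars[f"layers_{i}"]
            for i in range(start, end)
            if f"layers_{i}" in layer_vars  # some layers may have no entry
        }
    return sliced


# Placeholder strings stand in for the real parameter arrays.
variables = {
    "params": {"layers_0": "w0", "layers_1": "w1", "layers_2": "w2"},
    "batch_stats": {"layers_0": "s0", "layers_2": "s2"},
}
subset = slice_layer_dict(variables, 1, 3)
```

The key point is that the sliced dict has to be renumbered to line up with the sliced model, which is why taking a subset of the variables is more than filtering keys.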
The top 1 and top 5 accuracies reported below are on the ImageNet2012 validation split. The data was preprocessed as in the official PyTorch example.
Model | Size | Top 1 | Top 5 |
---|---|---|---|
ResNet | 18 | 69.75% | 89.06% |
ResNet | 34 | 73.29% | 91.42% |
ResNet | 50 | 76.13% | 92.86% |
ResNet | 101 | 77.37% | 93.53% |
ResNet | 152 | 78.30% | 94.04% |
Wide ResNet | 50 | 78.48% | 94.08% |
Wide ResNet | 101 | 78.88% | 94.29% |
ResNeXt | 50 | 77.60% | 93.70% |
ResNeXt | 101 | 79.30% | 94.51% |
ResNet-D | 50 | 77.57% | 93.85% |
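For reference, the normalization step of the official PyTorch ImageNet example can be sketched in NumPy as follows. This is a sketch under assumptions: the images are taken to be already resized and center-cropped to 224x224, stored as `uint8` RGB in NHWC layout (matching the input shape used above), and `normalize_batch` is an illustrative name, not a function from this package.

```python
import numpy as np

# Per-channel RGB statistics used by the official PyTorch ImageNet example.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_batch(images_uint8):
    """Scale uint8 NHWC images to [0, 1], then standardize each RGB channel."""
    images = images_uint8.astype(np.float32) / 255.0
    # Broadcasting applies the statistics along the trailing channel axis.
    return (images - IMAGENET_MEAN) / IMAGENET_STD


# Dummy all-white batch standing in for real, already-cropped images.
batch = np.full((2, 224, 224, 3), 255, dtype=np.uint8)
normalized = normalize_batch(batch)
```

The result can be passed to `model.apply` in place of the `jnp.ones` placeholder input.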
The ResNeSt validation data was preprocessed as in zhang1989/ResNeSt.
Model | Size | Crop Size | Top 1 | Top 5 |
---|---|---|---|---|
ResNeSt-Fast | 50 | 224 | 80.53% | 95.34% |
ResNeSt | 50 | 224 | 81.05% | 95.42% |
ResNeSt | 101 | 256 | 82.82% | 96.32% |
ResNeSt | 200 | 320 | 83.84% | 96.86% |
ResNeSt | 269 | 416 | 84.53% | 96.98% |
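The top 1 and top 5 metrics in the tables above count a prediction as correct when the true label is among the 1 or 5 highest-scoring classes. A minimal NumPy sketch of that computation is shown below; `top_k_accuracy` is an illustrative helper rather than part of this package, and it assumes `logits` has shape `(batch, num_classes)` and `labels` holds integer class indices.

```python
import numpy as np


def top_k_accuracy(logits, labels, k):
    """Fraction of examples whose true label is among the k highest logits."""
    # Indices of the k largest logits per row (their internal order is irrelevant).
    top_k = np.argpartition(logits, -k, axis=-1)[:, -k:]
    hits = (top_k == labels[:, None]).any(axis=-1)
    return float(hits.mean())


# Toy example with 3 classes; real ImageNet logits have 1000 classes.
logits = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.3, 0.2],
                   [0.2, 0.3, 0.5]])
labels = np.array([1, 1, 0])
top1 = top_k_accuracy(logits, labels, k=1)
top2 = top_k_accuracy(logits, labels, k=2)
```

Model outputs from `model.apply` can be converted with `np.asarray(out)` before being passed to a helper like this.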
- Deep Residual Learning for Image Recognition. Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. arXiv 2015.
- Wide Residual Networks. Sergey Zagoruyko, Nikos Komodakis. BMVC 2016.
- Aggregated Residual Transformations for Deep Neural Networks. Saining Xie, Ross Girshick, Piotr Dollár, Zhuowen Tu, Kaiming He. CVPR 2017.
- Bag of Tricks for Image Classification with Convolutional Neural Networks. Tong He, Zhi Zhang, Hang Zhang, Zhongyue Zhang, Junyuan Xie, Mu Li. CVPR 2019.
- ResNeSt: Split-Attention Networks. Hang Zhang, Chongruo Wu, Zhongyue Zhang, Yi Zhu, Zhi Zhang, Haibin Lin, Yue Sun, Tong He, Jonas Mueller, R. Manmatha, Mu Li, Alexander Smola. arXiv 2020.