This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit

Initial commit
Yann Dauphin authored and committed on Feb 27, 2018
0 parents commit 2d3298e
Showing 15 changed files with 2,123 additions and 0 deletions.
399 changes: 399 additions & 0 deletions LICENSE

Large diffs are not rendered by default.

49 changes: 49 additions & 0 deletions README.md
@@ -0,0 +1,49 @@
# Mixup-CIFAR10
By [Hongyi Zhang](http://web.mit.edu/~hongyiz/www/), [Moustapha Cisse](https://mine.kaust.edu.sa/Pages/cisse.aspx), [Yann Dauphin](http://dauphin.io/), [David Lopez-Paz](https://lopezpaz.org/).

Facebook AI Research

## Introduction

Mixup is a generic and straightforward data augmentation principle.
In essence, mixup trains a neural network on convex combinations of pairs of
examples and their labels. By doing so, mixup regularizes the neural network to
favor simple linear behavior in-between training examples.
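
For illustration, here is a minimal sketch of that convex combination. The helper name `mixup_data` and the Beta-distribution parameter `alpha` are illustrative; see `train.py` in this commit for the actual implementation.

```
import numpy as np
import torch


def mixup_data(x, y, alpha=1.0):
    """Return mixed inputs, the two target sets, and the mixing coefficient.

    Illustrative sketch of the convex combination described above; not
    necessarily the exact interface used by train.py.
    """
    # lambda ~ Beta(alpha, alpha); alpha = 1.0 gives a uniform distribution.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    # Pair every example with a random partner from the same batch.
    index = torch.randperm(x.size(0))
    mixed_x = lam * x + (1 - lam) * x[index]
    # Both label sets are returned so the loss can be mixed with the same lambda.
    return mixed_x, y, y[index], lam
```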

This repository contains the implementation used for the results in
our paper (https://arxiv.org/abs/1710.09412).

## Citation

If you use this method or this code in your research, please cite our paper:

```
@article{zhang2018mixup,
  title={mixup: Beyond Empirical Risk Minimization},
  author={Hongyi Zhang and Moustapha Cisse and Yann N. Dauphin and David Lopez-Paz},
  journal={International Conference on Learning Representations},
  year={2018},
  url={https://openreview.net/forum?id=r1Ddp1-Rb},
}
```

## Requirements and Installation
* A computer running macOS or Linux
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
* Python version 3.6
* A [PyTorch installation](http://pytorch.org/)

## Training
Use `python train.py` to train a new model.
Here is an example setting:
```
$ CUDA_VISIBLE_DEVICES=0 python train.py --lr=0.1 --seed=20170922 --decay=1e-4
```
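
For intuition, a hypothetical single training step is sketched below; it illustrates how the mixed batch and the lambda-weighted loss fit together, and is not the verbatim contents of `train.py`. The stand-in model, fake batch, and optimizer settings are illustrative only.

```
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Stand-in "model" and fake CIFAR-10 batch, used only for shape-checking.
net = nn.Linear(3 * 32 * 32, 10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, weight_decay=1e-4)

inputs = torch.randn(8, 3, 32, 32)
targets = torch.randint(0, 10, (8,))

lam = float(np.random.beta(1.0, 1.0))    # mixing coefficient from Beta(alpha, alpha)
index = torch.randperm(inputs.size(0))   # random partner for every example
mixed = lam * inputs + (1 - lam) * inputs[index]

outputs = net(mixed.view(mixed.size(0), -1))
# The same lambda weights the two cross-entropy terms (the "mixed labels").
loss = lam * criterion(outputs, targets) + (1 - lam) * criterion(outputs, targets[index])
optimizer.zero_grad()
loss.backward()
optimizer.step()
```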

## License

This project is CC-BY-NC-licensed.

## Acknowledgement
The CIFAR-10 reimplementation of _mixup_ is adapted from the [pytorch-cifar](https://github.com/kuangliu/pytorch-cifar) repository by [kuangliu](https://github.com/kuangliu).
9 changes: 9 additions & 0 deletions models/__init__.py
@@ -0,0 +1,9 @@
from .vgg import *
from .lenet import *
from .resnet import *
from .resnext import *
from .densenet import *
from .googlenet import *
from .mobilenet import *
# from .densenet_efficient_multi_gpu import DenseNet190
from .densenet3 import DenseNet190
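
Not part of the commit, but for orientation: a minimal usage sketch of the package above, assuming the repository root is on the Python path (`densenet_cifar` and `DenseNet190` are defined in the DenseNet files shown further down).

```
import torch

from models import DenseNet190             # re-exported from models/densenet3.py
from models.densenet import densenet_cifar

net = densenet_cifar()                      # CIFAR-sized DenseNet
logits = net(torch.randn(2, 3, 32, 32))     # CIFAR-10 shaped input
print(logits.shape)                         # torch.Size([2, 10])
big = DenseNet190()                         # the 190-layer DenseNet from models/densenet3.py
```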
33 changes: 33 additions & 0 deletions models/alldnet.py
@@ -0,0 +1,33 @@
'''AllDNet (a LeNet-style network) in PyTorch.'''
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class AllDNet(nn.Module):
    def __init__(self):
        super(AllDNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # self.conv2 = nn.Linear(6*14*14, 16*10*10)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        activations = []
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        # out = out.view(out.size(0), -1)
        # activations.append(out)
        out = F.relu(self.conv2(out))
        # out = out.view(out.size(0), 16, 10, -1)
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        activations.append(out)
        out = F.relu(self.fc1(out))
        activations.append(out)
        out = F.relu(self.fc2(out))
        activations.append(out)
        out = self.fc3(out)
        return out, activations

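Not part of the commit: a quick, hypothetical shape check for the class above, assuming it is importable as `models.alldnet` from the repository root. For a 32x32 CIFAR-10 input, the two 5x5 convolutions and 2x2 poolings leave a 16x5x5 map, matching the `16*5*5` input size of `fc1`.

```
import torch

from models.alldnet import AllDNet

net = AllDNet()
logits, activations = net(torch.randn(2, 3, 32, 32))
print(logits.shape)                     # torch.Size([2, 10])
print([a.shape for a in activations])   # flattened conv features, then the two hidden FC outputs
```
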
109 changes: 109 additions & 0 deletions models/densenet.py
@@ -0,0 +1,109 @@
'''DenseNet in PyTorch.'''
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable


class Bottleneck(nn.Module):
    def __init__(self, in_planes, growth_rate):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(4*growth_rate)
        self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        out = torch.cat([out,x], 1)
        return out


class Transition(nn.Module):
    def __init__(self, in_planes, out_planes):
        super(Transition, self).__init__()
        self.bn = nn.BatchNorm2d(in_planes)
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.conv(F.relu(self.bn(x)))
        out = F.avg_pool2d(out, 2)
        return out


class DenseNet(nn.Module):
    def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10):
        super(DenseNet, self).__init__()
        self.growth_rate = growth_rate

        num_planes = 2*growth_rate
        self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)

        self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])
        num_planes += nblocks[0]*growth_rate
        out_planes = int(math.floor(num_planes*reduction))
        self.trans1 = Transition(num_planes, out_planes)
        num_planes = out_planes

        self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])
        num_planes += nblocks[1]*growth_rate
        out_planes = int(math.floor(num_planes*reduction))
        self.trans2 = Transition(num_planes, out_planes)
        num_planes = out_planes

        self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])
        num_planes += nblocks[2]*growth_rate
        out_planes = int(math.floor(num_planes*reduction))
        self.trans3 = Transition(num_planes, out_planes)
        num_planes = out_planes

        self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])
        num_planes += nblocks[3]*growth_rate

        self.bn = nn.BatchNorm2d(num_planes)
        self.linear = nn.Linear(num_planes, num_classes)

    def _make_dense_layers(self, block, in_planes, nblock):
        layers = []
        for i in range(nblock):
            layers.append(block(in_planes, self.growth_rate))
            in_planes += self.growth_rate
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.trans3(self.dense3(out))
        out = self.dense4(out)
        out = F.avg_pool2d(F.relu(self.bn(out)), 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

def DenseNet121():
    return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32)

def DenseNet169():
    return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32)

def DenseNet201():
    return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32)

def DenseNet161():
    return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48)

def densenet_cifar():
    return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12)

def test_densenet():
    net = densenet_cifar()
    x = torch.randn(1,3,32,32)
    y = net(Variable(x))
    print(y)

# test_densenet()
121 changes: 121 additions & 0 deletions models/densenet3.py
@@ -0,0 +1,121 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        return torch.cat([x, out], 1)

class BottleneckBlock(nn.Module):
    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(BottleneckBlock, self).__init__()
        inter_planes = out_planes * 4
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,
                               padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(inter_planes)
        self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        out = self.conv2(self.relu(self.bn2(out)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        return torch.cat([x, out], 1)

class TransitionBlock(nn.Module):
    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(TransitionBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
                               padding=0, bias=False)
        self.droprate = dropRate
    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        return F.avg_pool2d(out, 2)

class DenseBlock(nn.Module):
    def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0):
        super(DenseBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate)
    def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate):
        layers = []
        for i in range(nb_layers):
            layers.append(block(in_planes+i*growth_rate, growth_rate, dropRate))
        return nn.Sequential(*layers)
    def forward(self, x):
        return self.layer(x)

class DenseNet3(nn.Module):
    def __init__(self, depth, num_classes, growth_rate=12,
                 reduction=0.5, bottleneck=True, dropRate=0.0):
        super(DenseNet3, self).__init__()
        in_planes = 2 * growth_rate
        n = (depth - 4) // 3
        if bottleneck == True:
            n = n//2
            block = BottleneckBlock
        else:
            block = BasicBlock
        # 1st conv before any dense block
        self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
        in_planes = int(in_planes+n*growth_rate)
        self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)
        in_planes = int(math.floor(in_planes*reduction))
        # 2nd block
        self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
        in_planes = int(in_planes+n*growth_rate)
        self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)
        in_planes = int(math.floor(in_planes*reduction))
        # 3rd block
        self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
        in_planes = int(in_planes+n*growth_rate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(in_planes, num_classes)
        self.in_planes = in_planes

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()
    def forward(self, x):
        out = self.conv1(x)
        out = self.trans1(self.block1(out))
        out = self.trans2(self.block2(out))
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.in_planes)
        return self.fc(out)

def DenseNet190():
    return DenseNet3(190, 10, growth_rate=40)
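
Not part of the commit: a small, hypothetical shape check for `DenseNet190`, analogous to `test_densenet` in `models/densenet.py`. For a 32x32 input, the two transition blocks halve the spatial size twice, leaving the 8x8 map consumed by `F.avg_pool2d(out, 8)` in `forward`.

```
import torch

from models.densenet3 import DenseNet190

net = DenseNet190()                  # depth 190, 10 classes, growth rate 40
y = net(torch.randn(1, 3, 32, 32))
print(y.size())                      # torch.Size([1, 10])
```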
(The remaining changed files in this commit are not rendered here.)