Skip to content

Commit

Permalink
Model experimentation and data augmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
ArcCha committed May 10, 2018
1 parent 3bc04c1 commit 7e327d0
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 5 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
augment
CNN
train_config.yaml
notebooks/*.json
Expand Down
53 changes: 53 additions & 0 deletions data.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
from pathlib import Path

import Augmentor
import numpy as np
import torch
import torch.nn.functional as F
import yaml
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from tqdm import tqdm


class DigitRecognizerDataset(Dataset):
def __init__(self, X, Y, pretransform=False):
self.pretransform = pretransform
self.transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.RandomRotation(15),
transforms.ToTensor()])
if self.pretransform:
X = list(map(self.transform, map(Image.fromarray, X)))
Expand Down Expand Up @@ -61,3 +67,50 @@ def get_test_dataset(csv_file, pretransform=False):
X = data.reshape(-1, 28, 28)
test_dataset = DigitRecognizerDataset(X, None, pretransform=pretransform)
return test_dataset


def save_as_png(csv_file):
    """Write every row of a labelled digit CSV as a PNG under ``augment/original``.

    The CSV is read with one header line skipped.  The first column of each
    row is used as the class label (an index into ten per-class
    sub-directories ``0`` .. ``9``); the remaining columns are reshaped to a
    28x28 image.  Files are named with a zero-padded, per-class running
    counter (``00000.png``, ``00001.png``, ...).

    Args:
        csv_file: Path (or anything ``np.genfromtxt`` accepts) of the CSV.
    """
    table = np.genfromtxt(csv_file, delimiter=',', skip_header=1)
    # Column 0 carries the label, everything after it the pixel values.
    Y, X = np.split(table, [1], axis=1)
    Y = np.squeeze(Y).astype(int)
    print(Y.shape)

    root_path = Path('augment/original')
    root_path.mkdir(parents=True, exist_ok=True)
    class_dirs = []
    for digit in range(10):
        class_dir = root_path.joinpath(str(digit))
        class_dir.mkdir(exist_ok=True)
        class_dirs.append(class_dir)

    # Next free file number for each class.
    next_index = [0] * 10
    for row in tqdm(range(len(X))):
        label = Y[row]
        file_name = str(next_index[label]).zfill(5) + '.png'
        next_index[label] += 1
        pixels = X[row].reshape(28, 28)
        image = Image.fromarray(pixels).convert('RGB')
        with class_dirs[label].joinpath(file_name).open('wb') as f:
            image.save(f, format='PNG')


def augment_data(root_path):
    """Generate distorted copies of the per-class PNGs with Augmentor.

    Images are read from ``<root_path>/original/<digit>`` and the samples are
    written to ``<root_path>/out/<digit>`` for the digits 0-9.  The ``out``
    directory and its ten sub-directories are created when missing.  A single
    pipeline covering all ten directory pairs applies ``random_distortion``
    and then draws 80000 samples.

    Args:
        root_path: ``pathlib.Path`` of the directory containing ``original``.
    """
    source_root = root_path.joinpath('original')
    output_root = root_path.joinpath('out')
    output_root.mkdir(exist_ok=True)

    # (source, output) directory pairs as absolute path strings, one per digit.
    directory_pairs = []
    for digit in range(10):
        source_dir = source_root.joinpath(str(digit)).resolve()
        output_dir = output_root.joinpath(str(digit))
        output_dir.mkdir(exist_ok=True)
        directory_pairs.append((str(source_dir), str(output_dir.resolve())))

    # The pipeline is constructed from the first pair; the others are added.
    (first_source, first_output), *remaining_pairs = directory_pairs
    p = Augmentor.Pipeline(
        source_directory=first_source,
        output_directory=first_output,
        save_format='PNG')
    for source_dir, output_dir in remaining_pairs:
        p.add_further_directory(source_dir, output_dir)
    p.random_distortion(1.0, 5, 5, 1)
    p.sample(80000)


if __name__ == '__main__':
    # Script entry point: read the training configuration, then build the
    # augmented data set below ./augment.
    config_path = Path('train_config.yaml')
    with config_path.open('r') as f:
        # safe_load instead of yaml.load(f): calling load() without an explicit
        # Loader is deprecated since PyYAML 5.1 and raises TypeError on
        # PyYAML 6, and a plain config file never needs arbitrary object
        # construction.
        config = yaml.safe_load(f)
    # Step that writes the PNGs into augment/original; re-enable it when they
    # have to be regenerated from the training CSV.
    # save_as_png(config['train_path'])
    augment_data(Path('augment'))
4 changes: 4 additions & 0 deletions doc.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,7 @@ Just an experiment, maybe a baseline for other models, it gets ~95%.
# PracticalCNN

Just a different name and different layer parameters; nothing interesting.

# RichCNN

Potentially a little bit better: more filters in the conv layers.
41 changes: 38 additions & 3 deletions net.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.bn = nn.BatchNorm2d(1)
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = self.bn(x)
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, reduce(lambda x, y: x * y, x.size()[1:]))
Expand All @@ -28,24 +26,61 @@ def forward(self, x):
class CNN(nn.Module):
    """Small convolutional classifier with batch normalisation before each layer.

    Layout: two 5x5 convolutions (1 -> 6 -> 16 channels), each followed by
    ReLU and 2x2 max-pooling, then three fully connected layers
    (16*5*5 -> 120 -> 84 -> 10) with dropout in front of the last one.
    ``fc1`` is sized for 16 feature maps of 5x5, which is what a
    single-channel 32x32 input produces after the two conv/pool stages.
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.bn1 = nn.BatchNorm2d(1)
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.bn2 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn3 = nn.BatchNorm2d(16)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.bn4 = nn.BatchNorm1d(120)
        self.fc2 = nn.Linear(120, 84)
        self.bn5 = nn.BatchNorm1d(84)
        self.drop = nn.Dropout()
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the raw 10-way class scores for a batch of images.

        Args:
            x: Input batch; ``bn1``/``conv1`` require one channel, and
                ``fc1`` requires the flattened features to have 16*5*5
                elements per sample.

        Returns:
            Tensor with one row of 10 scores per sample (no softmax applied).
        """
        # The input is normalised first and conv1 is applied exactly once; a
        # leftover conv1/pool call ahead of bn1 fed 6 channels into
        # BatchNorm2d(1) and broke the forward pass.
        x = self.bn1(x)
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = self.bn2(x)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = self.bn3(x)
        # Flatten everything except the batch dimension.
        x = x.view(-1, reduce(lambda a, b: a * b, x.size()[1:]))
        x = F.relu(self.fc1(x))
        x = self.bn4(x)
        x = F.relu(self.fc2(x))
        x = self.bn5(x)
        x = self.drop(x)
        x = self.fc3(x)
        return x


class RichCNN(nn.Module):
    """Wider variant of the conv net: 32 and 64 filters and one hidden FC layer.

    Layout: two 5x5 convolutions (1 -> 32 -> 64 channels), each followed by
    ReLU and 2x2 max-pooling, then ``fc1`` (64*5*5 -> 1024), dropout and
    ``fc2`` (1024 -> 10).  A batch-norm layer sits in front of every
    convolution, in front of the flattening step and after ``fc1``.
    """

    def __init__(self):
        super().__init__()
        # Convolutional part.
        self.bn1 = nn.BatchNorm2d(1)
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.bn3 = nn.BatchNorm2d(64)
        # Classifier head.
        self.fc1 = nn.Linear(64 * 5 * 5, 1024)
        self.bn4 = nn.BatchNorm1d(1024)
        self.drop = nn.Dropout()
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        """Return the raw 10-way class scores (no softmax) for a batch ``x``."""
        normalized = self.bn1(x)
        features = F.max_pool2d(F.relu(self.conv1(normalized)), 2)
        features = F.max_pool2d(F.relu(self.conv2(self.bn2(features))), 2)
        features = self.bn3(features)
        # Collapse channel and spatial dimensions into one vector per sample.
        flat = features.view(features.size(0), -1)
        hidden = self.bn4(F.relu(self.fc1(flat)))
        return self.fc2(self.drop(hidden))


class PracticalCNN(nn.Module):
def __init__(self):
super(PracticalCNN, self).__init__()
Expand Down
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
validation_classes[idx] += counts
H['validation_classes'] = validation_classes.tolist()

net = CNN()
net = RichCNN()
H['net'] = type(net).__name__
net_dir = Path('./' + H['net'])
net_dir.mkdir(parents=True, exist_ok=True)
Expand All @@ -49,7 +49,7 @@
optimizer = torch.optim.Adam(net.parameters(), lr=config['learning_rate'])
H['optimizer'] = str(optimizer)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='max')
optimizer, mode='max', verbose=True)
H['lr_scheduler'] = str(lr_scheduler)
criterion = nn.CrossEntropyLoss()
H['criterion'] = str(criterion)
Expand Down

0 comments on commit 7e327d0

Please sign in to comment.