Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dockerize #19

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Requires a pre-built local image called pytorch, build from the pytorch repo.
# The full process is documented here:
# https://medium.com/@russferriday/exploring-capsule-networks-with-nvidia-docker-25b843d3964c
# that doc can be integrated in the capsule-network README, if appropriate.
#
# Build the docker image with:
#
# $ docker build -t capsnet .
#
# Run with:
#
# $ nvidia-docker run --rm -it --ipc=host capsnet:latest
#
# At the container prompt, start the visdom server, and the capsnet processing:
#
# # python -m visdom.server & python capsule_network.py
#
# In a separate terminal on the docker host:
#
# Obtain the CONTAINER_ID...
#
# $ docker ps
#
# Get the container IP address...
#
# $ docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' <CONTAINER_ID>
#
# On the host, browse to <returned_IP>:8097

FROM pytorch

COPY ./requirements.txt requirements.txt
RUN pip install -r requirements.txt
COPY . /workspace
WORKDIR /workspace

85 changes: 59 additions & 26 deletions capsule_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
PyTorch implementation by Kenta Iwasaki @ Gram.AI.
"""
import sys

sys.setrecursionlimit(15000)

import torch
Expand All @@ -20,8 +21,11 @@

def softmax(input, dim=1):
transposed_input = input.transpose(dim, len(input.size()) - 1)
softmaxed_output = F.softmax(transposed_input.contiguous().view(-1, transposed_input.size(-1)), dim=-1)
return softmaxed_output.view(*transposed_input.size()).transpose(dim, len(input.size()) - 1)
softmaxed_output = F.softmax(
transposed_input.contiguous().view(-1, transposed_input.size(-1)),
dim=-1)
return softmaxed_output.view(*transposed_input.size()).transpose(dim, len(
input.size()) - 1)


def augmentation(x, max_shift=2):
Expand All @@ -34,12 +38,15 @@ def augmentation(x, max_shift=2):
target_width_slice = slice(max(0, -w_shift), -w_shift + width)

shifted_image = torch.zeros(*x.size())
shifted_image[:, :, source_height_slice, source_width_slice] = x[:, :, target_height_slice, target_width_slice]
shifted_image[:, :, source_height_slice,
source_width_slice] = x[:, :, target_height_slice,
target_width_slice]
return shifted_image.float()


class CapsuleLayer(nn.Module):
def __init__(self, num_capsules, num_route_nodes, in_channels, out_channels, kernel_size=None, stride=None,
def __init__(self, num_capsules, num_route_nodes, in_channels,
out_channels, kernel_size=None, stride=None,
num_iterations=NUM_ROUTING_ITERATIONS):
super(CapsuleLayer, self).__init__()

Expand All @@ -49,10 +56,13 @@ def __init__(self, num_capsules, num_route_nodes, in_channels, out_channels, ker
self.num_capsules = num_capsules

if num_route_nodes != -1:
self.route_weights = nn.Parameter(torch.randn(num_capsules, num_route_nodes, in_channels, out_channels))
self.route_weights = nn.Parameter(
torch.randn(num_capsules, num_route_nodes, in_channels,
out_channels))
else:
self.capsules = nn.ModuleList(
[nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0) for _ in
[nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
stride=stride, padding=0) for _ in
range(num_capsules)])

def squash(self, tensor, dim=-1):
Expand All @@ -62,18 +72,21 @@ def squash(self, tensor, dim=-1):

def forward(self, x):
if self.num_route_nodes != -1:
priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]
priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :,
:]

logits = Variable(torch.zeros(*priors.size())).cuda()
for i in range(self.num_iterations):
probs = softmax(logits, dim=2)
outputs = self.squash((probs * priors).sum(dim=2, keepdim=True))
outputs = self.squash(
(probs * priors).sum(dim=2, keepdim=True))

if i != self.num_iterations - 1:
delta_logits = (priors * outputs).sum(dim=-1, keepdim=True)
logits = logits + delta_logits
else:
outputs = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules]
outputs = [capsule(x).view(x.size(0), -1, 1) for capsule in
self.capsules]
outputs = torch.cat(outputs, dim=-1)
outputs = self.squash(outputs)

Expand All @@ -84,10 +97,15 @@ class CapsuleNet(nn.Module):
def __init__(self):
super(CapsuleNet, self).__init__()

self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
self.primary_capsules = CapsuleLayer(num_capsules=8, num_route_nodes=-1, in_channels=256, out_channels=32,
self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9,
stride=1)
self.primary_capsules = CapsuleLayer(num_capsules=8,
num_route_nodes=-1,
in_channels=256, out_channels=32,
kernel_size=9, stride=2)
self.digit_capsules = CapsuleLayer(num_capsules=NUM_CLASSES, num_route_nodes=32 * 6 * 6, in_channels=8,
self.digit_capsules = CapsuleLayer(num_capsules=NUM_CLASSES,
num_route_nodes=32 * 6 * 6,
in_channels=8,
out_channels=16)

self.decoder = nn.Sequential(
Expand All @@ -110,7 +128,9 @@ def forward(self, x, y=None):
if y is None:
# In all batches, get the most active capsule.
_, max_length_indices = classes.max(dim=1)
y = Variable(torch.sparse.torch.eye(NUM_CLASSES)).cuda().index_select(dim=0, index=max_length_indices.data)
y = Variable(torch.eye(NUM_CLASSES)).cuda().index_select(
dim=0,
index=max_length_indices.data)

reconstructions = self.decoder((x * y[:, :, None]).view(x.size(0), -1))

Expand Down Expand Up @@ -160,14 +180,20 @@ def forward(self, images, labels, classes, reconstructions):
confusion_meter = tnt.meter.ConfusionMeter(NUM_CLASSES, normalized=True)

train_loss_logger = VisdomPlotLogger('line', opts={'title': 'Train Loss'})
train_error_logger = VisdomPlotLogger('line', opts={'title': 'Train Accuracy'})
train_error_logger = VisdomPlotLogger('line',
opts={'title': 'Train Accuracy'})
test_loss_logger = VisdomPlotLogger('line', opts={'title': 'Test Loss'})
test_accuracy_logger = VisdomPlotLogger('line', opts={'title': 'Test Accuracy'})
confusion_logger = VisdomLogger('heatmap', opts={'title': 'Confusion matrix',
'columnnames': list(range(NUM_CLASSES)),
'rownames': list(range(NUM_CLASSES))})
test_accuracy_logger = VisdomPlotLogger('line',
opts={'title': 'Test Accuracy'})
confusion_logger = VisdomLogger('heatmap',
opts={'title': 'Confusion matrix',
'columnnames': list(
range(NUM_CLASSES)),
'rownames': list(
range(NUM_CLASSES))})
ground_truth_logger = VisdomLogger('image', opts={'title': 'Ground Truth'})
reconstruction_logger = VisdomLogger('image', opts={'title': 'Reconstruction'})
reconstruction_logger = VisdomLogger('image',
opts={'title': 'Reconstruction'})

capsule_loss = CapsuleLoss()

Expand All @@ -178,7 +204,8 @@ def get_iterator(mode):
labels = getattr(dataset, 'train_labels' if mode else 'test_labels')
tensor_dataset = tnt.dataset.TensorDataset([data, labels])

return tensor_dataset.parallel(batch_size=BATCH_SIZE, num_workers=4, shuffle=mode)
return tensor_dataset.parallel(batch_size=BATCH_SIZE, num_workers=4,
shuffle=mode)


def processor(sample):
Expand All @@ -187,7 +214,7 @@ def processor(sample):
data = augmentation(data.unsqueeze(1).float() / 255.0)
labels = torch.LongTensor(labels)

labels = torch.sparse.torch.eye(NUM_CLASSES).index_select(dim=0, index=labels)
labels = torch.eye(NUM_CLASSES).index_select(dim=0, index=labels)

data = Variable(data).cuda()
labels = Variable(labels).cuda()
Expand All @@ -213,8 +240,10 @@ def on_sample(state):


def on_forward(state):
meter_accuracy.add(state['output'].data, torch.LongTensor(state['sample'][1]))
confusion_meter.add(state['output'].data, torch.LongTensor(state['sample'][1]))
meter_accuracy.add(state['output'].data,
torch.LongTensor(state['sample'][1]))
confusion_meter.add(state['output'].data,
torch.LongTensor(state['sample'][1]))
meter_loss.add(state['loss'].data[0])


Expand Down Expand Up @@ -251,9 +280,12 @@ def on_end_epoch(state):
reconstruction = reconstructions.cpu().view_as(ground_truth).data

ground_truth_logger.log(
make_grid(ground_truth, nrow=int(BATCH_SIZE ** 0.5), normalize=True, range=(0, 1)).numpy())
make_grid(ground_truth, nrow=int(BATCH_SIZE ** 0.5),
normalize=True, range=(0, 1)).numpy())
reconstruction_logger.log(
make_grid(reconstruction, nrow=int(BATCH_SIZE ** 0.5), normalize=True, range=(0, 1)).numpy())
make_grid(reconstruction, nrow=int(BATCH_SIZE ** 0.5),
normalize=True, range=(0, 1)).numpy())


# def on_start(state):
# state['epoch'] = 327
Expand All @@ -264,4 +296,5 @@ def on_end_epoch(state):
engine.hooks['on_start_epoch'] = on_start_epoch
engine.hooks['on_end_epoch'] = on_end_epoch

engine.train(processor, get_iterator(True), maxepoch=NUM_EPOCHS, optimizer=optimizer)
engine.train(processor, get_iterator(True), maxepoch=NUM_EPOCHS,
optimizer=optimizer)
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
git+https://github.com/pytorch/tnt.git@master
tqdm
visdom