forked from Riroaki/CapsNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
87 lines (77 loc) · 3.05 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
from torch import optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from capsnet import CapsNet, CapsuleLoss
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def main():
    """Train CapsNet on MNIST, evaluate on the test split, and save the weights.

    The model is optimised with Adam (lr=1e-3) and an exponential learning-rate
    decay (gamma=0.96) applied once per epoch. Training images are randomly
    shifted by up to 2 pixels; test images are only normalised so the reported
    accuracy is deterministic. After the final epoch the state dict is written
    to ``./model/capsnet_ep{EPOCHES}_acc{accuracy}.pt``.
    """
    # Local import: only needed here to create the checkpoint directory.
    import os

    DATA_PATH = './data'
    MODEL_DIR = './model'
    BATCH_SIZE = 128
    EPOCHES = 50

    # Create the output directory up front so a missing folder fails fast
    # instead of raising after the whole training run has finished.
    os.makedirs(MODEL_DIR, exist_ok=True)

    # Load model
    model = CapsNet().to(device)
    criterion = CapsuleLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96)

    # Load data
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    train_transform = transforms.Compose([
        # shift by up to 2 pixels in either direction with zero padding.
        transforms.RandomCrop(28, padding=2),
        transforms.ToTensor(),
        normalize,
    ])
    # No random augmentation at test time: evaluation must be repeatable.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    train_loader = DataLoader(
        dataset=MNIST(root=DATA_PATH, download=True, train=True,
                      transform=train_transform),
        batch_size=BATCH_SIZE,
        num_workers=4,
        shuffle=True)
    test_loader = DataLoader(
        dataset=MNIST(root=DATA_PATH, download=True, train=False,
                      transform=test_transform),
        batch_size=BATCH_SIZE,
        num_workers=4,
        shuffle=False)

    # Rows of the identity matrix are the one-hot encodings of the 10 classes;
    # build it once instead of once per batch.
    one_hot_table = torch.eye(10)

    # Train
    model.train()
    for ep in range(EPOCHES):
        correct, total, total_loss = 0, 0, 0.
        for batch_id, (images, labels) in enumerate(train_loader, start=1):
            optimizer.zero_grad()
            images = images.to(device)
            targets = labels.to(device)
            # Categorical (one-hot) encoding of the labels for the loss.
            one_hot = one_hot_table.index_select(dim=0, index=labels).to(device)
            logits, reconstruction = model(images)

            # Compute loss & accuracy
            loss = criterion(images, one_hot, logits, reconstruction)
            correct += torch.sum(torch.argmax(logits, dim=1) == targets).item()
            total += len(labels)
            accuracy = correct / total
            # .item() detaches the value; accumulating the tensor itself would
            # keep every batch's autograd graph alive for the whole epoch.
            total_loss += loss.item()

            loss.backward()
            optimizer.step()
            print('Epoch {}, batch {}, loss: {}, accuracy: {}'.format(ep + 1,
                                                                      batch_id,
                                                                      total_loss / batch_id,
                                                                      accuracy))
        # The explicit epoch argument to step() is deprecated; a plain step()
        # decays the learning rate by gamma once per completed epoch.
        scheduler.step()
        print('Total loss for epoch {}: {}'.format(ep + 1, total_loss))

    # Eval
    model.eval()
    correct, total = 0, 0
    # Gradients are not needed for evaluation; skipping them saves memory.
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            targets = labels.to(device)
            logits, _reconstructions = model(images)
            pred_labels = torch.argmax(logits, dim=1)
            correct += torch.sum(pred_labels == targets).item()
            total += len(labels)
    print('Accuracy: {}'.format(correct / total))

    # Save model
    torch.save(model.state_dict(),
               './model/capsnet_ep{}_acc{}.pt'.format(EPOCHES, correct / total))
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()