'''
Script demonstrating the Mish activation function,
applied to classification of the Fashion-MNIST dataset.
'''
# import standard libraries
from collections import OrderedDict
import argparse
# import pytorch
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms
# import the Mish activation function (module and functional forms)
from Mish.Torch.mish import Mish as mish
import Mish.Torch.functional as Func
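# Note: Mish is defined as mish(x) = x * tanh(softplus(x)). If the Mish
# package above is unavailable, an equivalent stand-in (a minimal sketch
# using only torch primitives; the name mish_fallback is hypothetical)
# would be:
#
#   def mish_fallback(x):
#       return x * torch.tanh(F.softplus(x))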
# activation name constant
MISH = 'mish'
# create class for a basic fully-connected deep neural network
class Classifier(nn.Module):
    def __init__(self, activation='mish'):
        super().__init__()
        # store the name of the activation function to use
        self.activation = activation
        # initialize the fully-connected layers
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 10)

    def forward(self, x):
        # make sure the input tensor is flattened
        x = x.view(x.shape[0], -1)
        # apply the Mish activation function after the first layer
        if self.activation == MISH:
            x = Func.mish(self.fc1(x))
        else:
            # fall back to ReLU so fc1 is never skipped
            x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.log_softmax(self.fc4(x), dim=1)
        return x
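# Quick sanity check (illustrative sketch, not executed by this script):
#
#   clf = Classifier()
#   out = clf(torch.randn(64, 1, 28, 28))   # a batch of 64 Fashion-MNIST-sized images
#   assert out.shape == (64, 10)            # one log-probability per class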
def main():
    '''
    Demonstrate the Mish activation function by classifying Fashion-MNIST.
    '''
    # parse command line arguments
    parser = argparse.ArgumentParser(description='Argument parser')
    # add argument to choose the activation function
    parser.add_argument('--activation', action='store', default=MISH,
                        help='Activation function for demonstration.',
                        choices=[MISH])
    # add argument to choose how to initialize the model
    parser.add_argument('--model_initialization', action='store', default='class',
                        help='Model initialization mode: use the custom class or nn.Sequential.',
                        choices=['class', 'sequential'])
    results = parser.parse_args()
    activation = results.activation
    model_initialization = results.model_initialization
    # define a transform; ToTensor() converts each image to a FloatTensor
    # scaled to [0, 1] (no normalization is applied in this demo)
    transform = transforms.Compose([transforms.ToTensor()])
    # download and load the training data for Fashion-MNIST
    trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
    # download and load the test data for Fashion-MNIST
    testset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=False, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
print("Create model with {activation} function.\n".format(activation = activation))
# Initialize the model
if (model_initialization == 'class'):
# Initialize the model using defined Classifier class
model = Classifier(activation = activation)
else:
# Setup the activation function
if (activation == MISH):
activation_function = mish()
# Initialize the model using nn.Sequential
model = nn.Sequential(OrderedDict([
('fc1', nn.Linear(784, 256)),
('mila', activation_function), # use custom activation function
('fc2', nn.Linear(256, 128)),
('bn2', nn.BatchNorm1d(num_features=128)),
('relu2', nn.ReLU()),
('dropout', nn.Dropout(0.3)),
('fc3', nn.Linear(128, 64)),
('bn3', nn.BatchNorm1d(num_features=64)),
('relu3', nn.ReLU()),
('logits', nn.Linear(64, 10)),
('logsoftmax', nn.LogSoftmax(dim=1))]))
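        # Note: unlike Classifier, this Sequential model does not flatten its
        # input itself; the training loop below flattens each batch before the
        # forward pass. An alternative (assuming a PyTorch version that
        # provides nn.Flatten) would be to prepend ('flatten', nn.Flatten())
        # to the OrderedDict above.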
    # train the model
    print("Training the model on the Fashion-MNIST dataset with the {} activation function.\n".format(activation))
    # NLLLoss expects log-probabilities, which both model variants produce
    # via log_softmax / LogSoftmax
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.003)
    epochs = 5
    for e in range(epochs):
        running_loss = 0
        for images, labels in trainloader:
            # flatten each image into a 784-dimensional vector
            images = images.view(images.shape[0], -1)
            log_ps = model(images)
            loss = criterion(log_ps, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {e + 1}/{epochs} - Training loss: {running_loss / len(trainloader):.4f}")
if __name__ == '__main__':
    main()