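"""Train a softmax baseline and three angular-penalty models (CosFace,
SphereFace, ArcFace) on Fashion-MNIST, and save an embedding plot for
each model under ./figs/."""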
import argparse
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from tqdm import tqdm
from models import ConvAngularPen, ConvBaseline
from plotting import plot


def main():
    train_ds = datasets.FashionMNIST(
        root='./data',
        train=True,
        transform=transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize(mean=(0.1307,), std=(0.3081,))]),
        download=True)
    train_loader = torch.utils.data.DataLoader(dataset=train_ds,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    # Fixed-order loader, used only to extract embeddings for plotting.
    example_loader = torch.utils.data.DataLoader(dataset=train_ds,
                                                 batch_size=args.batch_size,
                                                 shuffle=False)
    os.makedirs('./figs', exist_ok=True)

    print('Training Baseline model....')
    model_baseline = train_baseline(train_loader)
    bl_embeds, bl_labels = get_embeds(model_baseline, example_loader)
    plot(bl_embeds, bl_labels, fig_path='./figs/baseline.png')
    print('Saved Baseline figure')
    del model_baseline, bl_embeds, bl_labels

    loss_types = ['cosface', 'sphereface', 'arcface']
    for loss_type in loss_types:
        print('Training {} model....'.format(loss_type))
        model_am = train_am(train_loader, loss_type)
        am_embeds, am_labels = get_embeds(model_am, example_loader)
        plot(am_embeds, am_labels, fig_path='./figs/{}.png'.format(loss_type))
        print('Saved {} figure'.format(loss_type))
        del model_am, am_embeds, am_labels


def train_baseline(train_loader):
    model = ConvBaseline().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    total_step = len(train_loader)
    for epoch in tqdm(range(args.num_epochs)):
        for i, (feats, labels) in enumerate(tqdm(train_loader)):
            feats = feats.to(device)
            labels = labels.to(device)
            out = model(feats)
            loss = criterion(out, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i + 1) % 100 == 0:
                print('Baseline: Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, args.num_epochs, i + 1, total_step, loss.item()))
        # Decay the learning rate by a factor of 4 every 8 epochs.
        if (epoch + 1) % 8 == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] / 4
    return model.cpu()


def train_am(train_loader, loss_type):
    model = ConvAngularPen(loss_type=loss_type).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    total_step = len(train_loader)
    for epoch in tqdm(range(args.num_epochs)):
        for i, (feats, labels) in enumerate(tqdm(train_loader)):
            feats = feats.to(device)
            labels = labels.to(device)
            # The angular-penalty model computes its own loss from the labels,
            # so no separate criterion is needed here.
            loss = model(feats, labels=labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i + 1) % 100 == 0:
                print('{}: Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(loss_type, epoch + 1, args.num_epochs, i + 1, total_step, loss.item()))
        # Decay the learning rate by a factor of 4 every 8 epochs.
        if (epoch + 1) % 8 == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] / 4
    return model.cpu()
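

# For intuition only: a minimal sketch of the kind of margin these
# loss_type values usually denote. The actual loss lives in
# models.ConvAngularPen; this helper, and the s/m defaults below, are
# illustrative and hypothetical, not necessarily what that class uses.
def _angular_margin_logits_sketch(embeds, weight, labels,
                                  loss_type='cosface', s=30.0, m=0.4):
    # Cosine similarity between L2-normalized embeddings (B, D) and
    # per-class weight vectors (C, D), giving logits of shape (B, C).
    cos = F.normalize(embeds) @ F.normalize(weight).t()
    theta = torch.acos(cos.clamp(-1 + 1e-7, 1 - 1e-7))
    if loss_type == 'cosface':       # subtract the margin from the cosine
        target = cos - m
    elif loss_type == 'arcface':     # add the margin to the angle
        target = torch.cos(theta + m)
    else:                            # 'sphereface': multiply the angle
        target = torch.cos(m * theta)
    onehot = F.one_hot(labels, cos.size(1)).bool()
    # Penalize only the true-class logit, then scale by s; cross-entropy
    # on the result gives the angular-penalty loss.
    return s * torch.where(onehot, target, cos)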


def get_embeds(model, loader):
    model = model.to(device).eval()
    full_embeds = []
    full_labels = []
    with torch.no_grad():
        for i, (feats, labels) in enumerate(loader):
            # Keep only the first 100 examples of each batch for plotting.
            feats = feats[:100].to(device)
            full_labels.append(labels[:100].cpu().detach().numpy())
            embeds = model(feats, embed=True)
            full_embeds.append(F.normalize(embeds.detach().cpu()).numpy())
    model = model.cpu()
    return np.concatenate(full_embeds), np.concatenate(full_labels)


def parse_args():
    parser = argparse.ArgumentParser(description='Run Angular Penalty and Baseline experiments on fMNIST')
    parser.add_argument('--batch-size', type=int, default=512,
                        help='input batch size for training (default: 512)')
    parser.add_argument('--num-epochs', type=int, default=40,
                        help='number of epochs to train each model for (default: 40)')
    parser.add_argument('--seed', type=int, default=1234,
                        help='random seed (default: 1234)')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate (default: 0.01)')
    parser.add_argument('--use-cuda', action='store_true', default=False,
                        help='enables CUDA training')
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    use_cuda = args.use_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    torch.manual_seed(args.seed)
    main()
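
# Example invocation (figures are written to ./figs/):
#   python train_fMNIST.py --batch-size 512 --num-epochs 40 --lr 0.01 --use-cuda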