-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
149 lines (131 loc) · 5.62 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import json
import time
import yaml
from committee import net_dirs, nets
from cuda import *
from data import *
from net import *
from plot import plot_confusion_matrix
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm
# Load the training configuration from ``train_config.yaml`` (resolved against
# the current working directory).
#
# ``yaml.safe_load`` is used instead of ``yaml.load(f)``: calling ``load``
# without an explicit Loader is deprecated since PyYAML 5.1, raises a
# TypeError in PyYAML 6, and can construct arbitrary Python objects from
# tagged input. The keys read in this script are used as plain scalars, so
# the safe loader is presumably sufficient -- confirm if the config file
# relies on custom YAML tags.
config_path = Path('train_config.yaml')
with config_path.open('r') as f:
    config = yaml.safe_load(f)
# ---------------------------------------------------------------------------
# Data setup: build the train/validation datasets and loaders, and start the
# history dict ``H`` that is later dumped to JSON next to each trained net.
# ---------------------------------------------------------------------------
H = {}  # Training history and statistics
CUDA, device = get_cuda_if_available()
H['cuda'] = CUDA
train_path = config['train_path']
H['batch_size'] = config['batch_size']

train_dataset, validation_dataset = train_validation_split(
    train_path,
    max_rows=config['data_num'],
    validation_num=config['validation_num'],
    pretransform=True)

# Alternative pipeline (disabled): train on the augmented images from
# 'augment/out/' and take only the validation set from ``train_path``.
# transform = transforms.Compose([
# transforms.Resize((32, 32)),
# transforms.Grayscale(),
# transforms.ToTensor()])
#train_dataset = ImageFolder('augment/out/', transform=transform)
#validation_dataset, _ = train_validation_split(train_path, pretransform=True)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=4,
    pin_memory=True)
validation_loader = DataLoader(
    dataset=validation_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=1,
    pin_memory=True)

# Record sample counts. ``len(loader)`` is the number of *batches*, which is
# not what the ``*_num`` keys are meant to describe, so the dataset lengths
# are used instead.
H['data_num'] = len(train_dataset)
H['validation_num'] = len(validation_dataset)

# Count how many validation samples fall into each class.
# NOTE(review): the 10 slots assume integer labels in 0..9 -- a label outside
# that range would raise an IndexError below; confirm against the dataset.
validation_classes = np.zeros(10)
for _, y in tqdm(validation_loader, desc='Validation stats'):
    idx, counts = np.unique(y, return_counts=True)
    validation_classes[idx] += counts
H['validation_classes'] = validation_classes.tolist()
# ---------------------------------------------------------------------------
# Training: every net in ``nets`` is trained for ``config['epoch_num']``
# epochs and its checkpoints, statistics and confusion matrices are written
# into the matching directory from ``net_dirs``.
# ---------------------------------------------------------------------------


def _evaluate(net, loader, device, desc, collect=False):
    """Measure the accuracy of ``net`` on ``loader`` without tracking gradients.

    The caller is responsible for putting ``net`` into eval mode.

    Args:
        net: Model called as ``net(x)``; the prediction is the argmax over
            dim 1 of its output.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        desc: Label shown on the tqdm progress bar.
        collect: When true, also gather every true and predicted label.

    Returns:
        ``(accuracy, true_labels, predicted_labels)``. ``accuracy`` is the
        fraction of correctly classified samples actually seen (0.0 for an
        empty loader). Both lists are empty unless ``collect`` is true.
    """
    correct = 0
    seen = 0
    true_labels = []
    predicted_labels = []
    # No backward pass happens here, so skip building the autograd graph.
    with torch.no_grad():
        for x, y_true in tqdm(loader, desc=desc):
            x = x.to(device)
            y_true = y_true.to(device)
            y_pred = net(x).argmax(dim=1)
            correct += y_true.eq(y_pred).sum().item()
            # Count real samples so a partial final batch is not overcounted.
            seen += y_true.size(0)
            if collect:
                true_labels += y_true.cpu().tolist()
                predicted_labels += y_pred.cpu().tolist()
    accuracy = correct / seen if seen else 0.0
    return accuracy, true_labels, predicted_labels


def _save_state(net, path):
    """Write ``net.state_dict()`` to ``path``, replacing any existing file."""
    with path.open(mode='wb') as f:
        torch.save(net.state_dict(), f)


def _save_stats(history, path):
    """Dump ``history`` to ``path`` as indented, key-sorted JSON."""
    with path.open('w') as f:
        json.dump(history, f, indent=2, sort_keys=True)


# Make sure every output directory exists before any training starts.
for net_dir in net_dirs:
    net_dir.mkdir(parents=True, exist_ok=True)

for net, net_dir in tqdm(zip(nets, net_dirs), desc='Net'):
    # ``H`` is shared between nets: the per-net entries below are overwritten
    # (and the history lists reset) at the start of every net.
    H['net'] = type(net).__name__
    net.to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=config['learning_rate'])
    H['optimizer'] = str(optimizer)
    # mode='max' because the scheduler is stepped with validation accuracy.
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', verbose=True)
    H['lr_scheduler'] = str(lr_scheduler)
    criterion = nn.CrossEntropyLoss()
    H['criterion'] = str(criterion)

    H['epoch_num'] = config['epoch_num']
    H['loss'] = []
    H['train_acc'] = []
    H['test_acc'] = []

    # Labels/predictions of the final epoch, used for the confusion matrices.
    predicted_train = []
    true_train = []
    predicted_test = []
    true_test = []

    # NOTE: process_time() reports CPU time of this process, not wall-clock
    # time; kept as in the original statistics.
    start = time.process_time()

    for epoch in tqdm(range(config['epoch_num']), desc='Total'):
        is_last_epoch = epoch + 1 == config['epoch_num']

        # --- optimisation pass over the training set ---
        running_loss = 0.0
        for x, y in tqdm(train_loader, desc='Epoch ' + str(epoch)):
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            outputs = net(x)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Mean loss per batch: divide by the number of batches actually run
        # rather than by a value derived from the config.
        H['loss'].append(running_loss / max(len(train_loader), 1))

        # --- evaluation on both splits (eval mode, no gradients) ---
        net.eval()
        train_acc, true_train, predicted_train = _evaluate(
            net, train_loader, device,
            desc='Train acc ' + str(epoch), collect=is_last_epoch)
        H['train_acc'].append(train_acc)

        test_acc, true_test, predicted_test = _evaluate(
            net, validation_loader, device,
            desc='Validation ' + str(epoch), collect=is_last_epoch)
        net.train()
        H['test_acc'].append(test_acc)
        lr_scheduler.step(test_acc)

        # Checkpoint after every epoch; statistics only every 100th epoch.
        _save_state(net, net_dir.joinpath('net' + str(epoch) + '.state'))
        if epoch % 100 == 99:
            _save_stats(H, net_dir.joinpath('stats' + str(epoch) + '.json'))

    end = time.process_time()
    H['learning_duration'] = end - start

    # Final artefacts for this net.
    _save_state(net, net_dir.joinpath('net.state'))
    _save_stats(H, net_dir.joinpath('stats.json'))

    cnf_matrix = confusion_matrix(true_train, predicted_train)
    plot_confusion_matrix(cm=cnf_matrix, classes=list(range(10)),
                          title='Confusion matrix, without normalization',
                          filesave=str(net_dir.joinpath('train_cnf.png')))
    cnf_matrix = confusion_matrix(true_test, predicted_test)
    plot_confusion_matrix(cm=cnf_matrix, classes=list(range(10)),
                          title='Confusion matrix, without normalization',
                          filesave=str(net_dir.joinpath('test_cnf.png')))