-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
144 lines (94 loc) · 3.84 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
import torch.nn as nn
import torchvision
from vggnet import VGGNet
from resnet import ResNet18
from mobilenetv1 import mobilenetv1_small
from inceptionMolule import InceptionNetSmall
from base_resnet import resnet
from resnetV1 import resnet as resnetV1
from pre_resnet import pytorch_resnet18
from load_cifar10 import train_loader, test_loader
import os
import tensorboardX
# Device selection: train on GPU when one is available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training hyper-parameters.
epoch_num = 200
lr = 0.1
batch_size = 128

# Model: VGG wrapped in DataParallel so multiple GPUs are used when present.
# net = pytorch_resnet18().to(device)
# net = VGGNet().to(device)
net = VGGNet()
net = nn.DataParallel(net)
net = net.to(device)

# Loss: standard cross-entropy for multi-class classification.
loss_func = nn.CrossEntropyLoss()

# Optimizer: Adam (an SGD + momentum variant is kept commented for reference).
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
# optimizer = torch.optim.SGD(net.parameters(), lr = lr,
#                             momentum=0.9, weight_decay=5e-4)

# Multiply the learning rate by 0.9 every 5 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            step_size=5,
                                            gamma=0.9)

# Output locations for model checkpoints and TensorBoard event files.
model_path = "models/pytorch_vgg"
log_path = "logs/pytorch_vgg"
for needed_dir in (log_path, model_path):
    if not os.path.exists(needed_dir):
        os.makedirs(needed_dir)

writer = tensorboardX.SummaryWriter(log_path)
step_n = 0  # global step counter for per-batch TensorBoard scalars/images
if __name__ == '__main__':
    # Train for `epoch_num` epochs; after each epoch, checkpoint the model,
    # step the LR scheduler, and evaluate on the test set. Per-batch metrics
    # go to TensorBoard under global step `step_n`, per-epoch test metrics
    # under the epoch number.
    for epoch in range(epoch_num):
        print(" epoch is ", epoch)
        print(device)

        # ---- training phase ----
        net.train()  # enable train-mode behaviour of BN / dropout layers
        for i, data in enumerate(train_loader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = net(inputs)
            loss = loss_func(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            _, pred = torch.max(outputs.data, dim=1)
            correct = pred.eq(labels.data).cpu().sum()
            # BUGFIX: use the actual mini-batch size — the last batch of an
            # epoch can be smaller than `batch_size`, which skewed accuracy.
            cur_batch = labels.size(0)
            print("train step", i, "loss is:", loss.item(),
                  "mini-batch correct is:", 100.0 * correct / cur_batch)

            writer.add_scalar("train loss", loss.item(), global_step=step_n)
            writer.add_scalar("train correct",
                              100.0 * correct.item() / cur_batch,
                              global_step=step_n)
            im = torchvision.utils.make_grid(inputs)
            writer.add_image("train im", im, global_step=step_n)
            step_n += 1

        # Checkpoint after every epoch, named by 1-based epoch number.
        torch.save(net.state_dict(), "{}/{}.pth".format(model_path,
                                                        epoch + 1))
        scheduler.step()

        # ---- evaluation phase ----
        net.eval()  # BUGFIX: was called inside the loop for every test batch
        sum_loss = 0
        sum_correct = 0
        sum_samples = 0
        # BUGFIX: disable gradient tracking during evaluation — the original
        # built autograd graphs for every test batch, wasting memory/time.
        with torch.no_grad():
            for i, data in enumerate(test_loader):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = net(inputs)
                loss = loss_func(outputs, labels)
                _, pred = torch.max(outputs.data, dim=1)
                correct = pred.eq(labels.data).cpu().sum()
                sum_loss += loss.item()
                sum_correct += correct.item()
                sum_samples += labels.size(0)
                im = torchvision.utils.make_grid(inputs)
                writer.add_image("test im", im, global_step=step_n)

        test_loss = sum_loss * 1.0 / len(test_loader)
        # BUGFIX: divide by the true number of evaluated samples rather than
        # len(test_loader) * batch_size (wrong when the last batch is short).
        test_correct = sum_correct * 100.0 / sum_samples
        writer.add_scalar("test loss", test_loss, global_step=epoch + 1)
        writer.add_scalar("test correct",
                          test_correct, global_step=epoch + 1)
        print("epoch is", epoch + 1, "loss is:", test_loss,
              "test correct is:", test_correct)

    writer.close()