-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_pytorch_model.py
104 lines (79 loc) · 3.1 KB
/
train_pytorch_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models
from pre_resnet import resnet18
from load_cifar10 import train_loader, test_loader
import os
import tensorboardX
# Device selection: use the GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters.
epoch_num = 200   # total number of training epochs
lr = 0.1          # initial learning rate
batch_size = 128  # NOTE(review): assumed to match the loaders built in load_cifar10 — confirm

# Model: project-local pre-activation ResNet-18 (see pre_resnet.py).
net = resnet18().to(device)

# Loss: standard cross-entropy for multi-class classification.
loss_func = nn.CrossEntropyLoss()

# Optimizer: SGD with momentum and L2 weight decay.
optimizer = torch.optim.SGD(net.parameters(), lr=lr,
                            momentum=0.9, weight_decay=5e-4)
# Multiply the learning rate by 0.9 every 10 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            step_size=10,
                                            gamma=0.9)

model_path = "models/pytorch_resnet18"
log_path = "logs/pytorch_resnet18"
# exist_ok=True avoids the check-then-create race of
# `if not os.path.exists(p): os.makedirs(p)` and is the idiomatic form.
os.makedirs(log_path, exist_ok=True)
os.makedirs(model_path, exist_ok=True)

writer = tensorboardX.SummaryWriter(log_path)
step_n = 0  # global training-step counter for TensorBoard scalars/images
# Main loop: one training pass + one evaluation pass per epoch.
for epoch in range(epoch_num):
    print(" epoch is ", epoch)

    # ---- training phase ----
    net.train()  # enable train-mode behavior (BatchNorm running stats, dropout)
    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = net(inputs)
        loss = loss_func(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _, pred = torch.max(outputs.data, dim=1)
        correct = pred.eq(labels.data).cpu().sum()

        writer.add_scalar("train loss", loss.item(), global_step=step_n)
        # Divide by the actual mini-batch size (labels.size(0)), not the
        # configured batch_size, so a smaller final batch is not under-reported.
        writer.add_scalar("train correct",
                          100.0 * correct.item() / labels.size(0),
                          global_step=step_n)
        im = torchvision.utils.make_grid(inputs)
        writer.add_image("train im", im, global_step=step_n)
        step_n += 1

    # Checkpoint after every epoch, then advance the LR schedule.
    torch.save(net.state_dict(), "{}/{}.pth".format(model_path,
                                                    epoch + 1))
    scheduler.step()

    # ---- evaluation phase ----
    net.eval()  # hoisted out of the loop: mode only needs to be set once
    sum_loss = 0.0
    sum_correct = 0
    sum_samples = 0  # true sample count; the last batch may be smaller
    with torch.no_grad():  # no autograd graph at eval time: saves memory/compute
        for i, data in enumerate(test_loader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            loss = loss_func(outputs, labels)
            _, pred = torch.max(outputs.data, dim=1)
            correct = pred.eq(labels.data).cpu().sum()
            sum_loss += loss.item()
            sum_correct += correct.item()
            sum_samples += labels.size(0)
            im = torchvision.utils.make_grid(inputs)
            writer.add_image("test im", im, global_step=step_n)

    test_loss = sum_loss * 1.0 / len(test_loader)
    # Divide by the real number of test samples. The original used
    # len(test_loader) * batch_size, which overcounts whenever the test set
    # size is not a multiple of batch_size (e.g. 10000 / 128) and so
    # under-reports accuracy.
    test_correct = sum_correct * 100.0 / sum_samples
    writer.add_scalar("test loss", test_loss, global_step=epoch + 1)
    writer.add_scalar("test correct",
                      test_correct, global_step=epoch + 1)
    print("epoch is", epoch + 1, "loss is:", test_loss,
          "test correct is:", test_correct)

writer.close()