forked from zhangxu0307/Ind-RNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
66 lines (52 loc) · 2.25 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch as th
import torchvision
from torch.autograd import Variable
from torch import nn
from torch import optim
from torchvision import datasets
import torchvision.transforms as transforms
from sequential_mnist import loadSequentialMNIST
from IndRNN import IndRNNModel
def train(model, batchSize, epoch, useCuda=False):
    """Train ``model`` on sequential MNIST and evaluate it after every epoch.

    Each epoch runs one optimisation pass over the training loader followed by
    one gradient-free pass over the test loader, printing running-average
    losses (and test accuracy) along the way. After the last epoch the model's
    ``state_dict`` is written to the path returned by ``model.name()``.

    Args:
        model: A ``torch.nn.Module`` that is called as ``model(x, batchSize)``
            and exposes a ``name()`` method giving the checkpoint path.
        batchSize: Batch size handed to ``loadSequentialMNIST`` and forwarded
            to the model on every call.
        epoch: Number of full passes over the training data.
        useCuda: When true, move the model and every batch to the GPU.

    Returns:
        None. Progress is printed and the final weights are saved to disk.
    """
    if useCuda:
        model = model.cuda()

    optimizer = optim.RMSprop(model.parameters(), lr=0.1, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    trainLoader, testLoader = loadSequentialMNIST(batchSize=batchSize)

    for i in range(epoch):
        # ---- training ----
        model.train()
        sum_loss = 0.0
        for batch_idx, (x, target) in enumerate(trainLoader):
            optimizer.zero_grad()
            if useCuda:
                x, target = x.cuda(), target.cuda()
            # NOTE(review): the model receives the nominal batchSize, not
            # x.size(0). If the loader can yield a smaller final batch this
            # may mismatch inside the model - confirm against
            # loadSequentialMNIST / IndRNNModel.
            out = model(x, batchSize)
            loss = criterion(out, target)
            # .item() extracts the Python float; indexing a 0-dim loss tensor
            # (loss.data[0]) fails on PyTorch >= 0.4.
            sum_loss += loss.item()
            loss.backward()
            optimizer.step()

            batches_seen = batch_idx + 1
            if batches_seen % 10 == 0 or batches_seen == len(trainLoader):
                # Average over the number of batches processed so far
                # (batch_idx + 1); dividing by batch_idx was off by one and
                # raised ZeroDivisionError for a single-batch loader.
                print('==>>> epoch: {}, batch index: {}, train loss: {:.6f}'.format(
                    i, batches_seen, sum_loss / batches_seen))

        # ---- testing ----
        model.eval()
        correct_cnt, sum_loss = 0, 0.0
        total_cnt = 0
        # no_grad() replaces the removed ``volatile=True`` flag: no autograd
        # graph is built during evaluation.
        with th.no_grad():
            for batch_idx, (x, target) in enumerate(testLoader):
                if useCuda:
                    x, target = x.cuda(), target.cuda()
                out = model(x, batchSize)
                loss = criterion(out, target)
                # Accumulate the test loss so the printed value is meaningful
                # (it was never updated before and always showed 0).
                sum_loss += loss.item()
                _, pred_label = th.max(out, 1)
                total_cnt += x.size(0)
                # .item() keeps the counter a plain Python int.
                correct_cnt += (pred_label == target).sum().item()

                batches_seen = batch_idx + 1
                if batches_seen % 100 == 0 or batches_seen == len(testLoader):
                    print('==>>> epoch: {}, batch index: {}, test loss: {:.6f}, acc: {:.3f}'.format(
                        i, batches_seen, sum_loss / batches_seen,
                        correct_cnt * 1.0 / total_cnt))

    # Persist the final weights once all epochs are done.
    th.save(model.state_dict(), model.name())
if __name__ == '__main__':
    # Script entry point: build a single-layer IndRNN (1 input feature,
    # 256 hidden units, 10 output classes) and train it on the CPU
    # for 10 epochs with batches of 128.
    train(
        IndRNNModel(inputDim=1, hiddenNum=256, outputDim=10, layerNum=1),
        batchSize=128,
        epoch=10,
        useCuda=False,
    )