forked from Shawn1993/cnn-text-classification-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
85 lines (72 loc) · 3.26 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import sys
import torch
import torch.autograd as autograd
import torch.nn.functional as F
def train(train_iter, dev_iter, model, args):
    """Train ``model`` with Adam, periodically logging, evaluating and saving.

    Args:
        train_iter: Iterable of training batches. Each batch must expose
            ``text``, ``label`` and ``batch_size``. ``text`` is transposed
            before being fed to the model (so the model sees it batch-first)
            and ``label`` is shifted down by one to form the class index.
        dev_iter: Iterable of validation batches, handed to ``eval`` every
            ``args.test_interval`` steps.
        model: The ``torch.nn.Module`` to optimise. Moved to the GPU in place
            when ``args.cuda`` is true.
        args: Namespace providing ``cuda``, ``lr``, ``epochs``,
            ``log_interval``, ``test_interval``, ``save_interval`` and
            ``save_dir``.

    Returns:
        None. Progress is written to stdout and snapshots are written to
        ``args.save_dir`` as ``snapshot_steps<N>.pt``.
    """
    if args.cuda:
        model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    steps = 0
    model.train()
    for _ in range(args.epochs):
        for batch in train_iter:
            # Build new tensors instead of mutating the batch in place:
            # in-place edits through ``.data`` do not transpose the original
            # tensor on current PyTorch, and an in-place ``sub_`` would shift
            # the labels again whenever the same batch object is reused.
            feature = batch.text.t()   # batch first
            target = batch.label - 1   # index align
            if args.cuda:
                # Both tensors must follow the model to the GPU (the label
                # tensor used to be overwritten with the feature tensor here).
                feature, target = feature.cuda(), target.cuda()

            optimizer.zero_grad()
            logit = model(feature)
            loss = F.cross_entropy(logit, target)
            loss.backward()
            optimizer.step()

            steps += 1
            if steps % args.log_interval == 0:
                predicted = torch.max(logit, 1)[1].view(target.size())
                corrects = (predicted == target).sum().item()
                accuracy = 100.0 * corrects / batch.batch_size
                sys.stdout.write(
                    '\rBatch[{}] - loss: {:.6f}  acc: {:.4f}%({}/{})'.format(
                        steps,
                        loss.item(),
                        accuracy,
                        corrects,
                        batch.batch_size))
            if steps % args.test_interval == 0:
                # ``eval`` is this module's evaluation routine (it shadows the
                # builtin); it switches the model back to train mode on exit.
                eval(dev_iter, model, args)
            if steps % args.save_interval == 0:
                os.makedirs(args.save_dir, exist_ok=True)
                save_prefix = os.path.join(args.save_dir, 'snapshot')
                save_path = '{}_steps{}.pt'.format(save_prefix, steps)
                torch.save(model, save_path)
def eval(data_iter, model, args):
    """Evaluate ``model`` on ``data_iter`` and print the mean loss and accuracy.

    Note: this function intentionally keeps its historical name, which
    shadows the ``eval`` builtin inside this module.

    Args:
        data_iter: Iterable of batches exposing ``text`` and ``label``; it
            must also provide ``dataset`` so that ``len(data_iter.dataset)``
            gives the number of evaluated examples.
        model: The ``torch.nn.Module`` to evaluate. It is put in eval mode
            for the duration of the pass and switched back to train mode
            before returning.
        args: Namespace providing ``cuda``.

    Returns:
        None. The summary line is printed to stdout.
    """
    model.eval()
    corrects, total_loss = 0, 0.0
    # No gradients are needed for evaluation; skip building the graph.
    with torch.no_grad():
        for batch in data_iter:
            # Non-mutating equivalents of the old in-place ``t_``/``sub_``.
            feature = batch.text.t()   # batch first
            target = batch.label - 1   # index align
            if args.cuda:
                # The label tensor used to be overwritten with the features.
                feature, target = feature.cuda(), target.cuda()

            logit = model(feature)
            # Sum (not mean) per batch so dividing by the dataset size below
            # yields the true per-example average.
            total_loss += F.cross_entropy(logit, target, reduction='sum').item()
            predicted = torch.max(logit, 1)[1].view(target.size())
            corrects += (predicted == target).sum().item()

    size = len(data_iter.dataset)
    denominator = size if size else 1  # avoid ZeroDivisionError on empty data
    # Average the loss accumulated over ALL batches (previously only the
    # last batch's loss was used).
    avg_loss = total_loss / denominator
    accuracy = 100.0 * corrects / denominator
    model.train()
    print('\nEvaluation - loss: {:.6f}  acc: {:.4f}%({}/{}) \n'.format(avg_loss,
                                                                       accuracy,
                                                                       corrects,
                                                                       size))
def predict(text, model, text_field, label_feild):
    """Classify a single raw string and return its label.

    Args:
        text: The raw input string.
        model: The trained ``torch.nn.Module``; it is switched to eval mode.
        text_field: Field object providing ``tokenize``, ``preprocess`` and
            ``vocab.stoi``. If it exposes ``tensor_type`` that constructor is
            used to build the input tensor, otherwise ``torch.LongTensor``.
        label_feild: Field object providing ``vocab.itos`` (the parameter
            name keeps its historical spelling for backward compatibility).

    Returns:
        The entry of ``label_feild.vocab.itos`` at the predicted class index
        plus one (undoing the "minus one" label shift applied in training).

    Raises:
        TypeError: If ``text`` is not a ``str``.
    """
    if not isinstance(text, str):
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        raise TypeError(
            'text must be a str, got {}'.format(type(text).__name__))

    model.eval()
    tokens = text_field.preprocess(text_field.tokenize(text))
    # Wrap in an outer list to form a batch of one example.
    indices = [[text_field.vocab.stoi[token] for token in tokens]]

    make_tensor = getattr(text_field, 'tensor_type', torch.LongTensor)
    x = make_tensor(indices)

    # Feed the input from the same device as the model's weights, otherwise
    # a model that was moved to the GPU cannot consume it.
    if hasattr(model, 'parameters'):
        first_param = next(iter(model.parameters()), None)
        if first_param is not None:
            x = x.to(first_param.device)

    with torch.no_grad():
        output = model(x)

    predicted = torch.max(output, 1)[1]
    # ``predicted`` holds one index per batch row; flatten so this works
    # whether or not the reduced dimension was kept.
    return label_feild.vocab.itos[predicted.view(-1)[0].item() + 1]