train.py

import os
import time
import torch
import torch.nn as nn
import utils
from torch.autograd import Variable  # pre-0.4 PyTorch API: tensors and Variables were still separate types


def instance_bce_with_logits(logits, labels):
    # Binary cross-entropy against VQA-style soft answer scores.
    assert logits.dim() == 2
    loss = nn.functional.binary_cross_entropy_with_logits(logits, labels)
    loss *= labels.size(1)  # undo the mean over answer candidates (see note below)
    return loss
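
# Why the labels.size(1) factor (editor's note, not original code): the default
# reduction averages over all batch * num_candidates entries, so for a (2, 4)
# batch the raw loss is a mean over 8 terms; multiplying by 4 recovers the
# per-example loss summed over the 4 candidates, averaged over the batch of 2.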


def compute_score_with_logits(logits, labels):
    logits = torch.max(logits, 1)[1].data  # argmax over answer candidates
    one_hots = torch.zeros(*labels.size()).cuda()
    one_hots.scatter_(1, logits.view(-1, 1), 1)  # one-hot at the predicted answer
    scores = (one_hots * labels)  # soft score of the predicted answer, per example
    return scores
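
# Illustrative example (values are made up; a GPU is required because the
# helper allocates on .cuda()):
#   logits = torch.Tensor([[0.1, 2.3, -1.0]]).cuda()  # argmax -> index 1
#   labels = torch.Tensor([[0.0, 0.9, 0.3]]).cuda()   # VQA soft scores
#   compute_score_with_logits(logits, labels)         # -> [[0.0, 0.9, 0.0]]
# i.e. each example contributes the soft score of its predicted answer.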


def train(model, train_loader, eval_loader, num_epochs, output):
    utils.create_dir(output)
    optim = torch.optim.Adamax(model.parameters())
    logger = utils.Logger(os.path.join(output, 'log.txt'))
    best_eval_score = 0

    for epoch in range(num_epochs):
        total_loss = 0
        train_score = 0
        t = time.time()

        for i, (v, b, q, a) in enumerate(train_loader):
            v = Variable(v).cuda()  # image features
            b = Variable(b).cuda()  # bounding-box features
            q = Variable(q).cuda()  # question tokens
            a = Variable(a).cuda()  # soft answer scores

            pred = model(v, b, q, a)
            loss = instance_bce_with_logits(pred, a)
            loss.backward()
            nn.utils.clip_grad_norm(model.parameters(), 0.25)  # renamed clip_grad_norm_ in PyTorch >= 0.4
            optim.step()
            optim.zero_grad()

            batch_score = compute_score_with_logits(pred, a.data).sum()
            total_loss += loss.data[0] * v.size(0)  # pre-0.4 idiom for loss.item()
            train_score += batch_score

        total_loss /= len(train_loader.dataset)
        train_score = 100 * train_score / len(train_loader.dataset)

        model.train(False)  # switch to eval mode for validation
        eval_score, bound = evaluate(model, eval_loader)
        model.train(True)

        logger.write('epoch %d, time: %.2f' % (epoch, time.time() - t))
        logger.write('\ttrain_loss: %.2f, score: %.2f' % (total_loss, train_score))
        logger.write('\teval score: %.2f (%.2f)' % (100 * eval_score, 100 * bound))

        if eval_score > best_eval_score:
            model_path = os.path.join(output, 'model.pth')
            torch.save(model.state_dict(), model_path)
            best_eval_score = eval_score


def evaluate(model, dataloader):
    score = 0
    upper_bound = 0
    num_data = 0

    for v, b, q, a in iter(dataloader):
        v = Variable(v, volatile=True).cuda()  # volatile: pre-0.4 equivalent of torch.no_grad()
        b = Variable(b, volatile=True).cuda()
        q = Variable(q, volatile=True).cuda()
        pred = model(v, b, q, None)
        batch_score = compute_score_with_logits(pred, a.cuda()).sum()
        score += batch_score
        upper_bound += (a.max(1)[0]).sum()  # best achievable score per example
        num_data += pred.size(0)

    score = score / len(dataloader.dataset)
    upper_bound = upper_bound / len(dataloader.dataset)
    return score, upper_bound
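

# Minimal usage sketch (editor's addition; `build_model` and `VQAFeatureDataset`
# are hypothetical stand-ins for the repo's actual model/dataset constructors,
# and the hyperparameters are assumptions):
#
#   from torch.utils.data import DataLoader
#   train_set = VQAFeatureDataset('train')
#   eval_set = VQAFeatureDataset('val')
#   train_loader = DataLoader(train_set, batch_size=512, shuffle=True)
#   eval_loader = DataLoader(eval_set, batch_size=512)
#   model = build_model(train_set).cuda()
#   train(model, train_loader, eval_loader, num_epochs=30, output='saved_models')
#
#   # reload the best checkpoint written by train():
#   model.load_state_dict(torch.load('saved_models/model.pth'))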