-
Notifications
You must be signed in to change notification settings - Fork 12
/
validation.py
49 lines (41 loc) · 1.9 KB
/
validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# coding:utf8
from torch.autograd import Variable
import torch as t
import os
from tqdm import tqdm
from utils import AverageMeter, calculate_accuracy
def val_epoch(epoch, data_loader, model, criterion, opt, logger, path=None, best_avg=None):
    """Run one validation pass over ``data_loader`` and return the average loss.

    The model is switched to eval mode and every forward pass runs under
    ``torch.no_grad()``. Per-step losses and accuracy are logged every
    ``opt.step_every_summary`` steps. The epoch-average accuracy is logged once
    at the end. If ``best_avg`` and ``path`` are both given and this epoch's
    average accuracy exceeds ``best_avg``, the model's ``state_dict`` is
    written to ``<path>/best.model``.

    Args:
        epoch: Epoch number; the global step is computed as
            ``i + 1 + (epoch - 1) * len(data_loader)``, i.e. epochs are 1-based.
        data_loader: Sized iterable yielding ``(inputs, targets)`` batches.
            ``targets`` must be 1-D (the per-frame expansion below relies on it).
        model: Module whose forward returns a pair ``(outputs, sv_out)``.
        criterion: Loss function applied to both model outputs.
        opt: Options object; reads ``cuda``, ``frames`` and
            ``step_every_summary``.
        logger: Object exposing ``log_value(name, value, step)``, or ``None``
            to disable logging.
        path: Directory for ``best.model``; nothing is saved when falsy.
        best_avg: Accuracy to beat; nothing is saved when ``None``.

    Returns:
        ``losses.avg`` from an ``AverageMeter`` fed each batch's main-head loss
        with ``inputs.size(0)`` as its count.

    Note:
        The model is left in eval mode; callers that continue training should
        call ``model.train()`` afterwards.
    """
    model.eval()
    losses = AverageMeter()
    accuracies = AverageMeter()
    num_batches = len(data_loader)

    for i, (inputs, targets) in tqdm(enumerate(data_loader), total=num_batches):
        # Inputs and targets must end up on the same device, so move them
        # together. ``non_blocking`` replaces the old ``async=`` keyword
        # argument, which is a SyntaxError on Python 3.7+.
        if opt.cuda:
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)

        with t.no_grad():
            # Repeat each label opt.frames times consecutively to get one
            # label per frame: (B,) -> (B * frames,), e.g. [a, b] with
            # frames=2 -> [a, a, b, b]. Presumably matches the row order of
            # sv_out — TODO confirm against the model.
            sv_targets = (targets.unsqueeze(0)
                          .repeat(opt.frames, 1)
                          .permute(1, 0)
                          .contiguous()
                          .view(-1))
            outputs, sv_out = model(inputs)
            loss = criterion(outputs, targets)
            sv_loss = criterion(sv_out, sv_targets)
            acc = calculate_accuracy(outputs, targets)

        batch_size = inputs.size(0)
        losses.update(loss.item(), batch_size)
        accuracies.update(acc, batch_size)

        step = i + 1 + (epoch - 1) * num_batches
        if logger is not None and step % opt.step_every_summary == 0:
            logger.log_value('Val_Loss', loss.item(), step)
            logger.log_value('SV_Val_Loss', sv_loss.item(), step)
            logger.log_value('Val_Accuracy', acc, step)

    if logger is not None:
        logger.log_value('Val_Accuracy_avg_epoch', accuracies.avg, epoch)

    if best_avg is not None and path and accuracies.avg > best_avg:
        t.save(model.state_dict(), os.path.join(path, 'best.model'))

    return losses.avg