-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathtest.py
124 lines (100 loc) · 4.42 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import argparse
import os
import time
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from utils import accuracy, ProgressMeter, AverageMeter
from repvgg import get_RepVGG_func_by_name
from utils import load_checkpoint, get_default_ImageNet_val_loader
parser = argparse.ArgumentParser(description='PyTorch ImageNet Test')

# --- positional arguments: dataset root, model mode, checkpoint path ---
parser.add_argument('data', metavar='DIR', help='path to dataset')
# NOTE(review): argparse does not apply `default` to a required positional
# (no nargs='?'/'*'), so 'train' here is never used; MODE must be given.
parser.add_argument(
    'mode',
    metavar='MODE',
    default='train',
    choices=['train', 'deploy'],
    help='train or deploy',
)
parser.add_argument('weights', metavar='WEIGHTS', help='path to the weights file')

# --- optional arguments ---
parser.add_argument('-a', '--arch', metavar='ARCH', default='RepVGG-A0')
parser.add_argument(
    '-j', '--workers',
    type=int,
    default=4,
    metavar='N',
    help='number of data loading workers (default: 4)',
)
parser.add_argument(
    '-b', '--batch-size',
    type=int,
    default=100,
    metavar='N',
    help='mini-batch size (default: 100) for test',
)
parser.add_argument(
    '-r', '--resolution',
    type=int,
    default=224,
    metavar='R',
    help='resolution (default: 224) for test',
)
def test():
    """Validate a MobileOne-S0 checkpoint, merge its branches, re-validate, and save.

    Arguments are read from the module-level ``parser``. The model is always
    built in training-time form (``make_mobileone_s0(deploy=False)``) so that
    ``repvgg_model_convert`` has branches to merge. ``args.arch`` and
    ``args.mode`` are parsed but not consulted by this function.

    The converted model is moved to CPU and its ``state_dict`` is written to
    ``mobileone_depoly_model.pt`` in the current working directory.
    """
    args = parser.parse_args()

    # Project-local import, kept function-scoped as in the original script.
    from mobileone import make_mobileone_s0, repvgg_model_convert

    use_gpu = torch.cuda.is_available()
    if not use_gpu:
        print('using CPU, this will be slow')
    device = torch.device('cuda' if use_gpu else 'cpu')

    model = make_mobileone_s0(deploy=False)
    model.eval()
    model.to(device)

    # The loss is only used for reporting inside validate(); keep it on the
    # same device as the model so the CPU path never touches CUDA.
    criterion = nn.CrossEntropyLoss().to(device)

    if os.path.isfile(args.weights):
        print("=> loading checkpoint '{}'".format(args.weights))
        load_checkpoint(model, args.weights)
    else:
        print("=> no checkpoint found at '{}'".format(args.weights))

    cudnn.benchmark = True

    val_loader = get_default_ImageNet_val_loader(args)

    print('validating model before merging blocks ...')
    validate(val_loader, model, criterion, use_gpu)

    print('starting to convert to model (merge blocks)')
    converted_model = repvgg_model_convert(model, do_copy=True)
    converted_model.eval()
    # Previously this was an unconditional .cuda(), which crashed on CPU-only
    # machines after the first validation had already succeeded on CPU.
    converted_model.to(device)
    print('validating model after merging blocks ...')
    validate(val_loader, converted_model, criterion, use_gpu)

    # Save from CPU so the weights load on machines without a GPU.
    # NOTE: the 'depoly' spelling in the filename is kept as-is in case
    # downstream scripts look for this exact path.
    converted_model.to(torch.device('cpu'))
    save_path = 'mobileone_depoly_model.pt'
    torch.save(converted_model.state_dict(), save_path)
    print(f'convert model weight has saved into {save_path}')
def validate(val_loader, model, criterion, use_gpu):
    """Run one no-grad evaluation pass over ``val_loader`` and report metrics.

    Args:
        val_loader: iterable of ``(images, target)`` batches that also supports
            ``len()`` (used to size the progress display).
        model: network to evaluate; it is switched to eval mode here.
        criterion: loss callable applied as ``criterion(output, target)``.
        use_gpu: when true, each batch is copied to CUDA with ``non_blocking=True``.

    Returns:
        ``top1.avg``, the running top-1 accuracy average kept by ``AverageMeter``.
    """
    meters = [
        AverageMeter('Time', ':6.3f'),
        AverageMeter('Loss', ':.4e'),
        AverageMeter('Acc@1', ':6.2f'),
        AverageMeter('Acc@5', ':6.2f'),
    ]
    batch_time, losses, top1, top5 = meters
    progress = ProgressMeter(len(val_loader), meters, prefix='Test: ')

    model.eval()
    with torch.no_grad():
        tick = time.time()
        for step, (images, target) in enumerate(val_loader):
            if use_gpu:
                images = images.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            logits = model(images)
            # A training-time RepVGGplus returns a dict (auxiliary classifiers
            # included); only the 'main' head is scored.
            if isinstance(logits, dict):
                logits = logits['main']
            loss = criterion(logits, target)

            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            batch_size = images.size(0)
            losses.update(loss.item(), batch_size)
            top1.update(acc1[0], batch_size)
            top5.update(acc5[0], batch_size)

            # Wall-clock time for this iteration (includes data loading).
            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % 10 == 0:
                progress.display(step)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg
# Example invocation:
#   python test.py /dataset/ILSVRC/Data/CLS-LOC deploy mobileone_s0_hello_best.pth.tar
if __name__ == '__main__':
    # Script entry point: parse CLI args, validate, convert, and save.
    test()