eval.py
import os
import sys
import argparse

from tqdm import tqdm
import torch
from torch.utils import data
from torchvision import transforms

# Make the repository root importable so the local packages below resolve.
cur_path = os.path.dirname(__file__)
sys.path.insert(0, os.path.join(cur_path, '..'))

import utils as ptutil
from model.lednet import LEDNet
from data import get_segmentation_dataset
from data.sampler import make_data_sampler
from utils.metric_seg import SegmentationMetric

def parse_args():
    parser = argparse.ArgumentParser(description='Evaluate LEDNet on a segmentation dataset.')
    parser.add_argument('--batch-size', type=int, default=1,
                        help='Evaluation mini-batch size')
    parser.add_argument('--num-workers', '-j', dest='num_workers', type=int,
                        default=4, help='Number of data-loading workers')
    parser.add_argument('--dataset', type=str, default='citys',
                        help='Dataset to evaluate on')
    parser.add_argument('--split', type=str, default='val',
                        help='Data split to evaluate on: val|test')
    parser.add_argument('--mode', type=str, default='testval',
                        help='testval (no crop) or val (with crop)')
    parser.add_argument('--base-size', type=int, default=1024,
                        help='Base image size')
    parser.add_argument('--crop-size', type=int, default=768,
                        help='Crop image size')
    parser.add_argument('--pretrained', type=str,
                        default='./LEDNet_iter_073600.pth',
                        help='Path to the pretrained model weights')
    # device
    parser.add_argument('--cuda', type=ptutil.str2bool, default='true',
                        help='Evaluate on GPU(s) when available')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--init-method', type=str, default="env://")
    args = parser.parse_args()
    return args
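
# Example invocations (sketches, not verified commands; torch.distributed.launch
# is one assumed way of supplying --local_rank with the env:// init method):
#   python eval.py --dataset citys --split val --mode testval
#   python -m torch.distributed.launch --nproc_per_node=2 eval.py --cuda true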

def validate(net, val_data, metric, device):
    net.eval()
    tbar = tqdm(val_data)
    # 'images' avoids shadowing the imported torch.utils.data module
    for images, targets in tbar:
        images, targets = images.to(device), targets.to(device)
        with torch.no_grad():  # inference only, no gradient bookkeeping
            predicts = net(images)
        metric.update(targets, predicts)
    return metric
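
# For reference, a minimal sketch of how pixel accuracy and mean IoU can be
# computed from one batch of logits and labels. The function name and the
# ignore-label convention are illustrative assumptions; they are not
# necessarily what utils.metric_seg.SegmentationMetric implements.
def _seg_scores_sketch(pred_logits, target, num_class, ignore_index=-1):
    pred = pred_logits.argmax(dim=1)        # (N, H, W) class indices
    valid = target != ignore_index          # mask out ignored pixels
    pix_acc = (pred[valid] == target[valid]).float().mean().item()
    ious = []
    for c in range(num_class):
        pred_c = (pred == c) & valid
        tgt_c = (target == c) & valid
        union = (pred_c | tgt_c).sum().item()
        if union > 0:                       # skip classes absent from this batch
            ious.append((pred_c & tgt_c).sum().item() / union)
    return pix_acc, sum(ious) / max(len(ious), 1)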

if __name__ == '__main__':
    args = parse_args()

    # device and (optional) distributed setup
    device = torch.device('cpu')
    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1
    if args.cuda and torch.cuda.is_available():
        # cudnn benchmarking only helps with fixed input shapes, which is not
        # the case for uncropped 'testval' evaluation.
        torch.backends.cudnn.benchmark = args.mode != 'testval'
        device = torch.device('cuda')
    else:
        distributed = False
    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method=args.init_method)

    # Load model; map_location keeps CPU-only evaluation working with a GPU checkpoint.
    model = LEDNet(19)  # 19 classes for Cityscapes
    model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
    model.keep_shape = (args.mode == 'testval')
    model.to(device)

    # testing data
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        # ImageNet channel statistics
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])
    data_kwargs = {'base_size': args.base_size, 'crop_size': args.crop_size, 'transform': input_transform}
    val_dataset = get_segmentation_dataset(args.dataset, split=args.split, mode=args.mode, **data_kwargs)
    sampler = make_data_sampler(val_dataset, False, distributed)
    batch_sampler = data.BatchSampler(sampler=sampler, batch_size=args.batch_size, drop_last=False)
    # shuffle must not be set when a batch_sampler is supplied
    val_data = data.DataLoader(val_dataset, batch_sampler=batch_sampler,
                               num_workers=args.num_workers)

    # evaluate, then reduce the metric across processes
    metric = SegmentationMetric(val_dataset.num_class)
    metric = validate(model, val_data, metric, device)
    ptutil.synchronize()
    pixAcc, mIoU = ptutil.accumulate_metric(metric)
    if ptutil.is_main_process():
        print('pixAcc: %.4f, mIoU: %.4f' % (pixAcc, mIoU))