forked from qq456cvb/Point-Transformers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_cls.py
72 lines (61 loc) · 2.55 KB
/
test_cls.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from dataset import ModelNetDataLoader
import argparse
import numpy as np
import os
import torch
import datetime
import logging
from pathlib import Path
from tqdm import tqdm
import sys
import provider
import importlib
import shutil
import hydra
import omegaconf
def test(model, loader, num_class=40):
    """Evaluate a classifier over ``loader`` and return overall and per-class accuracy.

    Args:
        model: classifier module. It is switched to eval mode once and called
            as ``model(points)``. The argmax of its output along dim 1 is taken
            as the predicted label.
        loader: iterable of ``(points, target)`` batches that also supports
            ``len()`` (used only as the progress-bar total). ``target`` may be
            shaped ``(B,)`` or ``(B, k)``; in the latter case column 0 is used.
        num_class: number of classes. Labels are assumed to lie in
            ``[0, num_class)``.

    Returns:
        ``(instance_acc, class_acc)``:
        ``instance_acc`` is correct predictions / total samples over the whole
        loader. ``class_acc`` is the mean, over classes that occur at least
        once, of that class's correct / total. For an empty loader both are
        NaN.
    """
    classifier = model.eval()
    # Run on whatever device holds the model's weights instead of hard-coding
    # CUDA. A model without parameters falls back to CUDA, as before.
    first_param = next(iter(classifier.parameters()), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    # Raw per-class counts. Accumulating counts (rather than averaging per-batch
    # ratios) weights every sample equally, whatever the batch sizes are.
    correct_per_class = np.zeros(num_class)
    total_per_class = np.zeros(num_class)

    with torch.no_grad():
        for points, target in tqdm(loader, total=len(loader)):
            if target.dim() > 1:
                target = target[:, 0]
            labels = target.long().cpu()
            pred_choice = classifier(points.to(device)).argmax(dim=1).cpu()
            hits = pred_choice.eq(labels).numpy().astype(np.float64)
            labels_np = labels.numpy()
            total_per_class += np.bincount(labels_np, minlength=num_class)
            correct_per_class += np.bincount(labels_np, weights=hits, minlength=num_class)

    total = total_per_class.sum()
    if total == 0:
        # Nothing was evaluated; report NaN explicitly instead of np.mean([]).
        return float('nan'), float('nan')

    instance_acc = correct_per_class.sum() / total
    # Only classes that actually occur contribute; absent classes would
    # otherwise divide by zero and turn the mean into NaN.
    seen = total_per_class > 0
    class_acc = np.mean(correct_per_class[seen] / total_per_class[seen])
    return float(instance_acc), float(class_acc)
@hydra.main(version_base=None, config_path='config', config_name='cls')
def main(args):
    """Evaluate a trained classifier on the ``test`` split and log its accuracy.

    ``args`` is the Hydra config ``config/cls`` (plus any command-line
    overrides). Fields read here: ``gpu``, ``num_point``, ``normal``,
    ``batch_size`` and ``model.name``. The class ``PointTransformerCls`` is
    imported from ``models.<model.name>.model``. Its weights come from the
    ``model_state_dict`` entry of ``best_model.pth``. Instance and class
    accuracy are written to the module logger.
    """
    # Allow adding keys (num_class, input_dim) that the config does not declare.
    omegaconf.OmegaConf.set_struct(args, False)

    # HYPER PARAMETER
    # CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been initialised
    # yet, so it is set before any model/tensor is moved to the GPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    os.environ["HYDRA_FULL_ERROR"] = "1"
    logger = logging.getLogger(__name__)

    # DATA LOADING
    logger.info('Load dataset ...')
    # Resolve relative to the launch directory, since Hydra may change the cwd.
    DATA_PATH = hydra.utils.to_absolute_path('modelnet40_normal_resampled/')
    TEST_DATASET = ModelNetDataLoader(
        root=DATA_PATH, npoint=args.num_point, split='test', normal_channel=args.normal)
    testDataLoader = torch.utils.data.DataLoader(
        TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # MODEL LOADING
    args.num_class = 40
    # 6 input channels when normals are enabled, else 3 (presumably xyz only).
    args.input_dim = 6 if args.normal else 3
    model_module = importlib.import_module('models.{}.model'.format(args.model.name))
    classifier = model_module.PointTransformerCls(args).cuda()

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load('best_model.pth')
    classifier.load_state_dict(checkpoint['model_state_dict'])

    with torch.no_grad():
        instance_acc, class_acc = test(
            classifier.eval(), testDataLoader, num_class=args.num_class)
    logger.info('Test Instance Accuracy: %f, Class Accuracy: %f', instance_acc, class_acc)


if __name__ == '__main__':
    main()