from __future__ import print_function
import sys
import os
import argparse

import h5py
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from model.model import AGRNN
from datasets.hico_constants import HicoConstants
from datasets.hico_dataset import HicoDataset, collate_fn
from datasets import metadata

def main(args):
    # use GPU if available, otherwise fall back to CPU
    device = torch.device('cuda' if torch.cuda.is_available() and args.gpu else 'cpu')
    print("Testing on", device)

    # load checkpoint and set up the model
    try:
        checkpoint = torch.load(args.pretrained, map_location=device)
        print('Checkpoint loaded!')
        # derive an experiment version tag from the checkpoint path if none is given,
        # e.g. 'checkpoints/v3_2048/epoch_train/checkpoint_300_epoch.pth' -> 'v3_2048_300'
        if not args.exp_ver:
            args.exp_ver = args.pretrained.split("/")[-3] + "_" + args.pretrained.split("/")[-1].split("_")[-2]
        data_const = HicoConstants(feat_type=checkpoint['feat_type'], exp_ver=args.exp_ver)
        # construct the model and initialize it with the loaded checkpoint
        model = AGRNN(feat_type=checkpoint['feat_type'], bias=checkpoint['bias'], bn=checkpoint['bn'],
                      dropout=checkpoint['dropout'], multi_attn=checkpoint['multi_head'],
                      layer=checkpoint['layers'], diff_edge=checkpoint['diff_edge'])
        model.load_state_dict(checkpoint['state_dict'])
        model.to(device)
        model.eval()
        print('Constructed model successfully!')
    except Exception as e:
        print('Failed to load checkpoint or construct model!', e)
        sys.exit(1)

    print('Creating hdf5 file for predicted hoi dets ...')
    if not os.path.exists(data_const.result_dir):
        os.mkdir(data_const.result_dir)
    pred_hoi_dets_hdf5 = os.path.join(data_const.result_dir, 'pred_hoi_dets.hdf5')
    pred_hois = h5py.File(pred_hoi_dets_hdf5, 'w')

    test_dataset = HicoDataset(data_const=data_const, subset='test', test=True)
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)
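    # batch_size stays at 1: each image yields a graph whose node count depends on
    # its number of detections, so images are evaluated one at a time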
    for data in tqdm(test_dataloader):
        global_id = data['global_id'][0]
        det_boxes = data['det_boxes'][0]
        roi_scores = data['roi_scores'][0]
        roi_labels = data['roi_labels'][0]
        node_num = data['node_num']
        features = data['features']
        spatial_feat = data['spatial_feat']
        word2vec = data['word2vec']

        # move input tensors to the target device
        features, spatial_feat, word2vec = features.to(device), spatial_feat.to(device), word2vec.to(device)

        # NOTE: roi_labels must be wrapped in a list
        with torch.no_grad():
            outputs, attn, attn_lang = model(node_num, features, spatial_feat, word2vec, [roi_labels])

        action_score = torch.sigmoid(outputs).cpu().numpy()
        attn = attn.cpu().numpy()
        attn_lang = attn_lang.cpu().numpy()
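        # action_score now holds multi-label action probabilities, one row per graph
        # edge; attn / attn_lang are kept for the optional attention-weighted scoring below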
        # save detection result
        pred_hois.create_group(global_id)
        det_data_dict = {}
        h_idxs = np.where(roi_labels == 1)[0]
        labeled_edge_list = np.cumsum(node_num - np.arange(len(h_idxs)) - 1)
        labeled_edge_list[-1] = 0
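        # Edge layout: each human node h is paired with every node j > h, so human h
        # owns (node_num - h - 1) consecutive entries in the flat edge list. The
        # cumulative sums above give the offset where each human's edges start;
        # overwriting the last slot with 0 lets labeled_edge_list[h_idx - 1] wrap
        # around to 0 for the first human (h_idx == 0). This indexing assumes human
        # nodes occupy the first positions in roi_labels.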
        for h_idx in h_idxs:
            for i_idx in range(len(roi_labels)):
                # only pairs (h, i) with i after h exist in the edge layout
                if i_idx <= h_idx:
                    continue
                edge_idx = labeled_edge_list[h_idx - 1] + (i_idx - h_idx - 1)
                # optional attention-weighted variant:
                # score = roi_scores[h_idx] * roi_scores[i_idx] * action_score[edge_idx] * (attn[h_idx][i_idx-1] + attn_lang[h_idx][i_idx-1])
                score = roi_scores[h_idx] * roi_scores[i_idx] * action_score[edge_idx]
                try:
                    hoi_ids = metadata.obj_hoi_index[roi_labels[i_idx]]
                except (KeyError, IndexError):
                    raise RuntimeError('no HOI index for object label {}'.format(roi_labels[i_idx]))
                for hoi_idx in range(hoi_ids[0] - 1, hoi_ids[1]):
                    hoi_pair_score = np.concatenate((det_boxes[h_idx], det_boxes[i_idx],
                                                     np.expand_dims(score[metadata.hoi_to_action[hoi_idx]], 0)), axis=0)
                    hoi_key = str(hoi_idx + 1).zfill(3)
                    if hoi_key not in det_data_dict:
                        det_data_dict[hoi_key] = hoi_pair_score[None, :]
                    else:
                        det_data_dict[hoi_key] = np.vstack((det_data_dict[hoi_key], hoi_pair_score[None, :]))
        for k, v in det_data_dict.items():
            pred_hois[global_id].create_dataset(k, data=v)
    pred_hois.close()
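
# For reference, a minimal sketch (assuming the default output layout above) of how
# the saved detections can be read back. Each per-image group holds one (N, 9)
# dataset per zero-padded HOI class id: human box (4), object box (4), score (1).
#
#   with h5py.File('pred_hoi_dets.hdf5', 'r') as f:
#       dets = f[global_id]['001'][()]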


def str2bool(arg):
    # parse a boolean-like command-line string; anything else is rejected
    arg = arg.lower()
    if arg in ['yes', 'true', '1']:
        return True
    elif arg in ['no', 'false', '0']:
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected!')
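# e.g. str2bool('Yes') -> True, str2bool('0') -> False; argparse applies this
# converter to the --gpu flag below.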


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Evaluate the model')
    parser.add_argument('--pretrained', '-p', type=str,
                        default='checkpoints/v3_2048/epoch_train/checkpoint_300_epoch.pth',
                        help='location of the checkpoint file, e.g. ./checkpoints/checkpoint_150_epoch.pth')
    parser.add_argument('--gpu', type=str2bool, default='true',
                        help='use GPU or not: true')
    # feat_type is not a flag here: main() reads it from the checkpoint
    parser.add_argument('--exp_ver', '--e_v', type=str, default=None,
                        help='experiment version tag; creates a subdirectory under log/ and checkpoints/')
    args = parser.parse_args()

    # run inference
    main(args)
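
# Example invocation (assuming a trained checkpoint exists at the default path):
#   python hico_eval.py --pretrained checkpoints/v3_2048/epoch_train/checkpoint_300_epoch.pth --gpu true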