eval.py
import argparse
import itertools
import pickle

import torch
from tqdm import tqdm

import evaluation
from data import COCO, DataLoader
from data import ImageDetectionsField, TextField, RawField
from models.MDSANet import Transformer, TransformerEncoder, TransformerDecoderLayer, \
    ScaledDotProductAttention, TransformerEnsemble

def predict_captions(model, dataloader, text_field):
    model.eval()
    gen = {}
    gts = {}
    with tqdm(desc='Evaluation', unit='it', total=len(dataloader)) as pbar:
        for it, (images, caps_gt) in enumerate(iter(dataloader)):
            images = images.to(device)
            with torch.no_grad():
                # Beam search: max length 20, beam size 5, keep the single best sequence.
                out, _ = model.beam_search(images, 20, text_field.vocab.stoi['<eos>'], 5, out_size=1)
            caps_gen = text_field.decode(out, join_words=False)
            for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
                # Collapse runs of consecutive repeated tokens in the generated caption.
                gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)])
                gen['%d_%d' % (it, i)] = [gen_i.strip(), ]
                gts['%d_%d' % (it, i)] = gts_i
            pbar.update()

    # Tokenize references and candidates, then compute the captioning metrics.
    gts = evaluation.PTBTokenizer.tokenize(gts)
    gen = evaluation.PTBTokenizer.tokenize(gen)
    scores, _ = evaluation.compute_scores(gts, gen)
    return scores
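
# The groupby-based dedup above collapses runs of identical consecutive
# tokens; for example,
#     ' '.join(k for k, _ in itertools.groupby(['a', 'a', 'red', 'car', 'car']))
# yields 'a red car'.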

if __name__ == '__main__':
    device = torch.device('cuda')

    parser = argparse.ArgumentParser(description='MDSANet')
    parser.add_argument('--batch_size', type=int, default=10)
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--features_path', type=str, default='/home/data/coco_grid_feats2.hdf5')
    parser.add_argument('--annotation_folder', type=str, default='./annotation')
    parser.add_argument('--model_path', type=str)
    args = parser.parse_args()
    print('MDSANet Evaluation')

    # Pipeline for image regions
    image_field = ImageDetectionsField(detections_path=args.features_path, max_detections=49, load_in_tmp=False)

    # Pipeline for text
    text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
                           remove_punctuation=True, nopoints=False)

    # Create the dataset, keep the test split, and restore the training vocabulary
    dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder)
    _, _, test_dataset = dataset.splits
    with open('./vocab.pkl', 'rb') as f:
        text_field.vocab = pickle.load(f)

    # Model and dataloaders
    encoder = TransformerEncoder(3, 0, attention_module=ScaledDotProductAttention)
    decoder = TransformerDecoderLayer(len(text_field.vocab), 130, 3, text_field.vocab.stoi['<pad>'])
    model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)

    # The checkpoint is expected to be a dict with a 'state_dict' entry.
    data = torch.load(args.model_path)
    model.load_state_dict(data['state_dict'])

    dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()})
    dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size, num_workers=args.workers)

    scores = predict_captions(model, dict_dataloader_test, text_field)
    print(scores)
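
# Example invocation (a sketch: the feature file and checkpoint paths below
# are placeholders and depend on your local setup):
#   python eval.py --batch_size 10 --workers 4 \
#       --features_path /home/data/coco_grid_feats2.hdf5 \
#       --annotation_folder ./annotation \
#       --model_path ./saved_models/mdsanet_best.pth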