forked from pcyin/tranX
-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluation.py
66 lines (51 loc) · 2.46 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# coding=utf-8
from __future__ import print_function
import sys
import traceback
from tqdm import tqdm
def decode(examples, model, args, verbose=False, **kwargs):
    """Decode every example with ``model``, keeping hypotheses convertible to code.

    The model is put in eval mode for the duration of decoding. If it was in
    training mode on entry, it is switched back to training mode afterwards,
    even when decoding raises.

    Args:
        examples: sized iterable of examples. Each needs ``src_sent``, which is
            joined with spaces in verbose output, so it is presumably a token
            list. When ``args.parser == 'wikisql_parser'``, each also needs a
            ``table``, passed as the parse context. Verbose error reporting
            also reads ``idx`` and ``tgt_code``.
        model: parser exposing ``training``, ``eval()``, ``train()``,
            ``parse(src_sent, context=..., beam_size=...)`` and
            ``transition_system.ast_to_surface_code(tree)``.
        args: namespace providing ``parser`` and ``beam_size``.
        verbose: if true, print the example count and a diagnostic dump
            (example, intent, target code, hypothesis tree, traceback) for
            every hypothesis whose tree fails to convert to surface code.
        **kwargs: accepted for call-site compatibility; unused.

    Returns:
        A list with one entry per example. Each entry is the list of that
        example's hypotheses, in the order ``model.parse`` returned them, for
        which ``hyp.code`` was successfully set. Hypotheses that failed to
        convert are dropped.
    """
    # TODO: create decoder for each dataset
    if verbose:
        print('evaluating %d examples' % len(examples))

    was_training = model.training
    model.eval()

    is_wikisql = args.parser == 'wikisql_parser'
    decode_results = []

    try:
        for example in tqdm(examples, desc='Decoding', file=sys.stdout, total=len(examples)):
            # Only the WikiSQL parser consumes a table as decoding context.
            context = example.table if is_wikisql else None
            hyps = model.parse(example.src_sent, context=context, beam_size=args.beam_size)

            decoded_hyps = []
            for hyp_id, hyp in enumerate(hyps):
                try:
                    hyp.code = model.transition_system.ast_to_surface_code(hyp.tree)
                except Exception:
                    # A tree that cannot be rendered as code is dropped from the
                    # results. Catch Exception (not a bare except) so Ctrl-C and
                    # SystemExit still propagate.
                    if verbose:
                        print("Exception in converting tree to code:", file=sys.stdout)
                        print('-' * 60, file=sys.stdout)
                        print('Example: %s\nIntent: %s\nTarget Code:\n%s\nHypothesis[%d]:\n%s' % (
                            example.idx,
                            ' '.join(example.src_sent),
                            example.tgt_code,
                            hyp_id,
                            hyp.tree.to_string()), file=sys.stdout)
                        traceback.print_exc(file=sys.stdout)
                        print('-' * 60, file=sys.stdout)
                else:
                    decoded_hyps.append(hyp)

            decode_results.append(decoded_hyps)
    finally:
        # Restore the caller's training mode even if decoding raised.
        if was_training:
            model.train()

    return decode_results
def evaluate(examples, parser, evaluator, args, verbose=False, return_decode_result=False, eval_top_pred_only=False):
    """Decode ``examples`` with ``parser`` and score the output with ``evaluator``.

    Args:
        examples: examples to decode; also passed to the evaluator.
        parser: model handed to ``decode`` as the decoder.
        evaluator: object whose ``evaluate_dataset`` scores the decode results.
        args: configuration forwarded to both ``decode`` and ``evaluate_dataset``.
        verbose: forwarded to ``decode``.
        return_decode_result: if true, also return the raw decode results.
        eval_top_pred_only: forwarded to ``evaluate_dataset`` as ``fast_mode``.
            The name suggests only the top hypothesis is scored; confirm
            against the evaluator.

    Returns:
        The value of ``evaluator.evaluate_dataset(...)``. When
        ``return_decode_result`` is true, returns the tuple
        ``(eval_result, decode_results)`` instead.
    """
    hypotheses = decode(examples, parser, args, verbose=verbose)
    metrics = evaluator.evaluate_dataset(examples, hypotheses,
                                         fast_mode=eval_top_pred_only, args=args)
    if not return_decode_result:
        return metrics
    return metrics, hypotheses