-
Notifications
You must be signed in to change notification settings - Fork 0
/
seq_translate.py
87 lines (68 loc) · 4.49 KB
/
seq_translate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# usage: python seq_translate.py model_prefix --start start_index --end end_index --gap interval --dataset dataset [--beam beam_size] [--result filename]
import argparse
import sys
import os
import subprocess
import operator
import time
from libs.constants import Datasets
from libs.utility.translate import de_bpe, get_bleu
# Set of dataset keys (presumably the ones used for testing); main() below does not read it.
# NOTE(review): spelled 'enfr_bpe' while the --dataset default is 'en-fr_bpe' - confirm which key Datasets defines.
TestDatasets = set(['enfr_bpe'])
def _parse_args():
    """Build the command-line parser for this script and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('model_prefix', nargs='?', default='model/complete/enfr.npz',
                        help='The prefix of nmt model path, default is "%(default)s"')
    parser.add_argument('--start', action="store", metavar="index", dest="start", type=int, default=1,
                        help='The starting index of saved model to test, default is %(default)s')
    parser.add_argument('--end', action="store", metavar="index", dest="end", type=int, default=10,
                        help='The ending index of saved model to test, default is %(default)s')
    parser.add_argument('--gap', action="store", metavar="index", dest="interval", type=int, default=10000,
                        help='The interval between two consecutive tested models\' indexes, default is %(default)s')
    parser.add_argument('--result', action='store', metavar='filename', dest='result_file', type=str,
                        default='trans_result.tsv',
                        help='Base name of the translation output files; the default "%(default)s" '
                             'derives the name from the model file name and beam size')
    parser.add_argument('--beam', action="store", metavar="beam_size", dest="beam_size", type=int, default=4,
                        help='The beam size for translation, default is %(default)s')
    parser.add_argument('--dataset', action='store', dest='dataset', default='en-fr_bpe',
                        help='Dataset, default is "%(default)s"')
    return parser.parse_args()


def _ensure_parent_dir(path):
    """Create the directory that will contain *path* if it does not exist yet."""
    parent = os.path.dirname(path)
    if parent and not os.path.isdir(parent):
        os.makedirs(parent)


def _translate(model_file, result_file, beam_size, zhen, dev1, test1, dic1, dic2):
    """Run translate_single.py on one checkpoint, writing to *result_file*.

    The child's stdout is captured and discarded.

    Returns:
        True when the child exits with status 0, otherwise False. On failure any
        partially written *result_file* is deleted, so a rerun retranslates it
        instead of scoring a truncated file.
    """
    # argv list instead of a shell string: paths with spaces stay intact.
    cmd = ['python', 'translate_single.py', '-b', '32']
    if zhen:
        cmd.append('-zhen')
    cmd += ['-k', str(beam_size), '-p', '1', '-n', model_file,
            './data/dic/{}'.format(dic1), './data/dic/{}'.format(dic2),
            './data/test/{}'.format(test1), result_file]
    if zhen:
        cmd.append('./data/dic/{}'.format(dev1))
    print(' '.join(cmd))
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE)
    proc.communicate()  # drain stdout and wait for (reap) the child
    if proc.returncode != 0:
        if os.path.exists(result_file):
            os.remove(result_file)
        return False
    return True


def _postprocess(result_file, dataset):
    """Undo truecasing ('tc' datasets) and then BPE ('bpe' datasets).

    Returns:
        Path of the file to score: *result_file* with '.detc' and/or '.bpe'
        appended for each step applied, or *result_file* itself if neither applies.
    """
    if 'tc' in dataset:  # de-truecase first, then de-bpe
        detc_file = '{}.detc'.format(result_file)
        with open(result_file, 'r') as fin, open(detc_file, 'w') as fout:
            subprocess.call(['perl', 'scripts/moses/detruecase.perl'], stdin=fin, stdout=fout)
        result_file = detc_file
    if 'bpe' in dataset:
        debpe_file = '{}.bpe'.format(result_file)
        with open(result_file, 'r') as fin:
            text = fin.read()
        with open(debpe_file, 'w') as fout:
            fout.write(de_bpe(text))
        result_file = debpe_file
    return result_file


def main():
    """Translate a range of saved checkpoints and record the BLEU score of each.

    For every index ``idx`` in ``[--start, --end]`` the checkpoint
    ``<model_prefix stem>.iter<idx * --gap>.npz`` is translated with
    translate_single.py. Translation is skipped when the output file already
    exists, and the index is skipped with a message when translation fails.
    The output is post-processed according to the dataset name and scored with
    ``get_bleu`` against the dataset's second test file. Finally a summary file
    with one ``idx<TAB>bleu`` line per scored index (sorted by idx) is written
    under ./translated/complete/.
    """
    args = _parse_args()
    if args.result_file == 'trans_result.tsv':
        # No explicit --result: derive the output name from the model file and beam size.
        model_file_name = os.path.split(args.model_prefix)[-1]
        args.result_file = './translated/complete/{}_bs{}.txt'.format(
            os.path.splitext(model_file_name)[0], args.beam_size)
    else:
        model_file_name = os.path.split(args.result_file)[-1]
    print(args)

    # Full unpacking also checks that the dataset entry has the expected 11 fields.
    (train1, train2, small1, small2, dev1, dev2, dev3,
     test1, test2, dic1, dic2) = Datasets[args.dataset]
    zhen = 'zh-en' in args.dataset and 'wmt17' not in args.dataset

    model_stem = os.path.splitext(args.model_prefix)[0]
    result_stem = os.path.splitext(args.result_file)[0]
    # Per-checkpoint outputs share the directory of args.result_file.
    _ensure_parent_dir(args.result_file)

    bleus = {}
    for idx in range(args.start, args.end + 1):
        iteration = idx * args.interval
        trans_model_file = '%s.iter%d.npz' % (model_stem, iteration)
        trans_result_file = '%s.iter%d.txt' % (result_stem, iteration)
        start_time = time.time()
        if not os.path.exists(trans_result_file):
            print('Translate model {} '.format(trans_model_file))
            if not _translate(trans_model_file, trans_result_file, args.beam_size,
                              zhen, dev1, test1, dic1, dic2):
                print('Translation of model {} failed, skipping'.format(trans_model_file))
                continue
        # Time covers translation only, not post-processing or scoring.
        m, s = divmod(time.time() - start_time, 60)
        trans_result_file = _postprocess(trans_result_file, args.dataset)
        bleus[idx] = get_bleu('./data/test/{}'.format(test2), trans_result_file, zhen=zhen)
        print('model %s, bleu %.2f, time %02d:%02d' % (iteration, bleus[idx], m, s))

    summary_file = './translated/complete/{}_s{}_e{}_bs{}.txt'.format(
        os.path.splitext(model_file_name)[0], args.start, args.end, args.beam_size)
    _ensure_parent_dir(summary_file)
    with open(summary_file, 'w') as fout:
        # The first column is the loop index, not the training iteration.
        fout.write('\n'.join([str(idx) + '\t' + str(score)
                              for idx, score in sorted(bleus.items())]))
if __name__ == "__main__":
    # Run only when executed as a script; importing this module runs nothing.
    main()