-
Notifications
You must be signed in to change notification settings - Fork 4
/
train_subgraph_aligner.py
84 lines (66 loc) · 3.06 KB
/
train_subgraph_aligner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import sys
from amr_utils.alignments import load_from_json
from amr_utils.amr_readers import AMR_Reader
from evaluate.utils import evaluate, perplexity, evaluate_duplicates
from models.subgraph_model import Subgraph_Model
from nlp_data import add_nlp_data
import argparse
# Command-line interface. Parsed once at import time; the resulting
# module-level ``args`` namespace is what main() reads its settings from.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-T', '--train',
    type=str,
    required=True,
    help='train AMR file (must have nlp data)',
)
parser.add_argument(
    '-t', '--test',
    type=str,
    nargs=2,
    help='2 arguments: test AMR file and gold alignments file (must have nlp data)',
)
parser.add_argument(
    '--iter',
    type=int,
    default=3,
    help='number of iterations to train model',
)
parser.add_argument(
    '--save-model',
    type=str,
    help='params file to store the trained model',
)
parser.add_argument(
    '--load-model',
    type=str,
    help='params file to load model',
)
args = parser.parse_args()
def report_progress(amr_file, alignments, reader, epoch=None):
    """Save *alignments* to a JSON file named after *amr_file* and log the path.

    The output path is *amr_file* with a trailing ``.txt`` removed, followed by
    ``.subgraph_alignments``, an optional ``.epoch{epoch}`` tag, and ``.json``;
    e.g. ``train.txt`` with ``epoch=2`` -> ``train.subgraph_alignments.epoch2.json``.

    Args:
        amr_file: path of the AMR file the alignments belong to.
        alignments: alignments to save; passed unchanged to
            ``reader.save_alignments_to_json``.
        reader: object providing ``save_alignments_to_json(path, alignments)``
            (an ``AMR_Reader`` in this script).
        epoch: optional epoch number to embed in the file name. ``None`` omits
            the tag; ``0`` still produces ``.epoch0``.
    """
    # Strip only a trailing '.txt'. str.replace would also delete '.txt'
    # occurring elsewhere in the path (e.g. a directory named 'amrs.txt_v2'),
    # sending the output to a different location.
    suffix = '.txt'
    stem = amr_file[:-len(suffix)] if amr_file.endswith(suffix) else amr_file
    epoch_tag = '' if epoch is None else f'.epoch{epoch}'
    align_file = f'{stem}.subgraph_alignments{epoch_tag}.json'
    print(f'Writing subgraph alignments to: {align_file}')
    reader.save_alignments_to_json(align_file, alignments)
def _load_amrs_with_nlp(reader, amr_file):
    """Load AMRs from *amr_file* (wiki removed) and attach that file's NLP data."""
    amrs = reader.load(amr_file, remove_wiki=True)
    add_nlp_data(amrs, amr_file)
    return amrs


def main():
    """Train the subgraph aligner on ``args.train`` and write its alignments.

    For each of ``args.iter`` epochs: align the training AMRs with the current
    model, update the model parameters from those alignments, report
    perplexity, and save that epoch's alignments. When ``--test`` is given,
    the evaluation AMRs are removed from the training set up front, and each
    epoch they are aligned, scored against the gold alignments, and saved.

    Afterwards the final training alignments are written once more (without
    an epoch tag) and, if ``--save-model`` is given, the model is saved.
    """
    amr_file = args.train
    reader = AMR_Reader()
    amrs = _load_amrs_with_nlp(reader, amr_file)

    eval_amr_file, eval_amrs, gold_eval_alignments = None, None, None
    if args.test:
        eval_amr_file, eval_align_file = args.test
        eval_amrs = _load_amrs_with_nlp(reader, eval_amr_file)
        gold_eval_alignments = load_from_json(eval_align_file, eval_amrs, unanonymize=True)
        # Keep the evaluation AMRs out of the training data.
        eval_amr_ids = {amr.id for amr in eval_amrs}
        amrs = [amr for amr in amrs if amr.id not in eval_amr_ids]

    if args.load_model:
        print('Loading model from:', args.load_model)
        align_model = Subgraph_Model.load_model(args.load_model)
    else:
        align_model = Subgraph_Model(amrs, align_duplicates=True)

    alignments = None
    for i in range(args.iter):
        print(f'Epoch {i}: Training data')
        alignments = align_model.align_all(amrs)
        align_model.update_parameters(amrs, alignments)
        perplexity(align_model, amrs, alignments)
        report_progress(amr_file, alignments, reader, epoch=i)
        print()

        if eval_amrs:
            print(f'Epoch {i}: Evaluation data')
            eval_alignments = align_model.align_all(eval_amrs)
            perplexity(align_model, eval_amrs, eval_alignments)
            evaluate(eval_amrs, eval_alignments, gold_eval_alignments)
            evaluate_duplicates(eval_amrs, eval_alignments, gold_eval_alignments)
            report_progress(eval_amr_file, eval_alignments, reader, epoch=i)
            print()

    if alignments is None:
        # --iter <= 0: the loop never ran. Align once with the current
        # (possibly loaded) model instead of reporting None alignments.
        alignments = align_model.align_all(amrs)
    report_progress(amr_file, alignments, reader)

    if args.save_model:
        # Log the destination before writing so it is visible even if saving fails.
        print('Saving model to:', args.save_model)
        align_model.save_model(args.save_model)
if __name__ == '__main__':
    # Run training only when executed as a script, not when imported.
    main()