-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_experiment.py
136 lines (103 loc) · 4.58 KB
/
run_experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import argparse
import torch
import os
from components.engine.base_engine import Engine
from components.evaluator.syn_evaluator import SynEvaluator
from components.utils.serialization import load_model, load_yaml_config
from components.utils.serialization import save_final_predictions_from_dgs
from components.engine.utils import make_model_dir, fix_seed
from components.engine.utils import set_logging
import logging
use_cuda = torch.cuda.is_available()
def run_one_stage(mode, config_fn, seed=None):
    """Run a single stage of the system (training or prediction).

    :param mode: one of 'train' or 'predict'
    :param config_fn: path to the YAML config file for this stage
    :param seed: optional random-seed override; only valid in 'train' mode
    """
    config_d = load_yaml_config(config_fn)

    # A seed override only makes sense when training from scratch.
    if seed:
        assert mode == 'train', 'Wrong mode!'
        config_d['random_seed'] = int(seed)

    if mode == 'train':
        model_dirname = make_model_dir(config_d)
    else:
        # Predict mode: reuse the directory holding the serialized (trained) model.
        model_dirname = os.path.split(config_d["model_fn"])[0]

    log_fn = config_d.get('log_fn', None)
    if log_fn is None:
        # Default log location: <model_dir>/<mode>.log
        log_fn = os.path.join(model_dirname, '%s.log' % mode)
    logger = set_logging(config_d.get('log_level', 'DEBUG'), log_fn)
    logger.debug('MODE: %s, STAGE: %s', mode, config_d['stage'])
    logger.debug('Model dir: %s', model_dirname)

    # Fix random number generators' seeds for reproducibility.
    fix_seed(config_d['random_seed'])

    engine = Engine(config_d, mode)
    engine.setup()
    engine.run()
def run_pipeline(config_files, out_fn):
    """Run the full two-stage pipeline: syntactic ordering, then realization.

    In the pipeline mode we have two config files, one per stage.
    First we do syntactic ordering, then morphological inflection /
    surface realization on the predicted dependency graphs.

    :param config_files: sequence of two config-file paths
                         (stage 1: syntactic, stage 2: morphological)
    :param out_fn: base path for output files and the prediction log
    """
    # Stage 1: syntactic ordering
    syn_config_fn = config_files[0]
    syn_config_d = load_yaml_config(syn_config_fn)
    log_fn = '%s.prediction_log' % (os.path.abspath(out_fn))
    logger = set_logging(logging.DEBUG, log_fn)
    fix_seed(syn_config_d['random_seed'])
    logger.info('PIPELINE MODE')
    logger.info('STAGE 1: syntactic ordering')
    syn_config_d['output_fn'] = out_fn
    syn_engine = Engine(syn_config_d, mode='predict')
    syn_engine.setup()
    depgraphs = syn_engine.run()

    # Stage 2: morphological inflection
    # FIX: use the configured logger instead of a bare print(), so the stage
    # banner also lands in the prediction log like every other message.
    logger.info('STAGE 2: morphological inflection')
    morph_config_fn = config_files[1]
    morph_config_d = load_yaml_config(morph_config_fn)
    morph_config_d['output_fn'] = out_fn
    morph_engine = Engine(morph_config_d, mode='predict')
    morph_engine.setup()

    # Load the trained weights for the morphological model.
    morph_model_fname = morph_engine.config["model_fn"]
    load_model(morph_engine.model_module, morph_model_fname)
    if use_cuda:
        morph_engine.model_module.cuda()

    # FIX: isinstance() instead of type(...) == dict (idiomatic type check;
    # also accepts dict subclasses such as OrderedDict).
    assert isinstance(depgraphs, dict)

    logger.info('Predicting on Syn outputs ...')
    for data_split, digraphs in depgraphs.items():
        for dg in digraphs:  # can later examine digraphs, if needed
            # Predict an inflected form for every node, in syntactic order.
            for node_id in dg.graph['node_order']:
                form = morph_engine.nlgen.predict_from_dgnode(morph_engine.model_module,
                                                              morph_engine.data_module.vocab,
                                                              dg, node_id)
                dg.node[node_id]['PRED_FORM'] = form

    # Save final predictions per data split; score the dev split.
    for data_split, dgs in depgraphs.items():
        fname = '%s.%s.final.txt' % (out_fn, data_split)
        logger.info('Saving Pipeline outputs (*%s*) to --> %s', data_split, fname)
        targets, predicted_snts = save_final_predictions_from_dgs(dgs, fname)
        if data_split == 'dev':
            SynEvaluator.compute_final_scores(targets, predicted_snts)
def parse_args():
    """Parse command-line options for the main script and return the namespace."""
    parser = argparse.ArgumentParser(
        description='Process cmd line options for the main script')
    # One config file for single-stage runs, two for the pipeline.
    parser.add_argument('-c', '--config', nargs='+',
                        help='config file(s) to use for system setup')
    parser.add_argument('-m', '--mode', choices=['train', 'predict', 'pipeline'],
                        help='mode to run the system in')
    parser.add_argument('-s', '--seed', nargs='?',
                        help='random seed value')
    parser.add_argument('-o', '--output', nargs='?',
                        help='output file storing pipeline predictions')
    return parser.parse_args()
if __name__ == '__main__':
    argvs = parse_args()
    mode = argvs.mode
    seed = argvs.seed
    output_fn = argvs.output
    config_files = argvs.config

    # FIX: if -c/--config is omitted, argparse leaves config as None and the
    # original len() call crashed with an opaque TypeError. Fail with a clear
    # message instead.
    if not config_files:
        raise SystemExit('Provide at least one config file via -c/--config!')

    num_config_files = len(config_files)
    if num_config_files == 1:
        # Single-stage run: either training or prediction.
        assert mode in ['train', 'predict']
        run_one_stage(mode, config_files[0], seed)
    elif num_config_files == 2:
        # Two configs => full two-stage pipeline; predictions must go somewhere.
        assert mode == 'pipeline', 'Wrong mode!'
        assert output_fn is not None, 'Provide an output file to store pipeline predictions!'
        run_pipeline(config_files, output_fn)
    else:
        raise NotImplementedError()