# train.py
# Forked from benbogin/spider-schema-gnn.
"""Launch AllenNLP training for the Spider parser from a plain Python script.

The script assembles the equivalent of

    allennlp train train_configs/defaults.jsonnet -s experiments/name_of_experiment \
        --include-package dataset_readers.spider \
        --include-package models.semantic_parsing.spider_parser \
        -o '{"trainer": {"cuda_device": 0}}'

into ``sys.argv`` and hands control to ``allennlp.commands.main``.

WARNING: the serialization directory is deleted before training starts.
"""
import argparse
import json
import shutil
import sys

from allennlp.commands import main

# Parse this script's own command line FIRST. ``sys.argv`` is replaced further
# down with the AllenNLP command, after which the user's flags are gone.
parser = argparse.ArgumentParser()
parser.add_argument('--toy', action='store_true', default=False,
                    help='If set, use small data; used for fast debugging.')
# parse_known_args (not parse_args): unrecognised arguments used to be ignored
# silently, so keep tolerating them instead of aborting with a usage error.
args, _ignored_args = parser.parse_known_args()
if args.toy:
    # NOTE(review): nothing visible in this script says which config or data
    # path "small data" refers to, so the flag cannot be wired up here without
    # guessing. Tell the user instead of silently ignoring it.
    print("warning: --toy was given, but this script does not yet map it to "
          "any config override; training with the unchanged configuration.",
          file=sys.stderr)

config_file = "train_configs/defaults.jsonnet"

# Use overrides to train on CPU
# overrides = json.dumps({"trainer": {"cuda_device": -1}})

# Use overrides to train on GPU
overrides = json.dumps({"trainer": {"cuda_device": 0}})

serialization_dir = "experiments/name_of_experiment"

# Training will fail if the serialization directory already
# has stuff in it. If you are running the same training loop
# over and over again for debugging purposes, it will.
# Hence we wipe it out in advance.
# BE VERY CAREFUL NOT TO DO THIS FOR ACTUAL TRAINING!
# ignore_errors=True also covers the first run, when the directory is absent.
shutil.rmtree(serialization_dir, ignore_errors=True)

# Assemble the command into sys.argv
sys.argv = [
    "allennlp",  # command name, not used by main
    "train",
    config_file,
    "-s", serialization_dir,
    "--include-package", "dataset_readers.spider",
    "--include-package", "models.semantic_parsing.spider_parser",
    "-o", overrides,
]

main()