-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_path_finding.py
118 lines (100 loc) · 3.61 KB
/
train_path_finding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
Training and testing for sequence-output tasks in bAbI.

Task 19 ('Path Finding') is used as the example.
"""
import argparse
import numpy as np
import torch
from data_utils import get_babi_dataloaders
from ggsnn import GGSNN
from torch.optim import Adam
def _evaluate(model, dataloader):
    """Return the exact-match accuracy of ``model`` over ``dataloader``.

    Puts the model in eval mode and runs it without gradient tracking on every
    ``(g, ground_truths, seq_lengths)`` batch. An example counts as correct only
    when its prediction element equals its ground-truth element (presumably the
    whole output sequence per example -- depends on GGSNN's output; confirm).

    Raises:
        ValueError: if ``dataloader`` yields no examples, since accuracy is undefined.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for g, ground_truths, seq_lengths in dataloader:
            # Without ground truths the model returns predictions only.
            preds = model(g, seq_lengths).tolist()
            targets = ground_truths.tolist()
            correct += sum(p == t for p, t in zip(preds, targets))
            total += len(preds)
    if total == 0:
        raise ValueError("Cannot compute accuracy on an empty dataloader")
    return correct / total


def main(args):
    """Train a GGSNN on a bAbI sequence-output task and report dev/test accuracy.

    Args:
        args: parsed command-line namespace providing ``task_id``, ``train_num``,
            ``batch_size``, ``lr`` and ``epochs``.

    Raises:
        ValueError: if ``args.task_id`` has no model configuration below.
    """
    # Per-task model hyper-parameters; only task 19 (Path Finding) is configured.
    out_feats = {19: 6}
    n_etypes = {19: 4}
    if args.task_id not in out_feats:
        # Fail fast with a clear message instead of a KeyError after data loading.
        raise ValueError(
            f"Unsupported task_id {args.task_id}; configured tasks: {sorted(out_feats)}"
        )
    eval_interval = 20  # log batch losses and evaluate on dev every N epochs

    train_dataloader, dev_dataloader, test_dataloaders = get_babi_dataloaders(
        batch_size=args.batch_size,
        train_size=args.train_num,
        task_id=args.task_id,
        q_type=-1,
    )
    model = GGSNN(
        annotation_size=2,
        out_feats=out_feats[args.task_id],
        n_steps=5,
        n_etypes=n_etypes[args.task_id],
        max_seq_length=2,
        num_cls=5,
    )
    opt = Adam(model.parameters(), lr=args.lr)

    print(f"Task {args.task_id}")
    print(f"Training set size: {len(train_dataloader.dataset)}")
    print(f"Dev set size: {len(dev_dataloader.dataset)}")

    # training and dev stage
    for epoch in range(args.epochs):
        log_epoch = epoch % eval_interval == 0
        model.train()
        for i, (g, ground_truths, seq_lengths) in enumerate(train_dataloader):
            loss, _ = model(g, seq_lengths, ground_truths)
            opt.zero_grad()
            loss.backward()
            opt.step()
            if log_epoch:
                print(f"Epoch {epoch}, batch {i} loss: {loss.item()}")
        if log_epoch:
            acc = _evaluate(model, dev_dataloader)
            print(f"Epoch {epoch}, Dev acc {acc}")

    # test stage
    for i, dataloader in enumerate(test_dataloaders):
        print(f"Test set {i} size: {len(dataloader.dataset)}")
    test_acc_list = [_evaluate(model, dataloader) for dataloader in test_dataloaders]
    test_acc_mean = np.mean(test_acc_list)
    test_acc_std = np.std(test_acc_list)
    print(
        f"Mean of accuracy in {len(test_acc_list)} test datasets: "
        f"{test_acc_mean}, std: {test_acc_std}"
    )
if __name__ == "__main__":
    # Command-line entry point: parse hyper-parameters and run training/testing.
    parser = argparse.ArgumentParser(
        description="Gated Graph Sequence Neural Networks for sequential output tasks in "
        "bAbI"
    )
    # main() only has model settings for task 19, so the help says so instead of
    # advertising tasks 1-20.
    parser.add_argument(
        "--task_id",
        type=int,
        default=19,
        help="bAbI task id (only task 19, Path Finding, is currently configured)",
    )
    parser.add_argument(
        "--train_num", type=int, default=250, help="Number of training examples"
    )
    parser.add_argument("--batch_size", type=int, default=10, help="batch size")
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument(
        "--epochs", type=int, default=200, help="number of training epochs"
    )
    args = parser.parse_args()
    main(args)