forked from LiberAI/NSpM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
split_in_train_dev_test.py
executable file
·79 lines (65 loc) · 2.98 KB
/
split_in_train_dev_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#!/usr/bin/env python
"""
Neural SPARQL Machines - Split into train, dev, and test sets.
'SPARQL as a Foreign Language' by Tommaso Soru and Edgard Marx et al., SEMANTiCS 2017
https://arxiv.org/abs/1708.07624
Version 1.0.0
"""
import argparse
import random
import os
import io
# Percentage of the corpus assigned to each split. The training share is
# implicit: every example that is not sampled for dev or test becomes training
# data, so TRAINING_PERCENTAGE is informational only.
TRAINING_PERCENTAGE = 80
TEST_PERCENTAGE = 10
DEV_PERCENTAGE = 10


def sample_dev_test_indices(lines, test_percentage=TEST_PERCENTAGE,
                            dev_percentage=DEV_PERCENTAGE):
    """Randomly choose which line indices go to the dev and test sets.

    Args:
        lines: Total number of examples; indices are drawn from range(lines).
        test_percentage: Share of the examples reserved for the test set.
        dev_percentage: Share of the examples reserved for the dev set.

    Returns:
        A ``(dev, test)`` tuple of disjoint sets of line indices. Indices in
        neither set are meant for training.
    """
    held_out_percentage = test_percentage + dev_percentage
    held_out_count = int(lines * held_out_percentage / 100)
    dev_count = int(held_out_count * dev_percentage / held_out_percentage)

    # Draw the held-out pool first, then carve the dev set out of it, so the
    # two sets can never overlap.
    held_out = random.sample(range(lines), held_out_count)
    # Sets give O(1) membership tests when the corpus is walked line by line.
    dev = set(random.sample(held_out, dev_count))
    test = set(held_out) - dev
    return dev, test


def partition_lines(sparql, english, dev, test):
    """Distribute aligned SPARQL/English lines over train, dev and test.

    Args:
        sparql: List of SPARQL lines.
        english: List of English lines, aligned index-by-index with `sparql`.
        dev: Collection of line indices that belong to the dev set.
        test: Collection of line indices that belong to the test set.

    Returns:
        A dict mapping ``'train'``, ``'dev'`` and ``'test'`` to a
        ``(sparql_lines, english_lines)`` tuple of lists, each keeping the
        original line order.

    Raises:
        ValueError: If the two inputs do not have the same number of lines.
    """
    if len(sparql) != len(english):
        raise ValueError(
            'misaligned dataset: {} SPARQL lines but {} English lines'.format(
                len(sparql), len(english)))

    splits = {'train': ([], []), 'dev': ([], []), 'test': ([], [])}
    for index, (sparql_line, en_line) in enumerate(zip(sparql, english)):
        if index in dev:
            name = 'dev'
        elif index in test:
            name = 'test'
        else:
            name = 'train'
        sparql_lines, en_lines = splits[name]
        sparql_lines.append(sparql_line)
        en_lines.append(en_line)
    return splits


def main():
    """Parse the command line and write the train/dev/test files.

    Reads ``<dataset>.sparql`` and ``<dataset>.en`` and writes
    ``train.*``, ``dev.*`` and ``test.*`` (``.sparql`` and ``.en``) into the
    current working directory.
    """
    parser = argparse.ArgumentParser()
    required_named = parser.add_argument_group('required named arguments')
    # type=int lets argparse reject a non-numeric value with a usage message
    # instead of an uncaught traceback.
    required_named.add_argument('--lines', dest='lines', metavar='lines', type=int,
                                help='total number of lines (wc -l <file>)', required=True)
    required_named.add_argument('--dataset', dest='dataset', metavar='dataset.sparql',
                                help='sparql dataset file', required=True)
    args = parser.parse_args()

    # Both corpus files share the dataset's base name and differ by extension.
    dataset_file = os.path.splitext(args.dataset)[0]
    sparql_file = dataset_file + '.sparql'
    en_file = dataset_file + '.en'

    random.seed()
    # Indices are sampled from range(--lines), not from the real file length:
    # sampled indices past the end of the file match nothing, and file lines
    # past --lines always land in the training set.
    dev, test = sample_dev_test_indices(args.lines)

    with io.open(sparql_file, encoding="utf-8") as original_sparql, \
            io.open(en_file, encoding="utf-8") as original_en:
        sparql = original_sparql.readlines()
        english = original_en.readlines()

    splits = partition_lines(sparql, english, dev, test)

    for name, (sparql_lines, en_lines) in splits.items():
        with io.open(name + '.sparql', 'w', encoding="utf-8") as sparql_out, \
                io.open(name + '.en', 'w', encoding="utf-8") as en_out:
            sparql_out.writelines(sparql_lines)
            en_out.writelines(en_lines)


if __name__ == '__main__':
    main()