# make_datafiles.py (from a fork of becxer/cnn-dailymail)
# Builds plain-text .source/.target data files from the CNN and Daily Mail .story files.
import sys
import os
import hashlib

dm_single_close_quote = u'\u2019'  # unicode right single quotation mark
dm_double_close_quote = u'\u201d'  # unicode right double quotation mark
# Acceptable ways to end a sentence
END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"]
all_train_urls = "url_lists/all_train.txt"
all_val_urls = "url_lists/all_val.txt"
all_test_urls = "url_lists/all_test.txt"

finished_files_dir = "cnn_dm"

# Expected number of .story files in cnn_stories_dir and dm_stories_dir
num_expected_cnn_stories = 92579
num_expected_dm_stories = 219506

def read_text_file(text_file):
    """Reads a text file and returns a list of its whitespace-stripped lines."""
    lines = []
    with open(text_file, "r") as f:
        for line in f:
            lines.append(line.strip())
    return lines

def hashhex(s):
    """Returns a hexadecimal-formatted SHA1 hash of the input string."""
    h = hashlib.sha1()
    h.update(s.encode())
    return h.hexdigest()
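
# Each .story file is named after the SHA1 hash of its source URL, so the URL
# lists can be mapped directly to filenames. For example (illustrative input,
# not a real URL):
#   hashhex("foo") == "0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"
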
def get_url_hashes(url_list):
    return [hashhex(url) for url in url_list]

def fix_missing_period(line):
    """Adds a period to a line that is missing one."""
    if "@highlight" in line: return line
    if line == "": return line
    if line[-1] in END_TOKENS: return line
    return line + " ."
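
# For example, fix_missing_period("Cats sat on the mat") returns
# "Cats sat on the mat .", while a line that already ends in an END_TOKENS
# character (or that contains "@highlight") is returned unchanged.
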
def get_art_abs(story_file):
    """Splits a .story file into the article text and the abstract (the @highlight lines)."""
    lines = read_text_file(story_file)

    # Put periods on the ends of lines that are missing them. This is a problem
    # in the dataset because many image captions don't end in periods;
    # consequently they end up in the body of the article as run-on sentences.
    lines = [fix_missing_period(line) for line in lines]

    # Separate out article and abstract sentences
    article_lines = []
    highlights = []
    next_is_highlight = False
    for line in lines:
        if line == "":
            continue  # empty line
        elif line.startswith("@highlight"):
            next_is_highlight = True
        elif next_is_highlight:
            highlights.append(line)
        else:
            article_lines.append(line)

    # Make article into a single string
    article = ' '.join(article_lines)

    # Make abstract into a single string
    abstract = ' '.join(highlights)

    return article, abstract
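
# A .story file interleaves the article with its highlights roughly like this
# (layout only; the text is illustrative):
#
#   (CNN) -- First paragraph of the article ...
#
#   Second paragraph ...
#
#   @highlight
#
#   First summary sentence
#
#   @highlight
#
#   Second summary sentence
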
def write_to_bin(url_file, out_prefix):
    """Reads the .story files corresponding to the urls listed in url_file and
    writes the articles to out_prefix.source and the abstracts to out_prefix.target."""
    print("Making data files for URLs listed in %s..." % url_file)
    url_list = read_text_file(url_file)
    url_hashes = get_url_hashes(url_list)
    story_fnames = [s + ".story" for s in url_hashes]
    num_stories = len(story_fnames)

    with open(out_prefix + '.source', 'wt') as source_file, open(out_prefix + '.target', 'wt') as target_file:
        for idx, s in enumerate(story_fnames):
            if idx % 1000 == 0:
                print("Writing story %i of %i; %.2f percent done" % (idx, num_stories, float(idx)*100.0/float(num_stories)))

            # Look in the story dirs to find the .story file corresponding to this url
            if os.path.isfile(os.path.join(cnn_stories_dir, s)):
                story_file = os.path.join(cnn_stories_dir, s)
            elif os.path.isfile(os.path.join(dm_stories_dir, s)):
                story_file = os.path.join(dm_stories_dir, s)
            else:
                print("Error: Couldn't find story file %s in either story directory %s or %s." % (s, cnn_stories_dir, dm_stories_dir))
                # Check again whether the stories directories contain the correct number of files
                print("Checking that the stories directories %s and %s contain the correct number of files..." % (cnn_stories_dir, dm_stories_dir))
                check_num_stories(cnn_stories_dir, num_expected_cnn_stories)
                check_num_stories(dm_stories_dir, num_expected_dm_stories)
                raise Exception("Stories directories %s and %s contain the correct number of files, but story file %s was found in neither." % (cnn_stories_dir, dm_stories_dir, s))

            # Get the article and abstract strings
            article, abstract = get_art_abs(story_file)

            # Write the article and abstract, one example per line
            source_file.write(article + '\n')
            target_file.write(abstract + '\n')

    print("Finished writing files")

def check_num_stories(stories_dir, num_expected):
    """Raises if stories_dir does not contain exactly num_expected files."""
    num_stories = len(os.listdir(stories_dir))
    if num_stories != num_expected:
        raise Exception("stories directory %s contains %i files but should contain %i" % (stories_dir, num_stories, num_expected))

if __name__ == '__main__':
    if len(sys.argv) != 3:
        print("USAGE: python make_datafiles.py <cnn_stories_dir> <dailymail_stories_dir>")
        sys.exit(1)
    cnn_stories_dir = sys.argv[1]
    dm_stories_dir = sys.argv[2]

    # Check the stories directories contain the correct number of .story files
    check_num_stories(cnn_stories_dir, num_expected_cnn_stories)
    check_num_stories(dm_stories_dir, num_expected_dm_stories)

    # Create the output directory if it doesn't exist
    if not os.path.exists(finished_files_dir):
        os.makedirs(finished_files_dir)

    # Read the stories, do a little postprocessing, then write to the data files
    write_to_bin(all_test_urls, os.path.join(finished_files_dir, "test"))
    write_to_bin(all_val_urls, os.path.join(finished_files_dir, "val"))
    write_to_bin(all_train_urls, os.path.join(finished_files_dir, "train"))
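
# Example invocation (paths are illustrative):
#   python make_datafiles.py /data/cnn/stories /data/dailymail/stories
# This writes cnn_dm/train.source and cnn_dm/train.target (and likewise for
# val and test), with one article per line in .source and the matching
# abstract on the same line of .target.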