-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcyk.py
125 lines (110 loc) · 5.04 KB
/
cyk.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from utils import is_rare
from utils import get_counts
import sys
import operator
import logging
import json
import timeit
"""
Usage:
python cyk.py cfg_vert.counts parse_dev.dat cfg.counts > [output_file]
Using the counts file from the vertical markovization performed on the training data,
calculates the maximum likelihood estimates for the PCFG rules, then generates the tree with
max probability per line of parse_dev.dat represented in the JSON tree format.
The difference is we are using cfg_vert.counts (resulting from the vertical markovization), and cfg.counts
(the original count file) is used in the case of a tree not found.
Uses functions in utils.py, namely to organize unary, binary and nonterminal counts into dictionaries.
First run
python count_cfg_freq.py parse_train_vert.dat > cfg_vert.counts
python relabel_rare.py cfg.counts parse_train.dat
python count_cfg_freq.py parse_train_vert.dat > cfg_vert.counts
to replace words with count < 5 with _RARE_
"""
def cyk(words, binary_count, unary_count, nonterminal_count):
    """Fill the CYK chart for one sentence under a PCFG given by rule counts.

    Rule probabilities are maximum likelihood estimates: the count of the
    rule divided by the count of its left-hand-side nonterminal.

    Args:
        words: list of tokens of the sentence.
        binary_count: mapping X -> {"Y Z": count} for binary rules X -> Y Z;
            the two children are separated by a single space in the key.
        unary_count: mapping X -> {word: count} for rules X -> word.
        nonterminal_count: mapping X -> count of nonterminal X.

    Returns:
        A pair (chart, backpointers), both indexed as [i][j] with
        0 <= i < n and 0 <= j <= n, where n = len(words):
        chart[i][j][X] is the best probability found for X spanning
        words[i:j]; backpointers[i][j][X] is [word] for a one-word span
        and [s, Y, Z] (split point and the two children) otherwise.
    """
    n = len(words)
    chart = [[{} for _ in range(n + 1)] for _ in range(n)]
    backpointers = [[{} for _ in range(n + 1)] for _ in range(n)]

    # Seed the diagonal with the unary (X -> word) rules.
    for i, original in enumerate(words):
        # Rare words are looked up under the _RARE_ token, but the
        # backpointer keeps the original word so it appears in the tree.
        word = "_RARE_" if is_rare(original, unary_count) else original
        for X in unary_count:
            if word in unary_count[X]:
                q = float(unary_count[X][word]) / float(nonterminal_count[X])
                chart[i][i + 1][X] = q
                backpointers[i][i + 1][X] = [original]

    # Parse every binary rule once, instead of re-splitting its key for
    # each (i, s, j) combination. The list keeps the same X-then-key
    # iteration order as the nested dict loops, so ties between equally
    # probable derivations are resolved exactly as before.
    rules = []
    for X in binary_count:
        for key in binary_count[X]:
            children = key.split(" ")
            rules.append((X, children[0], children[1], binary_count[X][key]))

    # Build bottom-up from the diagonal: combine two adjacent sub-spans
    # (i, s) and (s, j) through a binary rule X -> Y Z.
    for span in range(2, n + 1):
        for i in range(n - span + 1):
            j = i + span
            cell = chart[i][j]
            back = backpointers[i][j]
            for s in range(i + 1, j):
                left = chart[i][s]
                right = chart[s][j]
                if not left or not right:
                    # No rule can match when one side has no analysis.
                    continue
                for X, Y, Z, rule_count in rules:
                    if Y in left and Z in right:
                        q = float(rule_count) / float(nonterminal_count[X])
                        probability = float(left[Y]) * float(right[Z]) * q
                        # Strictly greater: the first best derivation wins.
                        if X not in cell or probability > cell[X]:
                            cell[X] = probability
                            back[X] = [s, Y, Z]
    return chart, backpointers
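# trace() rebuilds the parse tree for a given root nonterminal by following the backpointers.
# The script below roots the tree at 'S' (the max probability parse starting with 'S'); if no tree
# starting with 'S' exists, it uses the arg max tree starting with any nonterminal.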
def trace(root, i, j, table=None):
    """Rebuild the parse tree rooted at `root` over the chart span (i, j).

    Args:
        root: nonterminal label to start from; must be a key of the
            backpointer cell [i][j].
        i: start index of the span.
        j: end index of the span.
        table: backpointer table in the layout returned by cyk(). When
            omitted, the module-level `backpointers` variable is used,
            which keeps the original three-argument call working.

    Returns:
        [root, [word]] when the backpointer entry holds a single word,
        otherwise [root, left_subtree, right_subtree], built recursively.
    """
    if table is None:
        # Looked up at call time: the script rebinds this global per sentence.
        table = backpointers
    entry = table[i][j][root]
    if len(entry) == 1:
        return [root, entry]
    split, left, right = entry[0], entry[1], entry[2]
    return [root, trace(left, i, split, table), trace(right, split, j, table)]
# Driver: parse every sentence of the dev file and print one JSON tree per line.
#
# To get script runtime (optional)
# start = timeit.default_timer()

# Obtain the count(X->YZ), count(X->w), count(X), A.K.A, the binary, unary and non-terminal counts
# (see utils.py). argv[1] holds the vertically markovized counts, argv[3] the original counts.
nonterminal_count, unary_count, binary_count = get_counts(sys.argv[1])
nonterminal_simple, unary_simple, binary_simple = get_counts(sys.argv[3])

# Read dev file line by line. open() works on both Python 2 and 3 (file() is
# Python 2 only) and the with-block guarantees the handle is closed.
with open(sys.argv[2], "r") as dev_data:
    for raw_line in dev_data:
        line = raw_line.strip()
        if not line:
            # A blank line ends processing, exactly like the end of the file.
            break
        words = line.split(" ")
        n = len(words)

        # chart and backpointers stay module-level names: trace() reads
        # `backpointers` when it is called with three arguments.
        chart, backpointers = cyk(words, binary_count, unary_count, nonterminal_count)
        # If the vertical markovization model yields no tree at all for the
        # full sentence, try again with the original rule counts.
        if not backpointers[0][n]:
            chart, backpointers = cyk(words, binary_simple, unary_simple, nonterminal_simple)

        if 'S' in backpointers[0][n]:
            # Get parse tree of max probability starting with 'S'
            root = 'S'
        elif chart[0][n]:
            # No valid parse tree starting with 'S': take the arg max over any nonterminal.
            # items() is available on both Python 2 and 3 (iteritems() is Python 2 only).
            root = max(chart[0][n].items(), key=operator.itemgetter(1))[0]
        else:
            # Neither grammar covers the sentence. max() on the empty cell would
            # raise a ValueError anyway; raise it with a message naming the input.
            raise ValueError("no parse found for sentence: " + line)
        tree = trace(root, 0, n)

        # tree is in list form, convert to JSON for compatibility, write to output file
        data = json.dumps(tree)
        sys.stdout.write(data)
        sys.stdout.write("\n")
        # In case you want to see the parse trees as they are produced for each sentence
        # logging.warning(data)

# stop = timeit.default_timer()
# logging.warning('Runtime:')
# logging.warning(stop - start)