-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
133 lines (114 loc) · 4.4 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import pickle
import itertools
import codecs
import numpy as np
import collections
# A single training/eval example: a sentence (sequence of word indices)
# paired with its tag sequences.
Instance = collections.namedtuple("Instance", ["sentence", "tags"])
# Placeholder tag used when an attribute has no value for a token.
NONE_TAG = "<NONE>"
# Key under which the POS tag sequence is stored in tag mappings.
POS_KEY = "POS"
# Sentinel character used to pad both ends of a word's character sequence.
PADDING_CHAR = "<*>"
def normalize(word, normalized_words=None):
    """Replace URL-like and email-like tokens with canonical placeholders.

    :param word: raw token to normalize
    :param normalized_words: optional counter (e.g. collections.Counter)
        recording how often each raw token was replaced; incremented only
        when a replacement actually happens
    :returns: 'URL' if the token contains 'http', 'EMAIL' if it contains
        '@', otherwise the token unchanged
    """
    replacement = None
    if 'http' in word:
        replacement = 'URL'
    elif '@' in word:
        replacement = 'EMAIL'
    if replacement is None:
        return word
    if normalized_words is not None:
        normalized_words[word] += 1
    return replacement
def get_word_chars(sentence, i2w, c2i, normalized_words=None):
    """get_word_chars: gets the character index level representation of a sentence,
    prior to processing with the character level RNN
    :param sentence: a list of word indices
    :param i2w: index to word mappings
    :param c2i: character to index mappings
    :param normalized_words: optional counter forwarded to normalize()
    """
    pad = c2i[PADDING_CHAR]
    char_seqs = []
    for word_idx in sentence:
        token = normalize(i2w[word_idx], normalized_words)
        # each word becomes [pad, c1, c2, ..., pad]
        char_seqs.append([pad] + [c2i[ch] for ch in token] + [pad])
    return char_seqs
class CSVLogger:
    """Minimal CSV writer for logging rows of experiment data.

    Writes the header row on construction and flushes after every row so
    partial results survive a crash.  NOTE: values are joined naively with
    commas — fields containing commas are not quoted or escaped.

    Can be used as a context manager: ``with CSVLogger(path, cols) as log:``.
    """

    def __init__(self, filename, columns):
        self.file = open(filename, "w")
        self.columns = columns
        self.file.write(','.join(columns) + "\n")

    def add_column(self, data):
        """Append one row of values (despite the name, `data` is a full row)."""
        self.file.write(','.join([str(d) for d in data]) + "\n")
        # flush immediately so the log is readable while a long run is going
        self.file.flush()

    def close(self):
        """Close the underlying file handle."""
        self.file.close()

    # Context-manager support so callers cannot leak the file handle.
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False
def read_pretrained_embeddings(filename, w2i):
    """Load word embeddings from a whitespace-separated text file.

    Each usable line is "<word> <v1> <v2> ..." (lines with fewer than two
    vector components are skipped, e.g. word2vec-style headers).

    :param filename: path to the embedding file
    :param w2i: word-to-index mapping for the vocabulary to populate
    :returns: numpy array of shape (len(w2i), embedding_dim); vocabulary
        words absent from the file keep a uniform(-0.8, 0.8) random init
    :raises ValueError: if the file contains no usable embedding lines
    """
    word_to_embed = {}
    with codecs.open(filename, "r", "utf-8") as f:
        for line in f:
            split = line.split()
            if len(split) > 2:
                # Parse to floats immediately — the original kept raw
                # strings, which made np.linalg.norm below blow up.
                word_to_embed[split[0]] = np.array([float(s) for s in split[1:]])
    if not word_to_embed:
        raise ValueError("no embeddings found in " + filename)
    embedding_dim = len(next(iter(word_to_embed.values())))
    out = np.random.uniform(-0.8, 0.8, (len(w2i), embedding_dim))
    for word, embed in word_to_embed.items():
        # Theres a reason for this if condition. Some tokens in ptb
        # cause numerical problems because they are long strings of the same punctuation, e.g
        # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! which end up having huge norms, since Morfessor will
        # segment it as a ton of ! and then the sum of these morpheme vectors is huge.
        if word in w2i and np.linalg.norm(embed) < 15.0:
            out[w2i[word]] = embed
    return out
def read_text_embs(files):
    """Read one or more text-format embedding files.

    Each usable line is "<word> <v1> <v2> ..."; later files override
    earlier entries for duplicate words.

    :param files: iterable of file paths
    :returns: [words, vectors] — two parallel sequences zipped from the
        accumulated word->vector mapping
    """
    word_embs = {}
    for path in files:
        with codecs.open(path, "r", "utf-8") as handle:
            for raw_line in handle:
                fields = raw_line.split()
                if len(fields) > 2:
                    word_embs[fields[0]] = np.array([float(tok) for tok in fields[1:]])
    return list(zip(*word_embs.items()))
def read_pickle_embs(files):
    """Read pickled (words, embeddings) pairs from one or more files.

    Each file must contain a pickled 2-tuple: a sequence of words and a
    parallel sequence of embedding vectors.  Later files override earlier
    entries for duplicate words.

    :param files: iterable of file paths
    :returns: [words, vectors] — two parallel sequences zipped from the
        accumulated word->vector mapping
    """
    word_embs = dict()
    for filename in files:
        print(filename)
        # "rb", not "r": pickle data is binary and text mode fails on
        # Python 3.  `with` ensures the handle is closed (the original
        # leaked it).
        with open(filename, "rb") as f:
            words, embs = pickle.load(f)
        word_embs.update(zip(words, embs))
    return list(zip(*word_embs.items()))
def split_tagstring(s, uni_key=False, has_pos=False):
    '''
    Returns attribute-value mapping from UD-type CONLL field
    :param uni_key: if toggled, returns attribute-value pairs as joined strings (with the '=')
    :param has_pos: input line segment includes POS tag label
    '''
    if has_pos:
        s = s.split("\t")[1]
    if "=" not in s:  # malformed or empty feature column
        return [] if uni_key else {}
    pairs = [chunk.strip() for chunk in s.split('|')]
    if uni_key:
        return pairs
    return dict(pair.split('=') for pair in pairs)
def morphotag_strings(i2ts, tag_mapping, pos_separate_col=True):
    """Convert per-attribute tag-index sequences into per-token CONLL-style strings.

    :param i2ts: dict mapping attribute name -> (index -> tag string) lookup
    :param tag_mapping: dict mapping attribute name -> sequence of tag
        indices, one per token; all sequences must be the same length
    :param pos_separate_col: if True, the POS tag is emitted in its own
        tab-separated column before the morphological features
    :returns: list (one entry per token) of "POS\\tatt=val|..." strings,
        or just "att=val|..." when pos_separate_col is False
    """
    senlen = len(list(tag_mapping.values())[0])
    key_value_strs = []
    # j iterates along sentence, as we're building the string representations
    # in the opposite orientation as the mapping
    for j in range(senlen):
        place_strs = []
        for att, seq in list(tag_mapping.items()):
            val = i2ts[att][seq[j]]
            if pos_separate_col and att == POS_KEY:
                # NOTE(review): pos_str is only bound here — if POS_KEY is
                # missing from tag_mapping while pos_separate_col is True,
                # the reference below raises NameError. Callers apparently
                # always include POS; confirm before relying on it.
                pos_str = val
            elif val != NONE_TAG:
                # NONE_TAG values are dropped entirely from the output
                place_strs.append(att + "=" + val)
        morpho_str = "|".join(sorted(place_strs))
        if pos_separate_col:
            key_value_strs.append(pos_str + "\t" + morpho_str)
        else:
            key_value_strs.append(morpho_str)
    return key_value_strs
def sortvals(dct):
    """Return the dict's values ordered by their (sorted) keys."""
    return [dct[key] for key in sorted(dct)]